2 changes: 1 addition & 1 deletion production_demo/pyproject.toml
@@ -7,7 +7,7 @@ requires-python = ">=3.10,<3.14"
 readme = "README.md"
 dependencies = [
     "pydantic>=2.10.6",
-    "restack-ai==0.0.91",
+    "restack-ai==0.0.94",
     "watchfiles>=1.0.4",
     "python-dotenv==1.0.1",
     "openai>=1.61.0",
4 changes: 2 additions & 2 deletions production_demo/src/functions/evaluate.py
@@ -1,4 +1,4 @@
-from restack_ai.function import function, FunctionFailure, log
+from restack_ai.function import function, NonRetryableError, log
 from openai import OpenAI
 from pydantic import BaseModel

@@ -11,7 +11,7 @@ async def llm_evaluate(input: EvaluateInput) -> str:
         client = OpenAI(base_url="http://192.168.205.1:1234/v1/",api_key="llmstudio")
     except Exception as e:
         log.error(f"Failed to create LLM client {e}")
-        raise FunctionFailure(f"Failed to create OpenAI client {e}", non_retryable=True) from e
+        raise NonRetryableError(message=f"Failed to create OpenAI client {e}") from e

     prompt = (
         f"Evaluate the following joke for humor, creativity, and originality. "
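The migration pattern in this file: `FunctionFailure(..., non_retryable=True)` becomes `NonRetryableError(message=...)`, with the message now passed as a keyword. A minimal sketch of the resulting client-setup pattern, assuming the same local OpenAI-compatible endpoint; the `make_llm_client` helper is hypothetical and only illustrates the error handling adopted here:

```python
from openai import OpenAI
from restack_ai.function import NonRetryableError, log


def make_llm_client(base_url: str = "http://192.168.205.1:1234/v1/") -> OpenAI:
    """Hypothetical helper mirroring the try/except in llm_evaluate."""
    try:
        return OpenAI(base_url=base_url, api_key="llmstudio")
    except Exception as e:
        log.error(f"Failed to create LLM client {e}")
        # Client construction failures are treated as permanent: no retries.
        raise NonRetryableError(message=f"Failed to create OpenAI client {e}") from e
```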
4 changes: 2 additions & 2 deletions production_demo/src/functions/function.py
@@ -1,4 +1,4 @@
-from restack_ai.function import function, log, FunctionFailure
+from restack_ai.function import function, log, RetryableError

 tries = 0

@@ -14,7 +14,7 @@ async def example_function(input: ExampleFunctionInput) -> str:

     if tries == 0:
         tries += 1
-        raise FunctionFailure(message="Simulated failure", non_retryable=False)
+        raise RetryableError(message="Simulated failure")

     log.info("example function started", input=input)
     return f"Hello, {input.name}!"
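This file shows the other half of the split: `FunctionFailure(..., non_retryable=False)` becomes `RetryableError`, a transient failure the engine is allowed to retry, while `NonRetryableError` (used in evaluate.py and generate.py) fails the step for good. A hedged sketch of the contrast; both helper functions are hypothetical:

```python
from restack_ai.function import NonRetryableError, RetryableError, log

tries = 0  # module-level counter, as in example_function


def flaky_once() -> str:
    """Hypothetical: fail on the first call, succeed afterwards."""
    global tries
    if tries == 0:
        tries += 1
        # Transient failure: a retry policy on the calling step may re-run this.
        raise RetryableError(message="Simulated failure")
    return "ok"


def give_up(reason: str) -> None:
    """Hypothetical: a permanent failure that should not be retried."""
    log.error(f"Permanent failure {reason}")
    raise NonRetryableError(message=reason)
```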
6 changes: 3 additions & 3 deletions production_demo/src/functions/generate.py
@@ -1,4 +1,4 @@
-from restack_ai.function import function, FunctionFailure, log
+from restack_ai.function import function, NonRetryableError, log
 from openai import OpenAI

 from pydantic import BaseModel
@@ -10,10 +10,10 @@ class GenerateInput(BaseModel):
 async def llm_generate(input: GenerateInput) -> str:

     try:
-        client = OpenAI(base_url="http://192.168.205.1:1234/v1/",api_key="llmstudio")
+        client = OpenAI(base_url="http://192.168.178.57:1234/v1/",api_key="llmstudio")
     except Exception as e:
         log.error(f"Failed to create LLM client {e}")
-        raise FunctionFailure(f"Failed to create OpenAI client {e}", non_retryable=True) from e
+        raise NonRetryableError(message=f"Failed to create OpenAI client {e}") from e

     try:
         response = client.chat.completions.create(
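Note that after this commit generate.py points at 192.168.178.57 while evaluate.py above still points at 192.168.205.1, so the two functions target different hosts. A sketch, not part of this PR, of lifting the endpoint into the environment using the python-dotenv dependency already declared in pyproject.toml; the LLM_BASE_URL and LLM_API_KEY variable names are hypothetical:

```python
import os

from dotenv import load_dotenv
from openai import OpenAI

load_dotenv()

# Hypothetical environment variables; the hard-coded values from this diff are the fallbacks.
client = OpenAI(
    base_url=os.environ.get("LLM_BASE_URL", "http://192.168.178.57:1234/v1/"),
    api_key=os.environ.get("LLM_API_KEY", "llmstudio"),
)
```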
6 changes: 2 additions & 4 deletions production_demo/src/services.py
@@ -8,21 +8,19 @@
 from src.functions.function import example_function
 from src.functions.generate import llm_generate
 from src.functions.evaluate import llm_evaluate

 from src.workflows.workflow import ExampleWorkflow, ChildWorkflow

 import webbrowser


 async def main():

     await asyncio.gather(
         client.start_service(
             workflows=[ExampleWorkflow, ChildWorkflow],
             functions=[example_function],
             options=ServiceOptions(
                 max_concurrent_workflow_runs=1000
             )

         ),
         client.start_service(
             task_queue="llm",
@@ -31,7 +29,7 @@ async def main():
                 rate_limit=1,
                 max_concurrent_function_runs=1
             )
-        )
+        ),
     )

 def run_services():
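Beyond the trailing comma, the service layout is unchanged: one service runs both workflows plus example_function with high workflow concurrency, and a second service on the "llm" task queue throttles the LLM functions to one run at a time. A minimal sketch of main() as it reads after this change; the `client` and `ServiceOptions` import paths and the function list of the llm service are assumptions, since the diff does not show those lines:

```python
import asyncio

from restack_ai.restack import ServiceOptions  # assumed import path

from src.client import client  # assumed; the diff does not show where `client` is defined
from src.functions.evaluate import llm_evaluate
from src.functions.function import example_function
from src.functions.generate import llm_generate
from src.workflows.workflow import ChildWorkflow, ExampleWorkflow


async def main() -> None:
    await asyncio.gather(
        # Default task queue: workflows and the example function, high concurrency.
        client.start_service(
            workflows=[ExampleWorkflow, ChildWorkflow],
            functions=[example_function],
            options=ServiceOptions(max_concurrent_workflow_runs=1000),
        ),
        # Dedicated "llm" queue: at most one LLM function running, rate limited.
        client.start_service(
            task_queue="llm",
            functions=[llm_generate, llm_evaluate],  # assumed registration, hidden in the diff
            options=ServiceOptions(rate_limit=1, max_concurrent_function_runs=1),
        ),
    )
```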
52 changes: 29 additions & 23 deletions production_demo/src/workflows/child.py
@@ -1,6 +1,6 @@
 from datetime import timedelta
 from pydantic import BaseModel, Field
-from restack_ai.workflow import workflow, import_functions, log
+from restack_ai.workflow import workflow, import_functions, log, NonRetryableError, RetryPolicy

 with import_functions():
     from src.functions.function import example_function, ExampleFunctionInput
@@ -14,28 +14,34 @@ class ChildWorkflowInput(BaseModel):
 class ChildWorkflow:
     @workflow.run
     async def run(self, input: ChildWorkflowInput):

-        log.info("ChildWorkflow started")
-        await workflow.step(example_function, input=ExampleFunctionInput(name='John Doe'), start_to_close_timeout=timedelta(minutes=2))

-        await workflow.sleep(1)

-        generated_text = await workflow.step(
-            function=llm_generate,
-            function_input=GenerateInput(prompt=input.prompt),
-            task_queue="llm",
-            start_to_close_timeout=timedelta(minutes=2)
-        )

-        evaluation = await workflow.step(
-            function=llm_evaluate,
-            function_input=EvaluateInput(generated_text=generated_text),
-            task_queue="llm",
-            start_to_close_timeout=timedelta(minutes=5)
-        )

-        return {
-            "generated_text": generated_text,
-            "evaluation": evaluation
-        }

+        try:
+            await workflow.step(function=example_function, function_input=ExampleFunctionInput(name='John Doe'), start_to_close_timeout=timedelta(minutes=2), retry_policy=RetryPolicy(maximum_attempts=3))

+            await workflow.sleep(1)

+            generated_text = await workflow.step(
+                function=llm_generate,
+                function_input=GenerateInput(prompt=input.prompt),
+                task_queue="llm",
+                start_to_close_timeout=timedelta(minutes=2)
+            )

+            evaluation = await workflow.step(
+                function=llm_evaluate,
+                function_input=EvaluateInput(generated_text=generated_text),
+                task_queue="llm",
+                start_to_close_timeout=timedelta(minutes=5)
+            )

+            return {
+                "generated_text": generated_text,
+                "evaluation": evaluation
+            }
+        except Exception as e:
+            log.error(f"ChildWorkflow failed {e}")
+            raise NonRetryableError(message=f"ChildWorkflow failed {e}") from e


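The child workflow body itself is unchanged; the commit wraps it in try/except, switches the first step to keyword arguments, and bounds its retries with RetryPolicy(maximum_attempts=3). A stripped-down sketch of how that retry change interacts with example_function; the RetryDemoWorkflow name and the @workflow.defn() decorator are assumptions (the decorator sits above the lines shown in this diff):

```python
from datetime import timedelta

from pydantic import BaseModel
from restack_ai.workflow import NonRetryableError, RetryPolicy, import_functions, log, workflow

with import_functions():
    from src.functions.function import ExampleFunctionInput, example_function


class RetryDemoInput(BaseModel):
    name: str = "John Doe"


@workflow.defn()  # assumed decorator, not visible in the diff
class RetryDemoWorkflow:
    """Hypothetical workflow isolating the retry behaviour added in this commit."""

    @workflow.run
    async def run(self, input: RetryDemoInput) -> str:
        try:
            # example_function raises RetryableError on its first call (see function.py),
            # so maximum_attempts=3 leaves room for it to succeed on a later attempt.
            return await workflow.step(
                function=example_function,
                function_input=ExampleFunctionInput(name=input.name),
                start_to_close_timeout=timedelta(minutes=2),
                retry_policy=RetryPolicy(maximum_attempts=3),
            )
        except Exception as e:
            log.error(f"RetryDemoWorkflow failed {e}")
            # Once retries are exhausted, fail the workflow permanently.
            raise NonRetryableError(message=f"RetryDemoWorkflow failed {e}") from e
```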
63 changes: 34 additions & 29 deletions production_demo/src/workflows/workflow.py
@@ -1,7 +1,7 @@
 import asyncio
 from datetime import timedelta
 from pydantic import BaseModel, Field
-from restack_ai.workflow import workflow, log, workflow_info, import_functions
+from restack_ai.workflow import workflow, log, workflow_info, import_functions, NonRetryableError
 from .child import ChildWorkflow, ChildWorkflowInput

 with import_functions():
@@ -14,34 +14,39 @@ class ExampleWorkflowInput(BaseModel):
 class ExampleWorkflow:
     @workflow.run
     async def run(self, input: ExampleWorkflowInput):
-        # use the parent run id to create child workflow ids
-        parent_workflow_id = workflow_info().workflow_id

-        tasks = []
-        for i in range(input.amount):
-            log.info(f"Queue ChildWorkflow {i+1} for execution")
-            task = workflow.child_execute(
-                workflow=ChildWorkflow,
-                workflow_id=f"{parent_workflow_id}-child-execute-{i+1}",
-                input=ChildWorkflowInput(name=f"child workflow {i+1}")

+        try:
+            # use the parent run id to create child workflow ids
+            parent_workflow_id = workflow_info().workflow_id

+            tasks = []
+            for i in range(input.amount):
+                log.info(f"Queue ChildWorkflow {i+1} for execution")
+                task = workflow.child_execute(
+                    workflow=ChildWorkflow,
+                    workflow_id=f"{parent_workflow_id}-child-execute-{i+1}",
+                    workflow_input=ChildWorkflowInput(prompt="Generate a random joke in max 20 words."),
+                )
-            tasks.append(task)

-        # Run all child workflows in parallel and wait for their results
-        results = await asyncio.gather(*tasks)

-        for i, result in enumerate(results, start=1):
-            log.info(f"ChildWorkflow {i} completed", result=result)

-        generated_text = await workflow.step(
-            function=llm_generate,
-            function_input=GenerateInput(prompt=f"Give me the top 3 unique jokes according to the results. {results}"),
-            task_queue="llm",
-            start_to_close_timeout=timedelta(minutes=2)
-        )
+                tasks.append(task)

+            # Run all child workflows in parallel and wait for their results
+            results = await asyncio.gather(*tasks)

+            for i, result in enumerate(results, start=1):
+                log.info(f"ChildWorkflow {i} completed", result=result)

+            generated_text = await workflow.step(
+                function=llm_generate,
+                function_input=GenerateInput(prompt=f"Give me the top 3 unique jokes according to the results. {results}"),
+                task_queue="llm",
+                start_to_close_timeout=timedelta(minutes=2)
+            )

-        return {
-            "top_jokes": generated_text,
-            "results": results
-        }
+            return {
+                "top_jokes": generated_text,
+                "results": results
+            }

+        except Exception as e:
+            log.error(f"ExampleWorkflow failed {e}")
+            raise NonRetryableError(message=f"ExampleWorkflow failed {e}") from e