In [6]:
from pydantic import BaseModel, Field, ValidationError
from pydantic_ai import Agent, RunContext, ModelRetry
from typing import Any, List, Dict
import nest_asyncio

nest_asyncio.apply()

from typing import Any, Dict, List
from pydantic import BaseModel, Field, ValidationError
from pydantic_ai import Agent, RunContext, ModelRetry
from dslmodel.utils.pydantic_ai_tools import get_agent


# ------------------------------------------
# Shared State (Deps) Model
# ------------------------------------------

class WorkflowDeps(BaseModel):
    """Shared dependencies for the workflow."""
    user_input: str
    task_1_result: str = Field(default="", description="Result of Task 1.")
    task_2_result: str = Field(default="", description="Result of Task 2.")


# ------------------------------------------
# Structured Result Model
# ------------------------------------------

class WorkflowResult(BaseModel):
    """Final workflow result."""
    task: str
    result: str

    def as_dict(self) -> Dict[str, Any]:
        """Return a JSON-compatible dictionary representation."""
        return self.model_dump()


# ------------------------------------------
# Get Workflow Agent
# ------------------------------------------

workflow_agent = get_agent(
    system_prompt="You are a workflow manager. Execute tasks based on user input.",
    deps_type=WorkflowDeps,  # Shared state type
    result_type=Dict[str, Any],  # Expected result structure
    retries=2,
)


# ------------------------------------------
# Define Workflow Tools (Tasks)
# ------------------------------------------

@workflow_agent.tool
async def task_1(ctx: RunContext[WorkflowDeps]) -> Dict[str, Any]:
    """
    Task 1: Simulates a simple task that modifies the shared state.
    """
    result = f"Processed {ctx.deps.user_input} in Task 1."
    ctx.deps.task_1_result = result
    return WorkflowResult(task="Task 1", result=result).as_dict()


@workflow_agent.tool
async def task_2(ctx: RunContext[WorkflowDeps]) -> Dict[str, Any]:
    """
    Task 2: Uses the output of Task 1 to perform another operation.
    """
    if not ctx.deps.task_1_result:
        raise ModelRetry("Task 1 result is missing.")

    result = f"Task 2 received: {ctx.deps.task_1_result}."
    ctx.deps.task_2_result = result
    return WorkflowResult(task="Task 2", result=result).as_dict()


# ------------------------------------------
# Result Validation
# ------------------------------------------

@workflow_agent.result_validator
async def validate_final_result(ctx: RunContext[WorkflowDeps], result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validates the final workflow result.
    """
    if result.get("task") != "Task 2" or not result.get("result"):
        raise ModelRetry("Final result is invalid.")
    return result


# ------------------------------------------
# Workflow Manager
# ------------------------------------------

class Workflow:
    """Manages a sequence of tasks using Pydantic AI agent."""

    def __init__(self, tasks: List[str], initial_state: WorkflowDeps):
        self.tasks = tasks
        self.state = initial_state

    async def run(self) -> WorkflowDeps:
        """
        Executes all tasks sequentially, updating shared state.
        """
        for task_name in self.tasks:
            prompt = f"Execute {task_name}."
            # Execute the task via the agent
            result = await workflow_agent.run(prompt, deps=self.state)
            print(f"Task {task_name} completed: {result.data}")
        return self.state


# ------------------------------------------
# Example Workflow Execution
# ------------------------------------------

async def main():
    # Initialize shared state
    initial_state = WorkflowDeps(user_input="Example input")

    # Define the workflow with a sequence of tasks
    workflow = Workflow(
        tasks=["task_1", "task_2"],  # Names of tools to run
        initial_state=initial_state,
    )

    # Run the workflow
    final_state = await workflow.run()
    print("Final Workflow State:", final_state.model_dump())


# ------------------------------------------
# Run Example
# ------------------------------------------
import asyncio

asyncio.run(main())



INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"


UnexpectedModelBehavior: Exceeded maximum retries (2) for result validation