Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 240 additions & 0 deletions patchwork/common/multiturn_strategy/planning_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
import asyncio
from functools import partial
from typing import Any, Optional, Union

from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.agent import AgentRunResult

from patchwork.common.client.llm.protocol import LlmClient
from patchwork.common.client.llm.utils import example_json_to_base_model
from patchwork.common.tools import Tool


class StepCompletedResult(BaseModel):
is_step_completed: bool


class PlanCompletedResult(BaseModel):
is_plan_completed: bool


class ExecutionResult(BaseModel):
json_data: str
message: str
is_completed: bool


class _Plan:
def __init__(self, initial_plan: Optional[list[str]] = None):
self.__plan = initial_plan or []
self.__cursor = 0

def advance(self) -> bool:
self.__cursor += 1
return self.__cursor < len(self.__plan)

def is_empty(self) -> bool:
return len(self.__plan) == 0

def register_steps(self, agent: Agent):
agent.tool_plain(self.get_current_plan)
agent.tool_plain(self.get_current_step)
agent.tool_plain(self.get_current_step_index)
agent.tool_plain(self.add_step)
agent.tool_plain(self.delete_step)

def get_current_plan(self) -> str:
return "\n".join([f"{i}. {step}" for i, step in enumerate(self.__plan)])

def get_current_step(self) -> str:
if len(self.__plan) == 0:
return "There is currently no plan"

return self.__plan[self.__cursor]

def get_current_step_index(self) -> int:
return self.__cursor

def add_step(self, index: int, step: str) -> str:
if index < 0:
return "index cannot be a negative number"

if index >= len(self.__plan):
insertion_func = self.__plan.append
else:
insertion_func = partial(self.__plan.insert, index)

insertion_func(step)
return "Added step\nCurrent plan:\n" + self.get_current_plan()

def delete_step(self, step: str) -> str:
try:
i = self.__plan.index(step)
self.__plan.pop(i)
return self.get_current_plan()
except ValueError:
return "Step not found in plan\nCurrent plan:\n" + self.get_current_plan()


class PlanningStrategy:
def __init__(
self,
llm_client: LlmClient,
planner_system_prompt: str,
executor_system_prompt: str,
executor_tool_set: dict[str, Tool],
example_json: Union[str, dict[str, Any]] = '{"output":"output text"}',
):
"""
Use this like this::

class DatabaseAgent(Step, input_class=DatabaseAgentInputs, output_class=DatabaseAgentOutputs):
def __init__(self, inputs):
super().__init__(inputs)

llm_client = AioLlmClient.create_aio_client(inputs)

data = inputs.get("prompt_value", {})
self.task = mustache_render(inputs["task"], data)

db_dialect = inputs["db_dialect"]
self.planner = PlanningStrategy(
llm_client,
planner_system_prompt=f'''\\
You are a {db_dialect} database query planning assistant. You are tasked to plan the steps to assist with the provided task.
You will not execute the steps in the plan. The user will do that instead.
The first step of the plan should be as follows:
1. Tell me all tables currently available.
After the list of table names is provided, get the DDL of the tables that is relevant.
Your steps should be clear and concise like the following example:
1. Tell me the column descriptions of the table `orders`.
2. Execute the SQL Query: `SELECT * FROM orders`
After every step, you will be asked to edit the plan so feel free to plan 1 step at a time.
''',
executor_system_prompt=f'''\\
You are a {db_dialect} database query execution assistant. You will be provided instructions on what to do.
''',
)

def run(self) -> dict:
planner_response = self.planner.run(self.task, 10)
return {**planner_response, **self.planner.usage()}

"""
self.planner = Agent(
llm_client,
name="Planner",
system_prompt=planner_system_prompt,
model_settings=dict(
parallel_tool_calls=False,
model="gemini-2.0-flash",
),
)

self.plan = _Plan()
self.plan.register_steps(self.planner)

self.executor = Agent(
llm_client,
name="Executor",
system_prompt=executor_system_prompt,
result_type=ExecutionResult,
tools=[tool.to_pydantic_ai_function_tool() for tool in executor_tool_set.values()],
model_settings=dict(
parallel_tool_calls=False,
model="gemini-2.0-flash",
),
)

self.__summariser = Agent(
llm_client,
result_retries=5,
system_prompt="""\
Please summarise the conversation given and provide the result in the structure that is asked of you.
""",
result_type=example_json_to_base_model(example_json),
model_settings=dict(
parallel_tool_calls=False,
model="gemini-2.0-flash",
),
)

self.reset()

def reset(self):
self.__request_tokens = 0
self.__response_tokens = 0

def usage(self):
return {
"request_tokens": self.__request_tokens,
"response_tokens": self.__response_tokens,
}

def __agent_run(self, agent: Agent, prompt: str, **kwargs) -> AgentRunResult[Any]:
loop = asyncio.new_event_loop()
planner_response = loop.run_until_complete(agent.run(prompt, **kwargs))
loop.close()
self.__request_tokens += planner_response.usage().request_tokens
self.__response_tokens += planner_response.usage().response_tokens

return planner_response

def run(self, task: str, conversation_limit: int = 10) -> dict:

planner_response = self.__agent_run(self.planner, f"Produce the initial plan for {task}")
planner_history = planner_response.all_messages()
if self.plan.is_empty():
planner_response = self.__agent_run(
self.planner, f"Please use the tools provided to setup the plan", message_history=planner_history
)
planner_history = planner_response.all_messages()

for i in range(conversation_limit):
step = self.plan.get_current_step()
executor_prompt = f"Please execute the following task: {step}"
response = self.__agent_run(self.executor, executor_prompt)

plan_str = self.plan.get_current_plan()
step_index = self.plan.get_current_step_index()
planner_prompt = f"""\
The current plan is:
{plan_str}

We are current at {step_index}.
If the current step is not completed, edit the current step.

The execution result for the step {step_index} is:
{response.data}

"""
planner_response = self.__agent_run(
self.planner,
planner_prompt,
message_history=planner_history,
result_type=StepCompletedResult,
)
planner_history = planner_response.all_messages()
if not planner_response.data.is_step_completed:
continue

if self.plan.advance():
continue

planner_response = self.__agent_run(
self.planner,
"Is the task completed? If the task is not completed please add more steps using the tools provided.",
message_history=planner_history,
result_type=PlanCompletedResult,
)
if planner_response.data.is_plan_completed:
break

final_result = self.__agent_run(
self.__summariser,
"From the actions taken by the assistant. Please give me the result.",
message_history=planner_history,
)

return final_result.data.dict()
38 changes: 38 additions & 0 deletions patchwork/common/tools/db_query_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing_extensions import Any, Union

from patchwork.common.tools import Tool
from patchwork.steps import CallSQL


class DatabaseQueryTool(Tool, tool_name="db_query_tool"):
def __init__(self, inputs: dict[str, Any]):
super().__init__()
self.db_settings = inputs.copy()
self.db_dialect = inputs.get("db_dialect", "SQL")

@property
def json_schema(self) -> dict:
return {
"name": "db_query_tool",
"description": f"""\
Run SQL Query on current {self.db_dialect} database.
""",
"input_schema": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Database query to run.",
}
},
"required": ["query"],
},
}

def execute(self, query: str) -> Union[list[dict[str, Any]], str]:
db_settings = self.db_settings.copy()
db_settings["db_query"] = query
try:
return CallSQL(db_settings).run().get("results", [])
except Exception as e:
return str(e)
18 changes: 14 additions & 4 deletions patchwork/steps/AgenticLLM/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ class AgenticLLMInputs(TypedDict, total=False):
user_prompt: str
max_llm_calls: Annotated[int, StepTypeConfig(is_config=True)]
openai_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"])
str,
StepTypeConfig(
is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"]
),
]
anthropic_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"])
str,
StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"]),
]
patched_api_key: Annotated[
str,
Expand All @@ -31,10 +35,16 @@ class AgenticLLMInputs(TypedDict, total=False):
),
]
google_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"])
str,
StepTypeConfig(
is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"]
),
]
client_is_gcp: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"])
str,
StepTypeConfig(
is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"]
),
]


Expand Down
18 changes: 14 additions & 4 deletions patchwork/steps/CallLLM/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ class CallLLMInputs(TypedDict, total=False):
model_args: Annotated[str, StepTypeConfig(is_config=True)]
client_args: Annotated[str, StepTypeConfig(is_config=True)]
openai_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"])
str,
StepTypeConfig(
is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"]
),
]
anthropic_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"])
str,
StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"]),
]
patched_api_key: Annotated[
str,
Expand All @@ -33,10 +37,16 @@ class CallLLMInputs(TypedDict, total=False):
),
]
google_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"])
str,
StepTypeConfig(
is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"]
),
]
client_is_gcp: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"])
str,
StepTypeConfig(
is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"]
),
]
file: Annotated[str, StepTypeConfig(is_path=True)]

Expand Down
Loading