diff --git a/docs/agents.md b/docs/agents.md index 7bb13a864f..351254c248 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -121,6 +121,12 @@ async def main(): print(nodes) """ [ + UserPromptNode( + user_prompt='What is the capital of France?', + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), ModelRequestNode( request=ModelRequest( parts=[ @@ -338,6 +344,7 @@ if __name__ == '__main__': print(output_messages) """ [ + '=== UserPromptNode: What will the weather be like in Paris on Tuesday? ===', '=== ModelRequestNode: streaming partial request tokens ===', '[Request] Starting part 0: ToolCallPart(tool_name=\'weather_forecast\', args=\'{"location":"Pa\', tool_call_id=\'0001\', part_kind=\'tool-call\')', '[Request] Part 0 args_delta=ris","forecast_', diff --git a/docs/graph.md b/docs/graph.md index d7aa84f2de..c37aca7051 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -510,6 +510,7 @@ async def main(): #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() + #> Node: CountDown() #> Node: End(data=0) print('Final result:', run.result.output) # (3)! #> Final result: 0 diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index becc95619a..99cf6448ad 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -372,6 +372,12 @@ async def main(): print(nodes) ''' [ + UserPromptNode( + user_prompt='What is the capital of France?', + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), ModelRequestNode( request=ModelRequest( parts=[ @@ -1355,6 +1361,12 @@ async def main(): print(nodes) ''' [ + UserPromptNode( + user_prompt='What is the capital of France?', + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), ModelRequestNode( request=ModelRequest( parts=[ diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 0150e62528..3da34bca9a 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -607,6 +607,7 @@ async def main(): print(node_states) ''' [ + (Increment(), MyState(number=1)), (Increment(), MyState(number=1)), (Check42(), MyState(number=2)), (End(data=2), MyState(number=2)), @@ -621,6 +622,7 @@ async def main(): print(node_states) ''' [ + (Increment(), MyState(number=41)), (Increment(), MyState(number=41)), (Check42(), MyState(number=42)), (Increment(), MyState(number=42)), @@ -665,6 +667,7 @@ def __init__( self.deps = deps self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node + self._is_started: bool = False @property def next_node(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: @@ -777,8 +780,13 @@ def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunE async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: """Use the last returned node as the input to `Graph.next`.""" + if not self._is_started: + self._is_started = True + return self._next_node + if isinstance(self._next_node, End): raise StopAsyncIteration + return await self.next(self._next_node) def __repr__(self) -> str: diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 82810d851b..09502de3c0 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -312,7 +312,9 @@ async def test_iter(): assert graph_iter.result assert graph_iter.result.output == 8 - assert node_reprs == snapshot(["String2Length(input_data='3.14')", 'Double(input_data=4)', 'End(data=8)']) + assert node_reprs == snapshot( + ['Float2String(input_data=3.14)', "String2Length(input_data='3.14')", 'Double(input_data=4)', 'End(data=8)'] + ) async def test_iter_next(mock_snapshot_id: object):