61
61
ConversationManager ,
62
62
SlidingWindowConversationManager ,
63
63
)
64
+ from .execution_state import ExecutionState
64
65
from .state import AgentState
65
66
66
67
logger = logging .getLogger (__name__ )
@@ -142,6 +143,17 @@ def caller(
142
143
Raises:
143
144
AttributeError: If the tool doesn't exist.
144
145
"""
146
+ if record_direct_tool_call is not None :
147
+ should_record_direct_tool_call = record_direct_tool_call
148
+ else :
149
+ should_record_direct_tool_call = self ._agent .record_direct_tool_call
150
+
151
+ if should_record_direct_tool_call and self ._agent .execution_state != ExecutionState .ASSISTANT :
152
+ raise RuntimeError (
153
+ f"execution_state=<{ self ._agent .execution_state } > "
154
+ f"| recording direct tool calls is only allowed in ASSISTANT execution state"
155
+ )
156
+
145
157
normalized_name = self ._find_normalized_tool_name (name )
146
158
147
159
# Create unique tool ID and set up the tool request
@@ -167,11 +179,6 @@ def tcall() -> ToolResult:
167
179
future = executor .submit (tcall )
168
180
tool_result = future .result ()
169
181
170
- if record_direct_tool_call is not None :
171
- should_record_direct_tool_call = record_direct_tool_call
172
- else :
173
- should_record_direct_tool_call = self ._agent .record_direct_tool_call
174
-
175
182
if should_record_direct_tool_call :
176
183
# Create a record of this tool execution in the message history
177
184
self ._agent ._record_tool_execution (tool_use , tool_result , user_message_override )
@@ -349,6 +356,10 @@ def __init__(
349
356
self .hooks .add_hook (hook )
350
357
self .hooks .invoke_callbacks (AgentInitializedEvent (agent = self ))
351
358
359
+ self .execution_state = ExecutionState .ASSISTANT
360
+
361
+ self .interrupts = {}
362
+
352
363
@property
353
364
def tool (self ) -> ToolCaller :
354
365
"""Call tool as a function.
@@ -540,6 +551,7 @@ async def stream_async(
540
551
Args:
541
552
prompt: User input in various formats:
542
553
- str: Simple text input
554
+ - ContentBlock: Multi-modal content block
543
555
- list[ContentBlock]: Multi-modal content blocks
544
556
- list[Message]: Complete messages with roles
545
557
- None: Use existing conversation history
@@ -564,6 +576,8 @@ async def stream_async(
564
576
yield event["data"]
565
577
```
566
578
"""
579
+ self ._resume (prompt )
580
+
567
581
callback_handler = kwargs .get ("callback_handler" , self .callback_handler )
568
582
569
583
# Process input and get message to add (if any)
@@ -585,6 +599,11 @@ async def stream_async(
585
599
586
600
result = AgentResult (* event ["stop" ])
587
601
callback_handler (result = result )
602
+
603
+ if result .stop_reason == "interrupt" :
604
+ self .execution_state = ExecutionState .INTERRUPT
605
+ self .interrupts = {interrupt .name : interrupt for interrupt in result .interrupts }
606
+
588
607
yield AgentResultEvent (result = result ).as_dict ()
589
608
590
609
self ._end_agent_trace_span (response = result )
@@ -593,6 +612,16 @@ async def stream_async(
593
612
self ._end_agent_trace_span (error = e )
594
613
raise
595
614
615
+ def _resume (self , prompt : AgentInput ) -> None :
616
+ if self .execution_state != ExecutionState .INTERRUPT :
617
+ return
618
+
619
+ if not isinstance (prompt , dict ) or "resume" not in prompt :
620
+ raise ValueError ("<TODO>." )
621
+
622
+ for interrupt in self .interrupts .values ():
623
+ interrupt .resume = prompt ["resume" ][interrupt .name ]
624
+
596
625
async def _run_loop (self , messages : Messages , invocation_state : dict [str , Any ]) -> AsyncGenerator [TypedEvent , None ]:
597
626
"""Execute the agent's event loop with the given message and parameters.
598
627
@@ -673,6 +702,8 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
673
702
if isinstance (prompt , str ):
674
703
# String input - convert to user message
675
704
messages = [{"role" : "user" , "content" : [{"text" : prompt }]}]
705
+ elif isinstance (prompt , dict ):
706
+ messages = [{"role" : "user" , "content" : prompt }] if "resume" not in prompt else []
676
707
elif isinstance (prompt , list ):
677
708
if len (prompt ) == 0 :
678
709
# Empty list
@@ -692,7 +723,9 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
692
723
else :
693
724
messages = []
694
725
if messages is None :
695
- raise ValueError ("Input prompt must be of type: `str | list[Contentblock] | Messages | None`." )
726
+ raise ValueError (
727
+ "Input prompt must be of type: `str | ContentBlock | list[Contentblock] | Messages | None`."
728
+ )
696
729
return messages
697
730
698
731
def _record_tool_execution (
0 commit comments