diff --git a/examples/tutorials/00_sync/010_multiturn/project/acp.py b/examples/tutorials/00_sync/010_multiturn/project/acp.py index 0067cec3..d0a49118 100644 --- a/examples/tutorials/00_sync/010_multiturn/project/acp.py +++ b/examples/tutorials/00_sync/010_multiturn/project/acp.py @@ -39,10 +39,16 @@ async def handle_message_send( if params.content.author != "user": raise ValueError(f"Expected user message, got {params.content.author}") - if not os.environ.get("OPENAI_API_KEY"): + if not os.environ.get("SGP_API_KEY"): return TextContent( author="agent", - content="Hey, sorry I'm unable to respond to your message because you're running this example without an OpenAI API key. Please set the OPENAI_API_KEY environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.", + content="Hey, sorry I'm unable to respond to your message because you're running this example without an SGP API key. Please set the SGP_API_KEY environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.", + ) + + if not os.environ.get("SGP_ACCOUNT_ID"): + return TextContent( + author="agent", + content="Hey, sorry I'm unable to respond to your message because you're running this example without an SGP Account ID. Please set the SGP_ACCOUNT_ID environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.", ) ######################################################### @@ -54,7 +60,7 @@ async def handle_message_send( if not task_state: # If the state doesn't exist, create it. 
- state = StateModel(system_prompt="You are a helpful assistant that can answer questions.", model="gpt-4o-mini") + state = StateModel(system_prompt="You are a helpful assistant that can answer questions.", model="openai/gpt-4o-mini") task_state = await adk.state.create(task_id=params.task.id, agent_id=params.agent.id, state=state) else: state = StateModel.model_validate(task_state.state) @@ -96,7 +102,7 @@ async def handle_message_send( ######################################################### # Call an LLM to respond to the user's message - chat_completion = await adk.providers.litellm.chat_completion( + chat_completion = await adk.providers.sgp.chat_completion( llm_config=LLMConfig(model=state.model, messages=llm_messages), trace_id=params.task.id, ) diff --git a/examples/tutorials/00_sync/020_streaming/project/acp.py b/examples/tutorials/00_sync/020_streaming/project/acp.py index 787f2dae..107523f7 100644 --- a/examples/tutorials/00_sync/020_streaming/project/acp.py +++ b/examples/tutorials/00_sync/020_streaming/project/acp.py @@ -41,13 +41,23 @@ async def handle_message_send( if params.content.author != "user": raise ValueError(f"Expected user message, got {params.content.author}") - if not os.environ.get("OPENAI_API_KEY"): + if not os.environ.get("SGP_API_KEY"): yield StreamTaskMessageFull( index=0, type="full", content=TextContent( author="agent", - content="Hey, sorry I'm unable to respond to your message because you're running this example without an OpenAI API key. Please set the OPENAI_API_KEY environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.", + content="Hey, sorry I'm unable to respond to your message because you're running this example without an SGP API key. Please set the SGP_API_KEY environment variable to run this example. 
Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.", + ), + ) + + if not os.environ.get("SGP_ACCOUNT_ID"): + yield StreamTaskMessageFull( + index=0, + type="full", + content=TextContent( + author="agent", + content="Hey, sorry I'm unable to respond to your message because you're running this example without an SGP Account ID. Please set the SGP_ACCOUNT_ID environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.", ), ) @@ -56,7 +66,7 @@ async def handle_message_send( if not task_state: # If the state doesn't exist, create it. - state = StateModel(system_prompt="You are a helpful assistant that can answer questions.", model="gpt-4o-mini") + state = StateModel(system_prompt="You are a helpful assistant that can answer questions.", model="openai/gpt-4o-mini") task_state = await adk.state.create(task_id=params.task.id, agent_id=params.agent.id, state=state) else: state = StateModel.model_validate(task_state.state) @@ -83,7 +93,7 @@ async def handle_message_send( # The Agentex server automatically commits input and output messages to the database so you don't need to do this yourself, simply process the input content and return the output content. 
message_index = 0 - async for chunk in adk.providers.litellm.chat_completion_stream( + async for chunk in adk.providers.sgp.chat_completion_stream( llm_config=LLMConfig(model=state.model, messages=llm_messages, stream=True), trace_id=params.task.id, ): diff --git a/src/agentex/lib/adk/providers/_modules/sgp.py b/src/agentex/lib/adk/providers/_modules/sgp.py index 52c20e09..3c8b46fe 100644 --- a/src/agentex/lib/adk/providers/_modules/sgp.py +++ b/src/agentex/lib/adk/providers/_modules/sgp.py @@ -1,20 +1,29 @@ from datetime import timedelta +from typing import AsyncGenerator from agentex.lib.adk.utils._modules.client import create_async_agentex_client from scale_gp import SGPClient, SGPClientError from temporalio.common import RetryPolicy from agentex import AsyncAgentex +from agentex.lib.core.adapters.llm.adapter_sgp import SGPLLMGateway +from agentex.lib.core.adapters.streams.adapter_redis import RedisStreamRepository +from agentex.lib.core.services.adk.providers.litellm import LiteLLMService from agentex.lib.core.services.adk.providers.sgp import SGPService +from agentex.lib.core.services.adk.streaming import StreamingService from agentex.lib.core.temporal.activities.activity_helpers import ActivityHelpers +from agentex.lib.core.temporal.activities.adk.providers.litellm_activities import ChatCompletionParams, \ + LiteLLMActivityName, ChatCompletionAutoSendParams, ChatCompletionStreamAutoSendParams from agentex.lib.core.temporal.activities.adk.providers.sgp_activities import ( DownloadFileParams, FileContentResponse, SGPActivityName, ) from agentex.lib.core.tracing.tracer import AsyncTracer +from agentex.lib.types.llm_messages import LLMConfig, Completion from agentex.lib.utils.logging import make_logger from agentex.lib.utils.temporal import in_temporal_workflow +from agentex.types import TaskMessage logger = make_logger(__name__) @@ -30,6 +39,7 @@ class SGPModule: def __init__( self, sgp_service: SGPService | None = None, + litellm_service: LiteLLMService | None 
= None, ): if sgp_service is None: try: @@ -42,6 +52,24 @@ else: self._sgp_service = sgp_service + if litellm_service is not None: + self._litellm_service = litellm_service + else: + agentex_client = create_async_agentex_client() + stream_repository = RedisStreamRepository() + streaming_service = StreamingService( + agentex_client=agentex_client, + stream_repository=stream_repository, + ) + litellm_gateway = SGPLLMGateway() + tracer = AsyncTracer(agentex_client) + self._litellm_service = LiteLLMService( + agentex_client=agentex_client, + llm_gateway=litellm_gateway, + streaming_service=streaming_service, + tracer=tracer, + ) + async def download_file_content( self, params: DownloadFileParams, @@ -84,3 +112,178 @@ async def download_file_content( file_id=params.file_id, filename=params.filename, ) + + async def chat_completion( + self, + llm_config: LLMConfig, + trace_id: str | None = None, + parent_span_id: str | None = None, + start_to_close_timeout: timedelta = timedelta(seconds=120), + heartbeat_timeout: timedelta = timedelta(seconds=120), + retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + ) -> Completion: + """ + Perform a chat completion using LiteLLM. + + Args: + llm_config (LLMConfig): The configuration for the LLM. + trace_id (Optional[str]): The trace ID for tracing. + parent_span_id (Optional[str]): The parent span ID for tracing. + start_to_close_timeout (timedelta): The start to close timeout. + heartbeat_timeout (timedelta): The heartbeat timeout. + retry_policy (RetryPolicy): The retry policy. 
+ + Returns: + Completion: An OpenAI compatible Completion object + """ + if in_temporal_workflow(): + params = ChatCompletionParams( + trace_id=trace_id, parent_span_id=parent_span_id, llm_config=llm_config + ) + return await ActivityHelpers.execute_activity( + activity_name=LiteLLMActivityName.CHAT_COMPLETION, + request=params, + response_type=Completion, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + retry_policy=retry_policy, + ) + else: + return await self._litellm_service.chat_completion( + llm_config=llm_config, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + + async def chat_completion_auto_send( + self, + task_id: str, + llm_config: LLMConfig, + trace_id: str | None = None, + parent_span_id: str | None = None, + start_to_close_timeout: timedelta = timedelta(seconds=120), + heartbeat_timeout: timedelta = timedelta(seconds=120), + retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + ) -> TaskMessage | None: + """ + Chat completion with automatic TaskMessage creation. + + Args: + task_id (str): The ID of the task. + llm_config (LLMConfig): The configuration for the LLM (must have stream=False). + trace_id (Optional[str]): The trace ID for tracing. + parent_span_id (Optional[str]): The parent span ID for tracing. + start_to_close_timeout (timedelta): The start to close timeout. + heartbeat_timeout (timedelta): The heartbeat timeout. + retry_policy (RetryPolicy): The retry policy. 
+ + Returns: + TaskMessage: The final TaskMessage + """ + if in_temporal_workflow(): + # Use streaming activity with stream=False for non-streaming auto-send + params = ChatCompletionAutoSendParams( + trace_id=trace_id, + parent_span_id=parent_span_id, + task_id=task_id, + llm_config=llm_config, + ) + return await ActivityHelpers.execute_activity( + activity_name=LiteLLMActivityName.CHAT_COMPLETION_AUTO_SEND, + request=params, + response_type=TaskMessage, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + retry_policy=retry_policy, + ) + else: + return await self._litellm_service.chat_completion_auto_send( + task_id=task_id, + llm_config=llm_config, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + + async def chat_completion_stream( + self, + llm_config: LLMConfig, + trace_id: str | None = None, + parent_span_id: str | None = None, + ) -> AsyncGenerator[Completion, None]: + """ + Stream chat completion chunks using LiteLLM. + + DEFAULT: Returns raw streaming chunks for manual handling. + + NOTE: This method does NOT work in Temporal workflows! + Temporal activities cannot return generators. Use chat_completion_stream_auto_send() instead. + + Args: + llm_config (LLMConfig): The configuration for the LLM (must have stream=True). + trace_id (Optional[str]): The trace ID for tracing. + parent_span_id (Optional[str]): The parent span ID for tracing. + + Note: + No timeout/retry arguments: this method never runs as a Temporal activity. 
+ + Returns: + AsyncGenerator[Completion, None]: Generator yielding completion chunks + + Raises: + ValueError: If called from within a Temporal workflow + """ + # Delegate to service - it handles temporal workflow checks + async for chunk in self._litellm_service.chat_completion_stream( + llm_config=llm_config, + trace_id=trace_id, + parent_span_id=parent_span_id, + ): + yield chunk + + async def chat_completion_stream_auto_send( + self, + task_id: str, + llm_config: LLMConfig, + trace_id: str | None = None, + parent_span_id: str | None = None, + start_to_close_timeout: timedelta = timedelta(seconds=120), + heartbeat_timeout: timedelta = timedelta(seconds=120), + retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + ) -> TaskMessage | None: + """ + Stream chat completion with automatic TaskMessage creation and streaming. + + Args: + task_id (str): The ID of the task to run the agent for. + llm_config (LLMConfig): The configuration for the LLM (must have stream=True). + trace_id (Optional[str]): The trace ID for tracing. + parent_span_id (Optional[str]): The parent span ID for tracing. + start_to_close_timeout (timedelta): The start to close timeout. + heartbeat_timeout (timedelta): The heartbeat timeout. + retry_policy (RetryPolicy): The retry policy. 
+ + Returns: + TaskMessage: The final TaskMessage after streaming is complete + """ + if in_temporal_workflow(): + params = ChatCompletionStreamAutoSendParams( + trace_id=trace_id, + parent_span_id=parent_span_id, + task_id=task_id, + llm_config=llm_config, + ) + return await ActivityHelpers.execute_activity( + activity_name=LiteLLMActivityName.CHAT_COMPLETION_STREAM_AUTO_SEND, + request=params, + response_type=TaskMessage, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + retry_policy=retry_policy, + ) + else: + return await self._litellm_service.chat_completion_stream_auto_send( + task_id=task_id, + llm_config=llm_config, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) \ No newline at end of file diff --git a/src/agentex/lib/core/adapters/llm/adapter_sgp.py b/src/agentex/lib/core/adapters/llm/adapter_sgp.py index a14e66a2..4b51ac9f 100644 --- a/src/agentex/lib/core/adapters/llm/adapter_sgp.py +++ b/src/agentex/lib/core/adapters/llm/adapter_sgp.py @@ -11,10 +11,14 @@ class SGPLLMGateway(LLMGateway): - def __init__(self, sgp_api_key: str | None = None): - self.sync_client = SGPClient(api_key=os.environ.get("SGP_API_KEY", sgp_api_key)) + def __init__(self, sgp_api_key: str | None = None, sgp_account_id: str | None = None): + self.sync_client = SGPClient( + api_key=os.environ.get("SGP_API_KEY", sgp_api_key), + account_id=os.environ.get("SGP_ACCOUNT_ID", sgp_account_id) + ) self.async_client = AsyncSGPClient( - api_key=os.environ.get("SGP_API_KEY", sgp_api_key) + api_key=os.environ.get("SGP_API_KEY", sgp_api_key), + account_id=os.environ.get("SGP_ACCOUNT_ID", sgp_account_id) ) def completion(self, *args, **kwargs) -> Completion: