14 changes: 10 additions & 4 deletions examples/tutorials/00_sync/010_multiturn/project/acp.py
@@ -39,10 +39,16 @@ async def handle_message_send(
if params.content.author != "user":
raise ValueError(f"Expected user message, got {params.content.author}")

if not os.environ.get("OPENAI_API_KEY"):
if not os.environ.get("SGP_API_KEY"):
Contributor:
Do we need these checks here or will the SGPClient already raise SGPClientError exceptions when these keys are missing? Just wondering b/c it adds a lot of boilerplate to each project. Are you trying to protect against error message propagation issues?
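If the explicit checks stay, one way to cut the boilerplate would be a small shared helper the tutorial handlers call up front. A minimal sketch — missing_env_reply is a hypothetical name, and the TextContent import path is an assumption to be matched to the tutorial's existing import:

import os

from agentex.types import TextContent  # import path assumed; match the tutorial's existing import

REQUIRED_ENV_VARS = ("SGP_API_KEY", "SGP_ACCOUNT_ID")


def missing_env_reply() -> TextContent | None:
    # Return a canned agent reply naming the first unset variable, or None if all are set.
    for var in REQUIRED_ENV_VARS:
        if not os.environ.get(var):
            return TextContent(
                author="agent",
                content=(
                    f"Hey, sorry I'm unable to respond to your message because you're "
                    f"running this example without {var}. Please set it, either by adding "
                    f"a .env file to the project/ directory or by setting the environment "
                    f"variable in your terminal."
                ),
            )
    return None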

return TextContent(
author="agent",
content="Hey, sorry I'm unable to respond to your message because you're running this example without an OpenAI API key. Please set the OPENAI_API_KEY environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.",
content="Hey, sorry I'm unable to respond to your message because you're running this example without an SGP API key. Please set the SGP_API_KEY environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.",
)

if not os.environ.get("SGP_ACCOUNT_ID"):
return TextContent(
author="agent",
content="Hey, sorry I'm unable to respond to your message because you're running this example without an SGP Account ID. Please set the SGP_ACCOUNT_ID environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.",
)

#########################################################
@@ -54,7 +60,7 @@ async def handle_message_send(

if not task_state:
# If the state doesn't exist, create it.
state = StateModel(system_prompt="You are a helpful assistant that can answer questions.", model="gpt-4o-mini")
state = StateModel(system_prompt="You are a helpful assistant that can answer questions.", model="openai/gpt-4o-mini")
task_state = await adk.state.create(task_id=params.task.id, agent_id=params.agent.id, state=state)
else:
state = StateModel.model_validate(task_state.state)
@@ -96,7 +102,7 @@ async def handle_message_send(
#########################################################

# Call an LLM to respond to the user's message
chat_completion = await adk.providers.litellm.chat_completion(
Contributor:

Nit for the future:

Is the goal for each of these providers to be drop-in replacements for each other? If so, for the future, would you want to consider some type of "LLMProviderInterface" spec that these providers adhere to? And then have a providers.auto option which checks env vars and initializes the appropriate one instead of hard-coding the selection in the tutorials? A rough sketch follows.
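To make that concrete — LLMProviderInterface and auto_provider are hypothetical names here, with signatures trimmed to what these tutorials actually use:

import os
from typing import AsyncIterator, Protocol

from agentex.lib.types.llm_messages import Completion, LLMConfig


class LLMProviderInterface(Protocol):
    async def chat_completion(
        self, llm_config: LLMConfig, trace_id: str | None = None
    ) -> Completion: ...

    def chat_completion_stream(
        self, llm_config: LLMConfig, trace_id: str | None = None
    ) -> AsyncIterator[Completion]: ...


def auto_provider(adk) -> LLMProviderInterface:
    # Check env vars and hand back the matching provider module instead of
    # hard-coding adk.providers.sgp / adk.providers.litellm in each tutorial.
    if os.environ.get("SGP_API_KEY"):
        return adk.providers.sgp
    return adk.providers.litellm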

chat_completion = await adk.providers.sgp.chat_completion(
llm_config=LLMConfig(model=state.model, messages=llm_messages),
trace_id=params.task.id,
)
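The error messages in this file reference a .env file; a minimal project/.env for these examples would just set the two variables (placeholder values):

SGP_API_KEY=your-sgp-api-key
SGP_ACCOUNT_ID=your-sgp-account-id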
18 changes: 14 additions & 4 deletions examples/tutorials/00_sync/020_streaming/project/acp.py
@@ -41,13 +41,23 @@ async def handle_message_send(
if params.content.author != "user":
raise ValueError(f"Expected user message, got {params.content.author}")

if not os.environ.get("OPENAI_API_KEY"):
if not os.environ.get("SGP_API_KEY"):
yield StreamTaskMessageFull(
index=0,
type="full",
content=TextContent(
author="agent",
content="Hey, sorry I'm unable to respond to your message because you're running this example without an OpenAI API key. Please set the OPENAI_API_KEY environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.",
content="Hey, sorry I'm unable to respond to your message because you're running this example without an SGP API key. Please set the SGP_API_KEY environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.",
),
)

if not os.environ.get("SGP_ACCOUNT_ID"):
yield StreamTaskMessageFull(
index=0,
type="full",
content=TextContent(
author="agent",
content="Hey, sorry I'm unable to respond to your message because you're running this example without an SGP Account ID. Please set the SGP_ACCOUNT_ID environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.",
),
)

@@ -56,7 +66,7 @@ async def handle_message_send(

if not task_state:
# If the state doesn't exist, create it.
state = StateModel(system_prompt="You are a helpful assistant that can answer questions.", model="gpt-4o-mini")
state = StateModel(system_prompt="You are a helpful assistant that can answer questions.", model="openai/gpt-4o-mini")
task_state = await adk.state.create(task_id=params.task.id, agent_id=params.agent.id, state=state)
else:
state = StateModel.model_validate(task_state.state)
@@ -83,7 +93,7 @@ async def handle_message_send(
# The Agentex server automatically commits input and output messages to the database so you don't need to do this yourself, simply process the input content and return the output content.

message_index = 0
async for chunk in adk.providers.litellm.chat_completion_stream(
async for chunk in adk.providers.sgp.chat_completion_stream(
llm_config=LLMConfig(model=state.model, messages=llm_messages, stream=True),
trace_id=params.task.id,
):
200 changes: 200 additions & 0 deletions src/agentex/lib/adk/providers/_modules/sgp.py
@@ -1,20 +1,29 @@
from datetime import timedelta
from typing import AsyncGenerator

from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from scale_gp import SGPClient, SGPClientError
from temporalio.common import RetryPolicy

from agentex import AsyncAgentex
from agentex.lib.core.adapters.llm.adapter_sgp import SGPLLMGateway
from agentex.lib.core.adapters.streams.adapter_redis import RedisStreamRepository
from agentex.lib.core.services.adk.providers.litellm import LiteLLMService
from agentex.lib.core.services.adk.providers.sgp import SGPService
from agentex.lib.core.services.adk.streaming import StreamingService
from agentex.lib.core.temporal.activities.activity_helpers import ActivityHelpers
from agentex.lib.core.temporal.activities.adk.providers.litellm_activities import (
    ChatCompletionParams,
    LiteLLMActivityName,
    ChatCompletionAutoSendParams,
    ChatCompletionStreamAutoSendParams,
)
from agentex.lib.core.temporal.activities.adk.providers.sgp_activities import (
DownloadFileParams,
FileContentResponse,
SGPActivityName,
)
from agentex.lib.core.tracing.tracer import AsyncTracer
from agentex.lib.types.llm_messages import LLMConfig, Completion
from agentex.lib.utils.logging import make_logger
from agentex.lib.utils.temporal import in_temporal_workflow
from agentex.types import TaskMessage

logger = make_logger(__name__)

@@ -30,6 +39,7 @@ class SGPModule:
def __init__(
self,
sgp_service: SGPService | None = None,
litellm_service: LiteLLMService | None = None,
):
if sgp_service is None:
try:
@@ -42,6 +52,21 @@ def __init__(
else:
self._sgp_service = sgp_service

agentex_client = create_async_agentex_client()
stream_repository = RedisStreamRepository()
streaming_service = StreamingService(
agentex_client=agentex_client,
stream_repository=stream_repository,
)
litellm_gateway = SGPLLMGateway()
tracer = AsyncTracer(agentex_client)
self._litellm_service = LiteLLMService(
agentex_client=agentex_client,
llm_gateway=litellm_gateway,
streaming_service=streaming_service,
tracer=tracer,
)

async def download_file_content(
self,
params: DownloadFileParams,
@@ -84,3 +109,178 @@ async def download_file_content(
file_id=params.file_id,
filename=params.filename,
)

async def chat_completion(
self,
llm_config: LLMConfig,
trace_id: str | None = None,
parent_span_id: str | None = None,
start_to_close_timeout: timedelta = timedelta(seconds=120),
heartbeat_timeout: timedelta = timedelta(seconds=120),
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
) -> Completion:
"""
Perform a chat completion using SGP.

Args:
llm_config (LLMConfig): The configuration for the LLM.
trace_id (Optional[str]): The trace ID for tracing.
parent_span_id (Optional[str]): The parent span ID for tracing.
start_to_close_timeout (timedelta): The start to close timeout.
heartbeat_timeout (timedelta): The heartbeat timeout.
retry_policy (RetryPolicy): The retry policy.

Returns:
Completion: An OpenAI compatible Completion object
"""
if in_temporal_workflow():
params = ChatCompletionParams(
trace_id=trace_id, parent_span_id=parent_span_id, llm_config=llm_config
)
return await ActivityHelpers.execute_activity(
activity_name=LiteLLMActivityName.CHAT_COMPLETION,
request=params,
response_type=Completion,
start_to_close_timeout=start_to_close_timeout,
heartbeat_timeout=heartbeat_timeout,
retry_policy=retry_policy,
)
else:
return await self._litellm_service.chat_completion(
llm_config=llm_config,
trace_id=trace_id,
parent_span_id=parent_span_id,
)

async def chat_completion_auto_send(
self,
task_id: str,
llm_config: LLMConfig,
trace_id: str | None = None,
parent_span_id: str | None = None,
start_to_close_timeout: timedelta = timedelta(seconds=120),
heartbeat_timeout: timedelta = timedelta(seconds=120),
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
) -> TaskMessage | None:
"""
Chat completion with automatic TaskMessage creation.

Args:
task_id (str): The ID of the task.
llm_config (LLMConfig): The configuration for the LLM (must have stream=False).
trace_id (Optional[str]): The trace ID for tracing.
parent_span_id (Optional[str]): The parent span ID for tracing.
start_to_close_timeout (timedelta): The start to close timeout.
heartbeat_timeout (timedelta): The heartbeat timeout.
retry_policy (RetryPolicy): The retry policy.

Returns:
TaskMessage: The final TaskMessage
"""
if in_temporal_workflow():
# Use streaming activity with stream=False for non-streaming auto-send
params = ChatCompletionAutoSendParams(
trace_id=trace_id,
parent_span_id=parent_span_id,
task_id=task_id,
llm_config=llm_config,
)
return await ActivityHelpers.execute_activity(
activity_name=LiteLLMActivityName.CHAT_COMPLETION_AUTO_SEND,
request=params,
response_type=TaskMessage,
start_to_close_timeout=start_to_close_timeout,
heartbeat_timeout=heartbeat_timeout,
retry_policy=retry_policy,
)
else:
return await self._litellm_service.chat_completion_auto_send(
task_id=task_id,
llm_config=llm_config,
trace_id=trace_id,
parent_span_id=parent_span_id,
)

async def chat_completion_stream(
self,
llm_config: LLMConfig,
trace_id: str | None = None,
parent_span_id: str | None = None,
) -> AsyncGenerator[Completion, None]:
"""
Stream chat completion chunks using SGP.

DEFAULT: Returns raw streaming chunks for manual handling.

NOTE: This method does NOT work in Temporal workflows!
Temporal activities cannot return generators. Use chat_completion_stream_auto_send() instead.

Args:
llm_config (LLMConfig): The configuration for the LLM (must have stream=True).
trace_id (Optional[str]): The trace ID for tracing.
parent_span_id (Optional[str]): The parent span ID for tracing.

Returns:
AsyncGenerator[Completion, None]: Generator yielding completion chunks

Raises:
ValueError: If called from within a Temporal workflow
"""
# Delegate to service - it handles temporal workflow checks
async for chunk in self._litellm_service.chat_completion_stream(
llm_config=llm_config,
trace_id=trace_id,
parent_span_id=parent_span_id,
):
yield chunk

async def chat_completion_stream_auto_send(
self,
task_id: str,
llm_config: LLMConfig,
trace_id: str | None = None,
parent_span_id: str | None = None,
start_to_close_timeout: timedelta = timedelta(seconds=120),
heartbeat_timeout: timedelta = timedelta(seconds=120),
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
) -> TaskMessage | None:
"""
Stream chat completion with automatic TaskMessage creation and streaming.

Args:
task_id (str): The ID of the task to run the agent for.
llm_config (LLMConfig): The configuration for the LLM (must have stream=True).
trace_id (Optional[str]): The trace ID for tracing.
parent_span_id (Optional[str]): The parent span ID for tracing.
start_to_close_timeout (timedelta): The start to close timeout.
heartbeat_timeout (timedelta): The heartbeat timeout.
retry_policy (RetryPolicy): The retry policy.

Returns:
TaskMessage: The final TaskMessage after streaming is complete
"""
if in_temporal_workflow():
params = ChatCompletionStreamAutoSendParams(
trace_id=trace_id,
parent_span_id=parent_span_id,
task_id=task_id,
llm_config=llm_config,
)
return await ActivityHelpers.execute_activity(
activity_name=LiteLLMActivityName.CHAT_COMPLETION_STREAM_AUTO_SEND,
request=params,
response_type=TaskMessage,
start_to_close_timeout=start_to_close_timeout,
heartbeat_timeout=heartbeat_timeout,
retry_policy=retry_policy,
)
else:
return await self._litellm_service.chat_completion_stream_auto_send(
task_id=task_id,
llm_config=llm_config,
trace_id=trace_id,
parent_span_id=parent_span_id,
)
10 changes: 7 additions & 3 deletions src/agentex/lib/core/adapters/llm/adapter_sgp.py
@@ -11,10 +11,14 @@


class SGPLLMGateway(LLMGateway):
def __init__(self, sgp_api_key: str | None = None):
self.sync_client = SGPClient(api_key=os.environ.get("SGP_API_KEY", sgp_api_key))
def __init__(self, sgp_api_key: str | None = None, sgp_account_id: str | None = None):
self.sync_client = SGPClient(
Contributor:
  • The SGP client accepts a variety of configuration args (like number of retries, timeout). Wouldn't those need to be propagated from some llm_config (or at least timeout seems to be a runtime arg)?
  • Nit: maybe mirror the arg names of SGPClient / remove redundant sgp prefix: i.e. sgp_api_key --> api_key
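A sketch of both points together — assuming (worth verifying against scale_gp) that SGPClient and AsyncSGPClient accept timeout and max_retries constructor kwargs, and letting explicit arguments take precedence over the environment:

import os

from scale_gp import AsyncSGPClient, SGPClient


class SGPLLMGateway:  # LLMGateway base class omitted here for brevity
    def __init__(
        self,
        api_key: str | None = None,
        account_id: str | None = None,
        timeout: float | None = None,  # assumed SGPClient kwarg
        max_retries: int = 2,  # assumed SGPClient kwarg
    ):
        # Build one kwargs dict so both clients are configured identically.
        kwargs = dict(
            api_key=api_key or os.environ.get("SGP_API_KEY"),
            account_id=account_id or os.environ.get("SGP_ACCOUNT_ID"),
            timeout=timeout,
            max_retries=max_retries,
        )
        self.sync_client = SGPClient(**kwargs)
        self.async_client = AsyncSGPClient(**kwargs)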

api_key=os.environ.get("SGP_API_KEY", sgp_api_key),
account_id=os.environ.get("SGP_ACCOUNT_ID", sgp_account_id)
)
self.async_client = AsyncSGPClient(
api_key=os.environ.get("SGP_API_KEY", sgp_api_key)
api_key=os.environ.get("SGP_API_KEY", sgp_api_key),
account_id=os.environ.get("SGP_ACCOUNT_ID", sgp_account_id)
)

def completion(self, *args, **kwargs) -> Completion: