From dab765edc38814ffd00f3fb7a4b83e3380dbf9ba Mon Sep 17 00:00:00 2001 From: Abhi Shivaditya Date: Tue, 25 Nov 2025 15:57:30 -0800 Subject: [PATCH 1/2] feat(models): add SAP GenAI Hub model provider - Add SAPGenAIHubModel class with support for Nova, Claude, and Titan models - Implement streaming and non-streaming response handling - Add tool calling and structured output support - Include comprehensive unit tests (13 tests, 100% pass rate) - Follow project style guide and conventions The SAP GenAI Hub provider enables Strands Agents to work with models hosted on SAP BTP GenAI Hub, providing access to foundation models through SAP's AI Core infrastructure. Users need: pip install 'generative-ai-hub-sdk[all]' --- src/strands/models/sap_genai_hub.py | 751 +++++++++++++++++++++ tests/strands/models/test_sap_genai_hub.py | 194 ++++++ 2 files changed, 945 insertions(+) create mode 100644 src/strands/models/sap_genai_hub.py create mode 100644 tests/strands/models/test_sap_genai_hub.py diff --git a/src/strands/models/sap_genai_hub.py b/src/strands/models/sap_genai_hub.py new file mode 100644 index 000000000..d39158e6b --- /dev/null +++ b/src/strands/models/sap_genai_hub.py @@ -0,0 +1,751 @@ +"""SAP GenAI Hub model provider. + +- Docs: https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/consume-generative-ai-models-using-sap-ai-core#aws-bedrock +- SDK Reference: https://help.sap.com/doc/generative-ai-hub-sdk/CLOUD/en-US/_reference/gen_ai_hub.html +""" + +import asyncio +import json +import logging +from typing import ( + Any, + AsyncGenerator, + Callable, + Iterable, + Optional, + Type, + TypeVar, + Union, + cast, +) + +from pydantic import BaseModel +from typing_extensions import TypedDict, Unpack, override + +from ..types.content import Messages, SystemContentBlock +from ..types.exceptions import ( + ContextWindowOverflowException, + ModelThrottledException, +) +from ..types.streaming import StreamEvent +from ..types.tools import ToolChoice, ToolSpec +from ._validation import validate_config_keys +from .model import Model + +# Import SAP GenAI Hub SDK +try: + from gen_ai_hub.proxy.native.amazon.clients import Session +except ImportError as e: + raise ImportError( + "SAP GenAI Hub SDK is not installed. Please install it with: pip install 'generative-ai-hub-sdk[all]'" + ) from e + +logger = logging.getLogger(__name__) + +DEFAULT_SAP_GENAI_HUB_MODEL_ID = "amazon--nova-lite" + +# Common error messages for context window overflow +CONTEXT_WINDOW_OVERFLOW_MESSAGES = [ + "Input is too long for requested model", + "input length and `max_tokens` exceed context limit", + "too many total text bytes", +] + +T = TypeVar("T", bound=BaseModel) + + +class SAPGenAIHubModel(Model): + """SAP GenAI Hub model provider implementation. + + This implementation handles SAP GenAI Hub-specific features such as: + - Tool configuration for function calling + - Streaming responses + - Context window overflow detection + - Support for different model types (Nova, Claude, Titan) + """ + + class SAPGenAIHubConfig(TypedDict, total=False): + """Configuration options for SAP GenAI Hub models. + + Attributes: + additional_args: Any additional arguments to include in the request + max_tokens: Maximum number of tokens to generate in the response + model_id: The SAP GenAI Hub model ID (e.g., "amazon--nova-lite", "anthropic--claude-3-sonnet") + stop_sequences: List of sequences that will stop generation when encountered + streaming: Flag to enable/disable streaming. Defaults to True. 
+ temperature: Controls randomness in generation (higher = more random) + top_p: Controls diversity via nucleus sampling (alternative to temperature) + """ + + additional_args: Optional[dict[str, Any]] + max_tokens: Optional[int] + model_id: str + stop_sequences: Optional[list[str]] + streaming: Optional[bool] + temperature: Optional[float] + top_p: Optional[float] + + def __init__( + self, + **model_config: Unpack[SAPGenAIHubConfig], + ): + """Initialize provider instance. + + Args: + **model_config: Configuration options for the SAP GenAI Hub model. + """ + self.config = SAPGenAIHubModel.SAPGenAIHubConfig(model_id=DEFAULT_SAP_GENAI_HUB_MODEL_ID) + self.update_config(**model_config) + + logger.debug("config=<%s> | initializing", self.config) + + # Initialize the SAP GenAI Hub client + self.client = Session().client(model_name=self.config["model_id"]) + + @override + def update_config(self, **model_config: Unpack[SAPGenAIHubConfig]) -> None: # type: ignore + """Update the SAP GenAI Hub Model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + validate_config_keys(model_config, self.SAPGenAIHubConfig) + self.config.update(model_config) + + @override + def get_config(self) -> SAPGenAIHubConfig: + """Get the current SAP GenAI Hub Model configuration. + + Returns: + The SAP GenAI Hub model configuration. + """ + return self.config + + def _is_nova_model(self) -> bool: + """Check if the current model is an Amazon Nova model. + + Returns: + True if the model is an Amazon Nova model, False otherwise. + """ + nova_models = ["amazon--nova-pro", "amazon--nova-micro", "amazon--nova-lite"] + return self.config["model_id"] in nova_models + + def _is_claude_model(self) -> bool: + """Check if the current model is an Anthropic Claude model. + + Returns: + True if the model is an Anthropic Claude model, False otherwise. + """ + return self.config["model_id"].startswith("anthropic--claude") + + def _is_titan_embed_model(self) -> bool: + """Check if the current model is an Amazon Titan Embedding model. + + Returns: + True if the model is an Amazon Titan Embedding model, False otherwise. + """ + return self.config["model_id"] == "amazon--titan-embed-text" + + def _format_request( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_choice: ToolChoice | None = None, + ) -> dict[str, Any]: + """Format a request for the SAP GenAI Hub model. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + tool_choice: Selection strategy for tool invocation. + + Returns: + A formatted request for the SAP GenAI Hub model. 
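+
+        Example (illustrative values): for a Nova or Claude model the result
+        is a converse-style payload such as:
+
+            {
+                "messages": [...],
+                "system": [{"text": "You are a helpful assistant"}],
+                "inferenceConfig": {"maxTokens": 512, "temperature": 0.7},
+            }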
+ """ + # Format request based on model type + if self._is_nova_model(): + return self._format_nova_request(messages, tool_specs, system_prompt_content, tool_choice) + elif self._is_claude_model(): + return self._format_claude_request(messages, tool_specs, system_prompt_content, tool_choice) + elif self._is_titan_embed_model(): + return self._format_titan_embed_request(messages) + else: + raise ValueError(f"model_id=<{self.config['model_id']}> | unsupported model") + + def _format_nova_request( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_choice: ToolChoice | None = None, + ) -> dict[str, Any]: + """Format a request for Amazon Nova models. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + tool_choice: Selection strategy for tool invocation. + + Returns: + A formatted request for Amazon Nova models. + """ + request: dict[str, Any] = { + "messages": messages, + "inferenceConfig": { + key: value + for key, value in [ + ("maxTokens", self.config.get("max_tokens")), + ("temperature", self.config.get("temperature")), + ("topP", self.config.get("top_p")), + ("stopSequences", self.config.get("stop_sequences")), + ] + if value is not None + }, + } + + # Add system prompt if provided + if system_prompt_content: + request["system"] = system_prompt_content + + # Add tool specs if provided + if tool_specs: + request["toolConfig"] = { + "tools": [{"toolSpec": tool_spec} for tool_spec in tool_specs], + "toolChoice": tool_choice if tool_choice else {"auto": {}}, + } + + # Add additional arguments if provided + if "additional_args" in self.config and self.config["additional_args"] is not None: + request.update(self.config["additional_args"]) + + return request + + def _format_claude_request( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_choice: ToolChoice | None = None, + ) -> dict[str, Any]: + """Format a request for Anthropic Claude models. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + tool_choice: Selection strategy for tool invocation. + + Returns: + A formatted request for Anthropic Claude models. + """ + # For Claude models, we use the same format as Nova models + # since we're using the converse API for both + return self._format_nova_request(messages, tool_specs, system_prompt_content, tool_choice) + + def _format_titan_embed_request(self, messages: Messages) -> dict[str, Any]: + """Format a request for Amazon Titan Embedding models. + + Args: + messages: List of message objects to be processed by the model. + + Returns: + A formatted request for Amazon Titan Embedding models. 
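+
+        Example (illustrative): a single user message produces:
+
+            {"inputText": "Text to embed"}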
+ """ + # Extract the text from the last user message + input_text = "" + for message in reversed(messages): + if message["role"] == "user" and "content" in message: + content_blocks = message["content"] + if isinstance(content_blocks, list): + for block in content_blocks: + if "text" in block: + input_text = block["text"] + break + if input_text: + break + + request: dict[str, Any] = { + "inputText": input_text, + } + + # Add additional arguments if provided + if "additional_args" in self.config and self.config["additional_args"] is not None: + request.update(self.config["additional_args"]) + + return request + + def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format the SAP GenAI Hub response events into standardized message chunks. + + Args: + event: A response event from the SAP GenAI Hub model. + + Returns: + The formatted chunk. + """ + # Handle string events by wrapping them in a proper structure + if isinstance(event, str): + return {"contentBlockDelta": {"delta": {"text": event}}} + + # If it's already a proper dictionary, return it as is + if isinstance(event, dict): + return cast(StreamEvent, event) + + # For any other type, convert to string and wrap + return {"contentBlockDelta": {"delta": {"text": str(event)}}} + + def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEvent]: + """Convert a streaming response to the standardized streaming format. + + Args: + stream_response: The streaming response from the SAP GenAI Hub model. + + Returns: + An iterable of response events in the streaming format. + """ + try: + logger.debug( + "stream_type=<%s> | converting streaming response", + type(stream_response).__name__, + ) + + message_started = False + event_count = 0 + + # Check if it's a Bedrock-style response with 'stream' key + if hasattr(stream_response, "get") and callable(stream_response.get): + event_stream = stream_response.get("stream") + if event_stream: + logger.debug("processing bedrock-style event stream") + + for event in event_stream: + event_count += 1 + + # Start message if not started + if not message_started: + yield {"messageStart": {"role": "assistant"}} + message_started = True + + # Process the event based on its structure + if isinstance(event, dict): + # Pass through properly formatted events + if any( + key in event + for key in [ + "contentBlockDelta", + "contentBlockStart", + "contentBlockStop", + "messageStart", + "messageStop", + "metadata", + ] + ): + yield event + else: + # Format unknown events + yield self._format_chunk(event) + else: + # Handle non-dict events (strings, etc.) 
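+                            # e.g. SDKs that stream raw text deltas as plain strings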
+ yield self._format_chunk(event) + + logger.debug( + "event_count=<%d> | processed bedrock stream events", + event_count, + ) + return + + # Try to iterate directly over the stream_response + try: + if hasattr(stream_response, "__iter__"): + logger.debug("processing iterable stream response") + + for event in stream_response: + event_count += 1 + + # Start message if not started + if not message_started: + yield {"messageStart": {"role": "assistant"}} + message_started = True + + # Process the event + if isinstance(event, dict): + # Check if it's already a properly formatted streaming event + if any( + key in event + for key in [ + "contentBlockDelta", + "contentBlockStart", + "contentBlockStop", + "messageStart", + "messageStop", + "metadata", + ] + ): + yield event + + # If this is a messageStop event, we're done + if "messageStop" in event: + logger.debug("received messageStop event from stream") + return + else: + # Format unknown events + yield self._format_chunk(event) + else: + # Handle strings and other types + yield self._format_chunk(event) + + logger.debug( + "event_count=<%d> | processed direct iteration events", + event_count, + ) + else: + logger.debug("stream response not iterable, treating as single response") + yield {"messageStart": {"role": "assistant"}} + yield self._format_chunk(stream_response) + + except TypeError as te: + # stream_response is not iterable + logger.debug( + "error=<%s> | stream response not iterable, treating as single response", + te, + ) + yield {"messageStart": {"role": "assistant"}} + yield self._format_chunk(stream_response) + + except Exception as e: + logger.error( + "error=<%s>, stream_type=<%s> | error processing streaming response", + e, + type(stream_response), + ) + raise e + + async def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]: + """Convert a non-streaming response to the streaming format. + + Args: + response: The non-streaming response from the SAP GenAI Hub model. + + Yields: + Response events in the streaming format. 
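+
+        For a plain text response the replayed order is (illustrative):
+        messageStart -> contentBlockDelta -> contentBlockStop -> messageStop,
+        followed by a metadata event when usage or metrics are present.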
+ """ + if self._is_nova_model() or self._is_claude_model(): + # Nova and Claude models have a similar response format when using converse API + # Yield messageStart event + yield {"messageStart": {"role": response["output"]["message"]["role"]}} + + # Process content blocks + for content in response["output"]["message"]["content"]: + # Yield contentBlockStart event if needed + if "toolUse" in content: + yield { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": content["toolUse"]["toolUseId"], + "name": content["toolUse"]["name"], + } + }, + } + } + + # For tool use, we need to yield the input as a delta + input_value = json.dumps(content["toolUse"]["input"]) + + yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}} + elif "text" in content: + # Then yield the text as a delta + yield { + "contentBlockDelta": { + "delta": {"text": content["text"]}, + } + } + + # Yield contentBlockStop event + yield {"contentBlockStop": {}} + + # Yield messageStop event + yield { + "messageStop": { + "stopReason": response.get("stopReason", "stop"), + } + } + + # Yield metadata event + if "usage" in response or "metrics" in response: + metadata: StreamEvent = {"metadata": {}} + if "usage" in response: + metadata["metadata"]["usage"] = response["usage"] + if "metrics" in response: + metadata["metadata"]["metrics"] = response["metrics"] + yield metadata + + elif self._is_titan_embed_model(): + # Titan Embedding models have a different response format + # Yield messageStart event + yield {"messageStart": {"role": "assistant"}} + + # Yield content block for embedding + if "embedding" in response: + yield { + "contentBlockDelta": { + "delta": {"text": f"Embedding generated with {len(response['embedding'])} dimensions"}, + } + } + + # Yield contentBlockStop event + yield {"contentBlockStop": {}} + + # Yield messageStop event + yield { + "messageStop": { + "stopReason": "stop", + } + } + + @override + async def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + *, + tool_choice: ToolChoice | None = None, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Send the request to the SAP GenAI Hub model and get the response. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments. + + Yields: + Response events from the SAP GenAI Hub model + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. + ModelThrottledException: If the model service is throttling requests. 
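+
+        Example (minimal sketch):
+
+            model = SAPGenAIHubModel(model_id="amazon--nova-lite")
+            async for event in model.stream(messages):
+                print(event)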
+ """ + + def callback(event: Optional[StreamEvent] = None) -> None: + loop.call_soon_threadsafe(queue.put_nowait, event) + + loop = asyncio.get_event_loop() + queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue() + + # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None + if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + thread = asyncio.to_thread( + self._stream, + callback, + messages, + tool_specs, + system_prompt_content, + tool_choice, + ) + task = asyncio.create_task(thread) + + while True: + event = await queue.get() + if event is None: + break + + yield event + + await task + + def _stream( + self, + callback: Callable[..., None], + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_choice: ToolChoice | None = None, + ) -> None: + """Stream conversation with the SAP GenAI Hub model. + + This method operates in a separate thread to avoid blocking the async event loop. + + Args: + callback: Function to send events to the main thread. + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + tool_choice: Selection strategy for tool invocation. + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. + ModelThrottledException: If the model service is throttling requests. + """ + try: + logger.debug("formatting request") + request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice) + + logger.debug("invoking model") + streaming = self.config.get("streaming", True) + + if self._is_nova_model() or self._is_claude_model(): + if streaming: + # Use converse_stream for streaming responses + try: + logger.debug("using converse_stream api") + stream_response = self.client.converse_stream(**request) + + # Process all streaming events + event_count = 0 + has_content = False + + for event in self._convert_streaming_response(stream_response): + event_count += 1 + + # Check if we have actual content + if "contentBlockDelta" in event: + has_content = True + + callback(event) + + logger.debug( + "event_count=<%d>, has_content=<%s> | processed streaming events", + event_count, + has_content, + ) + + # If we didn't get any content, fallback to non-streaming + if event_count == 0 or not has_content: + logger.warning("no content received from streaming, falling back to non-streaming") + response = self.client.converse(**request) + for event in self._convert_non_streaming_to_streaming(response): + callback(event) + + except (AttributeError, Exception) as e: + # Fallback to non-streaming if converse_stream fails + logger.warning( + "error=<%s> | converse_stream failed, falling back to non-streaming", + e, + ) + response = self.client.converse(**request) + for event in self._convert_non_streaming_to_streaming(response): + callback(event) + else: + # Non-streaming path + logger.debug("using non-streaming converse api") + response = self.client.converse(**request) + for event in self._convert_non_streaming_to_streaming(response): + callback(event) + + elif self._is_titan_embed_model(): + if streaming: + # Try streaming for Titan models + try: + logger.debug("using invoke_model_with_response_stream for titan") + stream_response = 
self.client.invoke_model_with_response_stream(**request) + + event_count = 0 + for event in self._convert_streaming_response(stream_response): + event_count += 1 + callback(event) + + if event_count == 0: + logger.warning("no events from titan streaming, falling back to non-streaming") + response = self.client.invoke_model(**request) + for event in self._convert_non_streaming_to_streaming(response): + callback(event) + + except (AttributeError, Exception) as e: + logger.warning( + "error=<%s> | titan streaming failed, falling back to non-streaming", + e, + ) + response = self.client.invoke_model(**request) + for event in self._convert_non_streaming_to_streaming(response): + callback(event) + else: + # Non-streaming path for Titan + logger.debug("using non-streaming invoke_model api for titan") + response = self.client.invoke_model(**request) + for event in self._convert_non_streaming_to_streaming(response): + callback(event) + + except Exception as e: + error_message = str(e) + + # Handle throttling error + if "ThrottlingException" in error_message: + raise ModelThrottledException(error_message) from e + + # Handle context window overflow + if any(overflow_message in error_message for overflow_message in CONTEXT_WINDOW_OVERFLOW_MESSAGES): + logger.warning("sap genai hub threw context window overflow error") + raise ContextWindowOverflowException(e) from e + + # Otherwise raise the error + raise e + + finally: + callback() + logger.debug("finished streaming response from model") + + @override + async def structured_output( + self, + output_model: Type[T], + prompt: Messages, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + """Get structured output from the model. + + Args: + output_model: The output model to use for the agent. + prompt: The prompt messages to use for the agent. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Model events with the last being the structured output. + """ + from ..event_loop import streaming + from ..tools import convert_pydantic_to_tool_spec + + # Create a tool spec from the schema + tool_spec = convert_pydantic_to_tool_spec(output_model) + + response = self.stream( + messages=prompt, + tool_specs=[tool_spec], + system_prompt=system_prompt, + tool_choice=cast(ToolChoice, {"any": {}}), + **kwargs, + ) + async for event in streaming.process_stream(response): + yield event + + stop_reason, messages, _, _ = event["stop"] + + if stop_reason != "tool_use": + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') + + content = messages["content"] + output_response: dict[str, Any] | None = None + for block in content: + # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip. + # if the tool use name never matches, raise an error. 
+ if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]: + output_response = block["toolUse"]["input"] + else: + continue + + if output_response is None: + raise ValueError("No valid tool use or tool use input was found in the response.") + + yield {"output": output_model(**output_response)} diff --git a/tests/strands/models/test_sap_genai_hub.py b/tests/strands/models/test_sap_genai_hub.py new file mode 100644 index 000000000..43c62a05f --- /dev/null +++ b/tests/strands/models/test_sap_genai_hub.py @@ -0,0 +1,194 @@ +"""Tests for SAP GenAI Hub model provider.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from strands.models.sap_genai_hub import SAPGenAIHubModel + + +class TestSAPGenAIHubModel: + """Test suite for SAP GenAI Hub model provider.""" + + def test_initialization_with_default_config(self): + """Test model initialization with default configuration.""" + with patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + model = SAPGenAIHubModel() + + assert model.config["model_id"] == "amazon--nova-lite" + mock_session.return_value.client.assert_called_once_with(model_name="amazon--nova-lite") + + def test_initialization_with_custom_config(self): + """Test model initialization with custom configuration.""" + with patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + model = SAPGenAIHubModel( + model_id="anthropic--claude-3-sonnet", + temperature=0.7, + max_tokens=1000, + ) + + assert model.config["model_id"] == "anthropic--claude-3-sonnet" + assert model.config["temperature"] == 0.7 + assert model.config["max_tokens"] == 1000 + + def test_update_config(self): + """Test updating model configuration.""" + with patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + model = SAPGenAIHubModel() + model.update_config(temperature=0.5, max_tokens=2000) + + assert model.config["temperature"] == 0.5 + assert model.config["max_tokens"] == 2000 + + def test_get_config(self): + """Test getting model configuration.""" + with patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + model = SAPGenAIHubModel(temperature=0.8) + config = model.get_config() + + assert config["temperature"] == 0.8 + assert "model_id" in config + + def test_is_nova_model(self): + """Test Nova model detection.""" + with patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + # Test Nova models + for model_id in [ + "amazon--nova-pro", + "amazon--nova-micro", + "amazon--nova-lite", + ]: + model = SAPGenAIHubModel(model_id=model_id) + assert model._is_nova_model() is True + assert model._is_claude_model() is False + assert model._is_titan_embed_model() is False + + def test_is_claude_model(self): + """Test Claude model detection.""" + with patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + model = SAPGenAIHubModel(model_id="anthropic--claude-3-sonnet") + assert model._is_claude_model() is True + assert model._is_nova_model() is False + assert model._is_titan_embed_model() is False + + 
def test_is_titan_embed_model(self): + """Test Titan Embedding model detection.""" + with patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + model = SAPGenAIHubModel(model_id="amazon--titan-embed-text") + assert model._is_titan_embed_model() is True + assert model._is_nova_model() is False + assert model._is_claude_model() is False + + def test_format_nova_request(self): + """Test request formatting for Nova models.""" + with patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + model = SAPGenAIHubModel(model_id="amazon--nova-lite", temperature=0.7, max_tokens=1000) + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant"}] + + request = model._format_nova_request(messages=messages, system_prompt_content=system_prompt_content) + + assert request["messages"] == messages + assert request["system"] == system_prompt_content + assert request["inferenceConfig"]["temperature"] == 0.7 + assert request["inferenceConfig"]["maxTokens"] == 1000 + + def test_format_nova_request_with_tools(self): + """Test request formatting for Nova models with tools.""" + with patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + model = SAPGenAIHubModel(model_id="amazon--nova-lite") + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"type": "object", "properties": {}}, + } + ] + + request = model._format_nova_request(messages=messages, tool_specs=tool_specs) + + assert "toolConfig" in request + assert len(request["toolConfig"]["tools"]) == 1 + assert request["toolConfig"]["tools"][0]["toolSpec"]["name"] == "test_tool" + + def test_format_titan_embed_request(self): + """Test request formatting for Titan Embedding models.""" + with patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + model = SAPGenAIHubModel(model_id="amazon--titan-embed-text") + + messages = [ + {"role": "user", "content": [{"text": "Text to embed"}]}, + ] + + request = model._format_titan_embed_request(messages) + + assert request["inputText"] == "Text to embed" + + def test_unsupported_model(self): + """Test that unsupported model raises ValueError.""" + with patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + model = SAPGenAIHubModel(model_id="unsupported--model") + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + with pytest.raises(ValueError, match="unsupported model"): + model._format_request(messages) + + def test_format_chunk_with_string(self): + """Test chunk formatting with string input.""" + with patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + model = SAPGenAIHubModel() + chunk = model._format_chunk("test text") + + assert chunk == {"contentBlockDelta": {"delta": {"text": "test text"}}} + + def test_format_chunk_with_dict(self): + """Test chunk formatting with dict input.""" + with 
patch("strands.models.sap_genai_hub.Session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + + model = SAPGenAIHubModel() + input_chunk = {"contentBlockDelta": {"delta": {"text": "test"}}} + chunk = model._format_chunk(input_chunk) + + assert chunk == input_chunk From a16f2cc031a768d3b140e5d2ae14cca0a1d5a02d Mon Sep 17 00:00:00 2001 From: Abhi Shivaditya Date: Wed, 26 Nov 2025 19:06:28 -0800 Subject: [PATCH 2/2] tested and working --- pyproject.toml | 5 +- src/strands/models/sap_genai_hub.py | 163 ++++++---- tests/strands/models/test_sap_genai_hub.py | 16 +- tests_integ/models/providers.py | 53 +++- .../models/test_model_sap_genai_hub.py | 297 ++++++++++++++++++ 5 files changed, 470 insertions(+), 64 deletions(-) create mode 100644 tests_integ/models/test_model_sap_genai_hub.py diff --git a/pyproject.toml b/pyproject.toml index b542c7481..1e7e8f094 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,9 @@ sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] +sap_genai_hub = [ + "sap-ai-sdk-gen[all]>=5.0.0,<6.0.0", +] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ "sphinx>=5.0.0,<9.0.0", @@ -69,7 +72,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,sap_genai_hub,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", diff --git a/src/strands/models/sap_genai_hub.py b/src/strands/models/sap_genai_hub.py index d39158e6b..348d61571 100644 --- a/src/strands/models/sap_genai_hub.py +++ b/src/strands/models/sap_genai_hub.py @@ -1,7 +1,7 @@ """SAP GenAI Hub model provider. - Docs: https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/consume-generative-ai-models-using-sap-ai-core#aws-bedrock -- SDK Reference: https://help.sap.com/doc/generative-ai-hub-sdk/CLOUD/en-US/_reference/gen_ai_hub.html +- SDK Reference: https://help.sap.com/doc/sap-ai-sdk-gen/CLOUD/en-US/_reference/gen_ai_hub.html """ import asyncio @@ -37,7 +37,7 @@ from gen_ai_hub.proxy.native.amazon.clients import Session except ImportError as e: raise ImportError( - "SAP GenAI Hub SDK is not installed. Please install it with: pip install 'generative-ai-hub-sdk[all]'" + "SAP GenAI Hub SDK is not installed. Please install it with: pip install 'sap-ai-sdk-gen[all]'" ) from e logger = logging.getLogger(__name__) @@ -94,7 +94,9 @@ def __init__( Args: **model_config: Configuration options for the SAP GenAI Hub model. 
""" - self.config = SAPGenAIHubModel.SAPGenAIHubConfig(model_id=DEFAULT_SAP_GENAI_HUB_MODEL_ID) + self.config = SAPGenAIHubModel.SAPGenAIHubConfig( + model_id=DEFAULT_SAP_GENAI_HUB_MODEL_ID + ) self.update_config(**model_config) logger.debug("config=<%s> | initializing", self.config) @@ -166,13 +168,19 @@ def _format_request( """ # Format request based on model type if self._is_nova_model(): - return self._format_nova_request(messages, tool_specs, system_prompt_content, tool_choice) + return self._format_nova_request( + messages, tool_specs, system_prompt_content, tool_choice + ) elif self._is_claude_model(): - return self._format_claude_request(messages, tool_specs, system_prompt_content, tool_choice) + return self._format_claude_request( + messages, tool_specs, system_prompt_content, tool_choice + ) elif self._is_titan_embed_model(): return self._format_titan_embed_request(messages) else: - raise ValueError(f"model_id=<{self.config['model_id']}> | unsupported model") + raise ValueError( + f"model_id=<{self.config['model_id']}> | unsupported model" + ) def _format_nova_request( self, @@ -218,7 +226,10 @@ def _format_nova_request( } # Add additional arguments if provided - if "additional_args" in self.config and self.config["additional_args"] is not None: + if ( + "additional_args" in self.config + and self.config["additional_args"] is not None + ): request.update(self.config["additional_args"]) return request @@ -243,7 +254,9 @@ def _format_claude_request( """ # For Claude models, we use the same format as Nova models # since we're using the converse API for both - return self._format_nova_request(messages, tool_specs, system_prompt_content, tool_choice) + return self._format_nova_request( + messages, tool_specs, system_prompt_content, tool_choice + ) def _format_titan_embed_request(self, messages: Messages) -> dict[str, Any]: """Format a request for Amazon Titan Embedding models. @@ -272,7 +285,10 @@ def _format_titan_embed_request(self, messages: Messages) -> dict[str, Any]: } # Add additional arguments if provided - if "additional_args" in self.config and self.config["additional_args"] is not None: + if ( + "additional_args" in self.config + and self.config["additional_args"] is not None + ): request.update(self.config["additional_args"]) return request @@ -297,7 +313,9 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: # For any other type, convert to string and wrap return {"contentBlockDelta": {"delta": {"text": str(event)}}} - def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEvent]: + def _convert_streaming_response( + self, stream_response: Any + ) -> Iterable[StreamEvent]: """Convert a streaming response to the standardized streaming format. 
Args: @@ -388,7 +406,9 @@ def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEv # If this is a messageStop event, we're done if "messageStop" in event: - logger.debug("received messageStop event from stream") + logger.debug( + "received messageStop event from stream" + ) return else: # Format unknown events @@ -402,7 +422,9 @@ def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEv event_count, ) else: - logger.debug("stream response not iterable, treating as single response") + logger.debug( + "stream response not iterable, treating as single response" + ) yield {"messageStart": {"role": "assistant"}} yield self._format_chunk(stream_response) @@ -423,7 +445,9 @@ def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEv ) raise e - async def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]: + def _convert_non_streaming_to_streaming( + self, response: dict[str, Any] + ) -> Iterable[StreamEvent]: """Convert a non-streaming response to the streaming format. Args: @@ -455,7 +479,11 @@ async def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> # For tool use, we need to yield the input as a delta input_value = json.dumps(content["toolUse"]["input"]) - yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}} + yield { + "contentBlockDelta": { + "delta": {"toolUse": {"input": input_value}} + } + } elif "text" in content: # Then yield the text as a delta yield { @@ -492,7 +520,9 @@ async def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> if "embedding" in response: yield { "contentBlockDelta": { - "delta": {"text": f"Embedding generated with {len(response['embedding'])} dimensions"}, + "delta": { + "text": f"Embedding generated with {len(response['embedding'])} dimensions" + }, } } @@ -589,56 +619,60 @@ def _stream( """ try: logger.debug("formatting request") - request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice) + request = self._format_request( + messages, tool_specs, system_prompt_content, tool_choice + ) logger.debug("invoking model") streaming = self.config.get("streaming", True) if self._is_nova_model() or self._is_claude_model(): - if streaming: - # Use converse_stream for streaming responses - try: - logger.debug("using converse_stream api") - stream_response = self.client.converse_stream(**request) - - # Process all streaming events - event_count = 0 - has_content = False + # Try converse_stream first, fall back to converse if not supported + try: + logger.debug("attempting converse_stream api") + stream_response = self.client.converse_stream(**request) - for event in self._convert_streaming_response(stream_response): - event_count += 1 + # Process all streaming events + event_count = 0 + has_content = False - # Check if we have actual content - if "contentBlockDelta" in event: - has_content = True + for event in self._convert_streaming_response(stream_response): + event_count += 1 - callback(event) + # Check if we have actual content + if "contentBlockDelta" in event: + has_content = True - logger.debug( - "event_count=<%d>, has_content=<%s> | processed streaming events", - event_count, - has_content, - ) + callback(event) - # If we didn't get any content, fallback to non-streaming - if event_count == 0 or not has_content: - logger.warning("no content received from streaming, falling back to non-streaming") - response = self.client.converse(**request) - for 
event in self._convert_non_streaming_to_streaming(response): - callback(event) + logger.debug( + "event_count=<%d>, has_content=<%s> | processed streaming events", + event_count, + has_content, + ) - except (AttributeError, Exception) as e: - # Fallback to non-streaming if converse_stream fails - logger.warning( - "error=<%s> | converse_stream failed, falling back to non-streaming", - e, + # If we didn't get any content, fallback to non-streaming + if event_count == 0 or not has_content: + logger.debug( + "no content received from streaming, falling back to converse" ) response = self.client.converse(**request) for event in self._convert_non_streaming_to_streaming(response): callback(event) - else: - # Non-streaming path - logger.debug("using non-streaming converse api") + + except NotImplementedError as nie: + # converse_stream not supported by this model/deployment, use converse + logger.debug("converse_stream not supported, using converse api") + response = self.client.converse(**request) + for event in self._convert_non_streaming_to_streaming(response): + callback(event) + + except Exception as e: + # Other errors - log and fallback to converse + logger.debug( + "error=<%s> | converse_stream failed, falling back to converse", + e, + ) response = self.client.converse(**request) for event in self._convert_non_streaming_to_streaming(response): callback(event) @@ -647,8 +681,12 @@ def _stream( if streaming: # Try streaming for Titan models try: - logger.debug("using invoke_model_with_response_stream for titan") - stream_response = self.client.invoke_model_with_response_stream(**request) + logger.debug( + "using invoke_model_with_response_stream for titan" + ) + stream_response = self.client.invoke_model_with_response_stream( + **request + ) event_count = 0 for event in self._convert_streaming_response(stream_response): @@ -656,9 +694,13 @@ def _stream( callback(event) if event_count == 0: - logger.warning("no events from titan streaming, falling back to non-streaming") + logger.warning( + "no events from titan streaming, falling back to non-streaming" + ) response = self.client.invoke_model(**request) - for event in self._convert_non_streaming_to_streaming(response): + for event in self._convert_non_streaming_to_streaming( + response + ): callback(event) except (AttributeError, Exception) as e: @@ -684,7 +726,10 @@ def _stream( raise ModelThrottledException(error_message) from e # Handle context window overflow - if any(overflow_message in error_message for overflow_message in CONTEXT_WINDOW_OVERFLOW_MESSAGES): + if any( + overflow_message in error_message + for overflow_message in CONTEXT_WINDOW_OVERFLOW_MESSAGES + ): logger.warning("sap genai hub threw context window overflow error") raise ContextWindowOverflowException(e) from e @@ -733,7 +778,9 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') + raise ValueError( + f'Model returned stop_reason: {stop_reason} instead of "tool_use".' + ) content = messages["content"] output_response: dict[str, Any] | None = None @@ -746,6 +793,8 @@ async def structured_output( continue if output_response is None: - raise ValueError("No valid tool use or tool use input was found in the response.") + raise ValueError( + "No valid tool use or tool use input was found in the response." 
+ ) yield {"output": output_model(**output_response)} diff --git a/tests/strands/models/test_sap_genai_hub.py b/tests/strands/models/test_sap_genai_hub.py index 43c62a05f..cc3144dec 100644 --- a/tests/strands/models/test_sap_genai_hub.py +++ b/tests/strands/models/test_sap_genai_hub.py @@ -19,7 +19,9 @@ def test_initialization_with_default_config(self): model = SAPGenAIHubModel() assert model.config["model_id"] == "amazon--nova-lite" - mock_session.return_value.client.assert_called_once_with(model_name="amazon--nova-lite") + mock_session.return_value.client.assert_called_once_with( + model_name="amazon--nova-lite" + ) def test_initialization_with_custom_config(self): """Test model initialization with custom configuration.""" @@ -106,12 +108,16 @@ def test_format_nova_request(self): mock_client = MagicMock() mock_session.return_value.client.return_value = mock_client - model = SAPGenAIHubModel(model_id="amazon--nova-lite", temperature=0.7, max_tokens=1000) + model = SAPGenAIHubModel( + model_id="amazon--nova-lite", temperature=0.7, max_tokens=1000 + ) messages = [{"role": "user", "content": [{"text": "Hello"}]}] system_prompt_content = [{"text": "You are a helpful assistant"}] - request = model._format_nova_request(messages=messages, system_prompt_content=system_prompt_content) + request = model._format_nova_request( + messages=messages, system_prompt_content=system_prompt_content + ) assert request["messages"] == messages assert request["system"] == system_prompt_content @@ -135,7 +141,9 @@ def test_format_nova_request_with_tools(self): } ] - request = model._format_nova_request(messages=messages, tool_specs=tool_specs) + request = model._format_nova_request( + messages=messages, tool_specs=tool_specs + ) assert "toolConfig" in request assert len(request["toolConfig"]["tools"]) == 1 diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 75cc58f74..269dcc718 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -16,6 +16,7 @@ from strands.models.mistral import MistralModel from strands.models.ollama import OllamaModel from strands.models.openai import OpenAIModel +from strands.models.sap_genai_hub import SAPGenAIHubModel from strands.models.writer import WriterModel @@ -44,7 +45,10 @@ class OllamaProviderInfo(ProviderInfo): def __init__(self): super().__init__( - id="ollama", factory=lambda: OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b") + id="ollama", + factory=lambda: OllamaModel( + host="http://localhost:11434", model_id="llama3.3:70b" + ), ) is_server_available = False @@ -59,6 +63,46 @@ def __init__(self): ) +class SAPGenAIHubProviderInfo(ProviderInfo): + """Special case SAP GenAI Hub as it requires AI Core credentials to be configured.""" + + def __init__(self): + super().__init__( + id="sap_genai_hub", + factory=lambda: SAPGenAIHubModel(model_id="amazon--nova-lite"), + ) + + credentials_available = self._check_sap_credentials_available() + self.mark = mark.skipif( + not credentials_available, + reason="SAP AI Core credentials not available - configure service key or set environment variables", + ) + + def _check_sap_credentials_available(self) -> bool: + """Check if SAP GenAI Hub credentials are available.""" + try: + # Try to import the SAP GenAI Hub SDK + from gen_ai_hub.proxy.native.amazon.clients import Session + + # Try to create a session - this will fail if credentials are missing + session = Session() + # Try to create a client - this is where it would fail with missing base_url + client = 
session.client(model_name="amazon--nova-lite") + # If we got this far, credentials are available + return True + except (ImportError, TypeError, Exception) as e: + # Common errors when credentials are missing: + # - TypeError: AICoreV2Client.__init__() missing 1 required positional argument: 'base_url' + # - ImportError: No module named 'gen_ai_hub' + # - Other configuration-related exceptions + error_msg = str(e) + if "missing 1 required positional argument: 'base_url'" in error_msg: + # This is the specific error we expect when AI Core credentials are missing + return False + # For any other errors, also assume credentials not available + return False + + anthropic = ProviderInfo( id="anthropic", environment_variable="ANTHROPIC_API_KEY", @@ -84,7 +128,10 @@ def __init__(self): ), ) litellm = ProviderInfo( - id="litellm", factory=lambda: LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") + id="litellm", + factory=lambda: LiteLLMModel( + model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0" + ), ) llama = ProviderInfo( id="llama", @@ -138,6 +185,7 @@ def __init__(self): ) ollama = OllamaProviderInfo() +sap_genai_hub = SAPGenAIHubProviderInfo() all_providers = [ @@ -149,5 +197,6 @@ def __init__(self): litellm, mistral, openai, + sap_genai_hub, writer, ] diff --git a/tests_integ/models/test_model_sap_genai_hub.py b/tests_integ/models/test_model_sap_genai_hub.py new file mode 100644 index 000000000..be2cb831a --- /dev/null +++ b/tests_integ/models/test_model_sap_genai_hub.py @@ -0,0 +1,297 @@ +import os + +import pydantic +import pytest + +import strands +from strands import Agent +from strands.models.sap_genai_hub import SAPGenAIHubModel +from tests_integ.models import providers + +# these tests only run if we have the SAP AI Core credentials +pytestmark = providers.sap_genai_hub.mark + + +@pytest.fixture +def model(): + return SAPGenAIHubModel( + model_id="amazon--nova-lite", + temperature=0.15, # Lower temperature for consistent test behavior + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def system_prompt(): + return "You are a helpful AI assistant." + + +@pytest.fixture +def assistant_agent(model, system_prompt): + return Agent(model=model, system_prompt=system_prompt) + + +@pytest.fixture +def tool_agent(model, tools, system_prompt): + return Agent(model=model, tools=tools, system_prompt=system_prompt) + + +@pytest.fixture +def weather(): + class Weather(pydantic.BaseModel): + """Extracts the time and weather from the user's message with the exact strings.""" + + time: str + weather: str + + return Weather(time="12:00", weather="sunny") + + +@pytest.fixture +def yellow_color(): + class Color(pydantic.BaseModel): + """Describes a color.""" + + name: str + + @pydantic.field_validator("name", mode="after") + @classmethod + def lower(_, value): + return value.lower() + + return Color(name="yellow") + + +def test_agent_invoke(tool_agent): + result = tool_agent("What is the current time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.asyncio +async def test_agent_invoke_async(tool_agent): + result = await tool_agent.invoke_async( + "What is the current time and weather in New York?" 
+ ) + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.asyncio +async def test_agent_stream_async(tool_agent): + stream = tool_agent.stream_async( + "What is the current time and weather in New York?" + ) + async for event in stream: + _ = event + + result = event["result"] + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +def test_agent_invoke_multiturn(assistant_agent): + assistant_agent("What color is the sky?") + assistant_agent("What color is lava?") + result = assistant_agent("What was the answer to my first question?") + text = result.message["content"][0]["text"].lower() + + assert "blue" in text + + +def test_agent_invoke_image_input(assistant_agent, yellow_img): + content = [ + {"text": "what is in this image"}, + { + "image": { + "format": "png", + "source": { + "bytes": yellow_img, + }, + }, + }, + ] + result = assistant_agent(content) + text = result.message["content"][0]["text"].lower() + + assert "yellow" in text + + +def test_agent_invoke_document_input(assistant_agent, letter_pdf): + content = [ + {"text": "summarize this document"}, + { + "document": { + "format": "pdf", + "name": "letter_name", + "source": {"bytes": letter_pdf}, + } + }, + ] + result = assistant_agent(content) + text = result.message["content"][0]["text"].lower() + + assert "shareholder" in text + + +def test_agent_structured_output(assistant_agent, weather): + tru_weather = assistant_agent.structured_output( + type(weather), "The time is 12:00 and the weather is sunny" + ) + exp_weather = weather + assert tru_weather == exp_weather + + +@pytest.mark.asyncio +async def test_agent_structured_output_async(assistant_agent, weather): + tru_weather = await assistant_agent.structured_output_async( + type(weather), "The time is 12:00 and the weather is sunny" + ) + exp_weather = weather + assert tru_weather == exp_weather + + +def test_agent_structured_output_image_input(assistant_agent, yellow_img, yellow_color): + content = [ + {"text": "Is this image red, blue, or yellow?"}, + { + "image": { + "format": "png", + "source": { + "bytes": yellow_img, + }, + }, + }, + ] + tru_color = assistant_agent.structured_output(type(yellow_color), content) + exp_color = yellow_color + assert tru_color == exp_color + + +@pytest.mark.parametrize( + "model_id", + [ + "anthropic--claude-4.5-sonnet", + "anthropic--claude-4-sonnet", + "anthropic--claude-3.7-sonnet", + "anthropic--claude-3.5-sonnet", + "anthropic--claude-3-haiku", + "amazon--nova-pro", + "amazon--nova-lite", + "amazon--nova-micro", + ], +) +def test_different_models(model_id): + """Test various SAP GenAI Hub models.""" + model = SAPGenAIHubModel(model_id=model_id) + agent = Agent( + model=model, + system_prompt="You are a helpful assistant. 
Keep responses very brief.", + ) + + result = agent("What is 2+2?") + text = result.message["content"][0]["text"] + + assert "4" in text + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_id", + [ + "anthropic--claude-4.5-sonnet", + "anthropic--claude-4-sonnet", + "anthropic--claude-3.7-sonnet", + "anthropic--claude-3.5-sonnet", + "anthropic--claude-3-haiku", + "amazon--nova-pro", + "amazon--nova-lite", + "amazon--nova-micro", + ], +) +async def test_streaming_for_models(model_id): + """Test streaming for various SAP GenAI Hub models.""" + model = SAPGenAIHubModel(model_id=model_id) + agent = Agent(model=model) + + chunk_count = 0 + stream = agent.stream_async("Count from 1 to 5.") + + async for event in stream: + if "data" in event and event["data"]: + chunk_count += 1 + + assert chunk_count > 0 + + +def test_multi_agent_workflow(): + """Test multi-agent workflow using agents as tools pattern.""" + from textwrap import dedent + + nova_model = SAPGenAIHubModel(model_id="amazon--nova-lite") + claude_model = SAPGenAIHubModel(model_id="anthropic--claude-3.5-sonnet") + + @strands.tool + def research_assistant(query: str) -> str: + """Research assistant that provides factual information.""" + research_agent = Agent( + model=claude_model, + system_prompt=dedent( + """You are a specialized research assistant. Focus only on providing + factual information. Keep responses brief and to the point.""" + ), + ) + return research_agent(query).message + + @strands.tool + def creative_writing_assistant(query: str) -> str: + """Creative writing assistant that generates creative content.""" + creative_agent = Agent( + model=nova_model, + system_prompt=dedent( + """You are a specialized creative writing assistant. + Create engaging content. Keep responses brief and focused.""" + ), + ) + return creative_agent(query).message + + orchestrator = Agent( + model=nova_model, + system_prompt="""You are an assistant that routes queries to specialized agents: +- For research questions use research_assistant +- For creative writing use creative_writing_assistant +- For simple questions answer directly""", + tools=[research_assistant, creative_writing_assistant], + ) + + result = orchestrator("What is quantum computing? (1 sentence)") + text = result.message["content"][0]["text"].lower() + + assert "quantum" in text or "computing" in text + + +def test_model_with_custom_parameters(): + """Test model with custom parameters.""" + model = SAPGenAIHubModel( + model_id="amazon--nova-lite", temperature=0.3, top_p=0.9, max_tokens=50 + ) + agent = Agent(model=model) + + result = agent("Count from 1 to 5.") + + assert len(result.message["content"][0]["text"]) > 0
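
---

A minimal end-to-end usage sketch, assuming the optional dependency is
installed (pip install 'strands-agents[sap_genai_hub]') and SAP AI Core
credentials are configured for the gen_ai_hub SDK; model IDs and parameters
mirror the integration tests above:

    from strands import Agent
    from strands.models.sap_genai_hub import SAPGenAIHubModel

    # "amazon--nova-lite" is the provider default; Claude IDs such as
    # "anthropic--claude-3.5-sonnet" are requested the same way.
    model = SAPGenAIHubModel(model_id="amazon--nova-lite", temperature=0.3, max_tokens=512)
    agent = Agent(model=model, system_prompt="You are a helpful assistant.")

    result = agent("What is 2+2?")
    print(result.message["content"][0]["text"])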