diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index d1447732e..25b3ca7ce 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -4,7 +4,7 @@ import logging import os from dataclasses import dataclass -from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast +from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union import boto3 from botocore.config import Config as BotocoreConfig @@ -151,8 +151,8 @@ def __init__( validate_config_keys(payload_config, self.SageMakerAIPayloadSchema) payload_config.setdefault("stream", True) payload_config.setdefault("tool_results_as_user_messages", False) - self.endpoint_config = dict(endpoint_config) - self.payload_config = dict(payload_config) + self.endpoint_config = self.SageMakerAIEndpointConfig(**endpoint_config) + self.payload_config = self.SageMakerAIPayloadSchema(**payload_config) logger.debug( "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config ) @@ -193,7 +193,7 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i Returns: The Amazon SageMaker model configuration. """ - return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config) + return self.endpoint_config @override def format_request( @@ -238,6 +238,10 @@ def format_request( }, } + payload_additional_args = self.payload_config.get("additional_args") + if payload_additional_args: + payload.update(payload_additional_args) + # Remove tools and tool_choice if tools = [] if not payload["tools"]: payload.pop("tools") @@ -273,16 +277,20 @@ def format_request( } # Add optional SageMaker parameters if provided - if self.endpoint_config.get("inference_component_name"): - request["InferenceComponentName"] = self.endpoint_config["inference_component_name"] - if self.endpoint_config.get("target_model"): - request["TargetModel"] = self.endpoint_config["target_model"] - if self.endpoint_config.get("target_variant"): - request["TargetVariant"] = self.endpoint_config["target_variant"] - - # Add additional args if provided - if self.endpoint_config.get("additional_args"): - request.update(self.endpoint_config["additional_args"].__dict__) + inf_component_name = self.endpoint_config.get("inference_component_name") + if inf_component_name: + request["InferenceComponentName"] = inf_component_name + target_model = self.endpoint_config.get("target_model") + if target_model: + request["TargetModel"] = target_model + target_variant = self.endpoint_config.get("target_variant") + if target_variant: + request["TargetVariant"] = target_variant + + # Add additional request args if provided + additional_args = self.endpoint_config.get("additional_args") + if additional_args: + request.update(additional_args) return request diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index a5662ecdc..72ebf01c6 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -112,11 +112,13 @@ def test_init_with_all_params(self, boto_session): "endpoint_name": "test-endpoint", "inference_component_name": "test-component", "region_name": "us-west-2", + "additional_args": {"test_req_arg_name": "test_req_arg_value"}, } payload_config = { "stream": False, "max_tokens": 1024, "temperature": 0.7, + "additional_args": {"test_payload_arg_name": "test_payload_arg_value"}, } client_config = BotocoreConfig(user_agent_extra="test-agent") @@ -129,9 +131,11 @@ def test_init_with_all_params(self, boto_session): assert model.endpoint_config["endpoint_name"] == "test-endpoint" assert model.endpoint_config["inference_component_name"] == "test-component" + assert model.endpoint_config["additional_args"]["test_req_arg_name"] == "test_req_arg_value" assert model.payload_config["stream"] is False assert model.payload_config["max_tokens"] == 1024 assert model.payload_config["temperature"] == 0.7 + assert model.payload_config["additional_args"]["test_payload_arg_name"] == "test_payload_arg_value" boto_session.client.assert_called_once_with( service_name="sagemaker-runtime", @@ -239,6 +243,30 @@ def test_get_config(self, model, endpoint_config): # assert "tools" in payload # assert payload["tools"] == [] + def test_format_request_with_additional_args(self, boto_session, endpoint_config, messages, payload_config): + """Test formatting a request's `additional_args` where provided""" + endpoint_config_ext = { + **endpoint_config, + "additional_args": { + "extra_request_key": "extra_request_value", + }, + } + payload_config_ext = { + **payload_config, + "additional_args": { + "extra_payload_key": "extra_payload_value", + }, + } + model = SageMakerAIModel( + boto_session=boto_session, + endpoint_config=endpoint_config_ext, + payload_config=payload_config_ext, + ) + request = model.format_request(messages) + assert request.get("extra_request_key") == "extra_request_value" + payload = json.loads(request["Body"]) + assert payload.get("extra_payload_key") == "extra_payload_value" + @pytest.mark.asyncio async def test_stream_with_streaming_enabled(self, sagemaker_client, model, messages): """Test streaming response with streaming enabled."""