Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions src/strands/models/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import os
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast
from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union

import boto3
from botocore.config import Config as BotocoreConfig
Expand Down Expand Up @@ -151,8 +151,8 @@ def __init__(
validate_config_keys(payload_config, self.SageMakerAIPayloadSchema)
payload_config.setdefault("stream", True)
payload_config.setdefault("tool_results_as_user_messages", False)
self.endpoint_config = dict(endpoint_config)
self.payload_config = dict(payload_config)
self.endpoint_config = self.SageMakerAIEndpointConfig(**endpoint_config)
self.payload_config = self.SageMakerAIPayloadSchema(**payload_config)
logger.debug(
"endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config
)
Expand Down Expand Up @@ -193,7 +193,7 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i
Returns:
The Amazon SageMaker model configuration.
"""
return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config)
return self.endpoint_config

@override
def format_request(
Expand Down Expand Up @@ -238,6 +238,10 @@ def format_request(
},
}

payload_additional_args = self.payload_config.get("additional_args")
if payload_additional_args:
payload.update(payload_additional_args)

# Remove tools and tool_choice if tools = []
if not payload["tools"]:
payload.pop("tools")
Expand Down Expand Up @@ -273,16 +277,20 @@ def format_request(
}

# Add optional SageMaker parameters if provided
if self.endpoint_config.get("inference_component_name"):
request["InferenceComponentName"] = self.endpoint_config["inference_component_name"]
if self.endpoint_config.get("target_model"):
request["TargetModel"] = self.endpoint_config["target_model"]
if self.endpoint_config.get("target_variant"):
request["TargetVariant"] = self.endpoint_config["target_variant"]

# Add additional args if provided
if self.endpoint_config.get("additional_args"):
request.update(self.endpoint_config["additional_args"].__dict__)
inf_component_name = self.endpoint_config.get("inference_component_name")
if inf_component_name:
request["InferenceComponentName"] = inf_component_name
target_model = self.endpoint_config.get("target_model")
if target_model:
request["TargetModel"] = target_model
target_variant = self.endpoint_config.get("target_variant")
if target_variant:
request["TargetVariant"] = target_variant

# Add additional request args if provided
additional_args = self.endpoint_config.get("additional_args")
if additional_args:
request.update(additional_args)

return request

Expand Down
28 changes: 28 additions & 0 deletions tests/strands/models/test_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ def test_init_with_all_params(self, boto_session):
"endpoint_name": "test-endpoint",
"inference_component_name": "test-component",
"region_name": "us-west-2",
"additional_args": {"test_req_arg_name": "test_req_arg_value"},
}
payload_config = {
"stream": False,
"max_tokens": 1024,
"temperature": 0.7,
"additional_args": {"test_payload_arg_name": "test_payload_arg_value"},
}
client_config = BotocoreConfig(user_agent_extra="test-agent")

Expand All @@ -129,9 +131,11 @@ def test_init_with_all_params(self, boto_session):

assert model.endpoint_config["endpoint_name"] == "test-endpoint"
assert model.endpoint_config["inference_component_name"] == "test-component"
assert model.endpoint_config["additional_args"]["test_req_arg_name"] == "test_req_arg_value"
assert model.payload_config["stream"] is False
assert model.payload_config["max_tokens"] == 1024
assert model.payload_config["temperature"] == 0.7
assert model.payload_config["additional_args"]["test_payload_arg_name"] == "test_payload_arg_value"

boto_session.client.assert_called_once_with(
service_name="sagemaker-runtime",
Expand Down Expand Up @@ -239,6 +243,30 @@ def test_get_config(self, model, endpoint_config):
# assert "tools" in payload
# assert payload["tools"] == []

def test_format_request_with_additional_args(self, boto_session, endpoint_config, messages, payload_config):
    """Check that `additional_args` from both configs reach the formatted request.

    Extra endpoint arguments are expected as top-level keys of the request,
    while extra payload arguments are expected inside the JSON-encoded
    request body.
    """
    # Copy the fixture configs so the extra arguments never leak into other tests.
    extended_endpoint_config = dict(endpoint_config)
    extended_endpoint_config["additional_args"] = {"extra_request_key": "extra_request_value"}

    extended_payload_config = dict(payload_config)
    extended_payload_config["additional_args"] = {"extra_payload_key": "extra_payload_value"}

    model = SageMakerAIModel(
        boto_session=boto_session,
        endpoint_config=extended_endpoint_config,
        payload_config=extended_payload_config,
    )

    formatted = model.format_request(messages)

    # Endpoint-level extras are merged into the request itself.
    assert formatted.get("extra_request_key") == "extra_request_value"

    # Payload-level extras are merged into the serialized body.
    body = json.loads(formatted["Body"])
    assert body.get("extra_payload_key") == "extra_payload_value"

@pytest.mark.asyncio
async def test_stream_with_streaming_enabled(self, sagemaker_client, model, messages):
"""Test streaming response with streaming enabled."""
Expand Down