Problem Statement
Support OpenAI API-compatible vLLM/SGLang servers behind a SigV4 proxy
- I was trying to use a Strands agent with an OpenAI-compatible model server that sits behind API Gateway with IAM auth enabled.
- This does not seem possible with the model classes currently available in the Strands SDK, because there is no good way to attach SigV4 authentication to the requests.
- Note that this is not about Bedrock, and I do not want to use the MCP client. This is a Strands agent talking directly to a custom model that I am serving on self-hosted servers, and I would like that traffic to be authenticated.
- I did create a custom model class that inherits from the base Model class and brings its own SigV4-signed client, and it works (a minimal signing sketch follows this list). Still, it would be great to have this built into the OpenAI API-compatible ecosystem.
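For reference, signing a plain HTTP request to such a gateway is straightforward outside of Strands. This is only a minimal sketch, assuming the aws-requests-auth package and a hypothetical API Gateway host; it is not the exact client I used:

import requests
from aws_requests_auth.boto_utils import BotoAWSRequestsAuth

# Hypothetical API Gateway endpoint fronting the vLLM/SGLang server.
host = "abc123.execute-api.us-east-1.amazonaws.com"

# Credentials are resolved from the default boto3/botocore credential chain.
auth = BotoAWSRequestsAuth(
    aws_host=host,
    aws_region="us-east-1",
    aws_service="execute-api",
)

# Standard OpenAI-style completions payload, signed with SigV4.
response = requests.post(
    f"https://{host}/v1/completions",
    json={"model": "my-model", "prompt": "Hello", "max_tokens": 64},
    auth=auth,
    timeout=30,
)
print(response.json())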
Proposed Solution
Write a custom implementation of the Model class that accepts and returns OpenAI-compatible inputs and outputs. The model's client is built with SigV4 auth while preserving the original OpenAI request/response structures.
One simple implementation could look like the following; it is not exhaustive or production-ready with all the features customers would need.
import logging
from typing import Any, AsyncIterator, Dict, List, Optional

import requests
from aws_requests_auth.aws_auth import AWSRequestsAuth

from strands.models import Model
from strands.types.content import Messages
from strands.types.streaming import StreamEvent
from strands.types.tools import ToolSpec

logger = logging.getLogger(__name__)


class IAMAuthModel(Model):
    """Custom model implementation that uses AWS SigV4 auth."""

    def __init__(
        self,
        model_id: str,
        api_id: str,
        region: str,
        auth: AWSRequestsAuth,
        max_tokens: int = 500,
        temperature: float = 0.7,
    ):
        """Initialize the model.

        Args:
            model_id: ID of the model to use
            api_id: API Gateway ID
            region: AWS region
            auth: AWSRequestsAuth object
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
        """
        self.model_id = model_id
        self.api_id = api_id
        self.region = region
        self.auth = auth
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.base_url = f"https://{api_id}.execute-api.{region}.amazonaws.com/v1"

        # Config will store parameter values
        self.config = {
            "model_id": model_id,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        logger.info(f"Initialized IAMAuthModel with endpoint: {self.base_url}")

    def update_config(self, **model_config: Any) -> None:
        """Update model configuration.

        Args:
            **model_config: Configuration parameters
        """
        self.config.update(model_config)

        # Update specific attributes
        if "model_id" in model_config:
            self.model_id = model_config["model_id"]
        if "max_tokens" in model_config:
            self.max_tokens = model_config["max_tokens"]
        if "temperature" in model_config:
            self.temperature = model_config["temperature"]

    def get_config(self) -> Dict[str, Any]:
        """Get the current model configuration.

        Returns:
            Dict containing configuration parameters
        """
        return self.config

    def _format_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> str:
        """Format messages into a prompt string.

        Args:
            messages: List of message objects
            system_prompt: Optional system prompt

        Returns:
            Formatted prompt string
        """
        parts = []

        # Add system prompt if provided
        if system_prompt:
            parts.append(f"System: {system_prompt}")

        # Format each message
        for message in messages:
            role = "User:" if message["role"] == "user" else "Assistant:"

            # Extract text from content blocks
            texts = []
            for block in message["content"]:
                if "text" in block:
                    texts.append(block["text"])

            # Join all text from this message
            content = " ".join(texts)
            parts.append(f"{role} {content}")

        # Add a final assistant prompt if the last message was from user
        if messages and messages[-1]["role"] == "user":
            parts.append("Assistant:")

        return "\n\n".join(parts)

    async def stream(
        self,
        messages: Messages,
        tool_specs: Optional[List[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncIterator[StreamEvent]:
        """Stream response from the model.

        Args:
            messages: List of message objects
            tool_specs: Tool specifications (not supported)
            system_prompt: System prompt
            **kwargs: Additional parameters

        Yields:
            Stream events containing model output
        """
        # Format messages into a prompt
        prompt = self._format_messages(messages, system_prompt)
        logger.info(f"Formatted prompt: {prompt[:100]}...")

        # Start message
        yield {"messageStart": {"role": "assistant"}}
        yield {"contentBlockStart": {"start": {}}}

        # Initialize so the metadata check at the end is safe even if the request fails
        response = None
        result = {}
        input_tokens = output_tokens = total_tokens = 0

        try:
            # Use requests with AWSRequestsAuth for API Gateway.
            # Strands expects async streaming, but we'll use a direct request approach
            # and simulate streaming from the result.

            # Prepare the payload
            payload = {
                "model": self.model_id,
                "prompt": prompt,
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
            }

            # Add any additional parameters from config
            for key, value in self.config.items():
                if key not in ["model_id", "max_tokens", "temperature"] and key not in payload:
                    payload[key] = value

            # Make the API request
            logger.info(f"Sending request to {self.base_url}/completions")
            response = requests.post(
                f"{self.base_url}/completions",
                json=payload,
                auth=self.auth,
            )

            # Check if the request was successful
            if response.status_code == 200:
                # Parse the response
                result = response.json()

                # Extract the generated text
                if "choices" in result and len(result["choices"]) > 0:
                    text = result["choices"][0]["text"]

                    # Simulate streaming by yielding a single content delta
                    yield {"contentBlockDelta": {"delta": {"text": text}}}

                    # Extract usage information if available
                    usage = result.get("usage", {})
                    input_tokens = usage.get("prompt_tokens", 0)
                    output_tokens = usage.get("completion_tokens", 0)
                    total_tokens = usage.get("total_tokens", input_tokens + output_tokens)
                else:
                    logger.error(f"Unexpected response format: {result}")
                    yield {"contentBlockDelta": {"delta": {"text": "Error: Unexpected response format"}}}
            else:
                error_text = f"Error: {response.status_code} - {response.text}"
                logger.error(error_text)
                yield {"contentBlockDelta": {"delta": {"text": error_text}}}
        except Exception as e:
            error_text = f"Error: {str(e)}"
            logger.error(error_text)
            yield {"contentBlockDelta": {"delta": {"text": error_text}}}

        # End message
        yield {"contentBlockStop": {}}
        yield {"messageStop": {"stopReason": "end_turn"}}

        # Include metadata if available
        if response is not None and response.status_code == 200 and "usage" in result:
            yield {
                "metadata": {
                    "usage": {
                        "inputTokens": input_tokens,
                        "outputTokens": output_tokens,
                        "totalTokens": total_tokens,
                    },
                    "metrics": {"latencyMs": 0},
                }
            }

    async def structured_output(self, output_model, prompt, system_prompt=None, **kwargs):
        """Not implemented for this model."""
        raise NotImplementedError("Structured output is not supported by this model")

Use Case
- A fine-tuned model is hosted in a K8s/EKS cluster, and we would like to serve it behind an authenticated proxy, a SigV4-signing layer, or an API Gateway-like system (a usage sketch follows).
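As an illustration only, here is a rough sketch of how the IAMAuthModel above might be wired into an agent, assuming the aws-requests-auth package and placeholder API Gateway and model identifiers:

from aws_requests_auth.boto_utils import BotoAWSRequestsAuth
from strands import Agent

api_id = "abc123def"   # hypothetical API Gateway ID
region = "us-east-1"

# Sign requests with the caller's IAM credentials for the execute-api service.
auth = BotoAWSRequestsAuth(
    aws_host=f"{api_id}.execute-api.{region}.amazonaws.com",
    aws_region=region,
    aws_service="execute-api",
)

model = IAMAuthModel(
    model_id="my-finetuned-model",   # hypothetical model name served by vLLM/SGLang
    api_id=api_id,
    region=region,
    auth=auth,
    max_tokens=512,
)

agent = Agent(model=model)
agent("Summarize the deployment architecture in one paragraph.")

Every model call from the agent then goes through the SigV4-signed client, so the self-hosted server only ever sees IAM-authenticated traffic.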
Alternative Solutions
No response
Additional Context
No response