[FEATURE] Support OpenAI API compatible vLLM/SGLang server behind SigV4 proxy server #929

@bhks

Description

Problem Statement

Support OpenAI API compatible vLLM/SGLang server behind SigV4 proxy server

  • I was trying to use a Strands agent with an OpenAI-compatible model server behind API Gateway with IAM auth enabled.
  • This does not seem possible with the model classes available in the Strands SDK, since there is no good way to attach SigV4 auth to the requests.
  • Note that I am not talking to Bedrock, and I don't want to use an MCP client. This is a Strands agent interacting directly with a customized model that I serve on self-hosted servers, and I would like that traffic to be authenticated.
  • I did create a custom model class that inherits from the base Model and wires up a SigV4-signed client, and it works, but it would be great to have this built into the OpenAI API compatible ecosystem (one possible signing hook is sketched after this list).
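
For reference, a minimal sketch of how SigV4 signing could be threaded through the standard OpenAI-compatible path rather than a bespoke model class, assuming botocore is available and the underlying openai client is handed a custom httpx client (the class name and wiring here are assumptions, not existing SDK API):

import httpx
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.session import Session


class SigV4HttpxAuth(httpx.Auth):
    """httpx auth hook that SigV4-signs each outgoing request (illustrative sketch)."""

    requires_request_body = True  # the body must be read so it can be signed

    def __init__(self, service: str, region: str):
        # Resolve credentials from botocore's default provider chain
        self.credentials = Session().get_credentials()
        self.service = service  # "execute-api" for API Gateway
        self.region = region

    def auth_flow(self, request: httpx.Request):
        # Mirror the outgoing request as a botocore AWSRequest and sign it
        aws_request = AWSRequest(
            method=request.method,
            url=str(request.url),
            data=request.content,
            headers={"Content-Type": request.headers.get("content-type", "application/json")},
        )
        SigV4Auth(self.credentials, self.service, self.region).add_auth(aws_request)
        # Copy the SigV4 headers (Authorization, X-Amz-Date, ...) onto the request
        for key, value in aws_request.headers.items():
            request.headers[key] = value
        yield request

An httpx.AsyncClient(auth=SigV4HttpxAuth("execute-api", region)) built this way could then be passed wherever the OpenAI-compatible model accepts a custom HTTP client.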

Proposed Solution

Write a custom implementation of the Model class that accepts OpenAI-compatible input and produces OpenAI-compatible output. This model would use a client built with SigV4 auth while preserving the original OpenAI structures.

One simple implementation could look as follows, though it is not exhaustive or production ready with all the features customers need.

import logging
from typing import Any, AsyncIterator, Dict, List, Optional

import requests
from aws_requests_auth.aws_auth import AWSRequestsAuth

# Import paths below follow the strands SDK layout and may vary by version.
from strands.models import Model
from strands.types.content import Messages
from strands.types.streaming import StreamEvent
from strands.types.tools import ToolSpec

logger = logging.getLogger(__name__)


class IAMAuthModel(Model):
    """Custom model implementation that uses AWS SigV4 auth."""
    
    def __init__(
        self, 
        model_id: str,
        api_id: str,
        region: str,
        auth: AWSRequestsAuth,
        max_tokens: int = 500,
        temperature: float = 0.7,
    ):
        """Initialize the model.
        
        Args:
            model_id: ID of the model to use
            api_id: API Gateway ID
            region: AWS region
            auth: AWSRequestsAuth object
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
        """
        self.model_id = model_id
        self.api_id = api_id
        self.region = region
        self.auth = auth
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.base_url = f"https://{api_id}.execute-api.{region}.amazonaws.com/v1"
        
        # Config will store parameter values
        self.config = {
            "model_id": model_id,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        
        logger.info(f"Initialized IAMAuthModel with endpoint: {self.base_url}")
    
    def update_config(self, **model_config: Any) -> None:
        """Update model configuration.
        
        Args:
            **model_config: Configuration parameters
        """
        self.config.update(model_config)
        
        # Update specific attributes
        if "model_id" in model_config:
            self.model_id = model_config["model_id"]
        if "max_tokens" in model_config:
            self.max_tokens = model_config["max_tokens"]
        if "temperature" in model_config:
            self.temperature = model_config["temperature"]
    
    def get_config(self) -> Dict[str, Any]:
        """Get the current model configuration.
        
        Returns:
            Dict containing configuration parameters
        """
        return self.config
    
    def _format_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> str:
        """Format messages into a prompt string.
        
        Args:
            messages: List of message objects
            system_prompt: Optional system prompt
            
        Returns:
            Formatted prompt string
        """
        parts = []
        
        # Add system prompt if provided
        if system_prompt:
            parts.append(f"System: {system_prompt}")
        
        # Format each message
        for message in messages:
            role = "User:" if message["role"] == "user" else "Assistant:"
            
            # Extract text from content blocks
            texts = []
            for block in message["content"]:
                if "text" in block:
                    texts.append(block["text"])
            
            # Join all text from this message
            content = " ".join(texts)
            parts.append(f"{role} {content}")
        
        # Add a final assistant prompt if the last message was from user
        if messages and messages[-1]["role"] == "user":
            parts.append("Assistant:")
        
        return "\n\n".join(parts)
    
    async def stream(
        self,
        messages: Messages,
        tool_specs: Optional[List[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncIterator[StreamEvent]:
        """Stream response from the model.
        
        Args:
            messages: List of message objects
            tool_specs: Tool specifications (not supported)
            system_prompt: System prompt
            **kwargs: Additional parameters
            
        Yields:
            Stream events containing model output
        """
        # Format messages into a prompt
        prompt = self._format_messages(messages, system_prompt)
        logger.info(f"Formatted prompt: {prompt[:100]}...")
        
        # Start message
        yield {"messageStart": {"role": "assistant"}}
        yield {"contentBlockStart": {"start": {}}}
        
        # Track parsed results up front so the metadata yield at the end is
        # safe even if the request fails or raises.
        result: Optional[Dict[str, Any]] = None
        input_tokens = output_tokens = total_tokens = 0
        
        try:
            # Use requests with AWSRequestsAuth for API Gateway.
            # Strands expects async streaming, but we make a blocking request
            # and simulate streaming from the result. (requests.post blocks the
            # event loop; an async client such as httpx would be preferable.)
            
            # Prepare the payload
            payload = {
                "model": self.model_id,
                "prompt": prompt,
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
            }
            
            # Add any additional parameters from config
            for key, value in self.config.items():
                if key not in ["model_id", "max_tokens", "temperature"] and key not in payload:
                    payload[key] = value
            
            # Make the API request
            logger.info(f"Sending request to {self.base_url}/completions")
            response = requests.post(
                f"{self.base_url}/completions",
                json=payload,
                auth=self.auth,
            )
            
            # Check if the request was successful
            if response.status_code == 200:
                # Parse the response
                result = response.json()
                
                # Extract the generated text
                if "choices" in result and len(result["choices"]) > 0:
                    text = result["choices"][0]["text"]
                    
                    # Simulate streaming by yielding the whole completion as a
                    # single content delta
                    yield {"contentBlockDelta": {"delta": {"text": text}}}
                else:
                    logger.error(f"Unexpected response format: {result}")
                    yield {"contentBlockDelta": {"delta": {"text": "Error: Unexpected response format"}}}
                
                # Extract usage information if available
                usage = result.get("usage", {})
                input_tokens = usage.get("prompt_tokens", 0)
                output_tokens = usage.get("completion_tokens", 0)
                total_tokens = usage.get("total_tokens", input_tokens + output_tokens)
            else:
                error_text = f"Error: {response.status_code} - {response.text}"
                logger.error(error_text)
                yield {"contentBlockDelta": {"delta": {"text": error_text}}}
                
        except Exception as e:
            error_text = f"Error: {str(e)}"
            logger.error(error_text)
            yield {"contentBlockDelta": {"delta": {"text": error_text}}}
        
        # End message
        yield {"contentBlockStop": {}}
        yield {"messageStop": {"stopReason": "end_turn"}}
        
        # Include usage metadata if the request succeeded and reported usage
        if result is not None and "usage" in result:
            yield {
                "metadata": {
                    "usage": {
                        "inputTokens": input_tokens,
                        "outputTokens": output_tokens,
                        "totalTokens": total_tokens,
                    },
                    "metrics": {"latencyMs": 0},
                }
            }
    
    async def structured_output(self, output_model, prompt, system_prompt=None, **kwargs):
        """Not implemented for this model."""
        raise NotImplementedError("Structured output is not supported by this model")
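
To make the proposal concrete, here is a hedged usage sketch of the class above with a Strands Agent; the API Gateway ID, region, and model name are placeholders, and BotoAWSRequestsAuth (from the aws-requests-auth package) resolves credentials via boto3's default chain:

from aws_requests_auth.boto_utils import BotoAWSRequestsAuth
from strands import Agent

# Placeholders throughout: substitute the real API Gateway ID, region, and model
auth = BotoAWSRequestsAuth(
    aws_host="abc123.execute-api.us-east-1.amazonaws.com",
    aws_region="us-east-1",
    aws_service="execute-api",
)

model = IAMAuthModel(
    model_id="my-finetuned-model",
    api_id="abc123",
    region="us-east-1",
    auth=auth,
)

agent = Agent(model=model)
agent("Summarize the main points of the design document.")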

Use Case

  • A fine-tuned model is hosted in a K8s/EKS cluster, and we would like to serve it behind an authenticated proxy, a SigV4-signing server, or an API Gateway-like system (a quick smoke test of such an endpoint is sketched below).
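
As a quick smoke test of such a deployment, independent of the SDK, one can sign a raw completion request against the proxy (host and model name below are placeholders):

import requests
from aws_requests_auth.boto_utils import BotoAWSRequestsAuth

# Placeholders: substitute the real API Gateway host, region, and model name
host = "abc123.execute-api.us-east-1.amazonaws.com"
auth = BotoAWSRequestsAuth(aws_host=host, aws_region="us-east-1", aws_service="execute-api")

resp = requests.post(
    f"https://{host}/v1/completions",
    json={"model": "my-finetuned-model", "prompt": "Hello", "max_tokens": 16},
    auth=auth,
)
print(resp.status_code, resp.json())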

Alternative Solutions

No response

Additional Context

No response
