In [None]:
import json
import requests
from typing import Any, Iterable, Optional, TypedDict
from typing_extensions import Unpack, override
from strands.types.models import Model
from strands.types.content import Messages, Role
from strands.types.streaming import StreamEvent
from strands.types.tools import ToolSpec

In [None]:
import json
import requests
from typing import Any, Iterable, Optional, TypedDict
from typing_extensions import Unpack, override
from strands.types.models import Model
from strands.types.content import Messages, Role
from strands.types.streaming import StreamEvent
from strands.types.tools import ToolSpec
import urllib3

# Silence urllib3's InsecureRequestWarning, which is emitted on every request made
# with verify=False (the gateway uses self-signed certificates). NOTE: the filter is
# process-wide, so it also hides the warning for every other HTTPS call in this kernel.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

class LLMGatewayModel(Model):
    """AI Gateway model provider for Strands Agents.

    Converts Strands messages and tool specs into an Anthropic Messages (Bedrock
    ``invoke``) request, POSTs it to ``{base_url}/{model_id}/invoke`` and converts
    the gateway's reply back into Strands ``StreamEvent`` dicts.

    The reply may be either Server-Sent Events (``data: {...}`` lines carrying
    Anthropic streaming events) or a single JSON Messages response body.
    """

    class ModelConfig(TypedDict, total=False):
        """Provider configuration; every key is optional and defaulted in ``__init__``."""

        model_id: str
        max_tokens_to_sample: Optional[int]
        temperature: Optional[float]
        top_p: Optional[float]
        top_k: Optional[int]
        verify_ssl: Optional[bool]  # False disables TLS certificate verification
        timeout: Optional[float]  # seconds, forwarded to requests.post(timeout=...)
        debug: Optional[bool]  # print request/response diagnostics (never the API key)

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://aigateway.com/v0/r1/model",  # your gateway endpoint
        **model_config: Unpack[ModelConfig],
    ) -> None:
        """Create the provider.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: Gateway prefix; requests go to ``{base_url}/{model_id}/invoke``.
            **model_config: Overrides for the defaults below (see ``ModelConfig``).
        """
        default_config: dict[str, Any] = {
            "model_id": "anthropic.claude-3-7-sonnet-20250219-v1:0",
            "max_tokens_to_sample": 4096,
            "temperature": 0.7,
            "top_p": 0.9,
            "top_k": None,
            # Defaults to False for internal gateways with self-signed certificates;
            # this disables certificate checks, so pass verify_ssl=True where possible.
            "verify_ssl": False,
            "timeout": 60,
            "debug": False,
        }
        default_config.update(model_config)

        self.config = LLMGatewayModel.ModelConfig(**default_config)
        self.api_key = api_key
        # Strip a trailing slash so the request URL never contains "//".
        self.base_url = base_url.rstrip("/")

    @override
    def update_config(self, **model_config: Unpack[ModelConfig]) -> None:
        """Merge ``model_config`` into the current configuration in place."""
        self.config.update(model_config)

    @override
    def get_config(self) -> ModelConfig:
        """Return the live configuration dict (not a copy)."""
        return self.config

    # ------------------------------------------------------------------ request

    @override
    def format_request(
        self,
        messages: Messages,
        tools: list[ToolSpec] | None = None,
        system_prompt: str | None = None,
    ) -> Any:
        """Build an Anthropic Messages (Bedrock ``invoke``) request body.

        Args:
            messages: Conversation history. Strands content blocks (``{"text": ...}``,
                ``{"toolUse": ...}``, ``{"toolResult": ...}``) are converted to Anthropic
                blocks; plain strings become one text block; blocks that already carry
                a ``"type"`` key are passed through unchanged.
            tools: Strands tool specs, converted to Anthropic ``tools`` entries.
            system_prompt: Sent as the top-level ``system`` field, because the Messages
                API does not accept a ``"system"`` role inside ``messages``.

        Returns:
            The JSON-serialisable request payload.
        """
        formatted_messages = []
        for message in messages:
            # Handle both dict-style and object-style messages
            if isinstance(message, dict):
                role = message.get("role")
                content = message.get("content")
            else:
                role = getattr(message, "role", None)
                content = getattr(message, "content", None)
            formatted_messages.append({"role": role, "content": self._format_content(content)})

        request_payload: dict[str, Any] = {
            "messages": formatted_messages,
            "max_tokens": self.config.get("max_tokens_to_sample") or 4096,
            "anthropic_version": "bedrock-2023-05-31",
        }
        if system_prompt:
            request_payload["system"] = system_prompt

        # Only forward sampling parameters that are set, so an explicit None is never
        # serialised as JSON null.
        for config_key in ("temperature", "top_p", "top_k"):
            value = self.config.get(config_key)
            if value is not None:
                request_payload[config_key] = value

        if tools:
            request_payload["tools"] = [self._format_tool_spec(spec) for spec in tools]

        if self.config.get("debug"):
            print(f"request_payload provided: {request_payload}")

        return request_payload

    def _format_content(self, content: Any) -> list[Any]:
        """Convert one message's content into a list of Anthropic content blocks."""
        if isinstance(content, str):
            return [{"type": "text", "text": content}]
        if isinstance(content, list):
            return [self._format_content_block(block) for block in content]
        # Fallback for other content types
        return [{"type": "text", "text": str(content)}]

    def _format_content_block(self, block: Any) -> Any:
        """Convert a single Strands content block into its Anthropic equivalent."""
        if isinstance(block, str):
            return {"type": "text", "text": block}
        if not isinstance(block, dict) or "type" in block:
            # Already Anthropic-shaped (or not recognised): pass through unchanged.
            return block
        if "text" in block:
            return {"type": "text", "text": block["text"]}
        if "toolUse" in block:
            tool_use = block["toolUse"]
            return {
                "type": "tool_use",
                "id": tool_use["toolUseId"],
                "name": tool_use["name"],
                "input": tool_use["input"],
            }
        if "toolResult" in block:
            tool_result = block["toolResult"]
            return {
                "type": "tool_result",
                "tool_use_id": tool_result["toolUseId"],
                "is_error": tool_result.get("status") == "error",
                "content": [
                    # JSON tool output has no Anthropic block type; send it as text.
                    {"type": "text", "text": json.dumps(item["json"])}
                    if isinstance(item, dict) and "json" in item
                    else self._format_content_block(item)
                    for item in tool_result.get("content", [])
                ],
            }
        return block

    @staticmethod
    def _format_tool_spec(tool_spec: Any) -> Any:
        """Convert a Strands ToolSpec (``inputSchema.json``) to an Anthropic tool definition."""
        if not isinstance(tool_spec, dict) or "inputSchema" not in tool_spec:
            return tool_spec  # assume it is already Anthropic-shaped
        schema = tool_spec["inputSchema"]
        return {
            "name": tool_spec["name"],
            "description": tool_spec.get("description", ""),
            "input_schema": schema.get("json", schema) if isinstance(schema, dict) else schema,
        }

    # ----------------------------------------------------------------- response

    @override
    def format_chunk(self, chunk: Any) -> StreamEvent:
        """Convert one Anthropic-style event yielded by ``stream`` into a Strands StreamEvent.

        Events with no Strands counterpart (``ping``, a stop event without a stop
        reason, non-text deltas such as thinking) become an empty event. Non-dict
        chunks and unrecognised event types are surfaced as text deltas.
        """
        if not isinstance(chunk, dict):
            return self._text_delta(0, str(chunk))

        chunk_type = chunk.get("type")
        index = chunk.get("index", 0)

        if chunk_type == "message_start":
            message = chunk.get("message") or {}
            return {"messageStart": {"role": message.get("role", "assistant")}}

        if chunk_type == "content_block_start":
            content_block = chunk.get("content_block") or {}
            start: dict[str, Any] = {}
            if content_block.get("type") == "tool_use":
                start = {
                    "toolUse": {
                        "name": content_block.get("name"),
                        "toolUseId": content_block.get("id"),
                    }
                }
            return {"contentBlockStart": {"contentBlockIndex": index, "start": start}}

        if chunk_type == "content_block_delta":
            delta = chunk.get("delta") or {}
            if delta.get("type") == "input_json_delta":
                return {
                    "contentBlockDelta": {
                        "contentBlockIndex": index,
                        "delta": {"toolUse": {"input": delta.get("partial_json", "")}},
                    }
                }
            if "text" in delta:
                return self._text_delta(index, delta["text"])
            return {}

        if chunk_type == "content_block_stop":
            return {"contentBlockStop": {"contentBlockIndex": index}}

        if chunk_type in ("message_delta", "message_stop"):
            # Raw SSE carries stop_reason in message_delta.delta; some producers put it
            # on message_stop.message instead.
            stop_reason = (chunk.get("delta") or {}).get("stop_reason") or (
                chunk.get("message") or {}
            ).get("stop_reason")
            return {"messageStop": {"stopReason": stop_reason}} if stop_reason else {}

        if chunk_type == "ping":
            return {}

        if chunk_type == "metadata":
            # Synthesised by _events_from_message_body from the response's "usage".
            usage = chunk.get("usage") or {}
            input_tokens = usage.get("input_tokens", 0)
            output_tokens = usage.get("output_tokens", 0)
            return {
                "metadata": {
                    "usage": {
                        "inputTokens": input_tokens,
                        "outputTokens": output_tokens,
                        "totalTokens": input_tokens + output_tokens,
                    },
                    # The gateway response carries no latency figure.
                    "metrics": {"latencyMs": 0},
                }
            }

        # Unknown event shape: surface it as text rather than dropping it.
        return self._text_delta(0, str(chunk))

    @staticmethod
    def _text_delta(index: int, text: str) -> StreamEvent:
        """Build a Strands text ``contentBlockDelta`` event."""
        return {"contentBlockDelta": {"contentBlockIndex": index, "delta": {"text": text}}}

    @override
    def stream(self, request: Any) -> Iterable[Any]:
        """POST ``request`` to the gateway and yield Anthropic-style events.

        A synthetic ``message_start`` is yielded first. After it come either the events
        carried by SSE ``data:`` lines or, when the body is a single JSON Messages
        response, events synthesised from that body.

        Raises:
            requests.exceptions.SSLError: TLS verification failed.
            requests.exceptions.RequestException: Connection failure or non-2xx status
                (``HTTPError``); the error body is buffered, so ``e.response.text`` works.
        """
        url = f"{self.base_url}/{self.config['model_id']}/invoke"
        # Never print these headers: they carry the bearer token.
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        try:
            response = requests.post(
                url,
                headers=headers,
                json=request,
                stream=True,
                timeout=self.config.get("timeout", 60),
                verify=self.config.get("verify_ssl", False),
            )
            if self.config.get("debug"):
                print(f"Request sent to {url}")
                print(f"Request body: {json.dumps(request)}")
                print(f"Response status code: {response.status_code}")
            if not response.ok:
                # Reading .text buffers the error body so it stays readable on the
                # HTTPError. On success the body is NOT read here, to keep streaming.
                print(f"Response body: {response.text}")
            response.raise_for_status()
        except requests.exceptions.SSLError as e:
            print(f"SSL Error: {e}")
            print("Tip: Try setting verify_ssl=False when creating the model if using internal/self-signed certificates")
            raise
        except requests.exceptions.RequestException as e:
            print(f"Request Error: {e}")
            raise

        try:
            yield {
                "type": "message_start",
                "message": {"role": "assistant", "content": []},
            }
            yield from self._iter_response_events(response)
        finally:
            # Release the connection even if the consumer stops iterating early.
            response.close()

    def _iter_response_events(self, response: Any) -> Iterable[Any]:
        """Yield events parsed from SSE ``data:`` lines, or from a plain JSON body."""
        saw_sse_data = False
        plain_body_lines: list[str] = []

        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode("utf-8").strip()

            if not line.startswith("data:"):
                # SSE `event:`/comment lines, or lines of a non-SSE JSON body. Only
                # buffer them until SSE data is seen, so long streams do not grow memory.
                if not saw_sse_data:
                    plain_body_lines.append(line)
                continue

            saw_sse_data = True
            data_part = line[len("data:"):]
            if data_part.startswith(" "):  # SSE allows one optional space after "data:"
                data_part = data_part[1:]
            if data_part == "[DONE]":
                break

            try:
                event_data = json.loads(data_part)
            except json.JSONDecodeError:
                # Handle non-JSON data as plain text
                event_data = {
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "text_delta", "text": data_part},
                }
            yield event_data

        if not saw_sse_data and plain_body_lines:
            yield from self._events_from_message_body("\n".join(plain_body_lines))

    @staticmethod
    def _events_from_message_body(body: str) -> Iterable[dict[str, Any]]:
        """Synthesise streaming events from a non-streaming Messages API response body.

        A body that is not a JSON object with a ``content`` list is emitted as one
        text block, so the caller still sees what the gateway returned.
        """
        try:
            message = json.loads(body)
        except json.JSONDecodeError:
            message = None

        if not isinstance(message, dict) or not isinstance(message.get("content"), list):
            yield {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}
            yield {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": body}}
            yield {"type": "content_block_stop", "index": 0}
            return

        for index, block in enumerate(message["content"]):
            if not isinstance(block, dict):
                continue
            block_type = block.get("type")
            if block_type == "text":
                yield {"type": "content_block_start", "index": index, "content_block": {"type": "text", "text": ""}}
                yield {
                    "type": "content_block_delta",
                    "index": index,
                    "delta": {"type": "text_delta", "text": block.get("text", "")},
                }
                yield {"type": "content_block_stop", "index": index}
            elif block_type == "tool_use":
                yield {
                    "type": "content_block_start",
                    "index": index,
                    "content_block": {
                        "type": "tool_use",
                        "id": block.get("id"),
                        "name": block.get("name"),
                        "input": {},
                    },
                }
                yield {
                    "type": "content_block_delta",
                    "index": index,
                    "delta": {"type": "input_json_delta", "partial_json": json.dumps(block.get("input", {}))},
                }
                yield {"type": "content_block_stop", "index": index}
            # Other block types (e.g. thinking) have no mapping here and are skipped.

        yield {"type": "message_delta", "delta": {"stop_reason": message.get("stop_reason") or "end_turn"}}
        if isinstance(message.get("usage"), dict):
            yield {"type": "metadata", "usage": message["usage"]}

In [None]:
# Example usage and factory function
def create_llm_gateway_model(
    api_key: str,
    model_id: str = "anthropic.claude-3-7-sonnet-20250219-v1:0-pgo",
    verify_ssl: bool = False,  # Default to False for internal APIs
    **kwargs
) -> LLMGatewayModel:
    """Convenience constructor for ``LLMGatewayModel``.

    Args:
        api_key: Gateway bearer token.
        model_id: Model identifier placed in the gateway URL path.
        verify_ssl: Forwarded to the provider config; False skips TLS verification.
        **kwargs: Any other constructor argument (e.g. ``base_url``) or config key
            (``max_tokens_to_sample``, ``temperature``, ...).

    Returns:
        A configured ``LLMGatewayModel``.
    """
    provider_options = dict(kwargs)
    provider_options["model_id"] = model_id
    provider_options["verify_ssl"] = verify_ssl
    return LLMGatewayModel(api_key=api_key, **provider_options)

In [None]:
# Smoke test: build the model and agent, then inspect a formatted request
import os

from strands import Agent

try:
    # Read the key from the environment instead of hard-coding it in the notebook.
    # (The original line `api_key= ''# ...` was missing a comma: a SyntaxError.)
    gateway_model = create_llm_gateway_model(
        api_key=os.environ.get("AI_GATEWAY_API_KEY", ""),
        model_id="anthropic.claude-3-7-sonnet-20250219-v1:0-pgo",
        max_tokens_to_sample=2000,
        temperature=0.7,
    )

    # Create a Strands agent using your model
    agent = Agent(model=gateway_model)

    print("✅ LLMGatewayModel instance created successfully!")
    print("✅ Strands Agent created successfully!")

    # Test the format_request method directly
    test_messages = [{"role": "user", "content": "Hello, how are you?"}]
    formatted_request = gateway_model.format_request(test_messages, tools=None, system_prompt="You are a helpful assistant.")
    print("✅ format_request method works correctly!")
    print(f"Formatted request: {formatted_request}")

    # Note: Actual agent call would require valid API key
    print("\n📝 Note: To test the full agent call, provide a valid AI Gateway API key.")

except Exception as e:
    print(f"❌ Error: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# Production usage example: real agent call with SSL override and HTTP-error diagnostics
import json
import os
import traceback

import requests
from strands import Agent

try:
    # verify_ssl=False accepts the gateway's self-signed certificate, but it disables
    # TLS certificate verification entirely; prefer True whenever the gateway allows it.
    gateway_model = create_llm_gateway_model(
        api_key=os.environ.get("AI_GATEWAY_API_KEY", ""),  # never hard-code the key
        model_id="anthropic.claude-3-7-sonnet-20250219-v1:0-pgo",
        max_tokens_to_sample=2000,
        temperature=0.7,
        verify_ssl=False,
    )

    print("✅ Model created successfully!")

    # Inspect the request body before anything is sent over the network
    test_messages = [{"role": "user", "content": "Hello, how are you today?"}]
    formatted_request = gateway_model.format_request(test_messages, tools=None, system_prompt=None)
    print(f"🔍 Formatted request: {json.dumps(formatted_request, indent=2)}")

    # Create agent and use it
    agent = Agent(model=gateway_model)
    print("✅ Agent created successfully!")

    response = agent("Hello, how are you today?")
    print(f"🤖 Response: {response}")

except Exception as e:
    # The agent may wrap the provider's HTTPError in its own exception type, so walk
    # the cause/context chain (guarding against cycles) to find it.
    http_error = e
    seen_ids = set()
    while (
        http_error is not None
        and not isinstance(http_error, requests.exceptions.HTTPError)
        and id(http_error) not in seen_ids
    ):
        seen_ids.add(id(http_error))
        http_error = http_error.__cause__ or http_error.__context__

    if isinstance(http_error, requests.exceptions.HTTPError):
        print(f"❌ HTTP Error: {http_error}")
        if http_error.response is not None:
            print(f"📝 Response status: {http_error.response.status_code}")
            print(f"📝 Response headers: {dict(http_error.response.headers)}")
            try:
                print(f"📝 Response body: {http_error.response.text}")
            except Exception as read_error:  # e.g. the connection dropped mid-body
                print(f"📝 Could not read response body: {read_error}")
    else:
        print(f"❌ Error: {e}")
        traceback.print_exc()