diff --git a/pyproject.toml b/pyproject.toml index 8a95ba04c..39f21c690 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,9 @@ otel = [ writer = [ "writer-sdk>=2.2.0,<3.0.0" ] +clova = [ + "httpx>=0.27.0,<1.0.0", +] sagemaker = [ "boto3>=1.26.0,<2.0.0", @@ -108,7 +111,7 @@ a2a = [ "starlette>=0.46.2,<1.0.0", ] all = [ - "strands-agents[a2a,anthropic,dev,docs,litellm,llamaapi,mistral,ollama,openai,otel]", + "strands-agents[a2a,anthropic,clova,dev,docs,litellm,llamaapi,mistral,ollama,openai,otel]", ] [tool.hatch.version] @@ -116,7 +119,7 @@ all = [ source = "vcs" [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"] +features = ["anthropic", "clova", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -139,7 +142,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"] +features = ["anthropic", "clova", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -155,7 +158,7 @@ extra-args = [ [tool.hatch.envs.dev] dev-mode = true -features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a", "sagemaker"] +features = ["dev", "docs", "anthropic", "clova", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a", "sagemaker"] [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] diff --git a/src/strands/models/clova.py b/src/strands/models/clova.py new file mode 100644 index 000000000..7b7276938 --- /dev/null +++ b/src/strands/models/clova.py @@ -0,0 +1,261 @@ +"""CLOVA Studio model provider for Strands Agents SDK.""" + +import json +import logging +import os +from typing import Any, AsyncGenerator, AsyncIterable, Dict, List, Optional, Type, Union + +import httpx +from pydantic import BaseModel + +from ..types.content import Messages +from ..types.streaming import StreamEvent +from ..types.tools import ToolSpec +from .model import Model + +logger = logging.getLogger(__name__) + + +class ClovaModelException(Exception): + """Exception for CLOVA model errors.""" + + pass + + +class ClovaModel(Model): + """CLOVA Studio model provider implementation.""" + + def __init__( + self, + model: str = "HCX-005", + api_key: Optional[str] = None, + temperature: float = 0.7, + max_tokens: int = 4096, + top_p: float = 0.8, + top_k: int = 0, + repeat_penalty: float = 1.1, + stop: Optional[List[str]] = None, + **kwargs: Any, + ): + """Initialize CLOVA model. + + Args: + model: Model ID (default: HCX-005) + api_key: CLOVA API key (can be set via CLOVA_API_KEY env var) + temperature: Sampling temperature (0.0-1.0) + max_tokens: Maximum number of tokens to generate + top_p: Nucleus sampling parameter + top_k: Top-k sampling parameter + repeat_penalty: Repetition penalty + stop: List of stop sequences + **kwargs: Additional parameters + """ + self.model = model + self.api_key = api_key or os.getenv("CLOVA_API_KEY") + + if not self.api_key: + raise ValueError( + "CLOVA API key is required. Set CLOVA_API_KEY environment variable or pass api_key parameter." + ) + + self.temperature = temperature + self.max_tokens = max_tokens + self.top_p = top_p + self.top_k = top_k + self.repeat_penalty = repeat_penalty + self.stop = stop or [] + self.base_url = f"https://clovastudio.stream.ntruss.com/v3/chat-completions/{model}" + + # Store additional kwargs for future use + self.additional_params = kwargs + + def update_config(self, **model_config: Any) -> None: + """Update the model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + for key, value in model_config.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + self.additional_params[key] = value + + def get_config(self) -> Dict[str, Any]: + """Return the model configuration. + + Returns: + The model's configuration. + """ + return { + "model": self.model, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "top_k": self.top_k, + "repeat_penalty": self.repeat_penalty, + "stop": self.stop, + **self.additional_params, + } + + async def stream( + self, + messages: Union[Messages, str], + tool_specs: Optional[List[ToolSpec]] = None, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncIterable[StreamEvent]: + """Stream responses from CLOVA model. + + Args: + messages: Messages to be processed by the model. + tool_specs: List of tool specifications (not yet supported). + system_prompt: Optional system message. + **kwargs: Additional parameters. + + Yields: + Formatted message chunks from the model. + """ + if tool_specs: + logger.warning("Tool specs are not yet supported for CLOVA models") + + # Convert messages to CLOVA format + clova_messages = [] + + if system_prompt: + clova_messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]}) + + # Handle both Messages type and simple string + if isinstance(messages, str): + clova_messages.append({"role": "user", "content": [{"type": "text", "text": messages}]}) + elif hasattr(messages, "__iter__"): + for msg in messages: + if hasattr(msg, "role") and hasattr(msg, "content"): + # Convert content to CLOVA format + content = msg.content if isinstance(msg.content, str) else str(msg.content) + clova_messages.append({"role": msg.role, "content": [{"type": "text", "text": content}]}) + else: + # Fallback for dict-like messages + if isinstance(msg, dict) and "content" in msg: + if isinstance(msg["content"], str): + msg["content"] = [{"type": "text", "text": msg["content"]}] + clova_messages.append(msg) + + request_body = { + "messages": clova_messages, + "temperature": kwargs.get("temperature", self.temperature), + "maxTokens": kwargs.get("max_tokens", self.max_tokens), + "topP": kwargs.get("top_p", self.top_p), + "topK": kwargs.get("top_k", self.top_k), + "repetitionPenalty": kwargs.get("repeat_penalty", self.repeat_penalty), + "stop": kwargs.get("stop", self.stop), + "seed": kwargs.get("seed", 0), + "includeAiFilters": kwargs.get("includeAiFilters", True), + "stream": True, + } + + # Add any additional parameters from initialization or kwargs + for key, value in self.additional_params.items(): + if key not in request_body: + request_body[key] = value + + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {self.api_key}", + "Accept": "text/event-stream", + } + + # Add required request ID header + request_id = os.getenv("CLOVA_REQUEST_ID", "test-request-001") + headers["X-NCP-CLOVASTUDIO-REQUEST-ID"] = request_id + + async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client: + response = await client.post( + self.base_url, + json=request_body, + headers=headers, + ) + + if response.status_code != 200: + error_text = await response.aread() + error_msg = f"CLOVA API request failed with status {response.status_code}: {error_text.decode('utf-8')}" + raise ClovaModelException(error_msg) + + # Process SSE stream + buffer = b"" + async for chunk in response.aiter_bytes(): + buffer += chunk + # Split by double newline which separates SSE events + events = buffer.split(b"\n\n") + # Keep the last incomplete event in buffer + buffer = events[-1] + + # Process complete events + for event in events[:-1]: + if not event: + continue + + # Parse SSE event + lines = event.split(b"\n") + data_line = None + for line in lines: + if line.startswith(b"data:"): + data_line = line[5:].strip() + break + + if not data_line: + continue + + try: + data_str = data_line.decode("utf-8") + data = json.loads(data_str) + + # Handle different event types and convert to StreamEvent + # CLOVA returns content in message.content format + if "message" in data and data["message"].get("content"): + # Yield as a StreamEvent dict with text chunk + yield { + "type": "text", + "text": data["message"]["content"], + } + + # Check for finish reason (not stopReason) + if "finishReason" in data and data["finishReason"] == "stop": + # Yield completion event + yield { + "type": "message_stop", + "stop_reason": "stop", + } + break + + except (json.JSONDecodeError, KeyError, UnicodeDecodeError): + # Skip malformed data + continue + + async def structured_output( + self, + output_model: Type[BaseModel], + prompt: Messages, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[Dict[str, Union[BaseModel, Any]], None]: + """Get structured output from the model. + + Note: This is not yet implemented for CLOVA models. + + Args: + output_model: The output model to use for the agent. + prompt: The prompt messages to use for the agent. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments. + + Raises: + NotImplementedError: Structured output is not yet supported for CLOVA models. + """ + raise NotImplementedError("Structured output is not yet supported for CLOVA models") + # Make this a generator (unreachable code, but satisfies type hint) + yield # pragma: no cover + + def __str__(self) -> str: + """String representation of the model.""" + return f"ClovaModel(model='{self.model}', temperature={self.temperature}, max_tokens={self.max_tokens})" diff --git a/tests/strands/models/test_clova.py b/tests/strands/models/test_clova.py new file mode 100644 index 000000000..856003eb4 --- /dev/null +++ b/tests/strands/models/test_clova.py @@ -0,0 +1,156 @@ +"""Unit tests for CLOVA model provider.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from strands.models.clova import ClovaModel, ClovaModelException + + +@pytest.fixture +def clova_model(): + """Create a ClovaModel instance for testing.""" + return ClovaModel(api_key="test-api-key", model="HCX-005") + + +def test_initialization(): + """Test ClovaModel initialization.""" + model = ClovaModel(api_key="test-key", model="HCX-005") + assert model.api_key == "test-key" + assert model.model == "HCX-005" + assert model.temperature == 0.7 + assert model.max_tokens == 4096 + + +def test_initialization_with_params(): + """Test ClovaModel initialization with custom parameters.""" + model = ClovaModel( + api_key="test-key", + model="HCX-005", + temperature=0.5, + max_tokens=2048, + top_p=0.9, + ) + assert model.temperature == 0.5 + assert model.max_tokens == 2048 + assert model.top_p == 0.9 + + +def test_initialization_without_api_key(): + """Test ClovaModel initialization without API key.""" + with pytest.raises(ValueError, match="CLOVA API key is required"): + ClovaModel(model="HCX-005") + + +@pytest.mark.asyncio +async def test_stream_with_successful_response(clova_model): + """Test streaming with successful response.""" + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "text/event-stream"} + + # Mock SSE stream - CLOVA format with message.content + async def mock_aiter(): + yield b'data: {"message":{"content":"Hello"}}\n\n' + yield b'data: {"message":{"content":" world"}}\n\n' + yield b'data: {"finishReason":"stop"}\n\n' + + mock_response.aiter_bytes = mock_aiter + + with patch("httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.post = AsyncMock(return_value=mock_response) + mock_client.return_value.__aenter__.return_value = mock_instance + + result = "" + async for event in clova_model.stream("Test prompt"): + if hasattr(event, "type") and event.type == "text": + result += event.text + elif isinstance(event, dict) and event.get("type") == "text": + result += event["text"] + + assert result == "Hello world" + + +@pytest.mark.asyncio +async def test_stream_with_error_response(clova_model): + """Test streaming with error response.""" + mock_response = AsyncMock() + mock_response.status_code = 401 + mock_response.text = AsyncMock(return_value="Unauthorized") + + with patch("httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.post = AsyncMock(return_value=mock_response) + mock_client.return_value.__aenter__.return_value = mock_instance + + with pytest.raises(ClovaModelException, match="CLOVA API request failed"): + async for _ in clova_model.stream("Test prompt"): + pass + + +@pytest.mark.asyncio +async def test_stream_with_system_message(clova_model): + """Test streaming with system message.""" + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "text/event-stream"} + + async def mock_aiter(): + yield b'data: {"message":{"content":"Response"}}\n\n' + yield b'data: {"finishReason":"stop"}\n\n' + + mock_response.aiter_bytes = mock_aiter + + with patch("httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.post = AsyncMock(return_value=mock_response) + mock_client.return_value.__aenter__.return_value = mock_instance + + result = "" + async for event in clova_model.stream("Test prompt", system_prompt="You are a helpful assistant"): + if hasattr(event, "type") and event.type == "text": + result += event.text + elif isinstance(event, dict) and event.get("type") == "text": + result += event["text"] + + # Verify the system message was included in the request + call_args = mock_instance.post.call_args + json_data = call_args.kwargs["json"] + assert any(msg["role"] == "system" for msg in json_data["messages"]) + + +@pytest.mark.asyncio +async def test_structured_output_not_implemented(clova_model): + """Test that structured_output raises NotImplementedError.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + result: str + + with pytest.raises(NotImplementedError, match="Structured output is not yet supported for CLOVA models"): + # structured_output is an async generator, need to call it properly + async for _ in clova_model.structured_output(TestModel, "Test prompt"): + pass + + +def test_model_str_representation(clova_model): + """Test string representation of ClovaModel.""" + str_repr = str(clova_model) + assert "ClovaModel" in str_repr + assert "HCX-005" in str_repr + + +def test_update_config(clova_model): + """Test updating model configuration.""" + clova_model.update_config(temperature=0.5, max_tokens=2048) + assert clova_model.temperature == 0.5 + assert clova_model.max_tokens == 2048 + + +def test_get_config(clova_model): + """Test getting model configuration.""" + config = clova_model.get_config() + assert config["model"] == "HCX-005" + assert config["temperature"] == 0.7 + assert config["max_tokens"] == 4096 diff --git a/tests_integ/models/test_model_clova.py b/tests_integ/models/test_model_clova.py new file mode 100644 index 000000000..4cc43d52c --- /dev/null +++ b/tests_integ/models/test_model_clova.py @@ -0,0 +1,179 @@ +"""Integration tests for CLOVA model provider.""" + +import os + +import pytest + +from strands.models.clova import ClovaModel + + +@pytest.fixture +def clova_api_key(): + """Get CLOVA API key from environment.""" + api_key = os.getenv("CLOVA_API_KEY") + if not api_key: + pytest.skip("CLOVA_API_KEY not set") + return api_key + + +@pytest.fixture +def clova_model(clova_api_key): + """Create a ClovaModel instance for integration testing.""" + return ClovaModel(api_key=clova_api_key, model="HCX-005") + + +@pytest.mark.asyncio +async def test_basic_streaming(clova_model): + """Test basic streaming functionality with real API.""" + prompt = "안녕하세요. 한국어로 간단한 인사를 해주세요." + + response_chunks = [] + async for event in clova_model.stream(prompt): + if event.get("type") == "text": + response_chunks.append(event["text"]) + + # Check we got a response + assert len(response_chunks) > 0 + + # Combine chunks to get full response + full_response = "".join(response_chunks) + assert len(full_response) > 0 + + # Korean response should contain Korean characters + assert any(ord(char) >= 0xAC00 and ord(char) <= 0xD7AF for char in full_response) + + +@pytest.mark.asyncio +async def test_streaming_with_system_prompt(clova_model): + """Test streaming with system prompt.""" + prompt = "What is 2 + 2?" + system_prompt = "You are a helpful math tutor. Answer briefly." + + response_chunks = [] + async for event in clova_model.stream(prompt, system_prompt=system_prompt): + if event.get("type") == "text": + response_chunks.append(event["text"]) + + # Check we got a response + assert len(response_chunks) > 0 + + # Combine chunks and check for "4" in response + full_response = "".join(response_chunks).lower() + assert "4" in full_response or "four" in full_response + + +@pytest.mark.asyncio +async def test_temperature_parameter(clova_model): + """Test that temperature parameter affects output.""" + prompt = "Tell me a creative story in one sentence." + + # Low temperature (more deterministic) + clova_model.update_config(temperature=0.1) + response1_chunks = [] + async for event in clova_model.stream(prompt): + if event.get("type") == "text": + response1_chunks.append(event.text) + + response1 = "".join(response1_chunks) + + # High temperature (more creative) + clova_model.update_config(temperature=0.9) + response2_chunks = [] + async for event in clova_model.stream(prompt): + if event.get("type") == "text": + response2_chunks.append(event.text) + + response2 = "".join(response2_chunks) + + # Both should produce responses + assert len(response1) > 0 + assert len(response2) > 0 + + # Responses should likely be different (not guaranteed but highly probable) + # We just check that both produce valid responses + + +@pytest.mark.asyncio +async def test_max_tokens_limit(clova_model): + """Test that max_tokens parameter limits output.""" + prompt = "Count from 1 to 100 slowly, one number per line." + + # Set a low token limit + clova_model.update_config(max_tokens=50) + + response_chunks = [] + async for event in clova_model.stream(prompt): + if event.get("type") == "text": + response_chunks.append(event["text"]) + + # Check we got a response + assert len(response_chunks) > 0 + + # Response should be limited (not reach 100) + full_response = "".join(response_chunks) + assert "100" not in full_response # Shouldn't reach 100 with token limit + + +@pytest.mark.asyncio +async def test_model_configuration(clova_model): + """Test getting and updating model configuration.""" + # Get initial config + initial_config = clova_model.get_config() + assert initial_config["model"] == "HCX-005" + assert "temperature" in initial_config + assert "max_tokens" in initial_config + + # Update config + clova_model.update_config(temperature=0.5, max_tokens=2048, top_p=0.8) + + # Verify updates + updated_config = clova_model.get_config() + assert updated_config["temperature"] == 0.5 + assert updated_config["max_tokens"] == 2048 + assert updated_config["top_p"] == 0.8 + + +@pytest.mark.asyncio +async def test_bilingual_support(clova_model): + """Test that CLOVA supports both Korean and English.""" + # Test Korean + korean_prompt = "한국의 수도는 어디인가요?" + korean_response = [] + async for event in clova_model.stream(korean_prompt): + if event.get("type") == "text": + korean_response.append(event.text) + + korean_text = "".join(korean_response).lower() + assert "서울" in korean_text or "seoul" in korean_text + + # Test English + english_prompt = "What is the capital of South Korea?" + english_response = [] + async for event in clova_model.stream(english_prompt): + if event.get("type") == "text": + english_response.append(event.text) + + english_text = "".join(english_response).lower() + assert "seoul" in english_text or "서울" in english_text + + +@pytest.mark.asyncio +async def test_structured_output_not_supported(clova_model): + """Test that structured output is not yet supported.""" + from pydantic import BaseModel + + class TestOutput(BaseModel): + answer: str + + with pytest.raises(NotImplementedError, match="Structured output is not yet supported"): + async for _ in clova_model.structured_output(TestOutput, "Test prompt"): + pass + + +def test_model_string_representation(clova_model): + """Test string representation of the model.""" + model_str = str(clova_model) + assert "ClovaModel" in model_str + assert "HCX-005" in model_str + assert "temperature" in model_str.lower() + assert "max_tokens" in model_str.lower()