diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 005eed3df..1763f5dec 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -8,11 +8,13 @@
 from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast
 
 import litellm
+from litellm.exceptions import ContextWindowExceededError
 from litellm.utils import supports_response_schema
 from pydantic import BaseModel
 from typing_extensions import Unpack, override
 
 from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ContextWindowOverflowException
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
 from ._validation import validate_config_keys
@@ -135,7 +137,11 @@ async def stream(
         logger.debug("request=<%s>", request)
 
         logger.debug("invoking model")
-        response = await litellm.acompletion(**self.client_args, **request)
+        try:
+            response = await litellm.acompletion(**self.client_args, **request)
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow")
+            raise ContextWindowOverflowException(e) from e
 
         logger.debug("got response from model")
         yield self.format_chunk({"chunk_type": "message_start"})
@@ -205,15 +211,24 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        if not supports_response_schema(self.get_config()["model_id"]):
+        supports_schema = supports_response_schema(self.get_config()["model_id"])
+
+        # If the provider does not support response schemas, structured output cannot be parsed reliably.
+        # In that case, skip the provider call and raise the documented ValueError.
+        if not supports_schema:
             raise ValueError("Model does not support response_format")
 
-        response = await litellm.acompletion(
-            **self.client_args,
-            model=self.get_config()["model_id"],
-            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
-            response_format=output_model,
-        )
+        # For providers that do support response schemas, call litellm and map context-window errors.
+        try:
+            response = await litellm.acompletion(
+                **self.client_args,
+                model=self.get_config()["model_id"],
+                messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
+                response_format=output_model,
+            )
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow in structured_output")
+            raise ContextWindowOverflowException(e) from e
 
         if len(response.choices) > 1:
             raise ValueError("Multiple choices found in the response.")
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index bc81fc819..776ae7bae 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -3,9 +3,11 @@
 
 import pydantic
 import pytest
+from litellm.exceptions import ContextWindowExceededError
 
 import strands
 from strands.models.litellm import LiteLLMModel
+from strands.types.exceptions import ContextWindowOverflowException
 
 
 @pytest.fixture
@@ -332,3 +334,13 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings):
     model.format_request(messages, tool_choice=None)
 
     assert len(captured_warnings) == 0
+
+
+@pytest.mark.asyncio
+async def test_context_window_maps_to_typed_exception(litellm_acompletion, model):
+    """Test that litellm's ContextWindowExceededError is mapped to ContextWindowOverflowException."""
+    litellm_acompletion.side_effect = ContextWindowExceededError(message="test error", model="x", llm_provider="y")
+
+    with pytest.raises(ContextWindowOverflowException):
+        async for _ in model.stream([{"role": "user", "content": [{"text": "x"}]}]):
+            pass
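
For reference, a minimal caller-side sketch of what the new mapping enables; this is illustrative, not part of the patch. It assumes LiteLLMModel accepts model_id as a config kwarg (consistent with self.get_config()["model_id"] above), the model id is a placeholder, and the messages shape mirrors the test fixture.

    import asyncio

    from strands.models.litellm import LiteLLMModel
    from strands.types.exceptions import ContextWindowOverflowException


    async def main() -> None:
        # Placeholder model id; any litellm-routable model behaves the same way.
        model = LiteLLMModel(model_id="gpt-4o")
        messages = [{"role": "user", "content": [{"text": "hello"}]}]
        try:
            async for event in model.stream(messages):
                print(event)
        except ContextWindowOverflowException:
            # Callers now catch the strands-typed exception instead of
            # litellm's provider-specific ContextWindowExceededError.
            print("prompt exceeded the model's context window")


    asyncio.run(main())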