diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index c7b0b2caa..43836fe34 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -350,8 +350,11 @@ def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | Non Returns: The extracted usage metrics and latency. """ - usage = Usage(**event["usage"]) - metrics = Metrics(**event["metrics"]) + # MetadataEvent has total=False, making all fields optional, but Usage and Metrics types + # have Required fields. Provide defaults to handle cases where custom models don't + # provide usage/metrics (e.g., when latency info is unavailable). + usage = Usage(**{"inputTokens": 0, "outputTokens": 0, "totalTokens": 0, **event.get("usage", {})}) + metrics = Metrics(**{"latencyMs": 0, **event.get("metrics", {})}) if time_to_first_byte_ms: metrics["timeToFirstByteMs"] = time_to_first_byte_ms diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 714fbac27..3f5a6c998 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -421,6 +421,43 @@ def test_extract_usage_metrics_with_cache_tokens(): assert tru_usage == exp_usage and tru_metrics == exp_metrics +def test_extract_usage_metrics_without_metrics(): + """Test extract_usage_metrics when metrics field is missing.""" + event = { + "usage": {"inputTokens": 5, "outputTokens": 2, "totalTokens": 7}, + } + + tru_usage, tru_metrics = strands.event_loop.streaming.extract_usage_metrics(event) + exp_usage = {"inputTokens": 5, "outputTokens": 2, "totalTokens": 7} + exp_metrics = {"latencyMs": 0} + + assert tru_usage == exp_usage and tru_metrics == exp_metrics + + +def test_extract_usage_metrics_without_usage(): + """Test extract_usage_metrics when usage field is missing.""" + event = { + "metrics": {"latencyMs": 100}, + } + + tru_usage, tru_metrics = strands.event_loop.streaming.extract_usage_metrics(event) + exp_usage = {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0} + exp_metrics = {"latencyMs": 100} + + assert tru_usage == exp_usage and tru_metrics == exp_metrics + + +def test_extract_usage_metrics_empty_metadata(): + """Test extract_usage_metrics when both fields are missing.""" + event = {} + + tru_usage, tru_metrics = strands.event_loop.streaming.extract_usage_metrics(event) + exp_usage = {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0} + exp_metrics = {"latencyMs": 0} + + assert tru_usage == exp_usage and tru_metrics == exp_metrics + + @pytest.mark.parametrize( ("response", "exp_events"), [