diff --git a/src/agentex/lib/utils/completions.py b/src/agentex/lib/utils/completions.py index f33495c1..fe62c7d1 100644 --- a/src/agentex/lib/utils/completions.py +++ b/src/agentex/lib/utils/completions.py @@ -6,6 +6,7 @@ from agentex.lib.types.llm_messages import ( Delta, + Usage, Choice, ToolCall, Completion, @@ -21,6 +22,8 @@ def _concat_chunks(_a: None, b: Any): @_concat_chunks.register def _(a: Completion, b: Completion) -> Completion: a.choices = [_concat_chunks(*c) for c in zip(a.choices, b.choices, strict=False)] + a.usage = _concat_chunks(a.usage, b.usage) + return a @@ -35,6 +38,17 @@ def _(a: Choice, b: Choice) -> Choice: a.finish_reason = a.finish_reason or b.finish_reason return a +@_concat_chunks.register +def _(a: Usage | None, b: Usage | None) -> Usage | None: + if a is not None and b is not None: + return Usage( + prompt_tokens=a.prompt_tokens + b.prompt_tokens, + completion_tokens=a.completion_tokens + b.completion_tokens, + total_tokens=a.total_tokens + b.total_tokens, + ) + else: + return a or b + @_concat_chunks.register def _(a: Delta, b: Delta) -> Delta: