Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions pydantic_ai_slim/pydantic_ai/ui/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Generic,
Protocol,
TypeVar,
cast,
runtime_checkable,
)

Expand Down Expand Up @@ -43,7 +44,6 @@
'StateDeps',
]


RunInputT = TypeVar('RunInputT')
"""Type variable for protocol-specific run input types."""

Expand All @@ -53,10 +53,12 @@
EventT = TypeVar('EventT')
"""Type variable for protocol-specific event types."""


StateT = TypeVar('StateT', bound=BaseModel)
"""Type variable for the state type, which must be a subclass of `BaseModel`."""

DispatchDepsT = TypeVar('DispatchDepsT')
"""TypeVar for deps to avoid awkwardness with unbound classvar deps."""


@runtime_checkable
class StateHandler(Protocol):
Expand Down Expand Up @@ -328,18 +330,18 @@ async def dispatch_request(
cls,
request: Request,
*,
agent: AbstractAgent[AgentDepsT, OutputDataT],
agent: AbstractAgent[DispatchDepsT, OutputDataT],
message_history: Sequence[ModelMessage] | None = None,
deferred_tool_results: DeferredToolResults | None = None,
model: Model | KnownModelName | str | None = None,
instructions: Instructions[AgentDepsT] = None,
deps: AgentDepsT = None,
instructions: Instructions[DispatchDepsT] = None,
deps: DispatchDepsT = None,
output_type: OutputSpec[Any] | None = None,
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
usage: RunUsage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
toolsets: Sequence[AbstractToolset[DispatchDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
on_complete: OnCompleteFunc[EventT] | None = None,
) -> Response:
Expand Down Expand Up @@ -376,7 +378,11 @@ async def dispatch_request(
) from e

try:
adapter = await cls.from_request(request, agent=agent)
# The DepsT comes from `agent`, not from `cls`; the cast is necessary to explain this to pyright
adapter = cast(
UIAdapter[RunInputT, MessageT, EventT, DispatchDepsT, OutputDataT],
await cls.from_request(request, agent=agent),
)
except ValidationError as e: # pragma: no cover
return Response(
content=e.json(),
Expand Down