20 changes: 20 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_utils.py
@@ -15,7 +15,14 @@
from pydantic.json_schema import JsonSchemaValue
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict

from pydantic_graph._utils import AbstractSpan

AbstractSpan = AbstractSpan

if TYPE_CHECKING:
from pydantic_ai.agent import AgentRun, AgentRunResult
from pydantic_graph import GraphRun, GraphRunResult

from . import messages as _messages
from .tools import ObjectJsonSchema

@@ -281,3 +288,16 @@ async def __anext__(self) -> T:
except StopAsyncIteration:
self._exhausted = True
raise


def get_traceparent(x: AgentRun | AgentRunResult | GraphRun | GraphRunResult) -> str:
import logfire
import logfire_api
from logfire.experimental.annotations import get_traceparent

span: AbstractSpan | None = x._span(required=False) # type: ignore[reportPrivateUsage]
if not span: # pragma: no cover
return ''
if isinstance(span, logfire_api.LogfireSpan): # pragma: no cover
assert isinstance(span, logfire.LogfireSpan)
return get_traceparent(span)
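
For context, a minimal sketch of how this new helper might be used once the change lands. The model name and prompt are illustrative, logfire is assumed to be installed and configured, and get_traceparent is imported from the private _utils module exactly as added above, so treat it as internal rather than documented API:

import asyncio

import logfire
from pydantic_ai import Agent
from pydantic_ai._utils import get_traceparent

logfire.configure()  # assumption: logfire is installed and configured for this process
agent = Agent('openai:gpt-4o')  # illustrative model name

async def main() -> None:
    result = await agent.run('What is the capital of France?')
    # W3C traceparent string for the run's span, or '' if no span was recorded.
    print(get_traceparent(result))

asyncio.run(main())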
25 changes: 24 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent.py
@@ -10,7 +10,7 @@

from opentelemetry.trace import NoOpTracer, use_span
from pydantic.json_schema import GenerateJsonSchema
from typing_extensions import TypeGuard, TypeVar, deprecated
from typing_extensions import Literal, TypeGuard, TypeVar, deprecated

from pydantic_graph import End, Graph, GraphRun, GraphRunContext
from pydantic_graph._utils import get_event_loop
@@ -26,6 +26,7 @@
result,
usage as _usage,
)
from ._utils import AbstractSpan
from .models.instrumented import InstrumentationSettings, InstrumentedModel
from .result import FinalResult, ResultDataT, StreamedRunResult
from .settings import ModelSettings, merge_model_settings
@@ -52,6 +53,7 @@
if TYPE_CHECKING:
from pydantic_ai.mcp import MCPServer


__all__ = (
'Agent',
'AgentRun',
@@ -1402,6 +1404,16 @@ async def main():
_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT]
]

@overload
def _span(self, *, required: Literal[False]) -> AbstractSpan | None: ...
@overload
def _span(self) -> AbstractSpan: ...
def _span(self, *, required: bool = True) -> AbstractSpan | None:
span = self._graph_run._span(required=False) # type: ignore[reportPrivateUsage]
if span is None and required: # pragma: no cover
raise AttributeError('Span is not available for this agent run')
return span

@property
def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]:
"""The current context of the agent run."""
@@ -1439,6 +1451,7 @@ def result(self) -> AgentRunResult[ResultDataT] | None:
graph_run_result.output.tool_name,
graph_run_result.state,
self._graph_run.deps.new_message_index,
self._graph_run._span(required=False), # type: ignore[reportPrivateUsage]
)

def __aiter__(
@@ -1552,6 +1565,16 @@ class AgentRunResult(Generic[ResultDataT]):
_result_tool_name: str | None = dataclasses.field(repr=False)
_state: _agent_graph.GraphAgentState = dataclasses.field(repr=False)
_new_message_index: int = dataclasses.field(repr=False)
_span_value: AbstractSpan | None = dataclasses.field(repr=False)

@overload
def _span(self, *, required: Literal[False]) -> AbstractSpan | None: ...
@overload
def _span(self) -> AbstractSpan: ...
def _span(self, *, required: bool = True) -> AbstractSpan | None:
if self._span_value is None and required: # pragma: no cover
raise AttributeError('Span is not available for this agent run')
return self._span_value

def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
"""Set return content for the result tool.
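The _span accessors in this file lean on a small typing pattern: an overload keyed on Literal[False] makes _span() type-check as always returning a span, while _span(required=False) type-checks as optional. A standalone sketch of that pattern with a hypothetical find_config helper (not part of this PR):

from __future__ import annotations

from typing import Literal, overload

@overload
def find_config(name: str, *, required: Literal[False]) -> str | None: ...
@overload
def find_config(name: str) -> str: ...
def find_config(name: str, *, required: bool = True) -> str | None:
    value = {'region': 'eu-west-1'}.get(name)  # stand-in lookup
    if value is None and required:
        raise LookupError(f'no config value named {name!r}')
    return value

region: str = find_config('region')                      # checked as str
maybe: str | None = find_config('zone', required=False)  # checked as str | None
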
@@ -14,7 +14,7 @@

try:
from logfire._internal.tracer import (
ProxyTracerProvider as LogfireProxyTracerProvider, # pyright: ignore[reportAssignmentType,reportPrivateImportUsage]
ProxyTracerProvider as LogfireProxyTracerProvider, # pyright: ignore
)

_LOGFIRE_IS_INSTALLED = True
6 changes: 5 additions & 1 deletion pydantic_graph/pydantic_graph/_utils.py
@@ -5,10 +5,14 @@
from functools import partial
from typing import Any, Callable, TypeVar

from typing_extensions import ParamSpec, TypeIs, get_args, get_origin
from logfire_api import LogfireSpan
from opentelemetry.trace import Span
from typing_extensions import ParamSpec, TypeAlias, TypeIs, get_args, get_origin
from typing_inspection import typing_objects
from typing_inspection.introspection import is_union_origin

AbstractSpan: TypeAlias = 'LogfireSpan | Span'


def get_event_loop():
try:
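The new AbstractSpan alias simply unifies the two span types a graph run may hold: logfire's LogfireSpan or a plain OpenTelemetry Span. A rough sketch of how a consumer could narrow it at runtime, mirroring the isinstance check in get_traceparent above (the describe_span function is illustrative only):

from logfire_api import LogfireSpan
from opentelemetry.trace import Span

from pydantic_graph._utils import AbstractSpan

def describe_span(span: AbstractSpan) -> str:
    # LogfireSpan comes from logfire_api, which is a no-op stub when logfire isn't installed.
    if isinstance(span, LogfireSpan):
        return 'logfire span'
    assert isinstance(span, Span)
    return 'opentelemetry span'
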
73 changes: 52 additions & 21 deletions pydantic_graph/pydantic_graph/graph.py
@@ -6,15 +6,16 @@
from contextlib import AbstractContextManager, ExitStack, asynccontextmanager
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, Generic, cast
from typing import Any, Generic, cast, overload

import logfire_api
import typing_extensions
from logfire_api import LogfireSpan
from opentelemetry.trace import Span
from typing_extensions import deprecated
from typing_inspection import typing_objects

from . import _utils, exceptions, mermaid
from ._utils import AbstractSpan
from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, StateT
from .persistence import BaseStatePersistence
from .persistence.in_mem import SimpleStatePersistence
@@ -125,7 +126,6 @@ async def run(
deps: DepsT = None,
persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
infer_name: bool = True,
span: LogfireSpan | None = None,
) -> GraphRunResult[StateT, RunEndT]:
"""Run the graph from a starting node until it ends.

@@ -137,8 +137,6 @@
persistence: State persistence interface, defaults to
[`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`.
infer_name: Whether to infer the graph name from the calling frame.
span: The span to use for the graph run. If not provided, a span will be created depending on the value of
the `auto_instrument` field.

Returns:
A `GraphRunResult` containing information about the run, including its final result.
@@ -164,7 +162,7 @@ async def main():
self._infer_name(inspect.currentframe())

async with self.iter(
start_node, state=state, deps=deps, persistence=persistence, span=span, infer_name=False
start_node, state=state, deps=deps, persistence=persistence, infer_name=False
) as graph_run:
async for _node in graph_run:
pass
@@ -214,7 +212,7 @@ async def iter(
state: StateT = None,
deps: DepsT = None,
persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
span: AbstractContextManager[Any] | None = None,
span: AbstractContextManager[Span] | None = None,
infer_name: bool = True,
) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]:
"""A contextmanager which can be used to iterate over the graph's nodes as they are executed.
@@ -252,14 +250,15 @@ async def iter(
persistence = SimpleStatePersistence()
persistence.set_graph_types(self)

if self.auto_instrument and span is None:
span = logfire_api.span('run graph {graph.name}', graph=self)

with ExitStack() as stack:
if span is not None:
stack.enter_context(span)
entered_span: AbstractSpan | None = None
if span is None:
if self.auto_instrument:
entered_span = stack.enter_context(logfire_api.span('run graph {graph.name}', graph=self))
else:
entered_span = stack.enter_context(span)
yield GraphRun[StateT, DepsT, RunEndT](
graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps
graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps, span=entered_span
)

@asynccontextmanager
@@ -268,7 +267,7 @@ async def iter_from_persistence(
persistence: BaseStatePersistence[StateT, RunEndT],
*,
deps: DepsT = None,
span: AbstractContextManager[Any] | None = None,
span: AbstractContextManager[AbstractSpan] | None = None,
infer_name: bool = True,
) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]:
"""A contextmanager to iterate over the graph's nodes as they are executed, created from a persistence object.
@@ -301,15 +300,15 @@ async def iter_from_persistence(
span = logfire_api.span('run graph {graph.name}', graph=self)

with ExitStack() as stack:
if span is not None:
stack.enter_context(span)
entered_span = None if span is None else stack.enter_context(span)
yield GraphRun[StateT, DepsT, RunEndT](
graph=self,
start_node=snapshot.node,
persistence=persistence,
state=snapshot.state,
deps=deps,
snapshot_id=snapshot.id,
span=entered_span,
)

async def initialize(
@@ -370,6 +369,7 @@ async def next(
persistence=persistence,
state=state,
deps=deps,
span=None,
)
return await run.next(node)

@@ -644,6 +644,7 @@ def __init__(
persistence: BaseStatePersistence[StateT, RunEndT],
state: StateT,
deps: DepsT,
span: AbstractSpan | None,
snapshot_id: str | None = None,
):
"""Create a new run for a given graph, starting at the specified node.
@@ -658,6 +659,7 @@
to all nodes via `ctx.state`.
deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections,
configuration, or logging clients.
span: The span used for the graph run.
snapshot_id: The ID of the snapshot the node came from.
"""
self.graph = graph
@@ -666,9 +668,19 @@ def __init__(
self.state = state
self.deps = deps

self.__span = span
self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node
self._is_started: bool = False

@overload
def _span(self, *, required: typing_extensions.Literal[False]) -> AbstractSpan | None: ...
@overload
def _span(self) -> AbstractSpan: ...
def _span(self, *, required: bool = True) -> AbstractSpan | None:
if self.__span is None and required: # pragma: no cover
raise exceptions.GraphRuntimeError('No span available for this graph run.')
return self.__span

@property
def next_node(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
"""The next node that will be run in the graph.
@@ -682,10 +694,8 @@ def result(self) -> GraphRunResult[StateT, RunEndT] | None:
"""The final result of the graph run if the run is completed, otherwise `None`."""
if not isinstance(self._next_node, End):
return None # The GraphRun has not finished running
return GraphRunResult(
self._next_node.data,
state=self.state,
persistence=self.persistence,
return GraphRunResult[StateT, RunEndT](
self._next_node.data, state=self.state, persistence=self.persistence, span=self._span(required=False)
)

async def next(
@@ -793,10 +803,31 @@ def __repr__(self) -> str:
return f'<GraphRun graph={self.graph.name or "[unnamed]"}>'


@dataclass
@dataclass(init=False)
class GraphRunResult(Generic[StateT, RunEndT]):
"""The final result of running a graph."""

output: RunEndT
state: StateT
persistence: BaseStatePersistence[StateT, RunEndT] = field(repr=False)

def __init__(
self,
output: RunEndT,
state: StateT,
persistence: BaseStatePersistence[StateT, RunEndT],
span: AbstractSpan | None = None,
):
self.output = output
self.state = state
self.persistence = persistence
self.__span = span

@overload
def _span(self, *, required: typing_extensions.Literal[False]) -> AbstractSpan | None: ...
@overload
def _span(self) -> AbstractSpan: ...
def _span(self, *, required: bool = True) -> AbstractSpan | None: # pragma: no cover
if self.__span is None and required:
raise exceptions.GraphRuntimeError('No span available for this graph run.')
return self.__span
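
Tying the graph.py changes together: when auto_instrument is enabled, iter now enters the span itself and threads it through GraphRun into GraphRunResult, where _span(required=False) exposes it. A minimal sketch with a trivial one-node graph (the node and graph are illustrative, and _span is private API that may return None when no instrumentation is active):

from __future__ import annotations

import asyncio
from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext

@dataclass
class Hello(BaseNode[None, None, str]):
    async def run(self, ctx: GraphRunContext) -> End[str]:
        return End('hello world')

graph = Graph(nodes=[Hello])

async def main() -> None:
    result = await graph.run(Hello())
    # None unless instrumentation created a span (e.g. auto_instrument with logfire configured).
    span = result._span(required=False)
    print(result.output, span)

asyncio.run(main())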