diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index a98d4232e9..909149b838 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -15,7 +15,14 @@ from pydantic.json_schema import JsonSchemaValue from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict +from pydantic_graph._utils import AbstractSpan + +AbstractSpan = AbstractSpan + if TYPE_CHECKING: + from pydantic_ai.agent import AgentRun, AgentRunResult + from pydantic_graph import GraphRun, GraphRunResult + from . import messages as _messages from .tools import ObjectJsonSchema @@ -281,3 +288,16 @@ async def __anext__(self) -> T: except StopAsyncIteration: self._exhausted = True raise + + +def get_traceparent(x: AgentRun | AgentRunResult | GraphRun | GraphRunResult) -> str: + import logfire + import logfire_api + from logfire.experimental.annotations import get_traceparent + + span: AbstractSpan | None = x._span(required=False) # type: ignore[reportPrivateUsage] + if not span: # pragma: no cover + return '' + if isinstance(span, logfire_api.LogfireSpan): # pragma: no cover + assert isinstance(span, logfire.LogfireSpan) + return get_traceparent(span) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 096341fdb8..59f04314ae 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -10,7 +10,7 @@ from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema -from typing_extensions import TypeGuard, TypeVar, deprecated +from typing_extensions import Literal, TypeGuard, TypeVar, deprecated from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop @@ -26,6 +26,7 @@ result, usage as _usage, ) +from ._utils import AbstractSpan from .models.instrumented import InstrumentationSettings, InstrumentedModel from .result import FinalResult, ResultDataT, StreamedRunResult from .settings import ModelSettings, merge_model_settings @@ -52,6 +53,7 @@ if TYPE_CHECKING: from pydantic_ai.mcp import MCPServer + __all__ = ( 'Agent', 'AgentRun', @@ -1402,6 +1404,16 @@ async def main(): _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT] ] + @overload + def _span(self, *, required: Literal[False]) -> AbstractSpan | None: ... + @overload + def _span(self) -> AbstractSpan: ... + def _span(self, *, required: bool = True) -> AbstractSpan | None: + span = self._graph_run._span(required=False) # type: ignore[reportPrivateUsage] + if span is None and required: # pragma: no cover + raise AttributeError('Span is not available for this agent run') + return span + @property def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]: """The current context of the agent run.""" @@ -1439,6 +1451,7 @@ def result(self) -> AgentRunResult[ResultDataT] | None: graph_run_result.output.tool_name, graph_run_result.state, self._graph_run.deps.new_message_index, + self._graph_run._span(required=False), # type: ignore[reportPrivateUsage] ) def __aiter__( @@ -1552,6 +1565,16 @@ class AgentRunResult(Generic[ResultDataT]): _result_tool_name: str | None = dataclasses.field(repr=False) _state: _agent_graph.GraphAgentState = dataclasses.field(repr=False) _new_message_index: int = dataclasses.field(repr=False) + _span_value: AbstractSpan | None = dataclasses.field(repr=False) + + @overload + def _span(self, *, required: Literal[False]) -> AbstractSpan | None: ... + @overload + def _span(self) -> AbstractSpan: ... + def _span(self, *, required: bool = True) -> AbstractSpan | None: + if self._span_value is None and required: # pragma: no cover + raise AttributeError('Span is not available for this agent run') + return self._span_value def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]: """Set return content for the result tool. diff --git a/pydantic_evals/pydantic_evals/otel/_context_in_memory_span_exporter.py b/pydantic_evals/pydantic_evals/otel/_context_in_memory_span_exporter.py index 53587bcbb3..cd83f30a99 100644 --- a/pydantic_evals/pydantic_evals/otel/_context_in_memory_span_exporter.py +++ b/pydantic_evals/pydantic_evals/otel/_context_in_memory_span_exporter.py @@ -14,7 +14,7 @@ try: from logfire._internal.tracer import ( - ProxyTracerProvider as LogfireProxyTracerProvider, # pyright: ignore[reportAssignmentType,reportPrivateImportUsage] + ProxyTracerProvider as LogfireProxyTracerProvider, # pyright: ignore ) _LOGFIRE_IS_INSTALLED = True diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index bc49ed6fcb..d797a4516a 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -5,10 +5,14 @@ from functools import partial from typing import Any, Callable, TypeVar -from typing_extensions import ParamSpec, TypeIs, get_args, get_origin +from logfire_api import LogfireSpan +from opentelemetry.trace import Span +from typing_extensions import ParamSpec, TypeAlias, TypeIs, get_args, get_origin from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin +AbstractSpan: TypeAlias = 'LogfireSpan | Span' + def get_event_loop(): try: diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 3da34bca9a..7336451407 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -6,15 +6,16 @@ from contextlib import AbstractContextManager, ExitStack, asynccontextmanager from dataclasses import dataclass, field from functools import cached_property -from typing import Any, Generic, cast +from typing import Any, Generic, cast, overload import logfire_api import typing_extensions -from logfire_api import LogfireSpan +from opentelemetry.trace import Span from typing_extensions import deprecated from typing_inspection import typing_objects from . import _utils, exceptions, mermaid +from ._utils import AbstractSpan from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, StateT from .persistence import BaseStatePersistence from .persistence.in_mem import SimpleStatePersistence @@ -125,7 +126,6 @@ async def run( deps: DepsT = None, persistence: BaseStatePersistence[StateT, RunEndT] | None = None, infer_name: bool = True, - span: LogfireSpan | None = None, ) -> GraphRunResult[StateT, RunEndT]: """Run the graph from a starting node until it ends. @@ -137,8 +137,6 @@ async def run( persistence: State persistence interface, defaults to [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. infer_name: Whether to infer the graph name from the calling frame. - span: The span to use for the graph run. If not provided, a span will be created depending on the value of - the `auto_instrument` field. Returns: A `GraphRunResult` containing information about the run, including its final result. @@ -164,7 +162,7 @@ async def main(): self._infer_name(inspect.currentframe()) async with self.iter( - start_node, state=state, deps=deps, persistence=persistence, span=span, infer_name=False + start_node, state=state, deps=deps, persistence=persistence, infer_name=False ) as graph_run: async for _node in graph_run: pass @@ -214,7 +212,7 @@ async def iter( state: StateT = None, deps: DepsT = None, persistence: BaseStatePersistence[StateT, RunEndT] | None = None, - span: AbstractContextManager[Any] | None = None, + span: AbstractContextManager[Span] | None = None, infer_name: bool = True, ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]: """A contextmanager which can be used to iterate over the graph's nodes as they are executed. @@ -252,14 +250,15 @@ async def iter( persistence = SimpleStatePersistence() persistence.set_graph_types(self) - if self.auto_instrument and span is None: - span = logfire_api.span('run graph {graph.name}', graph=self) - with ExitStack() as stack: - if span is not None: - stack.enter_context(span) + entered_span: AbstractSpan | None = None + if span is None: + if self.auto_instrument: + entered_span = stack.enter_context(logfire_api.span('run graph {graph.name}', graph=self)) + else: + entered_span = stack.enter_context(span) yield GraphRun[StateT, DepsT, RunEndT]( - graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps + graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps, span=entered_span ) @asynccontextmanager @@ -268,7 +267,7 @@ async def iter_from_persistence( persistence: BaseStatePersistence[StateT, RunEndT], *, deps: DepsT = None, - span: AbstractContextManager[Any] | None = None, + span: AbstractContextManager[AbstractSpan] | None = None, infer_name: bool = True, ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]: """A contextmanager to iterate over the graph's nodes as they are executed, created from a persistence object. @@ -301,8 +300,7 @@ async def iter_from_persistence( span = logfire_api.span('run graph {graph.name}', graph=self) with ExitStack() as stack: - if span is not None: - stack.enter_context(span) + entered_span = None if span is None else stack.enter_context(span) yield GraphRun[StateT, DepsT, RunEndT]( graph=self, start_node=snapshot.node, @@ -310,6 +308,7 @@ async def iter_from_persistence( state=snapshot.state, deps=deps, snapshot_id=snapshot.id, + span=entered_span, ) async def initialize( @@ -370,6 +369,7 @@ async def next( persistence=persistence, state=state, deps=deps, + span=None, ) return await run.next(node) @@ -644,6 +644,7 @@ def __init__( persistence: BaseStatePersistence[StateT, RunEndT], state: StateT, deps: DepsT, + span: AbstractSpan | None, snapshot_id: str | None = None, ): """Create a new run for a given graph, starting at the specified node. @@ -658,6 +659,7 @@ def __init__( to all nodes via `ctx.state`. deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections, configuration, or logging clients. + span: The span used for the graph run. snapshot_id: The ID of the snapshot the node came from. """ self.graph = graph @@ -666,9 +668,19 @@ def __init__( self.state = state self.deps = deps + self.__span = span self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node self._is_started: bool = False + @overload + def _span(self, *, required: typing_extensions.Literal[False]) -> AbstractSpan | None: ... + @overload + def _span(self) -> AbstractSpan: ... + def _span(self, *, required: bool = True) -> AbstractSpan | None: + if self.__span is None and required: # pragma: no cover + raise exceptions.GraphRuntimeError('No span available for this graph run.') + return self.__span + @property def next_node(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: """The next node that will be run in the graph. @@ -682,10 +694,8 @@ def result(self) -> GraphRunResult[StateT, RunEndT] | None: """The final result of the graph run if the run is completed, otherwise `None`.""" if not isinstance(self._next_node, End): return None # The GraphRun has not finished running - return GraphRunResult( - self._next_node.data, - state=self.state, - persistence=self.persistence, + return GraphRunResult[StateT, RunEndT]( + self._next_node.data, state=self.state, persistence=self.persistence, span=self._span(required=False) ) async def next( @@ -793,10 +803,31 @@ def __repr__(self) -> str: return f'' -@dataclass +@dataclass(init=False) class GraphRunResult(Generic[StateT, RunEndT]): """The final result of running a graph.""" output: RunEndT state: StateT persistence: BaseStatePersistence[StateT, RunEndT] = field(repr=False) + + def __init__( + self, + output: RunEndT, + state: StateT, + persistence: BaseStatePersistence[StateT, RunEndT], + span: AbstractSpan | None = None, + ): + self.output = output + self.state = state + self.persistence = persistence + self.__span = span + + @overload + def _span(self, *, required: typing_extensions.Literal[False]) -> AbstractSpan | None: ... + @overload + def _span(self) -> AbstractSpan: ... + def _span(self, *, required: bool = True) -> AbstractSpan | None: # pragma: no cover + if self.__span is None and required: + raise exceptions.GraphRuntimeError('No span available for this graph run.') + return self.__span diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 30dde6be5c..74d51291b4 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -9,6 +9,7 @@ from typing_extensions import NotRequired, TypedDict from pydantic_ai import Agent +from pydantic_ai._utils import get_traceparent from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel from pydantic_ai.models.test import TestModel @@ -262,3 +263,89 @@ def get_model(): Agent.instrument_all(False) assert get_model() is model + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.anyio +async def test_feedback(capfire: CaptureLogfire) -> None: + try: + from logfire.experimental.annotations import record_feedback + except ImportError: + pytest.skip('Requires recent version of logfire') + + my_agent = Agent(model=TestModel(), instrument=True) + + async with my_agent.iter('Hello') as agent_run: + async for _ in agent_run: + pass + result = agent_run.result + assert result + traceparent = get_traceparent(result) + assert traceparent == get_traceparent(agent_run) + assert traceparent == snapshot('00-00000000000000000000000000000001-0000000000000001-01') + record_feedback(traceparent, 'factuality', 0.1, comment='the agent lied', extra={'foo': 'bar'}) + + assert capfire.exporter.exported_spans_as_dict() == snapshot( + [ + { + 'name': 'chat test', + 'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False}, + 'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'start_time': 2000000000, + 'end_time': 3000000000, + 'attributes': { + 'gen_ai.operation.name': 'chat', + 'gen_ai.system': 'test', + 'gen_ai.request.model': 'test', + 'model_request_parameters': '{"function_tools": [], "allow_text_result": true, "result_tools": []}', + 'logfire.span_type': 'span', + 'logfire.msg': 'chat test', + 'gen_ai.usage.input_tokens': 51, + 'gen_ai.usage.output_tokens': 4, + 'gen_ai.response.model': 'test', + 'events': '[{"content": "Hello", "role": "user", "gen_ai.system": "test", "gen_ai.message.index": 0, "event.name": "gen_ai.user.message"}, {"index": 0, "message": {"role": "assistant", "content": "success (no tool calls)"}, "gen_ai.system": "test", "event.name": "gen_ai.choice"}]', + 'logfire.json_schema': '{"type": "object", "properties": {"events": {"type": "array"}, "model_request_parameters": {"type": "object"}}}', + }, + }, + { + 'name': 'agent run', + 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 4000000000, + 'attributes': { + 'model_name': 'test', + 'agent_name': 'agent', + 'logfire.msg': 'agent run', + 'logfire.span_type': 'span', + 'gen_ai.usage.input_tokens': 51, + 'gen_ai.usage.output_tokens': 4, + 'all_messages_events': '[{"content": "Hello", "role": "user", "gen_ai.message.index": 0, "event.name": "gen_ai.user.message"}, {"role": "assistant", "content": "success (no tool calls)", "gen_ai.message.index": 1, "event.name": "gen_ai.assistant.message"}]', + 'final_result': 'success (no tool calls)', + 'logfire.json_schema': '{"type": "object", "properties": {"all_messages_events": {"type": "array"}, "final_result": {"type": "object"}}}', + }, + }, + { + 'name': 'feedback: factuality', + 'context': {'trace_id': 1, 'span_id': 5, 'is_remote': False}, + 'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': True}, + 'start_time': 5000000000, + 'end_time': 5000000000, + 'attributes': { + 'logfire.span_type': 'annotation', + 'logfire.level_num': 9, + 'logfire.msg_template': 'feedback: factuality', + 'logfire.msg': 'feedback: factuality = 0.1', + 'code.filepath': 'test_logfire.py', + 'code.function': 'test_feedback', + 'code.lineno': 123, + 'logfire.feedback.name': 'factuality', + 'factuality': 0.1, + 'foo': 'bar', + 'logfire.feedback.comment': 'the agent lied', + 'logfire.disable_console_log': True, + 'logfire.json_schema': '{"type":"object","properties":{"logfire.feedback.name":{},"factuality":{},"foo":{},"logfire.feedback.comment":{},"logfire.span_type":{},"logfire.disable_console_log":{}}}', + }, + }, + ] + ) diff --git a/uv.lock b/uv.lock index 235046b462..abf5589b97 100644 --- a/uv.lock +++ b/uv.lock @@ -1453,7 +1453,7 @@ wheels = [ [[package]] name = "logfire" -version = "3.11.0" +version = "3.14.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "executing" }, @@ -1465,9 +1465,9 @@ dependencies = [ { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/38/06/fcde41240092e5b28f4a213da653f773441fe9ac1cbbc59d45b16e691a15/logfire-3.11.0.tar.gz", hash = "sha256:2ce9e58d6a7eb6fb9888fe12464d2284d16479e49ade766846fcb84e1c4a1abe", size = 464360 } +sdist = { url = "https://files.pythonhosted.org/packages/9d/5b/ca57ad7a9e78fc9f7779bb2da8a3776233b832d211eaf49cea2582fe3f77/logfire-3.14.0.tar.gz", hash = "sha256:afdd23386a8a57da7ab97938cc5eec17928ce9195907b85860d906f04c5d33e3", size = 473733 } wheels = [ - { url = "https://files.pythonhosted.org/packages/66/64/8bfc71ef2fa44190e7a8b10c3aabddf72df2944a0bce04bbe27a500566ac/logfire-3.11.0-py3-none-any.whl", hash = "sha256:bd20e697030510cfc834b1805945b1ad62ac057679d01d13579ed2f454f0e979", size = 190143 }, + { url = "https://files.pythonhosted.org/packages/7f/13/e06647ca3d7fb9167dd82260cc0e978ffcae88aa396025cea3d86a04875d/logfire-3.14.0-py3-none-any.whl", hash = "sha256:4f95cf98a7c29cd7cd00e093ba75ce1e4e19e5069acda8b1577a4b7790e0237a", size = 193168 }, ] [package.optional-dependencies] @@ -2955,7 +2955,7 @@ dev = [ requires-dist = [ { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, - { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.34.116" }, + { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.35.74" }, { name = "cohere", marker = "sys_platform != 'emscripten' and extra == 'cohere'", specifier = ">=5.13.11" }, { name = "duckduckgo-search", marker = "extra == 'duckduckgo'", specifier = ">=7.0.0" }, { name = "eval-type-backport", specifier = ">=0.2.0" },