diff --git a/docs/components/agent.rst b/docs/components/agent.rst index 9c8c4b63e..0f7dca91b 100644 --- a/docs/components/agent.rst +++ b/docs/components/agent.rst @@ -314,6 +314,52 @@ If you want to expose the underlying error to the LLM, you can throw a custom ex } } +Tool Sources +~~~~~~~~~~~~ + +Some tools bring in data to the agent from external sources, like search engines or APIs. Those sources can be exposed +by enabling `keepToolSources` as argument of the :class:`Symfony\\AI\\Agent\\Toolbox\\AgentProcessor`:: + + use Symfony\AI\Agent\Toolbox\AgentProcessor; + use Symfony\AI\Agent\Toolbox\Toolbox; + + $toolbox = new Toolbox([new MyTool()]); + $toolProcessor = new AgentProcessor($toolbox, keepToolSources: true); + +In the tool implementation sources can be added by implementing the +:class:`Symfony\\AI\\Agent\\Toolbox\\Source\\HasSourcesInterface` in combination with the trait +:class:`Symfony\\AI\\Agent\\Toolbox\\Source\\HasSourcesTrait`:: + + use Symfony\AI\Agent\Toolbox\Source\HasSourcesInterface; + use Symfony\AI\Agent\Toolbox\Source\HasSourcesTrait; + + #[AsTool('my_tool', 'Example tool with sources.')] + final class MyTool implements HasSourcesInterface + { + use HasSourcesTrait; + + public function __invoke(string $query): string + { + // Add sources relevant for the result + + $this->addSource( + new Source('Example Source 1', 'https://example.com/source1', 'Relevant content from source 1'), + ); + + // return result + } + } + +The sources can be fetched from the metadata of the result after the agent execution:: + + $result = $agent->call($messages); + + foreach ($result->getMetadata()->get('sources', []) as $source) { + echo sprintf(' - %s (%s): %s', $source->getName(), $source->getReference(), $source->getContent()); + } + +See `Anthropic Toolbox Example`_ for a complete example using sources with Wikipedia tool. + Tool Filtering ~~~~~~~~~~~~~~ @@ -765,6 +811,7 @@ Code Examples .. _`Platform Component`: https://github.com/symfony/ai-platform +.. _`Anthropic Toolbox Example`: https://github.com/symfony/ai/blob/main/examples/anthropic/toolcall.php .. _`Brave Tool`: https://github.com/symfony/ai/blob/main/examples/toolbox/brave.php .. _`Clock Tool`: https://github.com/symfony/ai/blob/main/examples/toolbox/clock.php .. _`Crawler Tool`: https://github.com/symfony/ai/blob/main/examples/toolbox/brave.php diff --git a/examples/anthropic/toolcall.php b/examples/anthropic/toolcall.php index ea5127cb9..56c399a18 100644 --- a/examples/anthropic/toolcall.php +++ b/examples/anthropic/toolcall.php @@ -11,6 +11,7 @@ use Symfony\AI\Agent\Agent; use Symfony\AI\Agent\Toolbox\AgentProcessor; +use Symfony\AI\Agent\Toolbox\Source\Source; use Symfony\AI\Agent\Toolbox\Tool\Wikipedia; use Symfony\AI\Agent\Toolbox\Toolbox; use Symfony\AI\Platform\Bridge\Anthropic\PlatformFactory; @@ -23,10 +24,16 @@ $wikipedia = new Wikipedia(http_client()); $toolbox = new Toolbox([$wikipedia], logger: logger()); -$processor = new AgentProcessor($toolbox); +$processor = new AgentProcessor($toolbox, keepToolSources: true); $agent = new Agent($platform, 'claude-3-5-sonnet-20241022', [$processor], [$processor]); $messages = new MessageBag(Message::ofUser('Who is the current chancellor of Germany?')); $result = $agent->call($messages); -echo $result->getContent().\PHP_EOL; +echo $result->getContent().\PHP_EOL.\PHP_EOL; + +echo 'Used sources:'.\PHP_EOL; +foreach ($result->getMetadata()->get('sources', []) as $source) { + echo sprintf(' - %s (%s)', $source->getName(), $source->getReference()).\PHP_EOL; +} +echo \PHP_EOL; diff --git a/fixtures/Tool/ToolSources.php b/fixtures/Tool/ToolSources.php new file mode 100644 index 000000000..e225f4906 --- /dev/null +++ b/fixtures/Tool/ToolSources.php @@ -0,0 +1,37 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Agent\Toolbox\Source\HasSourcesInterface; +use Symfony\AI\Agent\Toolbox\Source\HasSourcesTrait; +use Symfony\AI\Agent\Toolbox\Source\Source; + +#[AsTool('tool_sources', 'Tool that records some sources')] +final class ToolSources implements HasSourcesInterface +{ + use HasSourcesTrait; + + /** + * @param string $query Search query + */ + public function __invoke(string $query): string + { + $foundContent = 'Content of that relevant article.'; + + $this->addSource( + new Source('Relevant Article', 'https://example.com/relevant-article', $foundContent), + ); + + return $foundContent; + } +} diff --git a/src/agent/src/Toolbox/AgentProcessor.php b/src/agent/src/Toolbox/AgentProcessor.php index 218406677..50db3e619 100644 --- a/src/agent/src/Toolbox/AgentProcessor.php +++ b/src/agent/src/Toolbox/AgentProcessor.php @@ -18,6 +18,7 @@ use Symfony\AI\Agent\Output; use Symfony\AI\Agent\OutputProcessorInterface; use Symfony\AI\Agent\Toolbox\Event\ToolCallsExecuted; +use Symfony\AI\Agent\Toolbox\Source\Source; use Symfony\AI\Agent\Toolbox\StreamResult as ToolboxStreamResponse; use Symfony\AI\Platform\Message\AssistantMessage; use Symfony\AI\Platform\Message\Message; @@ -34,11 +35,25 @@ final class AgentProcessor implements InputProcessorInterface, OutputProcessorIn { use AgentAwareTrait; + /** + * Sources get collected during tool calls on class level to be able to handle consecutive tool calls. + * They get added to the result metadata and reset when the outermost agent call is finished via nesting level. + * + * @var Source[] + */ + private array $sources = []; + + /** + * Tracks the nesting level of agent calls. + */ + private int $nestingLevel = 0; + public function __construct( private readonly ToolboxInterface $toolbox, private readonly ToolResultConverter $resultConverter = new ToolResultConverter(), private readonly ?EventDispatcherInterface $eventDispatcher = null, private readonly bool $keepToolMessages = false, + private readonly bool $keepToolSources = false, ) { } @@ -87,6 +102,7 @@ private function isFlatStringArray(array $tools): bool private function handleToolCallsCallback(Output $output): \Closure { return function (ToolCallResult $result, ?AssistantMessage $streamedAssistantResponse = null) use ($output): ResultInterface { + ++$this->nestingLevel; $messages = $this->keepToolMessages ? $output->getMessageBag() : clone $output->getMessageBag(); if (null !== $streamedAssistantResponse && '' !== $streamedAssistantResponse->getContent()) { @@ -101,6 +117,7 @@ private function handleToolCallsCallback(Output $output): \Closure foreach ($toolCalls as $toolCall) { $results[] = $toolResult = $this->toolbox->execute($toolCall); $messages->add(Message::ofToolCall($toolCall, $this->resultConverter->convert($toolResult))); + array_push($this->sources, ...$toolResult->getSources()); } $event = new ToolCallsExecuted(...$results); @@ -109,6 +126,12 @@ private function handleToolCallsCallback(Output $output): \Closure $result = $event->hasResult() ? $event->getResult() : $this->agent->call($messages, $output->getOptions()); } while ($result instanceof ToolCallResult); + --$this->nestingLevel; + if ($this->keepToolSources && 0 === $this->nestingLevel) { + $result->getMetadata()->add('sources', $this->sources); + $this->sources = []; + } + return $result; }; } diff --git a/src/agent/src/Toolbox/Source/HasSourcesInterface.php b/src/agent/src/Toolbox/Source/HasSourcesInterface.php new file mode 100644 index 000000000..118c4d2b9 --- /dev/null +++ b/src/agent/src/Toolbox/Source/HasSourcesInterface.php @@ -0,0 +1,17 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Source; + +interface HasSourcesInterface +{ + public function setSourceMap(SourceMap $sourceMap): void; +} diff --git a/src/agent/src/Toolbox/Source/HasSourcesTrait.php b/src/agent/src/Toolbox/Source/HasSourcesTrait.php new file mode 100644 index 000000000..b0e22b7d1 --- /dev/null +++ b/src/agent/src/Toolbox/Source/HasSourcesTrait.php @@ -0,0 +1,32 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Source; + +trait HasSourcesTrait +{ + private SourceMap $sourceMap; + + public function setSourceMap(SourceMap $sourceMap): void + { + $this->sourceMap = $sourceMap; + } + + public function getSourceMap(): SourceMap + { + return $this->sourceMap ??= new SourceMap(); + } + + private function addSource(Source $source): void + { + $this->getSourceMap()->addSource($source); + } +} diff --git a/src/agent/src/Toolbox/Source/Source.php b/src/agent/src/Toolbox/Source/Source.php new file mode 100644 index 000000000..599164afc --- /dev/null +++ b/src/agent/src/Toolbox/Source/Source.php @@ -0,0 +1,37 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Source; + +readonly class Source +{ + public function __construct( + private string $name, + private string $reference, + private string $content, + ) { + } + + public function getName(): string + { + return $this->name; + } + + public function getReference(): string + { + return $this->reference; + } + + public function getContent(): string + { + return $this->content; + } +} diff --git a/src/agent/src/Toolbox/Source/SourceMap.php b/src/agent/src/Toolbox/Source/SourceMap.php new file mode 100644 index 000000000..eccb83e80 --- /dev/null +++ b/src/agent/src/Toolbox/Source/SourceMap.php @@ -0,0 +1,33 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Source; + +class SourceMap +{ + /** + * @var Source[] + */ + private array $sources = []; + + /** + * @return Source[] + */ + public function getSources(): array + { + return $this->sources; + } + + public function addSource(Source $source): void + { + $this->sources[] = $source; + } +} diff --git a/src/agent/src/Toolbox/Tool/Wikipedia.php b/src/agent/src/Toolbox/Tool/Wikipedia.php index 8cf85e4a7..748c2c235 100644 --- a/src/agent/src/Toolbox/Tool/Wikipedia.php +++ b/src/agent/src/Toolbox/Tool/Wikipedia.php @@ -12,6 +12,9 @@ namespace Symfony\AI\Agent\Toolbox\Tool; use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Agent\Toolbox\Source\HasSourcesInterface; +use Symfony\AI\Agent\Toolbox\Source\HasSourcesTrait; +use Symfony\AI\Agent\Toolbox\Source\Source; use Symfony\Contracts\HttpClient\HttpClientInterface; /** @@ -19,8 +22,10 @@ */ #[AsTool('wikipedia_search', description: 'Searches Wikipedia for a given query', method: 'search')] #[AsTool('wikipedia_article', description: 'Retrieves a Wikipedia article by its title', method: 'article')] -final readonly class Wikipedia +final class Wikipedia implements HasSourcesInterface { + use HasSourcesTrait; + public function __construct( private HttpClientInterface $httpClient, private string $locale = 'en', @@ -81,6 +86,10 @@ public function article(string $title): string $result .= \PHP_EOL; } + $this->addSource( + new Source($article['title'], $this->getUrl($article['title']), $article['extract']) + ); + return $result.'This is the content of article "'.$article['title'].'":'.\PHP_EOL.$article['extract']; } @@ -96,4 +105,9 @@ private function execute(array $query, ?string $locale = null): array return $response->toArray(); } + + private function getUrl(string $title): string + { + return \sprintf('https://%s.wikipedia.org/wiki/%s', $this->locale, str_replace(' ', '_', $title)); + } } diff --git a/src/agent/src/Toolbox/ToolResult.php b/src/agent/src/Toolbox/ToolResult.php index c074e8087..1481ed40c 100644 --- a/src/agent/src/Toolbox/ToolResult.php +++ b/src/agent/src/Toolbox/ToolResult.php @@ -11,6 +11,7 @@ namespace Symfony\AI\Agent\Toolbox; +use Symfony\AI\Agent\Toolbox\Source\Source; use Symfony\AI\Platform\Result\ToolCall; /** @@ -18,9 +19,13 @@ */ final readonly class ToolResult { + /** + * @param Source[] $sources + */ public function __construct( private ToolCall $toolCall, private mixed $result, + private array $sources = [], ) { } @@ -33,4 +38,12 @@ public function getResult(): mixed { return $this->result; } + + /** + * @return Source[] + */ + public function getSources(): array + { + return $this->sources; + } } diff --git a/src/agent/src/Toolbox/Toolbox.php b/src/agent/src/Toolbox/Toolbox.php index 0762a2ce1..b22d68bed 100644 --- a/src/agent/src/Toolbox/Toolbox.php +++ b/src/agent/src/Toolbox/Toolbox.php @@ -19,6 +19,8 @@ use Symfony\AI\Agent\Toolbox\Exception\ToolExecutionException; use Symfony\AI\Agent\Toolbox\Exception\ToolExecutionExceptionInterface; use Symfony\AI\Agent\Toolbox\Exception\ToolNotFoundException; +use Symfony\AI\Agent\Toolbox\Source\HasSourcesInterface; +use Symfony\AI\Agent\Toolbox\Source\SourceMap; use Symfony\AI\Agent\Toolbox\ToolFactory\ReflectionToolFactory; use Symfony\AI\Platform\Result\ToolCall; use Symfony\AI\Platform\Tool\Tool; @@ -82,7 +84,17 @@ public function execute(ToolCall $toolCall): ToolResult $arguments = $this->argumentResolver->resolveArguments($metadata, $toolCall); $this->eventDispatcher?->dispatch(new ToolCallArgumentsResolved($tool, $metadata, $arguments)); - $result = new ToolResult($toolCall, $tool->{$metadata->getReference()->getMethod()}(...$arguments)); + + if ($tool instanceof HasSourcesInterface) { + $tool->setSourceMap($sourceMap = new SourceMap()); + } + + $result = new ToolResult( + $toolCall, + $tool->{$metadata->getReference()->getMethod()}(...$arguments), + $tool instanceof HasSourcesInterface ? $sourceMap->getSources() : [], + ); + $this->eventDispatcher?->dispatch(new ToolCallSucceeded($tool, $metadata, $arguments, $result)); } catch (ToolExecutionExceptionInterface $e) { $this->eventDispatcher?->dispatch(new ToolCallFailed($tool, $metadata, $arguments ?? [], $e)); diff --git a/src/agent/tests/Toolbox/AgentProcessorTest.php b/src/agent/tests/Toolbox/AgentProcessorTest.php index 819326632..5ea535564 100644 --- a/src/agent/tests/Toolbox/AgentProcessorTest.php +++ b/src/agent/tests/Toolbox/AgentProcessorTest.php @@ -12,17 +12,24 @@ namespace Symfony\AI\Agent\Tests\Toolbox; use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Agent; use Symfony\AI\Agent\AgentInterface; use Symfony\AI\Agent\Input; use Symfony\AI\Agent\Output; use Symfony\AI\Agent\Toolbox\AgentProcessor; +use Symfony\AI\Agent\Toolbox\Source\Source; use Symfony\AI\Agent\Toolbox\ToolboxInterface; use Symfony\AI\Agent\Toolbox\ToolResult; use Symfony\AI\Platform\Message\AssistantMessage; use Symfony\AI\Platform\Message\MessageBag; use Symfony\AI\Platform\Message\ToolCallMessage; +use Symfony\AI\Platform\PlatformInterface; +use Symfony\AI\Platform\Result\DeferredResult; +use Symfony\AI\Platform\Result\InMemoryRawResult; +use Symfony\AI\Platform\Result\TextResult; use Symfony\AI\Platform\Result\ToolCall; use Symfony\AI\Platform\Result\ToolCallResult; +use Symfony\AI\Platform\Test\PlainConverter; use Symfony\AI\Platform\Tool\ExecutionReference; use Symfony\AI\Platform\Tool\Tool; @@ -120,4 +127,110 @@ public function testProcessOutputWithToolCallResponseForgettingMessages() $this->assertCount(0, $messageBag); } + + public function testSourcesEndUpInResultMetadataWithSettingOn() + { + $toolCall = new ToolCall('call_1234', 'tool_sources', ['arg1' => 'value1']); + $source1 = new Source('Relevant Article 1', 'http://example.com/article1', 'Content of article about the topic'); + $source2 = new Source('Relevant Article 2', 'http://example.com/article2', 'More content of article about the topic'); + $toolbox = $this->createMock(ToolboxInterface::class); + $toolbox + ->expects($this->once()) + ->method('execute') + ->willReturn(new ToolResult($toolCall, 'Response based on the two articles.', [$source1, $source2])); + + $messageBag = new MessageBag(); + $result = new ToolCallResult($toolCall); + + $agent = $this->createMock(AgentInterface::class); + $agent + ->expects($this->once()) + ->method('call') + ->willReturn(new TextResult('Final response based on the two articles.')); + + $processor = new AgentProcessor($toolbox, keepToolSources: true); + $processor->setAgent($agent); + + $output = new Output('gpt-4', $result, $messageBag); + + $processor->processOutput($output); + + $metadata = $output->getResult()->getMetadata(); + $this->assertTrue($metadata->has('sources')); + $this->assertCount(2, $metadata->get('sources')); + $this->assertSame([$source1, $source2], $metadata->get('sources')); + } + + public function testSourcesDoNotEndUpInResultMetadataWithSettingOff() + { + $toolCall = new ToolCall('call_1234', 'tool_sources', ['arg1' => 'value1']); + $source1 = new Source('Relevant Article 1', 'http://example.com/article1', 'Content of article about the topic'); + $source2 = new Source('Relevant Article 2', 'http://example.com/article2', 'More content of article about the topic'); + $toolbox = $this->createMock(ToolboxInterface::class); + $toolbox + ->expects($this->once()) + ->method('execute') + ->willReturn(new ToolResult($toolCall, 'Response based on the two articles.', [$source1, $source2])); + + $messageBag = new MessageBag(); + $result = new ToolCallResult($toolCall); + + $agent = $this->createMock(AgentInterface::class); + $agent + ->expects($this->once()) + ->method('call') + ->willReturn(new TextResult('Final response based on the two articles.')); + + $processor = new AgentProcessor($toolbox, keepToolSources: false); + $processor->setAgent($agent); + + $output = new Output('gpt-4', $result, $messageBag); + + $processor->processOutput($output); + + $metadata = $output->getResult()->getMetadata(); + $this->assertFalse($metadata->has('sources')); + } + + public function testSourcesGetCollectedAcrossConsecutiveToolCalls() + { + $toolCall1 = new ToolCall('call_1234', 'tool_sources', ['arg1' => 'value1']); + $source1 = new Source('Relevant Article 1', 'http://example.com/article1', 'Content of article about the topic'); + $toolCall2 = new ToolCall('call_5678', 'tool_sources', ['arg1' => 'value2']); + $source2 = new Source('Relevant Article 2', 'http://example.com/article2', 'More content of article about the topic'); + + $toolbox = $this->createMock(ToolboxInterface::class); + $toolbox + ->expects($this->exactly(2)) + ->method('execute') + ->willReturnOnConsecutiveCalls( + new ToolResult($toolCall1, 'Response based on the first article.', [$source1]), + new ToolResult($toolCall2, 'Response based on the second article.', [$source2]) + ); + + $messageBag = new MessageBag(); + $result = new ToolCallResult($toolCall1); + + $platform = $this->createMock(PlatformInterface::class); + $platform + ->expects($this->exactly(2)) + ->method('invoke') + ->willReturnOnConsecutiveCalls( + new DeferredResult(new PlainConverter(new ToolCallResult($toolCall2)), new InMemoryRawResult()), + new DeferredResult(new PlainConverter(new TextResult('Final response based on both articles.')), new InMemoryRawResult()) + ); + + $processor = new AgentProcessor($toolbox, keepToolSources: true); + $agent = new Agent($platform, 'foo-bar', [$processor], [$processor]); + $processor->setAgent($agent); + + $output = new Output('gpt-4', $result, $messageBag); + + $processor->processOutput($output); + + $metadata = $output->getResult()->getMetadata(); + $this->assertTrue($metadata->has('sources')); + $this->assertCount(2, $metadata->get('sources')); + $this->assertSame([$source1, $source2], $metadata->get('sources')); + } } diff --git a/src/agent/tests/Toolbox/ToolboxTest.php b/src/agent/tests/Toolbox/ToolboxTest.php index e5fdae95e..ecb0a16b1 100644 --- a/src/agent/tests/Toolbox/ToolboxTest.php +++ b/src/agent/tests/Toolbox/ToolboxTest.php @@ -17,6 +17,7 @@ use Symfony\AI\Agent\Toolbox\Exception\ToolExecutionException; use Symfony\AI\Agent\Toolbox\Exception\ToolExecutionExceptionInterface; use Symfony\AI\Agent\Toolbox\Exception\ToolNotFoundException; +use Symfony\AI\Agent\Toolbox\Source\Source; use Symfony\AI\Agent\Toolbox\Toolbox; use Symfony\AI\Agent\Toolbox\ToolFactory\ChainFactory; use Symfony\AI\Agent\Toolbox\ToolFactory\MemoryToolFactory; @@ -30,6 +31,7 @@ use Symfony\AI\Fixtures\Tool\ToolNoParams; use Symfony\AI\Fixtures\Tool\ToolOptionalParam; use Symfony\AI\Fixtures\Tool\ToolRequiredParams; +use Symfony\AI\Fixtures\Tool\ToolSources; use Symfony\AI\Platform\Result\ToolCall; use Symfony\AI\Platform\Tool\ExecutionReference; use Symfony\AI\Platform\Tool\Tool; @@ -283,4 +285,16 @@ public function testToolboxMapWithOverrideViaChain() $this->assertEquals($expected, $toolbox->getTools()); } + + public function testSourcesGetFromToolIntoResult() + { + $toolbox = new Toolbox([new ToolSources()]); + $result = $toolbox->execute(new ToolCall('call_1234', 'tool_sources', ['query' => 'random'])); + + $this->assertCount(1, $result->getSources()); + $this->assertInstanceOf(Source::class, $source = $result->getSources()[0]); + $this->assertSame('Relevant Article', $source->getName()); + $this->assertSame('https://example.com/relevant-article', $source->getReference()); + $this->assertSame('Content of that relevant article.', $source->getContent()); + } }