diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index c3da28bb1..df5a7f688 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -104,6 +104,10 @@ jobs: composer-options: "--no-scripts" working-directory: demo + - name: Link local packages + working-directory: demo + run: ../link + - run: composer run-script auto-scripts --no-interaction working-directory: demo diff --git a/demo/AGENTS.md b/demo/AGENTS.md index 2da757cc0..b7a5f2926 100644 --- a/demo/AGENTS.md +++ b/demo/AGENTS.md @@ -31,8 +31,8 @@ composer install echo "OPENAI_API_KEY='sk-...'" > .env.local # Initialize vector store -symfony console app:blog:embed -vv -symfony console app:blog:query +symfony console ai:store:index blog -vv +symfony console ai:store:retrieve blog "Week of Symfony" # Start server symfony serve -d diff --git a/demo/CLAUDE.md b/demo/CLAUDE.md index 28d310398..82893752e 100644 --- a/demo/CLAUDE.md +++ b/demo/CLAUDE.md @@ -39,7 +39,7 @@ echo "OPENAI_API_KEY='sk-...'" > .env.local symfony console ai:store:index blog -vv # Test vector store -symfony console app:blog:query +symfony console ai:store:retrieve blog "Week of Symfony" # Start development server symfony serve -d diff --git a/demo/README.md b/demo/README.md index 0a5b7441c..0efcf8a1f 100644 --- a/demo/README.md +++ b/demo/README.md @@ -77,10 +77,10 @@ To initialize the Chroma DB, you need to run the following command: symfony console ai:store:index blog -vv ``` -Now you should be able to run the test command and get some results: +Now you should be able to retrieve documents from the store: ```shell -symfony console app:blog:query +symfony console ai:store:retrieve blog "Week of Symfony" ``` **Don't forget to set up the project in your favorite IDE or editor.** diff --git a/demo/config/packages/ai.yaml b/demo/config/packages/ai.yaml index a795c5d5a..06499c33e 100644 --- a/demo/config/packages/ai.yaml +++ b/demo/config/packages/ai.yaml @@ -100,6 +100,10 @@ ai: - 'Symfony\AI\Store\Document\Transformer\TextTrimTransformer' vectorizer: 'ai.vectorizer.openai' store: 'ai.store.chroma_db.symfonycon' + retriever: + blog: + vectorizer: 'ai.vectorizer.openai' + store: 'ai.store.chroma_db.symfonycon' services: _defaults: diff --git a/demo/src/Blog/Command/QueryCommand.php b/demo/src/Blog/Command/QueryCommand.php deleted file mode 100644 index 434ab7f2f..000000000 --- a/demo/src/Blog/Command/QueryCommand.php +++ /dev/null @@ -1,71 +0,0 @@ - - * - * For the full copyright and license information, please view the LICENSE - * file that was distributed with this source code. - */ - -namespace App\Blog\Command; - -use Codewithkyrian\ChromaDB\Client; -use Symfony\AI\Store\Document\VectorizerInterface; -use Symfony\Component\Console\Attribute\AsCommand; -use Symfony\Component\Console\Command\Command; -use Symfony\Component\Console\Style\SymfonyStyle; -use Symfony\Component\DependencyInjection\Attribute\Autowire; - -#[AsCommand('app:blog:query', description: 'Test command for querying the blog collection in Chroma DB.')] -final readonly class QueryCommand -{ - public function __construct( - private Client $chromaClient, - #[Autowire(service: 'ai.vectorizer.openai')] - private VectorizerInterface $vectorizer, - ) { - } - - public function __invoke(SymfonyStyle $io): int - { - $io->title('Testing Chroma DB Connection'); - - $io->comment('Connecting to Chroma DB ...'); - $collection = $this->chromaClient->getOrCreateCollection('symfony_blog'); - $io->table(['Key', 'Value'], [ - ['ChromaDB Version', $this->chromaClient->version()], - ['Collection Name', $collection->name], - ['Collection ID', $collection->id], - ['Total Documents', $collection->count()], - ]); - - $search = $io->ask('What do you want to know about?', 'New Symfony Features'); - $io->comment(\sprintf('Converting "%s" to vector & searching in Chroma DB ...', $search)); - $io->comment('Results are limited to 4 most similar documents.'); - - $vector = $this->vectorizer->vectorize((string) $search); - $queryResponse = $collection->query( - queryEmbeddings: [$vector->getData()], - nResults: 4, - ); - - if (1 === \count($queryResponse->ids, \COUNT_RECURSIVE)) { - $io->error('No results found!'); - - return Command::FAILURE; - } - - foreach ($queryResponse->ids[0] as $i => $id) { - /* @phpstan-ignore-next-line */ - $io->section($queryResponse->metadatas[0][$i]['title']); - /* @phpstan-ignore-next-line */ - $io->block($queryResponse->metadatas[0][$i]['description']); - } - - $io->success('Chroma DB Connection & Similarity Search Test Successful!'); - - return Command::SUCCESS; - } -} diff --git a/docs/bundles/ai-bundle.rst b/docs/bundles/ai-bundle.rst index a0b959442..0ddecfab5 100644 --- a/docs/bundles/ai-bundle.rst +++ b/docs/bundles/ai-bundle.rst @@ -972,6 +972,69 @@ Benefits of Configured Vectorizers * **Consistency**: Ensure all indexers using the same vectorizer have identical embedding configuration * **Maintainability**: Change vectorizer settings in one place +Retrievers +---------- + +Retrievers are the opposite of indexers. While indexers populate a vector store with documents, +retrievers allow you to search for documents in a store based on a query string. +They vectorize the query and retrieve similar documents from the store. + +Configuring Retrievers +~~~~~~~~~~~~~~~~~~~~~~ + +Retrievers are defined in the ``retriever`` section of your configuration: + +.. code-block:: yaml + + ai: + retriever: + default: + vectorizer: 'ai.vectorizer.openai_small' + store: 'ai.store.chroma_db.default' + + research: + vectorizer: 'ai.vectorizer.mistral_embed' + store: 'ai.store.memory.research' + +Using Retrievers +~~~~~~~~~~~~~~~~ + +The retriever can be injected into your services using the ``RetrieverInterface``:: + + use Symfony\AI\Store\RetrieverInterface; + + final readonly class MyService + { + public function __construct( + private RetrieverInterface $retriever, + ) { + } + + public function search(string $query): array + { + $documents = []; + foreach ($this->retriever->retrieve($query) as $document) { + $documents[] = $document; + } + + return $documents; + } + } + +When you have multiple retrievers configured, you can use the ``#[Autowire]`` attribute to inject a specific one:: + + use Symfony\AI\Store\RetrieverInterface; + use Symfony\Component\DependencyInjection\Attribute\Autowire; + + final readonly class ResearchService + { + public function __construct( + #[Autowire(service: 'ai.retriever.research')] + private RetrieverInterface $retriever, + ) { + } + } + Profiler -------- diff --git a/docs/components/store.rst b/docs/components/store.rst index 346920b59..7b518944c 100644 --- a/docs/components/store.rst +++ b/docs/components/store.rst @@ -33,7 +33,34 @@ used vector store:: $document = new TextDocument('This is a sample document.'); $indexer->index($document); -You can find more advanced usage in combination with an Agent using the store for RAG in the examples folder: +You can find more advanced usage in combination with an Agent using the store for RAG in the examples folder. + +Retrieving +---------- + +The opposite of indexing is retrieving. The :class:`Symfony\\AI\\Store\\Retriever` is a higher level feature that allows you to +search for documents in a store based on a query string. It vectorizes the query and retrieves similar documents from the store:: + + use Symfony\AI\Store\Retriever; + + $retriever = new Retriever($vectorizer, $store); + $documents = $retriever->retrieve('What is the capital of France?'); + + foreach ($documents as $document) { + echo $document->metadata->get('source'); + } + +The retriever accepts optional parameters to customize the retrieval: + +* ``$options``: An array of options to pass to the underlying store query (e.g., limit, filters) + +Example Usage +~~~~~~~~~~~~~ + +* `Basic Retriever Example`_ + +Similarity Search Examples +~~~~~~~~~~~~~~~~~~~~~~~~~~ * `Similarity Search with Cloudflare (RAG)`_ * `Similarity Search with Manticore (RAG)`_ @@ -129,6 +156,7 @@ This leads to a store implementing two methods:: } .. _`Retrieval Augmented Generation`: https://en.wikipedia.org/wiki/Retrieval-augmented_generation +.. _`Basic Retriever Example`: https://github.com/symfony/ai/blob/main/examples/retriever/basic.php .. _`Similarity Search with Cloudflare (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/cloudflare.php .. _`Similarity Search with Manticore (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/manticore.php .. _`Similarity Search with MariaDB (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/mariadb-gemini.php diff --git a/examples/retriever/basic.php b/examples/retriever/basic.php new file mode 100644 index 000000000..55bbb1dbd --- /dev/null +++ b/examples/retriever/basic.php @@ -0,0 +1,54 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory; +use Symfony\AI\Store\Bridge\Local\InMemoryStore; +use Symfony\AI\Store\Document\Loader\TextFileLoader; +use Symfony\AI\Store\Document\Transformer\TextSplitTransformer; +use Symfony\AI\Store\Document\Vectorizer; +use Symfony\AI\Store\Indexer; +use Symfony\AI\Store\Retriever; + +require_once dirname(__DIR__).'/bootstrap.php'; + +$store = new InMemoryStore(); + +$platform = PlatformFactory::create(env('OPENAI_API_KEY'), http_client()); +$vectorizer = new Vectorizer($platform, 'text-embedding-3-small'); + +$indexer = new Indexer( + loader: new TextFileLoader(), + vectorizer: $vectorizer, + store: $store, + source: [ + dirname(__DIR__, 2).'/fixtures/movies/gladiator.md', + dirname(__DIR__, 2).'/fixtures/movies/inception.md', + dirname(__DIR__, 2).'/fixtures/movies/jurassic-park.md', + ], + transformers: [ + new TextSplitTransformer(chunkSize: 500, overlap: 100), + ], +); +$indexer->index(); + +$retriever = new Retriever( + vectorizer: $vectorizer, + store: $store, +); + +echo "Searching for: 'Roman gladiator revenge'\n\n"; +$results = $retriever->retrieve('Roman gladiator revenge', ['maxItems' => 1]); + +foreach ($results as $i => $document) { + echo sprintf("%d. Document ID: %s\n", $i + 1, $document->id); + echo sprintf(" Score: %s\n", $document->score ?? 'n/a'); + echo sprintf(" Source: %s\n\n", $document->metadata->getSource() ?? 'unknown'); +} diff --git a/examples/retriever/movies.php b/examples/retriever/movies.php new file mode 100644 index 000000000..a4699c4f7 --- /dev/null +++ b/examples/retriever/movies.php @@ -0,0 +1,56 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +use Symfony\AI\Fixtures\Movies; +use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory; +use Symfony\AI\Store\Bridge\Local\InMemoryStore; +use Symfony\AI\Store\Document\Loader\InMemoryLoader; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\TextDocument; +use Symfony\AI\Store\Document\Vectorizer; +use Symfony\AI\Store\Indexer; +use Symfony\AI\Store\Retriever; +use Symfony\Component\Uid\Uuid; + +require_once dirname(__DIR__).'/bootstrap.php'; + +$store = new InMemoryStore(); + +$documents = []; +foreach (Movies::all() as $movie) { + $documents[] = new TextDocument( + id: Uuid::v4(), + content: 'Title: '.$movie['title'].\PHP_EOL.'Director: '.$movie['director'].\PHP_EOL.'Description: '.$movie['description'], + metadata: new Metadata($movie), + ); +} + +$platform = PlatformFactory::create(env('OPENAI_API_KEY'), http_client()); +$vectorizer = new Vectorizer($platform, 'text-embedding-3-small', logger()); + +$indexer = new Indexer(new InMemoryLoader($documents), $vectorizer, $store, logger: logger()); +$indexer->index(); + +$retriever = new Retriever($vectorizer, $store, logger()); + +echo "Searching for movies about 'crime family mafia'\n"; +echo "================================================\n\n"; + +$results = $retriever->retrieve('crime family mafia'); + +foreach ($results as $i => $document) { + $title = $document->metadata['title']; + $director = $document->metadata['director']; + $score = $document->score; + + echo sprintf("%d. %s (Director: %s)\n", $i + 1, $title, $director); + echo sprintf(" Score: %.4f\n\n", $score ?? 0); +} diff --git a/src/ai-bundle/config/options.php b/src/ai-bundle/config/options.php index 139fc560d..e7776f892 100644 --- a/src/ai-bundle/config/options.php +++ b/src/ai-bundle/config/options.php @@ -1089,6 +1089,22 @@ ->end() ->end() ->end() + ->arrayNode('retriever') + ->info('Retrievers for fetching documents from a vector store based on a query') + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('vectorizer') + ->info('Service name of vectorizer') + ->defaultValue(VectorizerInterface::class) + ->end() + ->stringNode('store') + ->info('Service name of store') + ->defaultValue(StoreInterface::class) + ->end() + ->end() + ->end() + ->end() ->end() ->validate() ->ifTrue(function ($v) { diff --git a/src/ai-bundle/config/services.php b/src/ai-bundle/config/services.php index 1baacccc9..a0f0e24c6 100644 --- a/src/ai-bundle/config/services.php +++ b/src/ai-bundle/config/services.php @@ -68,6 +68,7 @@ use Symfony\AI\Platform\StructuredOutput\ResponseFormatFactoryInterface; use Symfony\AI\Store\Command\DropStoreCommand; use Symfony\AI\Store\Command\IndexCommand; +use Symfony\AI\Store\Command\RetrieveCommand; use Symfony\AI\Store\Command\SetupStoreCommand; return static function (ContainerConfigurator $container): void { @@ -220,6 +221,11 @@ tagged_locator('ai.indexer', 'name'), ]) ->tag('console.command') + ->set('ai.command.retrieve', RetrieveCommand::class) + ->args([ + tagged_locator('ai.retriever', 'name'), + ]) + ->tag('console.command') ->set('ai.command.platform_invoke', PlatformInvokeCommand::class) ->args([ tagged_locator('ai.platform', 'name'), diff --git a/src/ai-bundle/src/AiBundle.php b/src/ai-bundle/src/AiBundle.php index 57d372e59..e0923fbc0 100644 --- a/src/ai-bundle/src/AiBundle.php +++ b/src/ai-bundle/src/AiBundle.php @@ -104,6 +104,8 @@ use Symfony\AI\Store\Indexer; use Symfony\AI\Store\IndexerInterface; use Symfony\AI\Store\ManagedStoreInterface; +use Symfony\AI\Store\Retriever; +use Symfony\AI\Store\RetrieverInterface; use Symfony\AI\Store\StoreInterface; use Symfony\Component\Clock\ClockInterface; use Symfony\Component\Config\Definition\Configurator\DefinitionConfigurator; @@ -262,6 +264,13 @@ public function loadExtension(array $config, ContainerConfigurator $container, C $builder->setAlias(IndexerInterface::class, 'ai.indexer.'.$indexerName); } + foreach ($config['retriever'] ?? [] as $retrieverName => $retriever) { + $this->processRetrieverConfig($retrieverName, $retriever, $builder); + } + if (1 === \count($config['retriever'] ?? []) && isset($retrieverName)) { + $builder->setAlias(RetrieverInterface::class, 'ai.retriever.'.$retrieverName); + } + $builder->registerAttributeForAutoconfiguration(AsTool::class, static function (ChildDefinition $definition, AsTool $attribute): void { $definition->addTag('ai.tool', [ 'name' => $attribute->name, @@ -1866,6 +1875,23 @@ private function processIndexerConfig(int|string $name, array $config, Container $container->registerAliasForArgument($serviceId, IndexerInterface::class, (new Target((string) $name))->getParsedName()); } + /** + * @param array $config + */ + private function processRetrieverConfig(int|string $name, array $config, ContainerBuilder $container): void + { + $definition = new Definition(Retriever::class, [ + new Reference($config['vectorizer']), + new Reference($config['store']), + new Reference('logger', ContainerInterface::IGNORE_ON_INVALID_REFERENCE), + ]); + $definition->addTag('ai.retriever', ['name' => $name]); + + $serviceId = 'ai.retriever.'.$name; + $container->setDefinition($serviceId, $definition); + $container->registerAliasForArgument($serviceId, RetrieverInterface::class, (new Target((string) $name))->getParsedName()); + } + /** * @param array $config */ diff --git a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php index 6dff314c4..42cac8aca 100644 --- a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php +++ b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php @@ -48,6 +48,7 @@ use Symfony\AI\Store\Document\VectorizerInterface; use Symfony\AI\Store\IndexerInterface; use Symfony\AI\Store\ManagedStoreInterface; +use Symfony\AI\Store\RetrieverInterface; use Symfony\AI\Store\StoreInterface; use Symfony\Component\Clock\ClockInterface; use Symfony\Component\Config\Definition\Exception\InvalidConfigurationException; @@ -3293,6 +3294,94 @@ public function testInjectionIndexerAliasIsRegistered() $this->assertTrue($container->hasAlias(IndexerInterface::class.' $another')); } + public function testRetrieverWithConfiguredVectorizer() + { + $container = $this->buildContainer([ + 'ai' => [ + 'store' => [ + 'memory' => [ + 'my_store' => [], + ], + ], + 'vectorizer' => [ + 'my_vectorizer' => [ + 'platform' => 'my_platform_service_id', + 'model' => 'text-embedding-3-small', + ], + ], + 'retriever' => [ + 'my_retriever' => [ + 'vectorizer' => 'ai.vectorizer.my_vectorizer', + 'store' => 'ai.store.memory.my_store', + ], + ], + ], + ]); + + $this->assertTrue($container->hasDefinition('ai.retriever.my_retriever')); + $this->assertTrue($container->hasDefinition('ai.vectorizer.my_vectorizer')); + + $retrieverDefinition = $container->getDefinition('ai.retriever.my_retriever'); + $arguments = $retrieverDefinition->getArguments(); + + $this->assertInstanceOf(Reference::class, $arguments[0]); + $this->assertSame('ai.vectorizer.my_vectorizer', (string) $arguments[0]); + + $this->assertInstanceOf(Reference::class, $arguments[1]); + $this->assertSame('ai.store.memory.my_store', (string) $arguments[1]); + + $this->assertInstanceOf(Reference::class, $arguments[2]); // logger + $this->assertSame('logger', (string) $arguments[2]); + } + + public function testRetrieverAliasIsRegistered() + { + $container = $this->buildContainer([ + 'ai' => [ + 'store' => [ + 'memory' => [ + 'my_store' => [], + ], + ], + 'retriever' => [ + 'my_retriever' => [ + 'vectorizer' => 'my_vectorizer_service', + 'store' => 'ai.store.memory.my_store', + ], + 'another' => [ + 'vectorizer' => 'my_vectorizer_service', + 'store' => 'ai.store.memory.my_store', + ], + ], + ], + ]); + + $this->assertTrue($container->hasAlias(RetrieverInterface::class.' $myRetriever')); + $this->assertTrue($container->hasAlias(RetrieverInterface::class.' $another')); + } + + public function testSingleRetrieverCreatesDefaultAlias() + { + $container = $this->buildContainer([ + 'ai' => [ + 'store' => [ + 'memory' => [ + 'my_store' => [], + ], + ], + 'retriever' => [ + 'default' => [ + 'vectorizer' => 'my_vectorizer_service', + 'store' => 'ai.store.memory.my_store', + ], + ], + ], + ]); + + $this->assertTrue($container->hasDefinition('ai.retriever.default')); + $this->assertTrue($container->hasAlias(RetrieverInterface::class)); + } + public function testValidMultiAgentConfiguration() { $container = $this->buildContainer([ diff --git a/src/store/src/Command/RetrieveCommand.php b/src/store/src/Command/RetrieveCommand.php new file mode 100644 index 000000000..ea37a385d --- /dev/null +++ b/src/store/src/Command/RetrieveCommand.php @@ -0,0 +1,145 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Command; + +use Symfony\AI\Store\Exception\RuntimeException; +use Symfony\AI\Store\RetrieverInterface; +use Symfony\Component\Console\Attribute\AsCommand; +use Symfony\Component\Console\Command\Command; +use Symfony\Component\Console\Completion\CompletionInput; +use Symfony\Component\Console\Completion\CompletionSuggestions; +use Symfony\Component\Console\Input\InputArgument; +use Symfony\Component\Console\Input\InputInterface; +use Symfony\Component\Console\Input\InputOption; +use Symfony\Component\Console\Output\OutputInterface; +use Symfony\Component\Console\Style\SymfonyStyle; +use Symfony\Component\DependencyInjection\ServiceLocator; + +/** + * @author Oskar Stark + */ +#[AsCommand( + name: 'ai:store:retrieve', + description: 'Retrieve documents from a store', +)] +final class RetrieveCommand extends Command +{ + /** + * @param ServiceLocator $retrievers + */ + public function __construct( + private readonly ServiceLocator $retrievers, + ) { + parent::__construct(); + } + + public function complete(CompletionInput $input, CompletionSuggestions $suggestions): void + { + if ($input->mustSuggestArgumentValuesFor('retriever')) { + $suggestions->suggestValues(array_keys($this->retrievers->getProvidedServices())); + } + } + + protected function configure(): void + { + $this + ->addArgument('retriever', InputArgument::REQUIRED, 'Name of the retriever to use') + ->addArgument('query', InputArgument::OPTIONAL, 'Search query') + ->addOption('limit', 'l', InputOption::VALUE_REQUIRED, 'Maximum number of results to return', '10') + ->setHelp(<<<'EOF' +The %command.name% command retrieves documents from a store using the specified retriever. + +Basic usage: + php %command.full_name% blog "search query" + +Interactive mode (prompts for query): + php %command.full_name% blog + +Limit results: + php %command.full_name% blog "search query" --limit=5 +EOF + ) + ; + } + + protected function execute(InputInterface $input, OutputInterface $output): int + { + $io = new SymfonyStyle($input, $output); + + $retriever = $input->getArgument('retriever'); + + if (!$this->retrievers->has($retriever)) { + throw new RuntimeException(\sprintf('The "%s" retriever does not exist.', $retriever)); + } + + $query = $input->getArgument('query'); + if (null === $query) { + $query = $io->ask('What do you want to search for?'); + if (null === $query || '' === $query) { + $io->error('A search query is required.'); + + return Command::FAILURE; + } + } + + $limit = (int) $input->getOption('limit'); + + $io->title(\sprintf('Retrieving documents using "%s" retriever', $retriever)); + $io->comment(\sprintf('Searching for: "%s"', $query)); + + try { + $retrieverService = $this->retrievers->get($retriever); + $documents = $retrieverService->retrieve($query, ['maxItems' => $limit]); + + $count = 0; + foreach ($documents as $document) { + ++$count; + $io->section(\sprintf('Result #%d', $count)); + + $tableData = [ + ['ID', (string) $document->id], + ['Score', $document->score ?? 'n/a'], + ]; + + if ($document->metadata->hasSource()) { + $tableData[] = ['Source', $document->metadata->getSource()]; + } + + if ($document->metadata->hasText()) { + $text = $document->metadata->getText(); + if (\strlen($text) > 200) { + $text = substr($text, 0, 200).'...'; + } + $tableData[] = ['Text', $text]; + } + + $io->table([], $tableData); + + if ($count >= $limit) { + break; + } + } + + if (0 === $count) { + $io->warning('No results found.'); + + return Command::SUCCESS; + } + + $io->success(\sprintf('Found %d result(s) using "%s" retriever.', $count, $retriever)); + } catch (\Exception $e) { + throw new RuntimeException(\sprintf('An error occurred while retrieving with "%s": ', $retriever).$e->getMessage(), previous: $e); + } + + return Command::SUCCESS; + } +} diff --git a/src/store/src/Retriever.php b/src/store/src/Retriever.php new file mode 100644 index 000000000..79e91297c --- /dev/null +++ b/src/store/src/Retriever.php @@ -0,0 +1,52 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store; + +use Psr\Log\LoggerInterface; +use Psr\Log\NullLogger; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\Document\VectorizerInterface; + +/** + * @author Oskar Stark + */ +final class Retriever implements RetrieverInterface +{ + public function __construct( + private readonly VectorizerInterface $vectorizer, + private readonly StoreInterface $store, + private readonly LoggerInterface $logger = new NullLogger(), + ) { + } + + /** + * @return iterable + */ + public function retrieve(string $query, array $options = []): iterable + { + $this->logger->debug('Starting document retrieval', ['query' => $query, 'options' => $options]); + + $vector = $this->vectorizer->vectorize($query); + + $this->logger->debug('Query vectorized, searching store'); + + $documents = $this->store->query($vector, $options); + + $count = 0; + foreach ($documents as $document) { + ++$count; + yield $document; + } + + $this->logger->debug('Document retrieval completed', ['retrieved_count' => $count]); + } +} diff --git a/src/store/src/RetrieverInterface.php b/src/store/src/RetrieverInterface.php new file mode 100644 index 000000000..26eea5aa6 --- /dev/null +++ b/src/store/src/RetrieverInterface.php @@ -0,0 +1,35 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store; + +use Symfony\AI\Store\Document\VectorDocument; + +/** + * Retrieves documents from a vector store based on a query string. + * + * The opposite of IndexerInterface - while the Indexer loads, transforms, vectorizes and stores documents, + * the Retriever vectorizes a query and retrieves similar documents from the store. + * + * @author Oskar Stark + */ +interface RetrieverInterface +{ + /** + * Retrieve documents from the store that are similar to the given query. + * + * @param string $query The search query to vectorize and use for similarity search + * @param array $options Options to pass to the store query (e.g., limit, filters) + * + * @return iterable The retrieved documents with similarity scores + */ + public function retrieve(string $query, array $options = []): iterable; +} diff --git a/src/store/tests/Command/RetrieveCommandTest.php b/src/store/tests/Command/RetrieveCommandTest.php new file mode 100644 index 000000000..b0ebc58b0 --- /dev/null +++ b/src/store/tests/Command/RetrieveCommandTest.php @@ -0,0 +1,319 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Tests\Command; + +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\Command\RetrieveCommand; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\Exception\RuntimeException; +use Symfony\AI\Store\RetrieverInterface; +use Symfony\Component\Console\Tester\CommandTester; +use Symfony\Component\DependencyInjection\ServiceLocator; +use Symfony\Component\Uid\Uuid; + +/** + * @author Oskar Stark + */ +final class RetrieveCommandTest extends TestCase +{ + public function testCommandIsConfigured() + { + $command = new RetrieveCommand(new ServiceLocator([])); + + $this->assertSame('ai:store:retrieve', $command->getName()); + $this->assertSame('Retrieve documents from a store', $command->getDescription()); + + $definition = $command->getDefinition(); + $this->assertTrue($definition->hasArgument('retriever')); + $this->assertTrue($definition->hasArgument('query')); + $this->assertTrue($definition->hasOption('limit')); + + $retrieverArgument = $definition->getArgument('retriever'); + $this->assertSame('Name of the retriever to use', $retrieverArgument->getDescription()); + $this->assertTrue($retrieverArgument->isRequired()); + + $queryArgument = $definition->getArgument('query'); + $this->assertSame('Search query', $queryArgument->getDescription()); + $this->assertFalse($queryArgument->isRequired()); + + $limitOption = $definition->getOption('limit'); + $this->assertSame('Maximum number of results to return', $limitOption->getDescription()); + $this->assertSame('10', $limitOption->getDefault()); + } + + public function testCommandCannotRetrieveFromNonExistingRetriever() + { + $command = new RetrieveCommand(new ServiceLocator([])); + + $tester = new CommandTester($command); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('The "foo" retriever does not exist.'); + $tester->execute([ + 'retriever' => 'foo', + 'query' => 'test query', + ]); + } + + public function testCommandCanRetrieveDocuments() + { + $metadata = new Metadata(); + $metadata->setText('Test document content'); + $metadata->setSource('test-source.txt'); + + $document = new VectorDocument( + Uuid::v4(), + new Vector([0.1, 0.2, 0.3]), + $metadata, + 0.95, + ); + + $retriever = $this->createMock(RetrieverInterface::class); + $retriever->expects($this->once()) + ->method('retrieve') + ->with('test query', ['maxItems' => 10]) + ->willReturn([$document]); + + $command = new RetrieveCommand(new ServiceLocator([ + 'blog' => static fn (): RetrieverInterface => $retriever, + ])); + + $tester = new CommandTester($command); + $tester->execute([ + 'retriever' => 'blog', + 'query' => 'test query', + ]); + + $display = $tester->getDisplay(); + $this->assertStringContainsString('Retrieving documents using "blog" retriever', $display); + $this->assertStringContainsString('Searching for: "test query"', $display); + $this->assertStringContainsString('Result #1', $display); + $this->assertStringContainsString('0.95', $display); + $this->assertStringContainsString('test-source.txt', $display); + $this->assertStringContainsString('Test document content', $display); + $this->assertStringContainsString('Found 1 result(s) using "blog" retriever.', $display); + } + + public function testCommandCanRetrieveDocumentsWithCustomLimit() + { + $documents = []; + for ($i = 0; $i < 3; ++$i) { + $metadata = new Metadata(); + $metadata->setText('Document '.$i); + $documents[] = new VectorDocument( + Uuid::v4(), + new Vector([0.1, 0.2, 0.3]), + $metadata, + 0.9 - ($i * 0.1), + ); + } + + $retriever = $this->createMock(RetrieverInterface::class); + $retriever->expects($this->once()) + ->method('retrieve') + ->with('my query', ['maxItems' => 5]) + ->willReturn($documents); + + $command = new RetrieveCommand(new ServiceLocator([ + 'products' => static fn (): RetrieverInterface => $retriever, + ])); + + $tester = new CommandTester($command); + $tester->execute([ + 'retriever' => 'products', + 'query' => 'my query', + '--limit' => '5', + ]); + + $display = $tester->getDisplay(); + $this->assertStringContainsString('Result #1', $display); + $this->assertStringContainsString('Result #2', $display); + $this->assertStringContainsString('Result #3', $display); + $this->assertStringContainsString('Found 3 result(s) using "products" retriever.', $display); + } + + public function testCommandShowsWarningWhenNoResultsFound() + { + $retriever = $this->createMock(RetrieverInterface::class); + $retriever->expects($this->once()) + ->method('retrieve') + ->with('unknown query', ['maxItems' => 10]) + ->willReturn([]); + + $command = new RetrieveCommand(new ServiceLocator([ + 'articles' => static fn (): RetrieverInterface => $retriever, + ])); + + $tester = new CommandTester($command); + $tester->execute([ + 'retriever' => 'articles', + 'query' => 'unknown query', + ]); + + $display = $tester->getDisplay(); + $this->assertStringContainsString('No results found.', $display); + } + + public function testCommandThrowsExceptionOnRetrieverError() + { + $retriever = $this->createMock(RetrieverInterface::class); + $retriever->expects($this->once()) + ->method('retrieve') + ->willThrowException(new RuntimeException('Connection failed')); + + $command = new RetrieveCommand(new ServiceLocator([ + 'docs' => static fn (): RetrieverInterface => $retriever, + ])); + + $tester = new CommandTester($command); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('An error occurred while retrieving with "docs": Connection failed'); + $tester->execute([ + 'retriever' => 'docs', + 'query' => 'test', + ]); + } + + public function testCommandTruncatesLongText() + { + $longText = str_repeat('a', 300); + $metadata = new Metadata(); + $metadata->setText($longText); + + $document = new VectorDocument( + Uuid::v4(), + new Vector([0.1, 0.2, 0.3]), + $metadata, + 0.8, + ); + + $retriever = $this->createMock(RetrieverInterface::class); + $retriever->expects($this->once()) + ->method('retrieve') + ->willReturn([$document]); + + $command = new RetrieveCommand(new ServiceLocator([ + 'test' => static fn (): RetrieverInterface => $retriever, + ])); + + $tester = new CommandTester($command); + $tester->execute([ + 'retriever' => 'test', + 'query' => 'search', + ]); + + $display = $tester->getDisplay(); + $this->assertStringContainsString(str_repeat('a', 200).'...', $display); + $this->assertStringNotContainsString(str_repeat('a', 201), $display); + } + + public function testCommandHandlesDocumentWithoutSourceOrText() + { + $document = new VectorDocument( + Uuid::v4(), + new Vector([0.1, 0.2, 0.3]), + new Metadata(), + 0.75, + ); + + $retriever = $this->createMock(RetrieverInterface::class); + $retriever->expects($this->once()) + ->method('retrieve') + ->willReturn([$document]); + + $command = new RetrieveCommand(new ServiceLocator([ + 'minimal' => static fn (): RetrieverInterface => $retriever, + ])); + + $tester = new CommandTester($command); + $tester->execute([ + 'retriever' => 'minimal', + 'query' => 'test', + ]); + + $display = $tester->getDisplay(); + $this->assertStringContainsString('Result #1', $display); + $this->assertStringContainsString('0.75', $display); + $this->assertStringContainsString('Found 1 result(s)', $display); + } + + public function testCommandHandlesDocumentWithoutScore() + { + $metadata = new Metadata(); + $metadata->setText('Some content'); + + $document = new VectorDocument( + Uuid::v4(), + new Vector([0.1, 0.2, 0.3]), + $metadata, + ); + + $retriever = $this->createMock(RetrieverInterface::class); + $retriever->expects($this->once()) + ->method('retrieve') + ->willReturn([$document]); + + $command = new RetrieveCommand(new ServiceLocator([ + 'noscore' => static fn (): RetrieverInterface => $retriever, + ])); + + $tester = new CommandTester($command); + $tester->execute([ + 'retriever' => 'noscore', + 'query' => 'test', + ]); + + $display = $tester->getDisplay(); + $this->assertStringContainsString('n/a', $display); + } + + public function testCommandRespectsLimit() + { + $documents = []; + for ($i = 0; $i < 10; ++$i) { + $metadata = new Metadata(); + $metadata->setText('Document '.$i); + $documents[] = new VectorDocument( + Uuid::v4(), + new Vector([0.1, 0.2, 0.3]), + $metadata, + ); + } + + $retriever = $this->createMock(RetrieverInterface::class); + $retriever->expects($this->once()) + ->method('retrieve') + ->with('test', ['maxItems' => 3]) + ->willReturn($documents); + + $command = new RetrieveCommand(new ServiceLocator([ + 'many' => static fn (): RetrieverInterface => $retriever, + ])); + + $tester = new CommandTester($command); + $tester->execute([ + 'retriever' => 'many', + 'query' => 'test', + '--limit' => '3', + ]); + + $display = $tester->getDisplay(); + $this->assertStringContainsString('Result #1', $display); + $this->assertStringContainsString('Result #2', $display); + $this->assertStringContainsString('Result #3', $display); + $this->assertStringNotContainsString('Result #4', $display); + $this->assertStringContainsString('Found 3 result(s)', $display); + } +} diff --git a/src/store/tests/Double/TestStore.php b/src/store/tests/Double/TestStore.php index 7b86eef00..6786fda5f 100644 --- a/src/store/tests/Double/TestStore.php +++ b/src/store/tests/Double/TestStore.php @@ -13,7 +13,6 @@ use Symfony\AI\Platform\Vector\Vector; use Symfony\AI\Store\Document\VectorDocument; -use Symfony\AI\Store\Exception\RuntimeException; use Symfony\AI\Store\StoreInterface; final class TestStore implements StoreInterface @@ -33,6 +32,6 @@ public function add(VectorDocument ...$documents): void public function query(Vector $vector, array $options = []): iterable { - throw new RuntimeException('Not yet implemented.'); + return $this->documents; } } diff --git a/src/store/tests/RetrieverTest.php b/src/store/tests/RetrieverTest.php new file mode 100644 index 000000000..6702ec035 --- /dev/null +++ b/src/store/tests/RetrieverTest.php @@ -0,0 +1,100 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Tests; + +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Result\VectorResult; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\Document\Vectorizer; +use Symfony\AI\Store\Retriever; +use Symfony\AI\Store\Tests\Double\PlatformTestHandler; +use Symfony\AI\Store\Tests\Double\TestStore; +use Symfony\Component\Uid\Uuid; + +/** + * @author Oskar Stark + */ +final class RetrieverTest extends TestCase +{ + public function testRetrieveReturnsDocuments() + { + $document1 = new VectorDocument( + Uuid::v4(), + new Vector([0.1, 0.2, 0.3]), + new Metadata(['title' => 'Document 1']), + ); + $document2 = new VectorDocument( + Uuid::v4(), + new Vector([0.4, 0.5, 0.6]), + new Metadata(['title' => 'Document 2']), + ); + + $store = new TestStore(); + $store->add($document1, $document2); + + $queryVector = new Vector([0.2, 0.3, 0.4]); + $vectorizer = new Vectorizer( + PlatformTestHandler::createPlatform(new VectorResult($queryVector)), + 'text-embedding-3-small' + ); + + $retriever = new Retriever($vectorizer, $store); + $results = iterator_to_array($retriever->retrieve('test query')); + + $this->assertCount(2, $results); + $this->assertInstanceOf(VectorDocument::class, $results[0]); + $this->assertInstanceOf(VectorDocument::class, $results[1]); + $this->assertSame('Document 1', $results[0]->metadata['title']); + $this->assertSame('Document 2', $results[1]->metadata['title']); + } + + public function testRetrieveWithEmptyStore() + { + $store = new TestStore(); + + $queryVector = new Vector([0.1, 0.2, 0.3]); + $vectorizer = new Vectorizer( + PlatformTestHandler::createPlatform(new VectorResult($queryVector)), + 'text-embedding-3-small' + ); + + $retriever = new Retriever($vectorizer, $store); + $results = iterator_to_array($retriever->retrieve('test query')); + + $this->assertCount(0, $results); + } + + public function testRetrievePassesOptionsToStore() + { + $document = new VectorDocument( + Uuid::v4(), + new Vector([0.1, 0.2, 0.3]), + new Metadata(['title' => 'Test Document']), + ); + + $store = new TestStore(); + $store->add($document); + + $queryVector = new Vector([0.2, 0.3, 0.4]); + $vectorizer = new Vectorizer( + PlatformTestHandler::createPlatform(new VectorResult($queryVector)), + 'text-embedding-3-small' + ); + + $retriever = new Retriever($vectorizer, $store); + $results = iterator_to_array($retriever->retrieve('test query', ['maxItems' => 10])); + + $this->assertCount(1, $results); + } +}