Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/store/composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"symfony/uid": "^7.3|^8.0"
},
"require-dev": {
"codewithkyrian/chromadb-php": "^0.2.1|^0.3|^0.4",
"codewithkyrian/chromadb-php": "^1.0",
"doctrine/dbal": "^3.3|^4.0",
"mongodb/mongodb": "^1.21|^2.0",
"phpstan/phpstan": "^2.0",
Expand Down
5 changes: 4 additions & 1 deletion src/store/src/Bridge/ChromaDb/Store.php
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ public function add(VectorDocument ...$documents): void
}

$collection = $this->client->getOrCreateCollection($this->collectionName);

// @phpstan-ignore argument.type (chromadb-php library has incorrect PHPDoc type for $metadatas parameter)
$collection->add($ids, $vectors, $metadata, $originalDocuments);
}

/**
* @param array{where?: array<string, string>, whereDocument?: array<string, mixed>, include?: array<string>} $options
* @param array{where?: array<string, string>, whereDocument?: array<string, mixed>, include?: array<string>, queryTexts?: array<string>} $options
*/
public function query(Vector $vector, array $options = []): iterable
{
Expand All @@ -65,6 +67,7 @@ public function query(Vector $vector, array $options = []): iterable
$collection = $this->client->getOrCreateCollection($this->collectionName);
$queryResponse = $collection->query(
queryEmbeddings: [$vector->getData()],
queryTexts: $options['queryTexts'] ?? null,
nResults: 4,
where: $options['where'] ?? null,
whereDocument: $options['whereDocument'] ?? null,
Expand Down
78 changes: 58 additions & 20 deletions src/store/src/Bridge/ChromaDb/Tests/StoreTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
namespace Symfony\AI\Store\Bridge\ChromaDb\Tests;

use Codewithkyrian\ChromaDB\Client;
use Codewithkyrian\ChromaDB\Generated\Responses\QueryItemsResponse;
use Codewithkyrian\ChromaDB\Resources\CollectionResource;
use Codewithkyrian\ChromaDB\Models\Collection;
use Codewithkyrian\ChromaDB\Responses\QueryItemsResponse;
use PHPUnit\Framework\Attributes\DataProvider;
use PHPUnit\Framework\TestCase;
use Symfony\AI\Platform\Vector\Vector;
Expand All @@ -39,7 +39,7 @@ public function testAddDocumentsSuccessfully(
array $expectedMetadata,
array $expectedOriginalDocuments,
): void {
$collection = $this->createMock(CollectionResource::class);
$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
Expand Down Expand Up @@ -146,7 +146,7 @@ public function testQueryWithoutFilters()
distances: null
);

$collection = $this->createMock(CollectionResource::class);
$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
Expand All @@ -159,7 +159,6 @@ public function testQueryWithoutFilters()
->with(
[[0.15, 0.25, 0.35]], // queryEmbeddings
null, // queryTexts
null, // queryImages
4, // nResults
null, // where
null, // whereDocument
Expand Down Expand Up @@ -191,7 +190,7 @@ public function testQueryWithWhereFilter()
distances: null
);

$collection = $this->createMock(CollectionResource::class);
$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
Expand All @@ -204,7 +203,6 @@ public function testQueryWithWhereFilter()
->with(
[[0.15, 0.25, 0.35]], // queryEmbeddings
null, // queryTexts
null, // queryImages
4, // nResults
['category' => 'technology'], // where
null, // whereDocument
Expand Down Expand Up @@ -235,7 +233,7 @@ public function testQueryWithWhereDocumentFilter()
distances: null
);

$collection = $this->createMock(CollectionResource::class);
$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
Expand All @@ -248,7 +246,6 @@ public function testQueryWithWhereDocumentFilter()
->with(
[[0.15, 0.25, 0.35]], // queryEmbeddings
null, // queryTexts
null, // queryImages
4, // nResults
null, // where
['$contains' => 'machine learning'], // whereDocument
Expand Down Expand Up @@ -280,7 +277,7 @@ public function testQueryWithBothFilters()
distances: null
);

$collection = $this->createMock(CollectionResource::class);
$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
Expand All @@ -293,7 +290,6 @@ public function testQueryWithBothFilters()
->with(
[[0.15, 0.25, 0.35]], // queryEmbeddings
null, // queryTexts
null, // queryImages
4, // nResults
['category' => 'AI', 'status' => 'published'], // where
['$contains' => 'neural networks'], // whereDocument
Expand Down Expand Up @@ -327,7 +323,7 @@ public function testQueryWithEmptyResults()
distances: null
);

$collection = $this->createMock(CollectionResource::class);
$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
Expand All @@ -340,7 +336,6 @@ public function testQueryWithEmptyResults()
->with(
[[0.15, 0.25, 0.35]], // queryEmbeddings
null, // queryTexts
null, // queryImages
4, // nResults
['category' => 'nonexistent'], // where
null, // whereDocument
Expand All @@ -367,7 +362,7 @@ public function testQueryReturnsDistancesAsScore()
distances: [[0.123, 0.456]]
);

$collection = $this->createMock(CollectionResource::class);
$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
Expand Down Expand Up @@ -400,7 +395,7 @@ public function testQueryReturnsNullScoreWhenDistancesNotAvailable()
distances: null
);

$collection = $this->createMock(CollectionResource::class);
$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
Expand Down Expand Up @@ -442,7 +437,7 @@ public function testQueryWithVariousFilterCombinations(
distances: null
);

$collection = $this->createMock(CollectionResource::class);
$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
Expand All @@ -455,7 +450,6 @@ public function testQueryWithVariousFilterCombinations(
->with(
[[0.1, 0.2, 0.3]], // queryEmbeddings
null, // queryTexts
null, // queryImages
4, // nResults
$expectedWhere, // where
$expectedWhereDocument,// whereDocument
Expand All @@ -482,7 +476,7 @@ public function testQueryReturnsMetadatasEmbeddingsDistanceWithoutInclude()
distances: null
);

$collection = $this->createMock(CollectionResource::class);
$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
Expand Down Expand Up @@ -516,7 +510,7 @@ public function testQueryReturnsMetadatasEmbeddingsDistanceWithOnlyDocuments()
distances: null
);

$collection = $this->createMock(CollectionResource::class);
$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
Expand Down Expand Up @@ -550,7 +544,7 @@ public function testQueryReturnsMetadatasEmbeddingsDistanceWithAll()
distances: null
);

$collection = $this->createMock(CollectionResource::class);
$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
Expand Down Expand Up @@ -641,4 +635,48 @@ public static function queryFilterProvider(): \Iterator
],
];
}

public function testQueryWithQueryTexts()
{
$queryVector = new Vector([0.15, 0.25, 0.35]);
$queryTexts = ['search for this text'];

$queryResponse = new QueryItemsResponse(
ids: [['01234567-89ab-cdef-0123-456789abcdef']],
embeddings: [[[0.1, 0.2, 0.3]]],
metadatas: [[['title' => 'Doc 1']]],
documents: null,
data: null,
uris: null,
distances: [[0.123]]
);

$collection = $this->createMock(Collection::class);
$client = $this->createMock(Client::class);

$client->expects($this->once())
->method('getOrCreateCollection')
->with('test-collection')
->willReturn($collection);

$collection->expects($this->once())
->method('query')
->with(
[[0.15, 0.25, 0.35]], // queryEmbeddings
['search for this text'], // queryTexts
4, // nResults
null, // where
null, // whereDocument
null // include
)
->willReturn($queryResponse);

$store = new Store($client, 'test-collection');
$documents = iterator_to_array($store->query($queryVector, ['queryTexts' => $queryTexts]));

$this->assertCount(1, $documents);
$this->assertSame('01234567-89ab-cdef-0123-456789abcdef', (string) $documents[0]->id);
$this->assertSame([0.1, 0.2, 0.3], $documents[0]->vector->getData());
$this->assertSame(0.123, $documents[0]->score);
}
}
2 changes: 1 addition & 1 deletion src/store/src/Bridge/ChromaDb/composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
],
"require": {
"php": ">=8.2",
"codewithkyrian/chromadb-php": "^0.2.1|^0.3|^0.4",
"codewithkyrian/chromadb-php": "^1.0",
"symfony/ai-platform": "@dev",
"symfony/ai-store": "@dev",
"symfony/uid": "^7.3|^8.0"
Expand Down