Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/store/src/Bridge/MariaDb/Store.php
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,20 @@ public function add(VectorDocument ...$documents): void
*/
public function query(Vector $vector, array $options = []): array
{
$where = null;

$maxScore = $options['maxScore'] ?? null;
if ($maxScore) {
$where = \sprintf('WHERE VEC_DISTANCE_EUCLIDEAN(`%1$s`, VEC_FromText(:embedding)) <= :maxScore', $this->vectorFieldName);
}

if ($options['where'] ?? false) {
if ($where) {
$where .= ' AND ('.$options['where'].')';
} else {
$where = 'WHERE '.$options['where'];
}
}

$statement = $this->connection->prepare(
\sprintf(
Expand All @@ -156,12 +169,15 @@ public function query(Vector $vector, array $options = []): array
SQL,
$this->vectorFieldName,
$this->tableName,
null !== $maxScore ? \sprintf('WHERE VEC_DISTANCE_EUCLIDEAN(%1$s, VEC_FromText(:embedding)) <= :maxScore', $this->vectorFieldName) : '',
$where ?? '',
$options['limit'] ?? 5,
),
);

$params = ['embedding' => json_encode($vector->getData())];
$params = [
'embedding' => json_encode($vector->getData()),
...$options['params'] ?? [],
];

if (null !== $maxScore) {
$params['maxScore'] = $maxScore;
Expand Down
152 changes: 151 additions & 1 deletion src/store/tests/Bridge/MariaDb/StoreTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public function testQueryWithMaxScore()
$expectedQuery = <<<'SQL'
SELECT id, VEC_ToText(`embedding`) embedding, metadata, VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) AS score
FROM embeddings_table
WHERE VEC_DISTANCE_EUCLIDEAN(embedding, VEC_FromText(:embedding)) <= :maxScore
WHERE VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) <= :maxScore
ORDER BY score ASC
LIMIT 5
SQL;
Expand Down Expand Up @@ -155,6 +155,151 @@ public function testQueryWithCustomLimit()
$this->assertCount(0, $results);
}

public function testQueryWithCustomWhereExpression()
{
$pdo = $this->createMock(\PDO::class);
$statement = $this->createMock(\PDOStatement::class);

$store = new Store($pdo, 'embeddings_table', 'embedding_idx', 'embedding');

$expectedQuery = <<<SQL
SELECT id, VEC_ToText(`embedding`) embedding, metadata, VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) AS score
FROM embeddings_table
WHERE metadata->>'category' = 'products'
ORDER BY score
ASC LIMIT 5
SQL;

$pdo->expects($this->once())
->method('prepare')
->with($this->callback(function ($sql) use ($expectedQuery) {
$this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql));

return true;
}))
->willReturn($statement);

$statement->expects($this->once())
->method('execute')
->with(['embedding' => '[0.1,0.2,0.3]']);

$statement->expects($this->once())
->method('fetchAll')
->with(\PDO::FETCH_ASSOC)
->willReturn([]);

$results = $store->query(new Vector([0.1, 0.2, 0.3]), ['where' => 'metadata->>\'category\' = \'products\'']);

$this->assertCount(0, $results);
}

public function testQueryWithCustomWhereExpressionAndMaxScore()
{
$pdo = $this->createMock(\PDO::class);
$statement = $this->createMock(\PDOStatement::class);

$store = new Store($pdo, 'embeddings_table', 'embedding_idx', 'embedding');

$expectedQuery = <<<SQL
SELECT id, VEC_ToText(`embedding`) embedding, metadata, VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) AS score
FROM embeddings_table
WHERE VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) <= :maxScore
AND (metadata->>'active' = 'true')
ORDER BY score ASC
LIMIT 5
SQL;

$pdo->expects($this->once())
->method('prepare')
->with($this->callback(function ($sql) use ($expectedQuery) {
$this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql));

return true;
}))
->willReturn($statement);

$statement->expects($this->once())
->method('execute')
->with([
'embedding' => '[0.1,0.2,0.3]',
'maxScore' => 0.5,
]);

$statement->expects($this->once())
->method('fetchAll')
->with(\PDO::FETCH_ASSOC)
->willReturn([]);

$results = $store->query(new Vector([0.1, 0.2, 0.3]), [
'maxScore' => 0.5,
'where' => 'metadata->>\'active\' = \'true\'',
]);

$this->assertCount(0, $results);
}

public function testQueryWithCustomWhereExpressionAndParams()
{
$pdo = $this->createMock(\PDO::class);
$statement = $this->createMock(\PDOStatement::class);

$store = new Store($pdo, 'embeddings_table', 'embedding_idx', 'embedding');

$expectedQuery = <<<SQL
SELECT id, VEC_ToText(`embedding`) embedding, metadata, VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) AS score
FROM embeddings_table
WHERE metadata->>'crawlId' = :crawlId
AND id != :currentId
ORDER BY score
ASC LIMIT 5
SQL;

$pdo->expects($this->once())
->method('prepare')
->with($this->callback(function ($sql) use ($expectedQuery) {
$this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql));

return true;
}))
->willReturn($statement);

$uuid = Uuid::v4();
$crawlId = '396af6fe-0dfd-47ed-b222-3dbcced3f38e';

$statement->expects($this->once())
->method('execute')
->with([
'embedding' => '[0.1,0.2,0.3]',
'crawlId' => $crawlId,
'currentId' => $uuid->toRfc4122(),
]);

$statement->expects($this->once())
->method('fetchAll')
->with(\PDO::FETCH_ASSOC)
->willReturn([
[
'id' => Uuid::v4()->toRfc4122(),
'embedding' => '[0.4,0.5,0.6]',
'metadata' => json_encode(['crawlId' => $crawlId, 'url' => 'https://example.com']),
'score' => 0.85,
],
]);

$results = $store->query(new Vector([0.1, 0.2, 0.3]), [
'where' => 'metadata->>\'crawlId\' = :crawlId AND id != :currentId',
'params' => [
'crawlId' => $crawlId,
'currentId' => $uuid->toRfc4122(),
],
]);

$this->assertCount(1, $results);
$this->assertSame(0.85, $results[0]->score);
$this->assertSame($crawlId, $results[0]->metadata['crawlId']);
$this->assertSame('https://example.com', $results[0]->metadata['url']);
}

public function testItCanDrop()
{
$pdo = $this->createMock(\PDO::class);
Expand All @@ -168,4 +313,9 @@ public function testItCanDrop()

$store->drop();
}

private function normalizeQuery(string $query): string
{
return trim(preg_replace('/\s+/', ' ', $query));
}
}
Loading