From 82e4110c1d743041dfe44718ffabd19c11554567 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9goire=20Pineau?= Date: Thu, 20 Nov 2025 17:25:27 +0100 Subject: [PATCH] [Store][MariaDB] Add support for custom WHERE clause --- src/store/src/Bridge/MariaDb/Store.php | 20 ++- src/store/tests/Bridge/MariaDb/StoreTest.php | 152 +++++++++++++++++- src/store/tests/Bridge/Postgres/StoreTest.php | 86 ++++++---- 3 files changed, 220 insertions(+), 38 deletions(-) diff --git a/src/store/src/Bridge/MariaDb/Store.php b/src/store/src/Bridge/MariaDb/Store.php index 9bc668121..aa2caf977 100644 --- a/src/store/src/Bridge/MariaDb/Store.php +++ b/src/store/src/Bridge/MariaDb/Store.php @@ -143,7 +143,20 @@ public function add(VectorDocument ...$documents): void */ public function query(Vector $vector, array $options = []): array { + $where = null; + $maxScore = $options['maxScore'] ?? null; + if ($maxScore) { + $where = \sprintf('WHERE VEC_DISTANCE_EUCLIDEAN(`%1$s`, VEC_FromText(:embedding)) <= :maxScore', $this->vectorFieldName); + } + + if ($options['where'] ?? false) { + if ($where) { + $where .= ' AND ('.$options['where'].')'; + } else { + $where = 'WHERE '.$options['where']; + } + } $statement = $this->connection->prepare( \sprintf( @@ -156,12 +169,15 @@ public function query(Vector $vector, array $options = []): array SQL, $this->vectorFieldName, $this->tableName, - null !== $maxScore ? \sprintf('WHERE VEC_DISTANCE_EUCLIDEAN(%1$s, VEC_FromText(:embedding)) <= :maxScore', $this->vectorFieldName) : '', + $where ?? '', $options['limit'] ?? 5, ), ); - $params = ['embedding' => json_encode($vector->getData())]; + $params = [ + 'embedding' => json_encode($vector->getData()), + ...$options['params'] ?? [], + ]; if (null !== $maxScore) { $params['maxScore'] = $maxScore; diff --git a/src/store/tests/Bridge/MariaDb/StoreTest.php b/src/store/tests/Bridge/MariaDb/StoreTest.php index e9c2c5507..afb5eeb2f 100644 --- a/src/store/tests/Bridge/MariaDb/StoreTest.php +++ b/src/store/tests/Bridge/MariaDb/StoreTest.php @@ -30,7 +30,7 @@ public function testQueryWithMaxScore() $expectedQuery = <<<'SQL' SELECT id, VEC_ToText(`embedding`) embedding, metadata, VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) AS score FROM embeddings_table - WHERE VEC_DISTANCE_EUCLIDEAN(embedding, VEC_FromText(:embedding)) <= :maxScore + WHERE VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) <= :maxScore ORDER BY score ASC LIMIT 5 SQL; @@ -155,6 +155,151 @@ public function testQueryWithCustomLimit() $this->assertCount(0, $results); } + public function testQueryWithCustomWhereExpression() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $store = new Store($pdo, 'embeddings_table', 'embedding_idx', 'embedding'); + + $expectedQuery = <<>'category' = 'products' + ORDER BY score + ASC LIMIT 5 + SQL; + + $pdo->expects($this->once()) + ->method('prepare') + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; + })) + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute') + ->with(['embedding' => '[0.1,0.2,0.3]']); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3]), ['where' => 'metadata->>\'category\' = \'products\'']); + + $this->assertCount(0, $results); + } + + public function testQueryWithCustomWhereExpressionAndMaxScore() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $store = new Store($pdo, 'embeddings_table', 'embedding_idx', 'embedding'); + + $expectedQuery = <<>'active' = 'true') + ORDER BY score ASC + LIMIT 5 + SQL; + + $pdo->expects($this->once()) + ->method('prepare') + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; + })) + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute') + ->with([ + 'embedding' => '[0.1,0.2,0.3]', + 'maxScore' => 0.5, + ]); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3]), [ + 'maxScore' => 0.5, + 'where' => 'metadata->>\'active\' = \'true\'', + ]); + + $this->assertCount(0, $results); + } + + public function testQueryWithCustomWhereExpressionAndParams() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $store = new Store($pdo, 'embeddings_table', 'embedding_idx', 'embedding'); + + $expectedQuery = <<>'crawlId' = :crawlId + AND id != :currentId + ORDER BY score + ASC LIMIT 5 + SQL; + + $pdo->expects($this->once()) + ->method('prepare') + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; + })) + ->willReturn($statement); + + $uuid = Uuid::v4(); + $crawlId = '396af6fe-0dfd-47ed-b222-3dbcced3f38e'; + + $statement->expects($this->once()) + ->method('execute') + ->with([ + 'embedding' => '[0.1,0.2,0.3]', + 'crawlId' => $crawlId, + 'currentId' => $uuid->toRfc4122(), + ]); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([ + [ + 'id' => Uuid::v4()->toRfc4122(), + 'embedding' => '[0.4,0.5,0.6]', + 'metadata' => json_encode(['crawlId' => $crawlId, 'url' => 'https://example.com']), + 'score' => 0.85, + ], + ]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3]), [ + 'where' => 'metadata->>\'crawlId\' = :crawlId AND id != :currentId', + 'params' => [ + 'crawlId' => $crawlId, + 'currentId' => $uuid->toRfc4122(), + ], + ]); + + $this->assertCount(1, $results); + $this->assertSame(0.85, $results[0]->score); + $this->assertSame($crawlId, $results[0]->metadata['crawlId']); + $this->assertSame('https://example.com', $results[0]->metadata['url']); + } + public function testItCanDrop() { $pdo = $this->createMock(\PDO::class); @@ -168,4 +313,9 @@ public function testItCanDrop() $store->drop(); } + + private function normalizeQuery(string $query): string + { + return trim(preg_replace('/\s+/', ' ', $query)); + } } diff --git a/src/store/tests/Bridge/Postgres/StoreTest.php b/src/store/tests/Bridge/Postgres/StoreTest.php index c48ce5b00..0067a9704 100644 --- a/src/store/tests/Bridge/Postgres/StoreTest.php +++ b/src/store/tests/Bridge/Postgres/StoreTest.php @@ -30,14 +30,16 @@ public function testAddSingleDocument() $store = new Store($pdo, 'embeddings_table', 'embedding'); - $expectedSql = 'INSERT INTO embeddings_table (id, metadata, embedding) + $expectedQuery = 'INSERT INTO embeddings_table (id, metadata, embedding) VALUES (:id, :metadata, :vector) ON CONFLICT (id) DO UPDATE SET metadata = EXCLUDED.metadata, embedding = EXCLUDED.embedding'; $pdo->expects($this->once()) ->method('prepare') - ->with($this->callback(function ($sql) use ($expectedSql) { - return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; })) ->willReturn($statement); @@ -102,7 +104,7 @@ public function testQueryWithoutMaxScore() $store = new Store($pdo, 'embeddings_table', 'embedding'); - $expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <-> :embedding) AS score + $expectedQuery = 'SELECT id, embedding AS embedding, metadata, (embedding <-> :embedding) AS score FROM embeddings_table ORDER BY score ASC @@ -110,8 +112,10 @@ public function testQueryWithoutMaxScore() $pdo->expects($this->once()) ->method('prepare') - ->with($this->callback(function ($sql) use ($expectedSql) { - return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; })) ->willReturn($statement); @@ -149,7 +153,7 @@ public function testQueryChangedDistanceMethodWithoutMaxScore() $store = new Store($pdo, 'embeddings_table', 'embedding', Distance::Cosine); - $expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score + $expectedQuery = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score FROM embeddings_table ORDER BY score ASC @@ -157,8 +161,10 @@ public function testQueryChangedDistanceMethodWithoutMaxScore() $pdo->expects($this->once()) ->method('prepare') - ->with($this->callback(function ($sql) use ($expectedSql) { - return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; })) ->willReturn($statement); @@ -196,7 +202,7 @@ public function testQueryWithMaxScore() $store = new Store($pdo, 'embeddings_table', 'embedding'); - $expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <-> :embedding) AS score + $expectedQuery = 'SELECT id, embedding AS embedding, metadata, (embedding <-> :embedding) AS score FROM embeddings_table WHERE (embedding <-> :embedding) <= :maxScore ORDER BY score ASC @@ -204,8 +210,10 @@ public function testQueryWithMaxScore() $pdo->expects($this->once()) ->method('prepare') - ->with($this->callback(function ($sql) use ($expectedSql) { - return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; })) ->willReturn($statement); @@ -233,7 +241,7 @@ public function testQueryWithMaxScoreAndDifferentDistance() $store = new Store($pdo, 'embeddings_table', 'embedding', Distance::Cosine); - $expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score + $expectedQuery = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score FROM embeddings_table WHERE (embedding <=> :embedding) <= :maxScore ORDER BY score ASC @@ -241,8 +249,10 @@ public function testQueryWithMaxScoreAndDifferentDistance() $pdo->expects($this->once()) ->method('prepare') - ->with($this->callback(function ($sql) use ($expectedSql) { - return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; })) ->willReturn($statement); @@ -270,7 +280,7 @@ public function testQueryWithCustomLimit() $store = new Store($pdo, 'embeddings_table', 'embedding'); - $expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <-> :embedding) AS score + $expectedQuery = 'SELECT id, embedding AS embedding, metadata, (embedding <-> :embedding) AS score FROM embeddings_table ORDER BY score ASC @@ -278,8 +288,10 @@ public function testQueryWithCustomLimit() $pdo->expects($this->once()) ->method('prepare') - ->with($this->callback(function ($sql) use ($expectedSql) { - return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; })) ->willReturn($statement); @@ -304,7 +316,7 @@ public function testQueryWithCustomVectorFieldName() $store = new Store($pdo, 'embeddings_table', 'custom_vector'); - $expectedSql = 'SELECT id, custom_vector AS embedding, metadata, (custom_vector <-> :embedding) AS score + $expectedQuery = 'SELECT id, custom_vector AS embedding, metadata, (custom_vector <-> :embedding) AS score FROM embeddings_table ORDER BY score ASC @@ -312,8 +324,10 @@ public function testQueryWithCustomVectorFieldName() $pdo->expects($this->once()) ->method('prepare') - ->with($this->callback(function ($sql) use ($expectedSql) { - return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; })) ->willReturn($statement); @@ -475,7 +489,7 @@ public function testQueryWithCustomWhereExpression() $store = new Store($pdo, 'embeddings_table', 'embedding'); - $expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <-> :embedding) AS score + $expectedQuery = 'SELECT id, embedding AS embedding, metadata, (embedding <-> :embedding) AS score FROM embeddings_table WHERE metadata->>\'category\' = \'products\' ORDER BY score ASC @@ -483,8 +497,10 @@ public function testQueryWithCustomWhereExpression() $pdo->expects($this->once()) ->method('prepare') - ->with($this->callback(function ($sql) use ($expectedSql) { - return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; })) ->willReturn($statement); @@ -509,7 +525,7 @@ public function testQueryWithCustomWhereExpressionAndMaxScore() $store = new Store($pdo, 'embeddings_table', 'embedding'); - $expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <-> :embedding) AS score + $expectedQuery = 'SELECT id, embedding AS embedding, metadata, (embedding <-> :embedding) AS score FROM embeddings_table WHERE (embedding <-> :embedding) <= :maxScore AND (metadata->>\'active\' = \'true\') ORDER BY score ASC @@ -517,8 +533,10 @@ public function testQueryWithCustomWhereExpressionAndMaxScore() $pdo->expects($this->once()) ->method('prepare') - ->with($this->callback(function ($sql) use ($expectedSql) { - return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; })) ->willReturn($statement); @@ -549,7 +567,7 @@ public function testQueryWithCustomWhereExpressionAndParams() $store = new Store($pdo, 'embeddings_table', 'embedding'); - $expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <-> :embedding) AS score + $expectedQuery = 'SELECT id, embedding AS embedding, metadata, (embedding <-> :embedding) AS score FROM embeddings_table WHERE metadata->>\'crawlId\' = :crawlId AND id != :currentId ORDER BY score ASC @@ -557,8 +575,10 @@ public function testQueryWithCustomWhereExpressionAndParams() $pdo->expects($this->once()) ->method('prepare') - ->with($this->callback(function ($sql) use ($expectedSql) { - return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); + ->with($this->callback(function ($sql) use ($expectedQuery) { + $this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql)); + + return true; })) ->willReturn($statement); @@ -601,10 +621,6 @@ public function testQueryWithCustomWhereExpressionAndParams() private function normalizeQuery(string $query): string { - // Remove extra spaces, tabs and newlines - $normalized = preg_replace('/\s+/', ' ', $query); - - // Trim the result - return trim($normalized); + return trim(preg_replace('/\s+/', ' ', $query)); } }