From a8b4fba9af434c6b2e3df683051c77bfb53032cf Mon Sep 17 00:00:00 2001 From: Denis Zunke Date: Thu, 24 Jul 2025 13:38:36 +0200 Subject: [PATCH] [Store][Postgres] allow store initialization with utilized distance --- src/store/src/Bridge/Postgres/Distance.php | 37 ++++++++ src/store/src/Bridge/Postgres/Store.php | 37 +++++--- src/store/tests/Bridge/Postgres/StoreTest.php | 85 +++++++++++++++++++ 3 files changed, 146 insertions(+), 13 deletions(-) create mode 100644 src/store/src/Bridge/Postgres/Distance.php diff --git a/src/store/src/Bridge/Postgres/Distance.php b/src/store/src/Bridge/Postgres/Distance.php new file mode 100644 index 00000000..70ee7540 --- /dev/null +++ b/src/store/src/Bridge/Postgres/Distance.php @@ -0,0 +1,37 @@ +<?php + +/* + * This file is part of the Symfony package. + * + * (c) Fabien Potencier <fabien@symfony.com> + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Bridge\Postgres; + +use OskarStark\Enum\Trait\Comparable; + +/** + * @author Denis Zunke + */ +enum Distance: string +{ + use Comparable; + + case Cosine = 'cosine'; + case InnerProduct = 'inner_product'; + case L1 = 'l1'; + case L2 = 'l2'; + + public function getComparisonSign(): string + { + return match ($this) { + self::Cosine => '<=>', + self::InnerProduct => '<#>', + self::L1 => '<+>', + self::L2 => '<->', + }; + } +} diff --git a/src/store/src/Bridge/Postgres/Store.php b/src/store/src/Bridge/Postgres/Store.php index c9581162..50319f6b 100644 --- a/src/store/src/Bridge/Postgres/Store.php +++ b/src/store/src/Bridge/Postgres/Store.php @@ -34,23 +34,32 @@ public function __construct( private \PDO $connection, private string $tableName, private string $vectorFieldName = 'embedding', + private Distance $distance = Distance::L2, ) { } - public static function fromPdo(\PDO $connection, string $tableName, string $vectorFieldName = 'embedding'): self - { - return new self($connection, $tableName, $vectorFieldName); + public static function fromPdo( + \PDO $connection, + string 
$tableName, + string $vectorFieldName = 'embedding', + Distance $distance = Distance::L2, + ): self { + return new self($connection, $tableName, $vectorFieldName, $distance); } - public static function fromDbal(Connection $connection, string $tableName, string $vectorFieldName = 'embedding'): self - { + public static function fromDbal( + Connection $connection, + string $tableName, + string $vectorFieldName = 'embedding', + Distance $distance = Distance::L2, + ): self { $pdo = $connection->getNativeConnection(); if (!$pdo instanceof \PDO) { throw new InvalidArgumentException('Only DBAL connections using PDO driver are supported.'); } - return self::fromPdo($pdo, $tableName, $vectorFieldName); + return self::fromPdo($pdo, $tableName, $vectorFieldName, $distance); } public function add(VectorDocument ...$documents): void @@ -84,16 +93,18 @@ public function add(VectorDocument ...$documents): void */ public function query(Vector $vector, array $options = [], ?float $minScore = null): array { - $sql = \sprintf( - 'SELECT id, %s AS embedding, metadata, (%s <-> :embedding) AS score - FROM %s - %s - ORDER BY score ASC - LIMIT %d', + $sql = \sprintf(<<<SQL + SELECT id, %s AS embedding, metadata, (%s %s :embedding) AS score + FROM %s + %s + ORDER BY score ASC + LIMIT %d + SQL, $this->vectorFieldName, $this->vectorFieldName, + $this->distance->getComparisonSign(), $this->tableName, - null !== $minScore ? "WHERE ({$this->vectorFieldName} <-> :embedding) >= :minScore" : '', + null !== $minScore ? "WHERE ({$this->vectorFieldName} {$this->distance->getComparisonSign()} :embedding) >= :minScore" : '', $options['limit'] ?? 
5, ); $statement = $this->connection->prepare($sql); diff --git a/src/store/tests/Bridge/Postgres/StoreTest.php b/src/store/tests/Bridge/Postgres/StoreTest.php index 976411fd..ede7f2bf 100644 --- a/src/store/tests/Bridge/Postgres/StoreTest.php +++ b/src/store/tests/Bridge/Postgres/StoreTest.php @@ -15,6 +15,7 @@ use PHPUnit\Framework\Attributes\CoversClass; use PHPUnit\Framework\TestCase; use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\Bridge\Postgres\Distance; use Symfony\AI\Store\Bridge\Postgres\Store; use Symfony\AI\Store\Document\Metadata; use Symfony\AI\Store\Document\VectorDocument; @@ -152,6 +153,53 @@ public function testQueryWithoutMinScore() $this->assertSame(['title' => 'Test Document'], $results[0]->metadata->getArrayCopy()); } + public function testQueryChangedDistanceMethodWithoutMinScore() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $store = new Store($pdo, 'embeddings_table', 'embedding', Distance::Cosine); + + $expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score + FROM embeddings_table + + ORDER BY score ASC + LIMIT 5'; + + $pdo->expects($this->once()) + ->method('prepare') + ->with($this->callback(function ($sql) use ($expectedSql) { + return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); + })) + ->willReturn($statement); + + $uuid = Uuid::v4(); + + $statement->expects($this->once()) + ->method('execute') + ->with(['embedding' => '[0.1,0.2,0.3]']); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([ + [ + 'id' => $uuid->toRfc4122(), + 'embedding' => '[0.1,0.2,0.3]', + 'metadata' => json_encode(['title' => 'Test Document']), + 'score' => 0.95, + ], + ]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3])); + + $this->assertCount(1, $results); + $this->assertInstanceOf(VectorDocument::class, $results[0]); + $this->assertEquals($uuid, $results[0]->id); 
+ $this->assertSame(0.95, $results[0]->score); + $this->assertSame(['title' => 'Test Document'], $results[0]->metadata->getArrayCopy()); + } + public function testQueryWithMinScore() { $pdo = $this->createMock(\PDO::class); @@ -189,6 +237,43 @@ public function testQueryWithMinScore() $this->assertCount(0, $results); } + public function testQueryWithMinScoreAndDifferentDistance() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $store = new Store($pdo, 'embeddings_table', 'embedding', Distance::Cosine); + + $expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score + FROM embeddings_table + WHERE (embedding <=> :embedding) >= :minScore + ORDER BY score ASC + LIMIT 5'; + + $pdo->expects($this->once()) + ->method('prepare') + ->with($this->callback(function ($sql) use ($expectedSql) { + return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); + })) + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute') + ->with([ + 'embedding' => '[0.1,0.2,0.3]', + 'minScore' => 0.8, + ]); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3]), [], 0.8); + + $this->assertCount(0, $results); + } + public function testQueryWithCustomLimit() { $pdo = $this->createMock(\PDO::class);