Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Contracts/VectorClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public function upsert(string $collection, array $points, bool $wait = true): Up
* @param array<string, mixed>|null $filter
* @return array<ScoredPoint>
*/
public function search(string $collection, array $vector, int $limit = 10, ?array $filter = null): array;
public function search(string $collection, array $vector, int $limit = 10, ?array $filter = null, ?float $scoreThreshold = null): array;

/**
* @param array<string, mixed>|null $filter
Expand Down
13 changes: 13 additions & 0 deletions src/Qdrant.php
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,19 @@ public function __construct(
protected readonly QdrantConnector $connector,
) {}

public static function collection(string $name): QueryBuilder
{
/** @var VectorClient $client */
$client = app(VectorClient::class);

return new QueryBuilder($client, $name);
}

public function query(string $collection): QueryBuilder
{
return new QueryBuilder($this, $collection);
}

/**
* @param array<string, array<string, mixed>>|null $sparseVectors
*/
Expand Down
119 changes: 119 additions & 0 deletions src/QueryBuilder.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
<?php

declare(strict_types=1);

namespace TheShit\Vector;

use TheShit\Vector\Contracts\FilterBuilder;
use TheShit\Vector\Contracts\VectorClient;
use TheShit\Vector\Data\ScoredPoint;
use TheShit\Vector\Filters\QdrantFilter;

final class QueryBuilder
{
private QdrantFilter $filter;

private int $queryLimit = 10;

private ?float $scoreThreshold = null;

/** @var array<float>|null */
private ?array $vector = null;

public function __construct(
private readonly VectorClient $client,
private readonly string $collection,
) {
$this->filter = new QdrantFilter;
}

public function where(string $key, mixed $value): self
{
$this->filter->must($key, $value);

return $this;
}

/**
* @param array<mixed> $values
*/
public function whereIn(string $key, array $values): self
{
$this->filter->mustAny($key, $values);

return $this;
}

public function whereNot(string $key, mixed $value): self
{
$this->filter->mustNot($key, $value);

return $this;
}

public function whereRange(string $key, ?float $gte = null, ?float $lte = null, ?float $gt = null, ?float $lt = null): self
{
$this->filter->mustRange($key, $gte, $lte, $gt, $lt);

return $this;
}

/**
* @param array<float> $vector
*/
public function nearVector(array $vector): self
{
$this->vector = $vector;

return $this;
}

public function limit(int $limit): self
{
$this->queryLimit = $limit;

return $this;
}

public function minScore(float $score): self
{
$this->scoreThreshold = $score;

return $this;
}

/**
* @return array<ScoredPoint>
*/
public function search(): array
{
if ($this->vector === null) {
throw new \InvalidArgumentException('A vector is required for search. Call nearVector() before search().');
}

$filter = $this->filter->toArray() ?: null;

return $this->client->search(
collection: $this->collection,
vector: $this->vector,
limit: $this->queryLimit,
filter: $filter,
scoreThreshold: $this->scoreThreshold,
);
}

public function count(): int
{
$filter = $this->filter->toArray() ?: null;

return $this->client->count(
collection: $this->collection,
filter: $filter,
);
}

public function filter(): FilterBuilder
{
return $this->filter;
}
}
145 changes: 145 additions & 0 deletions tests/Feature/QueryBuilderIntegrationTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
<?php

declare(strict_types=1);

use Saloon\Http\Faking\MockClient;
use Saloon\Http\Faking\MockResponse;
use TheShit\Vector\Qdrant;
use TheShit\Vector\QdrantConnector;
use TheShit\Vector\QueryBuilder;
use TheShit\Vector\Requests\Points\SearchPointsRequest;
use TheShit\Vector\Tests\TestCase;

uses(TestCase::class);

function makeQdrant(MockClient $mock): Qdrant
{
$connector = new QdrantConnector('http://localhost:6333', 'test-key');
$connector->withMockClient($mock);

return new Qdrant($connector);
}

describe('Qdrant::query', function (): void {
it('returns a QueryBuilder instance', function (): void {
$client = makeQdrant(new MockClient([]));

expect($client->query('memories'))->toBeInstanceOf(QueryBuilder::class);
});

it('executes search through the builder', function (): void {
$mock = new MockClient([
SearchPointsRequest::class => MockResponse::make([
'result' => [
['id' => 'x', 'score' => 0.9, 'payload' => ['project' => 'lexi']],
],
'status' => 'ok',
]),
]);

$results = makeQdrant($mock)
->query('memories')
->where('project', 'lexi')
->nearVector([0.1, 0.2])
->limit(5)
->search();

expect($results)->toHaveCount(1)
->and($results[0]->id)->toBe('x')
->and($results[0]->score)->toBe(0.9);

$mock->assertSent(function (SearchPointsRequest $request): bool {
$body = invade($request)->defaultBody();

return $body['limit'] === 5
&& $body['filter'] === ['must' => [['key' => 'project', 'match' => ['value' => 'lexi']]]]
&& ! isset($body['score_threshold']);
});
});

it('passes score_threshold through minScore', function (): void {
$mock = new MockClient([
SearchPointsRequest::class => MockResponse::make([
'result' => [],
'status' => 'ok',
]),
]);

makeQdrant($mock)
->query('memories')
->nearVector([0.1])
->minScore(0.75)
->search();

$mock->assertSent(function (SearchPointsRequest $request): bool {
$body = invade($request)->defaultBody();

return $body['score_threshold'] === 0.75;
});
});
});

describe('Qdrant::collection', function (): void {
it('returns a QueryBuilder via static call', function (): void {
$builder = Qdrant::collection('memories');

expect($builder)->toBeInstanceOf(QueryBuilder::class);
});

it('resolves VectorClient from the container', function (): void {
$mock = new MockClient([
SearchPointsRequest::class => MockResponse::make([
'result' => [
['id' => 'm-1', 'score' => 0.88, 'payload' => []],
],
'status' => 'ok',
]),
]);

$connector = $this->app->make(QdrantConnector::class);
$connector->withMockClient($mock);

$results = Qdrant::collection('memories')
->nearVector([0.5])
->search();

expect($results)->toHaveCount(1)
->and($results[0]->id)->toBe('m-1');
});
});

describe('Qdrant::search with scoreThreshold', function (): void {
it('passes score_threshold to request', function (): void {
$mock = new MockClient([
SearchPointsRequest::class => MockResponse::make([
'result' => [],
'status' => 'ok',
]),
]);

makeQdrant($mock)->search('coll', [0.1], scoreThreshold: 0.5);

$mock->assertSent(function (SearchPointsRequest $request): bool {
$body = invade($request)->defaultBody();

return $body['score_threshold'] === 0.5;
});
});

it('omits score_threshold when null', function (): void {
$mock = new MockClient([
SearchPointsRequest::class => MockResponse::make([
'result' => [],
'status' => 'ok',
]),
]);

makeQdrant($mock)->search('coll', [0.1]);

$mock->assertSent(function (SearchPointsRequest $request): bool {
$body = invade($request)->defaultBody();

return ! isset($body['score_threshold']);
});
});
});
Loading
Loading