Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/platform/src/Bridge/Albert/EmbeddingsModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ public function request(Model $model, array|string $payload, array $options = []
{
return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/embeddings', $this->baseUrl), [
'auth_bearer' => $this->apiKey,
'json' => \is_array($payload) ? array_merge($payload, $options) : $payload,
'json' => array_merge($options, [
'model' => $model->getName(),
'input' => $payload,
]),
]));
}
}
102 changes: 43 additions & 59 deletions src/platform/tests/Bridge/Albert/EmbeddingsModelClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@

namespace Symfony\AI\Platform\Tests\Bridge\Albert;

use PHPUnit\Framework\Attributes\DataProvider;
use PHPUnit\Framework\TestCase;
use Symfony\AI\Platform\Bridge\Albert\EmbeddingsModelClient;
use Symfony\AI\Platform\Bridge\OpenAi\Embeddings;
use Symfony\AI\Platform\Bridge\OpenAi\Gpt;
use Symfony\Component\HttpClient\MockHttpClient;
use Symfony\Component\HttpClient\Response\JsonMockResponse;
use Symfony\Component\HttpClient\Response\MockResponse;
use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse;

final class EmbeddingsModelClientTest extends TestCase
{
Expand All @@ -29,7 +30,7 @@ public function testSupportsEmbeddingsModel()
'https://albert.example.com/'
);

$embeddingsModel = new Embeddings('text-embedding-ada-002');
$embeddingsModel = new Embeddings('embedding-small');
$this->assertTrue($client->supports($embeddingsModel));
}

Expand All @@ -45,66 +46,49 @@ public function testDoesNotSupportNonEmbeddingsModel()
$this->assertFalse($client->supports($gptModel));
}

#[DataProvider('providePayloadToJson')]
public function testRequestSendsCorrectHttpRequest(array|string $payload, array $options, array|string $expectedJson)
public function testItIsExecutingTheCorrectRequest()
{
$capturedRequest = null;
$httpClient = new MockHttpClient(function ($method, $url, $options) use (&$capturedRequest) {
$capturedRequest = ['method' => $method, 'url' => $url, 'options' => $options];

return new JsonMockResponse(['data' => []]);
});

$client = new EmbeddingsModelClient(
$httpClient,
'test-api-key',
'https://albert.example.com/v1'
);
$resultCallback = static function (string $method, string $url, array $options): HttpResponse {
self::assertSame('POST', $method);
self::assertSame('https://albert.example.com/v1/embeddings', $url);
self::assertSame('Authorization: Bearer api-key', $options['normalized_headers']['authorization'][0]);
self::assertSame('{"model":"embedding-small","input":"test text"}', $options['body']);

return new MockResponse();
};
$httpClient = new MockHttpClient([$resultCallback]);
$modelClient = new EmbeddingsModelClient($httpClient, 'api-key', 'https://albert.example.com/v1');
$modelClient->request(new Embeddings('embedding-small'), 'test text', []);
}

$model = new Embeddings('text-embedding-ada-002');
$result = $client->request($model, $payload, $options);

$this->assertNotNull($capturedRequest);
$this->assertSame('POST', $capturedRequest['method']);
$this->assertSame('https://albert.example.com/v1/embeddings', $capturedRequest['url']);
$this->assertArrayHasKey('normalized_headers', $capturedRequest['options']);
$this->assertArrayHasKey('authorization', $capturedRequest['options']['normalized_headers']);
$this->assertStringContainsString('Bearer test-api-key', (string) $capturedRequest['options']['normalized_headers']['authorization'][0]);

// Check JSON body - it might be in 'body' after processing
if (isset($capturedRequest['options']['body'])) {
$actualJson = json_decode($capturedRequest['options']['body'], true);
$this->assertEquals($expectedJson, $actualJson);
} else {
$this->assertSame($expectedJson, $capturedRequest['options']['json']);
}
public function testItIsExecutingTheCorrectRequestWithCustomOptions()
{
$resultCallback = static function (string $method, string $url, array $options): HttpResponse {
self::assertSame('POST', $method);
self::assertSame('https://albert.example.com/v1/embeddings', $url);
self::assertSame('Authorization: Bearer api-key', $options['normalized_headers']['authorization'][0]);
self::assertSame('{"dimensions":256,"model":"embedding-small","input":"test text"}', $options['body']);

return new MockResponse();
};
$httpClient = new MockHttpClient([$resultCallback]);
$modelClient = new EmbeddingsModelClient($httpClient, 'api-key', 'https://albert.example.com/v1');
$modelClient->request(new Embeddings('embedding-small'), 'test text', ['dimensions' => 256]);
}

public static function providePayloadToJson(): iterable
public function testItIsExecutingTheCorrectRequestWithArrayInput()
{
yield 'with array payload and no options' => [
['input' => 'test text', 'model' => 'text-embedding-ada-002'],
[],
['input' => 'test text', 'model' => 'text-embedding-ada-002'],
];

yield 'with string payload and no options' => [
'test text',
[],
'test text',
];

yield 'with array payload and options' => [
['input' => 'test text', 'model' => 'text-embedding-ada-002'],
['dimensions' => 1536],
['dimensions' => 1536, 'input' => 'test text', 'model' => 'text-embedding-ada-002'],
];

yield 'options override payload values' => [
['input' => 'test text', 'model' => 'text-embedding-ada-002'],
['model' => 'text-embedding-3-small'],
['model' => 'text-embedding-3-small', 'input' => 'test text'],
];
$resultCallback = static function (string $method, string $url, array $options): HttpResponse {
self::assertSame('POST', $method);
self::assertSame('https://albert.example.com/v1/embeddings', $url);
self::assertSame('Authorization: Bearer api-key', $options['normalized_headers']['authorization'][0]);
self::assertSame('{"model":"embedding-small","input":["text1","text2","text3"]}', $options['body']);

return new MockResponse();
};
$httpClient = new MockHttpClient([$resultCallback]);
$modelClient = new EmbeddingsModelClient($httpClient, 'api-key', 'https://albert.example.com/v1');
$modelClient->request(new Embeddings('embedding-small'), ['text1', 'text2', 'text3'], []);
}

public function testRequestHandlesBaseUrlWithoutTrailingSlash()
Expand All @@ -122,7 +106,7 @@ public function testRequestHandlesBaseUrlWithoutTrailingSlash()
'https://albert.example.com/v1'
);

$model = new Embeddings('text-embedding-ada-002');
$model = new Embeddings('embedding-small');
$client->request($model, ['input' => 'test']);

$this->assertSame('https://albert.example.com/v1/embeddings', $capturedUrl);
Expand All @@ -143,7 +127,7 @@ public function testRequestHandlesBaseUrlWithTrailingSlash()
'https://albert.example.com/v1'
);

$model = new Embeddings('text-embedding-ada-002');
$model = new Embeddings('embedding-small');
$client->request($model, ['input' => 'test']);

$this->assertSame('https://albert.example.com/v1/embeddings', $capturedUrl);
Expand Down