Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions src/platform/tests/Bridge/HuggingFace/ModelClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
use Symfony\AI\Platform\Message\UserMessage;
use Symfony\AI\Platform\Model;
use Symfony\Component\HttpClient\MockHttpClient;
use Symfony\Component\HttpClient\Response\MockResponse;

#[CoversClass(ModelClient::class)]
#[Small]
Expand All @@ -35,16 +36,23 @@ final class ModelClientTest extends TestCase
#[DataProvider('urlTestCases')]
public function testGetUrlForDifferentInputsAndTasks(?string $task, string $expectedUrl)
{
$reflection = new \ReflectionClass(ModelClient::class);
$getUrlMethod = $reflection->getMethod('getUrl');
$response = new MockResponse('{"result": "test"}', [
'http_code' => 200,
]);

$httpClient = new MockHttpClient(function (string $method, string $url) use ($expectedUrl, $response): MockResponse {
$this->assertSame('POST', $method);
$this->assertSame($expectedUrl, $url);

return $response;
});

$model = new Model('test-model');
$httpClient = new MockHttpClient();
$modelClient = new ModelClient($httpClient, 'test-provider', 'test-api-key');

$actualUrl = $getUrlMethod->invoke($modelClient, $model, $task);

$this->assertEquals($expectedUrl, $actualUrl);
// Make a request to trigger URL generation
$options = $task ? ['task' => $task] : [];
$modelClient->request($model, 'test input', $options);
}

public static function urlTestCases(): \Iterator
Expand Down Expand Up @@ -76,37 +84,56 @@ public static function urlTestCases(): \Iterator
#[DataProvider('payloadTestCases')]
public function testGetPayloadForDifferentInputsAndTasks(object|array|string $input, array $options, array $expectedKeys, array $expectedValues = [])
{
$response = new MockResponse('{"result": "test"}');
$httpClient = new MockHttpClient($response);

$model = new Model('test-model');
$modelClient = new ModelClient($httpClient, 'test-provider', 'test-api-key');

// Contract handling first
$contract = Contract::create(
new FileNormalizer(),
new MessageBagNormalizer()
);

$payload = $contract->createRequestPayload(new Model('test-model'), $input);

$reflection = new \ReflectionClass(ModelClient::class);
$getPayloadMethod = $reflection->getMethod('getPayload');
$payload = $contract->createRequestPayload($model, $input);

$httpClient = new MockHttpClient();
$modelClient = new ModelClient($httpClient, 'test-provider', 'test-api-key');
// Make a request to trigger payload generation
$modelClient->request($model, $payload, $options);

$actual = $getPayloadMethod->invoke($modelClient, $payload, $options);
// Get the request options that were sent
$requestOptions = $response->getRequestOptions();

// Check that expected keys exist
// Check that expected keys exist in the transformed structure
foreach ($expectedKeys as $key) {
$this->assertArrayHasKey($key, $actual);
if ('json' === $key) {
// JSON gets transformed to body in HTTP client
$this->assertArrayHasKey('body', $requestOptions);
} elseif ('headers' === $key) {
$this->assertArrayHasKey('headers', $requestOptions);
}
}

// Check expected values if specified
foreach ($expectedValues as $path => $value) {
$keys = explode('.', $path);
$current = $actual;
foreach ($keys as $key) {
$this->assertArrayHasKey($key, $current);
$current = $current[$key];
}

$this->assertEquals($value, $current);
if ('headers' === $keys[0] && 'Content-Type' === $keys[1]) {
// Check Content-Type header in the normalized structure
$this->assertContains('Content-Type: application/json', $requestOptions['headers']);
} elseif ('json' === $keys[0]) {
// JSON content is in the body, need to decode
$body = json_decode($requestOptions['body'], true);
$current = $body;

// Navigate through the remaining keys
for ($i = 1; $i < \count($keys); ++$i) {
$this->assertArrayHasKey($keys[$i], $current);
$current = $current[$keys[$i]];
}

$this->assertEquals($value, $current);
}
}
}

Expand Down