Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/ai-bundle/config/options.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
use Codewithkyrian\ChromaDB\Client as ChromaDbClient;
use MongoDB\Client as MongoDbClient;
use Probots\Pinecone\Client as PineconeClient;
use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory;
use Symfony\AI\Platform\PlatformInterface;
use Symfony\AI\Store\StoreInterface;

Expand Down Expand Up @@ -59,6 +60,14 @@
->arrayNode('openai')
->children()
->scalarNode('api_key')->isRequired()->end()
->scalarNode('region')
->defaultNull()
->validate()
->ifNotInArray([null, PlatformFactory::REGION_EU, PlatformFactory::REGION_US])
->thenInvalid('The region must be either "EU" (https://eu.api.openai.com), "US" (https://us.api.openai.com) or null (https://api.openai.com)')
->end()
->info('The region for OpenAI API (EU, US, or null for default)')
->end()
->end()
->end()
->arrayNode('mistral')
Expand Down
1 change: 1 addition & 0 deletions src/ai-bundle/src/AiBundle.php
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ private function processPlatformConfig(string $type, array $platform, ContainerB
$platform['api_key'],
new Reference('http_client', ContainerInterface::NULL_ON_INVALID_REFERENCE),
new Reference('ai.platform.contract.openai'),
$platform['region'] ?? null,
])
->addTag('ai.platform');

Expand Down
67 changes: 66 additions & 1 deletion src/ai-bundle/tests/DependencyInjection/AiBundleTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ public function testTokenUsageProcessorTags()
'ai' => [
'platform' => [
'openai' => [
'api_key' => 'test_key',
'api_key' => 'sk-test_key',
],
],
'agent' => [
Expand Down Expand Up @@ -489,6 +489,71 @@ public function testTokenUsageProcessorTags()
$this->assertTrue($foundTag, 'Token usage processor should have output tag with full agent ID');
}

public function testOpenAiPlatformWithDefaultRegion()
{
$container = $this->buildContainer([
'ai' => [
'platform' => [
'openai' => [
'api_key' => 'sk-test-key',
],
],
],
]);

$this->assertTrue($container->hasDefinition('ai.platform.openai'));

$definition = $container->getDefinition('ai.platform.openai');
$arguments = $definition->getArguments();

$this->assertCount(4, $arguments);
$this->assertSame('sk-test-key', $arguments[0]);
$this->assertNull($arguments[3]); // region should be null by default
}

#[TestWith(['EU'])]
#[TestWith(['US'])]
#[TestWith([null])]
public function testOpenAiPlatformWithRegion(?string $region)
{
$container = $this->buildContainer([
'ai' => [
'platform' => [
'openai' => [
'api_key' => 'sk-test-key',
'region' => $region,
],
],
],
]);

$this->assertTrue($container->hasDefinition('ai.platform.openai'));

$definition = $container->getDefinition('ai.platform.openai');
$arguments = $definition->getArguments();

$this->assertCount(4, $arguments);
$this->assertSame('sk-test-key', $arguments[0]);
$this->assertSame($region, $arguments[3]);
}

public function testOpenAiPlatformWithInvalidRegion()
{
$this->expectException(InvalidConfigurationException::class);
$this->expectExceptionMessage('The region must be either "EU" (https://eu.api.openai.com), "US" (https://us.api.openai.com) or null (https://api.openai.com)');

$this->buildContainer([
'ai' => [
'platform' => [
'openai' => [
'api_key' => 'sk-test-key',
'region' => 'INVALID',
],
],
],
]);
}

private function buildContainer(array $configuration): ContainerBuilder
{
$container = new ContainerBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@

use Symfony\AI\Platform\Bridge\LmStudio\Completions;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface as PlatformResponseFactory;
use Symfony\AI\Platform\ModelClientInterface;
use Symfony\AI\Platform\Result\RawHttpResult;
use Symfony\Component\HttpClient\EventSourceHttpClient;
use Symfony\Contracts\HttpClient\HttpClientInterface;

/**
* @author André Lubian <lubiana123@gmail.com>
*/
final readonly class ModelClient implements PlatformResponseFactory
final readonly class ModelClient implements ModelClientInterface
{
private EventSourceHttpClient $httpClient;

Expand Down
4 changes: 2 additions & 2 deletions src/platform/src/Bridge/LmStudio/Embeddings/ModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@

use Symfony\AI\Platform\Bridge\LmStudio\Embeddings;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface as PlatformResponseFactory;
use Symfony\AI\Platform\ModelClientInterface;
use Symfony\AI\Platform\Result\RawHttpResult;
use Symfony\Contracts\HttpClient\HttpClientInterface;

/**
* @author Christopher Hertel <mail@christopher-hertel.de>
* @author André Lubian <lubiana123@gmail.com>
*/
final readonly class ModelClient implements PlatformResponseFactory
final readonly class ModelClient implements ModelClientInterface
{
public function __construct(
private HttpClientInterface $httpClient,
Expand Down
41 changes: 41 additions & 0 deletions src/platform/src/Bridge/OpenAi/AbstractModelClient.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
<?php

/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <fabien@symfony.com>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace Symfony\AI\Platform\Bridge\OpenAi;

use Symfony\AI\Platform\Exception\InvalidArgumentException;

/**
* @author Oskar Stark <oskar.stark@sensiolabs.de>
*/
abstract readonly class AbstractModelClient
{
protected static function getBaseUrl(?string $region): string
{
return match ($region) {
null => 'https://api.openai.com',
PlatformFactory::REGION_EU => 'https://eu.api.openai.com',
PlatformFactory::REGION_US => 'https://us.api.openai.com',
default => throw new InvalidArgumentException(\sprintf('Invalid region "%s". Valid options are: "%s", "%s", or null.', $region, PlatformFactory::REGION_EU, PlatformFactory::REGION_US)),
};
}

protected static function validateApiKey(string $apiKey): void
{
if ('' === $apiKey) {
throw new InvalidArgumentException('The API key must not be empty.');
}

if (!str_starts_with($apiKey, 'sk-')) {
throw new InvalidArgumentException('The API key must start with "sk-".');
}
}
}
14 changes: 5 additions & 9 deletions src/platform/src/Bridge/OpenAi/DallE/ModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

namespace Symfony\AI\Platform\Bridge\OpenAi\DallE;

use Symfony\AI\Platform\Bridge\OpenAi\AbstractModelClient;
use Symfony\AI\Platform\Bridge\OpenAi\DallE;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface;
use Symfony\AI\Platform\Result\RawHttpResult;
Expand All @@ -23,18 +23,14 @@
*
* @author Denis Zunke <denis.zunke@gmail.com>
*/
final readonly class ModelClient implements ModelClientInterface
final readonly class ModelClient extends AbstractModelClient implements ModelClientInterface
{
public function __construct(
private HttpClientInterface $httpClient,
#[\SensitiveParameter] private string $apiKey,
private ?string $region = null,
) {
if ('' === $apiKey) {
throw new InvalidArgumentException('The API key must not be empty.');
}
if (!str_starts_with($apiKey, 'sk-')) {
throw new InvalidArgumentException('The API key must start with "sk-".');
}
self::validateApiKey($apiKey);
}

public function supports(Model $model): bool
Expand All @@ -44,7 +40,7 @@ public function supports(Model $model): bool

public function request(Model $model, array|string $payload, array $options = []): RawHttpResult
{
return new RawHttpResult($this->httpClient->request('POST', 'https://api.openai.com/v1/images/generations', [
return new RawHttpResult($this->httpClient->request('POST', self::getBaseUrl($this->region).'/v1/images/generations', [
'auth_bearer' => $this->apiKey,
'json' => array_merge($options, [
'model' => $model->getName(),
Expand Down
16 changes: 6 additions & 10 deletions src/platform/src/Bridge/OpenAi/Embeddings/ModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,24 @@

namespace Symfony\AI\Platform\Bridge\OpenAi\Embeddings;

use Symfony\AI\Platform\Bridge\OpenAi\AbstractModelClient;
use Symfony\AI\Platform\Bridge\OpenAi\Embeddings;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface as PlatformResponseFactory;
use Symfony\AI\Platform\ModelClientInterface;
use Symfony\AI\Platform\Result\RawHttpResult;
use Symfony\Contracts\HttpClient\HttpClientInterface;

/**
* @author Christopher Hertel <mail@christopher-hertel.de>
*/
final readonly class ModelClient implements PlatformResponseFactory
final readonly class ModelClient extends AbstractModelClient implements ModelClientInterface
{
public function __construct(
private HttpClientInterface $httpClient,
#[\SensitiveParameter] private string $apiKey,
private ?string $region = null,
) {
if ('' === $apiKey) {
throw new InvalidArgumentException('The API key must not be empty.');
}
if (!str_starts_with($apiKey, 'sk-')) {
throw new InvalidArgumentException('The API key must start with "sk-".');
}
self::validateApiKey($apiKey);
}

public function supports(Model $model): bool
Expand All @@ -42,7 +38,7 @@ public function supports(Model $model): bool

public function request(Model $model, array|string $payload, array $options = []): RawHttpResult
{
return new RawHttpResult($this->httpClient->request('POST', 'https://api.openai.com/v1/embeddings', [
return new RawHttpResult($this->httpClient->request('POST', self::getBaseUrl($this->region).'/v1/embeddings', [
'auth_bearer' => $this->apiKey,
'json' => array_merge($options, [
'model' => $model->getName(),
Expand Down
16 changes: 6 additions & 10 deletions src/platform/src/Bridge/OpenAi/Gpt/ModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,28 @@

namespace Symfony\AI\Platform\Bridge\OpenAi\Gpt;

use Symfony\AI\Platform\Bridge\OpenAi\AbstractModelClient;
use Symfony\AI\Platform\Bridge\OpenAi\Gpt;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface as PlatformResponseFactory;
use Symfony\AI\Platform\ModelClientInterface;
use Symfony\AI\Platform\Result\RawHttpResult;
use Symfony\Component\HttpClient\EventSourceHttpClient;
use Symfony\Contracts\HttpClient\HttpClientInterface;

/**
* @author Christopher Hertel <mail@christopher-hertel.de>
*/
final readonly class ModelClient implements PlatformResponseFactory
final readonly class ModelClient extends AbstractModelClient implements ModelClientInterface
{
private EventSourceHttpClient $httpClient;

public function __construct(
HttpClientInterface $httpClient,
#[\SensitiveParameter] private string $apiKey,
private ?string $region = null,
) {
$this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);
if ('' === $apiKey) {
throw new InvalidArgumentException('The API key must not be empty.');
}
if (!str_starts_with($apiKey, 'sk-')) {
throw new InvalidArgumentException('The API key must start with "sk-".');
}
self::validateApiKey($apiKey);
}

public function supports(Model $model): bool
Expand All @@ -46,7 +42,7 @@ public function supports(Model $model): bool

public function request(Model $model, array|string $payload, array $options = []): RawHttpResult
{
return new RawHttpResult($this->httpClient->request('POST', 'https://api.openai.com/v1/chat/completions', [
return new RawHttpResult($this->httpClient->request('POST', self::getBaseUrl($this->region).'/v1/chat/completions', [
'auth_bearer' => $this->apiKey,
'json' => array_merge($options, $payload),
]));
Expand Down
12 changes: 8 additions & 4 deletions src/platform/src/Bridge/OpenAi/PlatformFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,23 @@
*/
final readonly class PlatformFactory
{
public const REGION_EU = 'EU';
public const REGION_US = 'US';

public static function create(
#[\SensitiveParameter] string $apiKey,
?HttpClientInterface $httpClient = null,
?Contract $contract = null,
?string $region = null,
): Platform {
$httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);

return new Platform(
[
new Gpt\ModelClient($httpClient, $apiKey),
new Embeddings\ModelClient($httpClient, $apiKey),
new DallE\ModelClient($httpClient, $apiKey),
new WhisperModelClient($httpClient, $apiKey),
new Gpt\ModelClient($httpClient, $apiKey, $region),
new Embeddings\ModelClient($httpClient, $apiKey, $region),
new DallE\ModelClient($httpClient, $apiKey, $region),
new WhisperModelClient($httpClient, $apiKey, $region),
],
[
new Gpt\ResultConverter(),
Expand Down
16 changes: 6 additions & 10 deletions src/platform/src/Bridge/OpenAi/Whisper/ModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,24 @@

namespace Symfony\AI\Platform\Bridge\OpenAi\Whisper;

use Symfony\AI\Platform\Bridge\OpenAi\AbstractModelClient;
use Symfony\AI\Platform\Bridge\OpenAi\Whisper;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface as BaseModelClient;
use Symfony\AI\Platform\ModelClientInterface;
use Symfony\AI\Platform\Result\RawHttpResult;
use Symfony\Contracts\HttpClient\HttpClientInterface;

/**
* @author Christopher Hertel <mail@christopher-hertel.de>
*/
final readonly class ModelClient implements BaseModelClient
final readonly class ModelClient extends AbstractModelClient implements ModelClientInterface
{
public function __construct(
private HttpClientInterface $httpClient,
#[\SensitiveParameter] private string $apiKey,
private ?string $region = null,
) {
if ('' === $apiKey) {
throw new InvalidArgumentException('The API key must not be empty.');
}
if (!str_starts_with($apiKey, 'sk-')) {
throw new InvalidArgumentException('The API key must start with "sk-".');
}
self::validateApiKey($apiKey);
}

public function supports(Model $model): bool
Expand All @@ -46,7 +42,7 @@ public function request(Model $model, array|string $payload, array $options = []
$endpoint = Task::TRANSCRIPTION === $task ? 'transcriptions' : 'translations';
unset($options['task']);

return new RawHttpResult($this->httpClient->request('POST', \sprintf('https://api.openai.com/v1/audio/%s', $endpoint), [
return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/v1/audio/%s', self::getBaseUrl($this->region), $endpoint), [
'auth_bearer' => $this->apiKey,
'headers' => ['Content-Type' => 'multipart/form-data'],
'body' => array_merge($options, $payload, ['model' => $model->getName()]),
Expand Down
Loading