Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions src/ai-bundle/src/AiBundle.php
Original file line number Diff line number Diff line change
Expand Up @@ -47,27 +47,41 @@
use Symfony\AI\Chat\ChatInterface;
use Symfony\AI\Chat\MessageStoreInterface;
use Symfony\AI\Platform\Bridge\Albert\PlatformFactory as AlbertPlatformFactory;
use Symfony\AI\Platform\Bridge\Anthropic\Claude;
use Symfony\AI\Platform\Bridge\Anthropic\PlatformFactory as AnthropicPlatformFactory;
use Symfony\AI\Platform\Bridge\Azure\OpenAi\PlatformFactory as AzureOpenAiPlatformFactory;
use Symfony\AI\Platform\Bridge\Cartesia\Cartesia;
use Symfony\AI\Platform\Bridge\Cartesia\PlatformFactory as CartesiaPlatformFactory;
use Symfony\AI\Platform\Bridge\Cerebras\PlatformFactory as CerebrasPlatformFactory;
use Symfony\AI\Platform\Bridge\DeepSeek\DeepSeek;
use Symfony\AI\Platform\Bridge\DeepSeek\PlatformFactory as DeepSeekPlatformFactory;
use Symfony\AI\Platform\Bridge\DockerModelRunner\PlatformFactory as DockerModelRunnerPlatformFactory;
use Symfony\AI\Platform\Bridge\ElevenLabs\ElevenLabs;
use Symfony\AI\Platform\Bridge\ElevenLabs\PlatformFactory as ElevenLabsPlatformFactory;
use Symfony\AI\Platform\Bridge\Gemini\Gemini;
use Symfony\AI\Platform\Bridge\Gemini\PlatformFactory as GeminiPlatformFactory;
use Symfony\AI\Platform\Bridge\HuggingFace\PlatformFactory as HuggingFacePlatformFactory;
use Symfony\AI\Platform\Bridge\LmStudio\PlatformFactory as LmStudioPlatformFactory;
use Symfony\AI\Platform\Bridge\Meta\Llama;
use Symfony\AI\Platform\Bridge\Mistral\Mistral;
use Symfony\AI\Platform\Bridge\Mistral\PlatformFactory as MistralPlatformFactory;
use Symfony\AI\Platform\Bridge\Ollama\Ollama;
use Symfony\AI\Platform\Bridge\Ollama\OllamaApiCatalog;
use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory as OllamaPlatformFactory;
use Symfony\AI\Platform\Bridge\OpenAi\Gpt;
use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory as OpenAiPlatformFactory;
use Symfony\AI\Platform\Bridge\OpenRouter\PlatformFactory as OpenRouterPlatformFactory;
use Symfony\AI\Platform\Bridge\Perplexity\Perplexity;
use Symfony\AI\Platform\Bridge\Perplexity\PlatformFactory as PerplexityPlatformFactory;
use Symfony\AI\Platform\Bridge\Scaleway\PlatformFactory as ScalewayPlatformFactory;
use Symfony\AI\Platform\Bridge\Scaleway\Scaleway;
use Symfony\AI\Platform\Bridge\VertexAi\PlatformFactory as VertexAiPlatformFactory;
use Symfony\AI\Platform\Bridge\Voyage\PlatformFactory as VoyagePlatformFactory;
use Symfony\AI\Platform\Bridge\Voyage\Voyage;
use Symfony\AI\Platform\Capability;
use Symfony\AI\Platform\Exception\RuntimeException;
use Symfony\AI\Platform\Message\Content\File;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface;
use Symfony\AI\Platform\Platform;
use Symfony\AI\Platform\PlatformInterface;
Expand Down Expand Up @@ -292,6 +306,11 @@ public function loadExtension(array $config, ContainerConfigurator $container, C
$builder->removeDefinition('ai.data_collector');
$builder->removeDefinition('ai.traceable_toolbox');
}

// Process model configuration and pass to ModelCatalog services
foreach ($config['model'] ?? [] as $platformName => $models) {
$this->processModelConfig($platformName, $models, $builder);
}
}

/**
Expand Down Expand Up @@ -1799,4 +1818,83 @@ private static function normalizeAgentServiceId(string $agentName): string
{
return str_starts_with($agentName, 'ai.agent.') ? $agentName : 'ai.agent.'.$agentName;
}

/**
* @param array<string, mixed> $models
*/
private function processModelConfig(string $platformName, array $models, ContainerBuilder $builder): void
{
$modelCatalogServiceId = $this->getModelCatalogServiceId($platformName);

if (!$builder->hasDefinition($modelCatalogServiceId)) {
return;
}

$modelCatalogDefinition = $builder->getDefinition($modelCatalogServiceId);
$additionalModels = [];

foreach ($models as $modelName => $modelConfig) {
$modelClass = $this->getModelClassForPlatform($platformName);

if (null === $modelClass) {
continue;
}

$capabilities = [];
foreach ($modelConfig['capabilities'] as $capability) {
if ($capability instanceof Capability) {
$capabilities[] = $capability;
} else {
$capabilities[] = Capability::from($capability);
}
}

$additionalModels[$modelName] = [
'class' => $modelClass,
'capabilities' => $capabilities,
];
}

$modelCatalogDefinition->setArgument(0, $additionalModels);
}

private function getModelCatalogServiceId(string $platformName): string
{
if ('vertexai' === $platformName) {
return 'ai.platform.model_catalog.vertexai.gemini';
}

if ('eleven_labs' === $platformName) {
return 'ai.platform.model_catalog.elevenlabs';
}

return 'ai.platform.model_catalog.'.$platformName;
}

private function getModelClassForPlatform(string $platformName): ?string
{
return match ($platformName) {
'anthropic' => Claude::class,
'openai' => Gpt::class,
'gemini' => Gemini::class,
'mistral' => Mistral::class,
'ollama' => Ollama::class,
'deepseek' => DeepSeek::class,
'perplexity' => Perplexity::class,
'cartesia' => Cartesia::class,
'voyage' => Voyage::class,
'scaleway' => Scaleway::class,
'meta' => Llama::class,
'vertexai' => Gemini::class,
'eleven_labs' => ElevenLabs::class,
'cerebras' => Model::class,
'openrouter' => Model::class,
'dockermodelrunner' => Model::class,
'aimlapi' => Model::class,
'replicate' => Llama::class,
'bedrock' => Model::class,
'albert' => Gpt::class,
default => null,
};
}
}
200 changes: 200 additions & 0 deletions src/ai-bundle/tests/DependencyInjection/AiBundleTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -3322,6 +3322,206 @@ public function testSurrealDbMessageStoreIsConfiguredWithNamespacedUser()
$this->assertTrue($surrealDbMessageStoreDefinition->hasTag('ai.message_store'));
}

#[TestDox('Model configuration is processed and passed to ModelCatalog services')]
public function testModelConfigurationIsProcessed()
{
$container = $this->buildContainer([
'ai' => [
'platform' => [
'anthropic' => [
'api_key' => 'test-key',
],
'openai' => [
'api_key' => 'test-key',
],
],
'model' => [
'anthropic' => [
'custom-claude-model' => [
'capabilities' => [
'input-messages',
'output-text',
'tool-calling',
],
],
],
'openai' => [
'custom-gpt-model' => [
'capabilities' => [
'input-messages',
'output-text',
'output-streaming',
],
],
],
],
],
]);

$anthropicCatalog = $container->getDefinition('ai.platform.model_catalog.anthropic');
$anthropicModels = $anthropicCatalog->getArgument(0);
$this->assertIsArray($anthropicModels);
$this->assertArrayHasKey('custom-claude-model', $anthropicModels);
$this->assertSame('Symfony\AI\Platform\Bridge\Anthropic\Claude', $anthropicModels['custom-claude-model']['class']);
$this->assertCount(3, $anthropicModels['custom-claude-model']['capabilities']);

$openaiCatalog = $container->getDefinition('ai.platform.model_catalog.openai');
$openaiModels = $openaiCatalog->getArgument(0);
$this->assertIsArray($openaiModels);
$this->assertArrayHasKey('custom-gpt-model', $openaiModels);
$this->assertSame('Symfony\AI\Platform\Bridge\OpenAi\Gpt', $openaiModels['custom-gpt-model']['class']);
$this->assertCount(3, $openaiModels['custom-gpt-model']['capabilities']);
}

#[TestDox('Model configuration for unsupported platforms is gracefully skipped')]
public function testModelConfigurationForUnsupportedPlatformIsSkipped()
{
$container = $this->buildContainer([
'ai' => [
'model' => [
'unsupported_platform' => [
'some-model' => [
'capabilities' => ['input-messages'],
],
],
],
],
]);

$this->assertFalse($container->hasDefinition('ai.platform.model_catalog.unsupported_platform'));
}

#[TestDox('Model configuration for vertexai uses correct ModelCatalog service ID')]
public function testModelConfigurationForVertexAi()
{
$container = $this->buildContainer([
'ai' => [
'platform' => [
'vertexai' => [
'location' => 'us-central1',
'project_id' => 'test-project',
],
],
'model' => [
'vertexai' => [
'custom-gemini-model' => [
'capabilities' => [
'input-messages',
'output-text',
],
],
],
],
],
]);

$vertexaiCatalog = $container->getDefinition('ai.platform.model_catalog.vertexai.gemini');
$vertexaiModels = $vertexaiCatalog->getArgument(0);
$this->assertIsArray($vertexaiModels);
$this->assertArrayHasKey('custom-gemini-model', $vertexaiModels);
$this->assertSame('Symfony\AI\Platform\Bridge\Gemini\Gemini', $vertexaiModels['custom-gemini-model']['class']);
}

#[TestDox('Model configuration for eleven_labs uses correct ModelCatalog service ID')]
public function testModelConfigurationForElevenLabs()
{
$container = $this->buildContainer([
'ai' => [
'platform' => [
'eleven_labs' => [
'api_key' => 'test-key',
'host' => 'https://api.elevenlabs.io/v1',
],
],
'model' => [
'eleven_labs' => [
'custom-elevenlabs-model' => [
'capabilities' => [
'input-text',
'output-audio',
'text-to-speech',
],
],
],
],
],
]);

$elevenlabsCatalog = $container->getDefinition('ai.platform.model_catalog.elevenlabs');
$elevenlabsModels = $elevenlabsCatalog->getArgument(0);
$this->assertIsArray($elevenlabsModels);
$this->assertArrayHasKey('custom-elevenlabs-model', $elevenlabsModels);
$this->assertSame('Symfony\AI\Platform\Bridge\ElevenLabs\ElevenLabs', $elevenlabsModels['custom-elevenlabs-model']['class']);
$this->assertCount(3, $elevenlabsModels['custom-elevenlabs-model']['capabilities']);
}

#[TestDox('Model configuration for newly added platforms (dockermodelrunner, aimlapi, replicate, albert) is processed correctly')]
public function testModelConfigurationForNewlyAddedPlatforms()
{
$container = $this->buildContainer([
'ai' => [
'model' => [
'dockermodelrunner' => [
'custom-docker-model' => [
'capabilities' => [
'input-messages',
'output-text',
],
],
],
'aimlapi' => [
'custom-aimlapi-model' => [
'capabilities' => [
'input-messages',
'output-text',
],
],
],
'replicate' => [
'custom-replicate-model' => [
'capabilities' => [
'input-messages',
'output-text',
],
],
],
'albert' => [
'custom-albert-model' => [
'capabilities' => [
'input-messages',
'output-text',
],
],
],
],
],
]);

$dockermodelrunnerCatalog = $container->getDefinition('ai.platform.model_catalog.dockermodelrunner');
$dockermodelrunnerModels = $dockermodelrunnerCatalog->getArgument(0);
$this->assertIsArray($dockermodelrunnerModels);
$this->assertArrayHasKey('custom-docker-model', $dockermodelrunnerModels);
$this->assertSame('Symfony\AI\Platform\Model', $dockermodelrunnerModels['custom-docker-model']['class']);

$aimlapiCatalog = $container->getDefinition('ai.platform.model_catalog.aimlapi');
$aimlapiModels = $aimlapiCatalog->getArgument(0);
$this->assertIsArray($aimlapiModels);
$this->assertArrayHasKey('custom-aimlapi-model', $aimlapiModels);
$this->assertSame('Symfony\AI\Platform\Model', $aimlapiModels['custom-aimlapi-model']['class']);

$replicateCatalog = $container->getDefinition('ai.platform.model_catalog.replicate');
$replicateModels = $replicateCatalog->getArgument(0);
$this->assertIsArray($replicateModels);
$this->assertArrayHasKey('custom-replicate-model', $replicateModels);
$this->assertSame('Symfony\AI\Platform\Bridge\Meta\Llama', $replicateModels['custom-replicate-model']['class']);

$albertCatalog = $container->getDefinition('ai.platform.model_catalog.albert');
$albertModels = $albertCatalog->getArgument(0);
$this->assertIsArray($albertModels);
$this->assertArrayHasKey('custom-albert-model', $albertModels);
$this->assertSame('Symfony\AI\Platform\Bridge\OpenAi\Gpt', $albertModels['custom-albert-model']['class']);
}

private function buildContainer(array $configuration): ContainerBuilder
{
$container = new ContainerBuilder();
Expand Down
Loading