diff --git a/src/ai-bundle/src/AiBundle.php b/src/ai-bundle/src/AiBundle.php index 64c5d3a32..908180fe6 100644 --- a/src/ai-bundle/src/AiBundle.php +++ b/src/ai-bundle/src/AiBundle.php @@ -47,6 +47,7 @@ use Symfony\AI\Chat\ChatInterface; use Symfony\AI\Chat\MessageStoreInterface; use Symfony\AI\Platform\Bridge\Albert\PlatformFactory as AlbertPlatformFactory; +use Symfony\AI\Platform\Bridge\Anthropic\Claude; use Symfony\AI\Platform\Bridge\Anthropic\PlatformFactory as AnthropicPlatformFactory; use Symfony\AI\Platform\Bridge\Azure\OpenAi\PlatformFactory as AzureOpenAiPlatformFactory; use Symfony\AI\Platform\Bridge\Cartesia\PlatformFactory as CartesiaPlatformFactory; @@ -54,18 +55,23 @@ use Symfony\AI\Platform\Bridge\DeepSeek\PlatformFactory as DeepSeekPlatformFactory; use Symfony\AI\Platform\Bridge\DockerModelRunner\PlatformFactory as DockerModelRunnerPlatformFactory; use Symfony\AI\Platform\Bridge\ElevenLabs\PlatformFactory as ElevenLabsPlatformFactory; +use Symfony\AI\Platform\Bridge\Gemini\Gemini; use Symfony\AI\Platform\Bridge\Gemini\PlatformFactory as GeminiPlatformFactory; use Symfony\AI\Platform\Bridge\HuggingFace\PlatformFactory as HuggingFacePlatformFactory; use Symfony\AI\Platform\Bridge\LmStudio\PlatformFactory as LmStudioPlatformFactory; +use Symfony\AI\Platform\Bridge\Mistral\Mistral; use Symfony\AI\Platform\Bridge\Mistral\PlatformFactory as MistralPlatformFactory; +use Symfony\AI\Platform\Bridge\Ollama\Ollama; use Symfony\AI\Platform\Bridge\Ollama\OllamaApiCatalog; use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory as OllamaPlatformFactory; +use Symfony\AI\Platform\Bridge\OpenAi\Gpt; use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory as OpenAiPlatformFactory; use Symfony\AI\Platform\Bridge\OpenRouter\PlatformFactory as OpenRouterPlatformFactory; use Symfony\AI\Platform\Bridge\Perplexity\PlatformFactory as PerplexityPlatformFactory; use Symfony\AI\Platform\Bridge\Scaleway\PlatformFactory as ScalewayPlatformFactory; use Symfony\AI\Platform\Bridge\VertexAi\PlatformFactory as VertexAiPlatformFactory; use Symfony\AI\Platform\Bridge\Voyage\PlatformFactory as VoyagePlatformFactory; +use Symfony\AI\Platform\Capability; use Symfony\AI\Platform\Exception\RuntimeException; use Symfony\AI\Platform\Message\Content\File; use Symfony\AI\Platform\ModelClientInterface; @@ -143,6 +149,12 @@ public function loadExtension(array $config, ContainerConfigurator $container, C foreach ($config['platform'] ?? [] as $type => $platform) { $this->processPlatformConfig($type, $platform, $builder); } + + // Process model configuration and pass to ModelCatalog services + foreach ($config['model'] ?? [] as $platformName => $models) { + $this->processModelConfig($platformName, $models, $builder); + } + $platforms = array_keys($builder->findTaggedServiceIds('ai.platform')); if (1 === \count($platforms)) { $builder->setAlias(PlatformInterface::class, reset($platforms)); @@ -1799,4 +1811,79 @@ private static function normalizeAgentServiceId(string $agentName): string { return str_starts_with($agentName, 'ai.agent.') ? $agentName : 'ai.agent.'.$agentName; } + + /** + * Process model configuration and pass it to ModelCatalog services. + * + * @param array}> $models + */ + private function processModelConfig(string $platformName, array $models, ContainerBuilder $builder): void + { + $modelCatalogId = 'ai.platform.model_catalog.'.$platformName; + + // Handle special cases for platform name mapping + if ('vertexai' === $platformName) { + $modelCatalogId = 'ai.platform.model_catalog.vertexai.gemini'; + } + + if (!$builder->hasDefinition($modelCatalogId)) { + return; + } + + $modelClass = $this->getModelClassForPlatform($platformName); + if (null === $modelClass) { + return; + } + + $additionalModels = []; + foreach ($models as $modelName => $modelConfig) { + $capabilities = []; + foreach ($modelConfig['capabilities'] ?? [] as $capability) { + // Capabilities may already be enum instances or strings + if ($capability instanceof Capability) { + $capabilities[] = $capability; + } elseif (\is_string($capability)) { + try { + $capabilities[] = Capability::from($capability); + } catch (\ValueError) { + // Skip invalid capability strings + continue; + } + } + } + + if ([] === $capabilities) { + continue; + } + + $additionalModels[$modelName] = [ + 'class' => $modelClass, + 'capabilities' => $capabilities, + ]; + } + + if ([] === $additionalModels) { + return; + } + + $definition = $builder->getDefinition($modelCatalogId); + $definition->setArguments([$additionalModels]); + } + + /** + * Get the model class for a given platform. + * + * @return class-string|null + */ + private function getModelClassForPlatform(string $platformName): ?string + { + return match ($platformName) { + 'anthropic' => Claude::class, + 'openai' => Gpt::class, + 'gemini' => Gemini::class, + 'mistral' => Mistral::class, + 'ollama' => Ollama::class, + default => null, + }; + } }