diff --git a/src/ai-bundle/config/services.php b/src/ai-bundle/config/services.php index 753940c0f..e65074102 100644 --- a/src/ai-bundle/config/services.php +++ b/src/ai-bundle/config/services.php @@ -46,7 +46,6 @@ use Symfony\AI\Platform\Bridge\Mistral\ModelCatalog as MistralModelCatalog; use Symfony\AI\Platform\Bridge\Mistral\TokenOutputProcessor as MistralTokenOutputProcessor; use Symfony\AI\Platform\Bridge\Ollama\Contract\OllamaContract; -use Symfony\AI\Platform\Bridge\Ollama\ModelCatalog as OllamaModelCatalog; use Symfony\AI\Platform\Bridge\OpenAi\Contract\OpenAiContract; use Symfony\AI\Platform\Bridge\OpenAi\ModelCatalog as OpenAiModelCatalog; use Symfony\AI\Platform\Bridge\OpenAi\TokenOutputProcessor as OpenAiTokenOutputProcessor; @@ -67,6 +66,7 @@ use Symfony\AI\Store\Command\DropStoreCommand; use Symfony\AI\Store\Command\IndexCommand; use Symfony\AI\Store\Command\SetupStoreCommand; +use Symfony\Component\DependencyInjection\Reference; return static function (ContainerConfigurator $container): void { $container->services() @@ -99,7 +99,6 @@ ->set('ai.platform.model_catalog.huggingface', HuggingFaceModelCatalog::class) ->set('ai.platform.model_catalog.lmstudio', LmStudioModelCatalog::class) ->set('ai.platform.model_catalog.mistral', MistralModelCatalog::class) - ->set('ai.platform.model_catalog.ollama', OllamaModelCatalog::class) ->set('ai.platform.model_catalog.openai', OpenAiModelCatalog::class) ->set('ai.platform.model_catalog.openrouter', OpenRouterModelCatalog::class) ->set('ai.platform.model_catalog.perplexity', PerplexityModelCatalog::class) diff --git a/src/ai-bundle/src/AiBundle.php b/src/ai-bundle/src/AiBundle.php index 4adcd4e5b..6c004a8ee 100644 --- a/src/ai-bundle/src/AiBundle.php +++ b/src/ai-bundle/src/AiBundle.php @@ -45,6 +45,7 @@ use Symfony\AI\Platform\Bridge\Gemini\PlatformFactory as GeminiPlatformFactory; use Symfony\AI\Platform\Bridge\LmStudio\PlatformFactory as LmStudioPlatformFactory; use Symfony\AI\Platform\Bridge\Mistral\PlatformFactory as MistralPlatformFactory; +use Symfony\AI\Platform\Bridge\Ollama\OllamaCatalog; use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory as OllamaPlatformFactory; use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory as OpenAiPlatformFactory; use Symfony\AI\Platform\Bridge\OpenRouter\PlatformFactory as OpenRouterPlatformFactory; @@ -442,7 +443,14 @@ private function processPlatformConfig(string $type, array $platform, ContainerB } if ('ollama' === $type) { - $platformId = 'ai.platform.ollama'; + $catalogDefinition = (new Definition(OllamaCatalog::class)) + ->setArguments([ + $platform['host_url'], + new Reference('http_client'), + ]); + + $container->setDefinition('ai.platform.model_catalog.ollama', $catalogDefinition); + $definition = (new Definition(Platform::class)) ->setFactory(OllamaPlatformFactory::class.'::create') ->setLazy(true) @@ -455,7 +463,7 @@ private function processPlatformConfig(string $type, array $platform, ContainerB ]) ->addTag('ai.platform', ['name' => 'ollama']); - $container->setDefinition($platformId, $definition); + $container->setDefinition('ai.platform.ollama', $definition); return; } diff --git a/src/platform/src/Bridge/Ollama/OllamaCatalog.php b/src/platform/src/Bridge/Ollama/OllamaCatalog.php new file mode 100644 index 000000000..43a57f76d --- /dev/null +++ b/src/platform/src/Bridge/Ollama/OllamaCatalog.php @@ -0,0 +1,46 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Ollama; + +use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelCatalog\DynamicModelCatalog; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +final class OllamaCatalog extends DynamicModelCatalog +{ + public function __construct( + private readonly string $host, + private readonly HttpClientInterface $httpClient, + ) { + parent::__construct(); + } + + public function getModel(string $modelName): Model + { + $model = parent::getModel($modelName); + + $response = $this->httpClient->request('POST', \sprintf('%s/api/show', $this->host), [ + 'json' => [ + 'model' => $model->getName(), + ], + ]); + + $payload = $response->toArray(); + + if ([] === $payload['capabilities'] ?? []) { + throw new InvalidArgumentException('The model information could not be retrieved from the Ollama API. Your Ollama server might be too old. Try upgrade it.'); + } + + return new Ollama($model->getName(), $payload['capabilities'], $model->getOptions()); + } +} diff --git a/src/platform/src/Bridge/Ollama/OllamaClient.php b/src/platform/src/Bridge/Ollama/OllamaClient.php index f1220ed79..90e929ef8 100644 --- a/src/platform/src/Bridge/Ollama/OllamaClient.php +++ b/src/platform/src/Bridge/Ollama/OllamaClient.php @@ -35,21 +35,9 @@ public function supports(Model $model): bool public function request(Model $model, array|string $payload, array $options = []): RawHttpResult { - $response = $this->httpClient->request('POST', \sprintf('%s/api/show', $this->hostUrl), [ - 'json' => [ - 'model' => $model->getName(), - ], - ]); - - $capabilities = $response->toArray()['capabilities'] ?? null; - - if (null === $capabilities) { - throw new InvalidArgumentException('The model information could not be retrieved from the Ollama API. Your Ollama server might be too old. Try upgrade it.'); - } - return match (true) { - \in_array('completion', $capabilities, true) => $this->doCompletionRequest($payload, $options), - \in_array('embedding', $capabilities, true) => $this->doEmbeddingsRequest($model, $payload, $options), + \in_array('completion', $model->getCapabilities(), true) => $this->doCompletionRequest($payload, $options), + \in_array('embedding', $model->getCapabilities(), true) => $this->doEmbeddingsRequest($model, $payload, $options), default => throw new InvalidArgumentException(\sprintf('Unsupported model "%s": "%s".', $model::class, $model->getName())), }; } diff --git a/src/platform/src/Bridge/Ollama/PlatformFactory.php b/src/platform/src/Bridge/Ollama/PlatformFactory.php index dd801f39c..46dfd4c4f 100644 --- a/src/platform/src/Bridge/Ollama/PlatformFactory.php +++ b/src/platform/src/Bridge/Ollama/PlatformFactory.php @@ -26,7 +26,7 @@ final class PlatformFactory public static function create( string $hostUrl = 'http://localhost:11434', ?HttpClientInterface $httpClient = null, - ModelCatalogInterface $modelCatalog = new ModelCatalog(), + ?ModelCatalogInterface $modelCatalog = null, ?Contract $contract = null, ): Platform { $httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); @@ -34,7 +34,7 @@ public static function create( return new Platform( [new OllamaClient($httpClient, $hostUrl)], [new OllamaResultConverter()], - $modelCatalog, + $modelCatalog ?? new OllamaCatalog($hostUrl, $httpClient), $contract ?? OllamaContract::create(), ); }