Skip to content

feat: add Google Gemini tool support (#331) #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions examples/google/toolcall.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
<?php

/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <fabien@symfony.com>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

use Symfony\AI\Agent\Agent;
use Symfony\AI\Agent\Toolbox\AgentProcessor;
use Symfony\AI\Agent\Toolbox\Tool\Clock;
use Symfony\AI\Agent\Toolbox\Toolbox;
use Symfony\AI\Platform\Bridge\Google\Gemini;
use Symfony\AI\Platform\Bridge\Google\PlatformFactory;
use Symfony\AI\Platform\Message\Message;
use Symfony\AI\Platform\Message\MessageBag;
use Symfony\Component\Dotenv\Dotenv;

require_once dirname(__DIR__, 2).'/vendor/autoload.php';
(new Dotenv())->loadEnv(dirname(__DIR__, 2).'/.env');

if (empty($_ENV['GOOGLE_API_KEY'])) {
echo 'Please set the GOOGLE_API_KEY environment variable.'.\PHP_EOL;
exit(1);
}

$platform = PlatformFactory::create($_ENV['GOOGLE_API_KEY']);
$llm = new Gemini(Gemini::GEMINI_2_FLASH);

$toolbox = Toolbox::create(new Clock());
$processor = new AgentProcessor($toolbox);
$chain = new Agent($platform, $llm, [$processor], [$processor]);

$messages = new MessageBag(Message::ofUser('What time is it?'));
$response = $chain->call($messages);

echo $response->getContent().\PHP_EOL;
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,12 @@
use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer;
use Symfony\AI\Platform\Message\AssistantMessage;
use Symfony\AI\Platform\Model;
use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface;
use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait;

/**
* @author Christopher Hertel <mail@christopher-hertel.de>
*/
final class AssistantMessageNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface
final class AssistantMessageNormalizer extends ModelContractNormalizer
{
use NormalizerAwareTrait;

protected function supportedDataClass(): string
{
return AssistantMessage::class;
Expand All @@ -42,8 +38,23 @@ protected function supportsModel(Model $model): bool
*/
public function normalize(mixed $data, ?string $format = null, array $context = []): array
{
return [
['text' => $data->content],
];
$normalized = [];

if (isset($data->content)) {
$normalized['text'] = $data->content;
}

if (isset($data->toolCalls[0])) {
$normalized['functionCall'] = [
'id' => $data->toolCalls[0]->id,
'name' => $data->toolCalls[0]->name,
];

if ($data->toolCalls[0]->arguments) {
$normalized['functionCall']['args'] = $data->toolCalls[0]->arguments;
}
}

return [$normalized];
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
<?php

/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <fabien@symfony.com>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace Symfony\AI\Platform\Bridge\Google\Contract;

use Symfony\AI\Platform\Bridge\Google\Gemini;
use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer;
use Symfony\AI\Platform\Message\ToolCallMessage;
use Symfony\AI\Platform\Model;

/**
* @author Valtteri R <valtzu@gmail.com>
*/
final class ToolCallMessageNormalizer extends ModelContractNormalizer
{
protected function supportedDataClass(): string
{
return ToolCallMessage::class;
}

protected function supportsModel(Model $model): bool
{
return $model instanceof Gemini;
}

/**
* @param ToolCallMessage $data
*
* @return array{
* functionResponse: array{
* id: string,
* name: string,
* response: array<int|string, mixed>
* }
* }[]
*/
public function normalize(mixed $data, ?string $format = null, array $context = []): array
{
$responseContent = json_validate($data->content) ? json_decode($data->content, true) : $data->content;

return [[
'functionResponse' => array_filter([
'id' => $data->toolCall->id,
'name' => $data->toolCall->name,
'response' => \is_array($responseContent) ? $responseContent : [
'rawResponse' => $responseContent, // Gemini expects the response to be an object, but not everyone uses objects as their responses.
],
]),
]];
}
}
63 changes: 63 additions & 0 deletions src/platform/src/Bridge/Google/Contract/ToolNormalizer.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
<?php

/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <fabien@symfony.com>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace Symfony\AI\Platform\Bridge\Google\Contract;

use Symfony\AI\Platform\Bridge\Google\Gemini;
use Symfony\AI\Platform\Contract\JsonSchema\Factory;
use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\Tool\Tool;

/**
* @author Valtteri R <valtzu@gmail.com>
*
* @phpstan-import-type JsonSchema from Factory
*/
final class ToolNormalizer extends ModelContractNormalizer
{
protected function supportedDataClass(): string
{
return Tool::class;
}

protected function supportsModel(Model $model): bool
{
return $model instanceof Gemini;
}

/**
* @param Tool $data
*
* @return array{
* functionDeclarations: array{
* name: string,
* description: string,
* parameters: JsonSchema|array{type: 'object'}
* }[]
* }
*/
public function normalize(mixed $data, ?string $format = null, array $context = []): array
{
$parameters = $data->parameters;
unset($parameters['additionalProperties']);

return [
'functionDeclarations' => [
[
'description' => $data->description,
'name' => $data->name,
'parameters' => $parameters,
],
],
];
}
}
1 change: 1 addition & 0 deletions src/platform/src/Bridge/Google/Gemini.php
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public function __construct(string $name = self::GEMINI_2_PRO, array $options =
Capability::INPUT_MESSAGES,
Capability::INPUT_IMAGE,
Capability::OUTPUT_STREAMING,
Capability::TOOL_CALLING,
];

parent::__construct($name, $capabilities, $options);
Expand Down
85 changes: 81 additions & 4 deletions src/platform/src/Bridge/Google/ModelHandler.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@
use Symfony\AI\Platform\Exception\RuntimeException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface;
use Symfony\AI\Platform\Response\Choice;
use Symfony\AI\Platform\Response\ChoiceResponse;
use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse;
use Symfony\AI\Platform\Response\StreamResponse;
use Symfony\AI\Platform\Response\TextResponse;
use Symfony\AI\Platform\Response\ToolCall;
use Symfony\AI\Platform\Response\ToolCallResponse;
use Symfony\AI\Platform\ResponseConverterInterface;
use Symfony\Component\HttpClient\EventSourceHttpClient;
use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface;
Expand Down Expand Up @@ -59,6 +63,12 @@ public function request(Model $model, array|string $payload, array $options = []

$generationConfig = ['generationConfig' => $options];
unset($generationConfig['generationConfig']['stream']);
unset($generationConfig['generationConfig']['tools']);

if (isset($options['tools'])) {
$generationConfig['tools'] = $options['tools'];
unset($options['tools']);
}

return $this->httpClient->request('POST', $url, [
'headers' => [
Expand All @@ -83,11 +93,22 @@ public function convert(ResponseInterface $response, array $options = []): LlmRe

$data = $response->toArray();

if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) {
if (!isset($data['candidates'][0]['content']['parts'][0])) {
throw new RuntimeException('Response does not contain any content');
}

return new TextResponse($data['candidates'][0]['content']['parts'][0]['text']);
/** @var Choice[] $choices */
$choices = array_map($this->convertChoice(...), $data['candidates']);

if (1 !== \count($choices)) {
return new ChoiceResponse(...$choices);
}

if ($choices[0]->hasToolCall()) {
return new ToolCallResponse(...$choices[0]->getToolCalls());
}

return new TextResponse($choices[0]->getContent());
}

private function convertStream(ResponseInterface $response): \Generator
Expand Down Expand Up @@ -121,12 +142,68 @@ private function convertStream(ResponseInterface $response): \Generator
throw new RuntimeException('Failed to decode JSON response', 0, $e);
}

if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) {
/** @var Choice[] $choices */
$choices = array_map($this->convertChoice(...), $data['candidates'] ?? []);

if (!$choices) {
continue;
}

yield $data['candidates'][0]['content']['parts'][0]['text'];
if (1 !== \count($choices)) {
yield new ChoiceResponse(...$choices);
continue;
}

if ($choices[0]->hasToolCall()) {
yield new ToolCallResponse(...$choices[0]->getToolCalls());
}

if ($choices[0]->hasContent()) {
yield $choices[0]->getContent();
}
}
}
}

/**
* @param array{
* finishReason?: string,
* content: array{
* parts: array{
* functionCall?: array{
* id: string,
* name: string,
* args: mixed[]
* },
* text?: string
* }[]
* }
* } $choice
*/
private function convertChoice(array $choice): Choice
{
$contentPart = $choice['content']['parts'][0] ?? [];

if (isset($contentPart['functionCall'])) {
return new Choice(toolCalls: [$this->convertToolCall($contentPart['functionCall'])]);
}

if (isset($contentPart['text'])) {
return new Choice($contentPart['text']);
}

throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choice['finishReason']));
}

/**
* @param array{
* id: string,
* name: string,
* args: mixed[]
* } $toolCall
*/
private function convertToolCall(array $toolCall): ToolCall
{
return new ToolCall($toolCall['id'], $toolCall['name'], $toolCall['args']);
}
}
4 changes: 4 additions & 0 deletions src/platform/src/Bridge/Google/PlatformFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

use Symfony\AI\Platform\Bridge\Google\Contract\AssistantMessageNormalizer;
use Symfony\AI\Platform\Bridge\Google\Contract\MessageBagNormalizer;
use Symfony\AI\Platform\Bridge\Google\Contract\ToolCallMessageNormalizer;
use Symfony\AI\Platform\Bridge\Google\Contract\ToolNormalizer;
use Symfony\AI\Platform\Bridge\Google\Contract\UserMessageNormalizer;
use Symfony\AI\Platform\Contract;
use Symfony\AI\Platform\Platform;
Expand All @@ -35,6 +37,8 @@ public static function create(
return new Platform([$responseHandler], [$responseHandler], Contract::create(
new AssistantMessageNormalizer(),
new MessageBagNormalizer(),
new ToolNormalizer(),
new ToolCallMessageNormalizer(),
new UserMessageNormalizer(),
));
}
Expand Down
Loading
Loading