Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/agent/src/StructuredOutput/AgentProcessor.php
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
use Symfony\AI\Agent\InputProcessorInterface;
use Symfony\AI\Agent\Output;
use Symfony\AI\Agent\OutputProcessorInterface;
use Symfony\AI\Platform\Capability;
use Symfony\AI\Platform\PlatformInterface;
use Symfony\AI\Platform\Result\ObjectResult;
use Symfony\Component\PropertyInfo\Extractor\PhpDocExtractor;
use Symfony\Component\PropertyInfo\Extractor\ReflectionExtractor;
Expand All @@ -39,6 +41,7 @@ final class AgentProcessor implements InputProcessorInterface, OutputProcessorIn
private string $outputStructure;

public function __construct(
private PlatformInterface $platform,
private readonly ResponseFormatFactoryInterface $responseFormatFactory = new ResponseFormatFactory(),
private ?SerializerInterface $serializer = null,
) {
Expand Down Expand Up @@ -77,6 +80,11 @@ public function processInput(Input $input): void
throw new InvalidArgumentException('Streamed responses are not supported for structured output.');
}

$modelObject = $this->platform->getModelCatalog()->getModel($input->getModel());
if (!\in_array(Capability::OUTPUT_STRUCTURED, $modelObject->getCapabilities(), true)) {
throw MissingModelSupportException::forStructuredOutput($modelObject->getName());
Comment on lines +83 to +85
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would be enough for me

Suggested change
$modelObject = $this->platform->getModelCatalog()->getModel($input->getModel());
if (!\in_array(Capability::OUTPUT_STRUCTURED, $modelObject->getCapabilities(), true)) {
throw MissingModelSupportException::forStructuredOutput($modelObject->getName());
$model = $this->platform->getModelCatalog()->getModel($input->getModel());
if (!\in_array(Capability::OUTPUT_STRUCTURED, $model->getCapabilities(), true)) {
throw MissingModelSupportException::forStructuredOutput($model->getName());

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not familiar with the new pattern.

Is it ok to inject the platform every time we need the capabilities ; or would it be better to have the model instance directly in the input like it was before @chr-hertel ?

}

$options['response_format'] = $this->responseFormatFactory->create($options['output_structure']);

$this->outputStructure = $options['output_structure'];
Expand Down
150 changes: 99 additions & 51 deletions src/agent/tests/StructuredOutput/AgentProcessorTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@

use PHPUnit\Framework\Attributes\DataProvider;
use PHPUnit\Framework\TestCase;
use Symfony\AI\Agent\Exception\MissingModelSupportException;
use Symfony\AI\Agent\Input;
use Symfony\AI\Agent\Output;
use Symfony\AI\Agent\StructuredOutput\AgentProcessor;
use Symfony\AI\Agent\StructuredOutput\ResponseFormatFactoryInterface;
use Symfony\AI\Fixtures\SomeStructure;
use Symfony\AI\Fixtures\StructuredOutput\MathReasoning;
use Symfony\AI\Fixtures\StructuredOutput\PolymorphicType\ListItemAge;
Expand All @@ -25,18 +27,31 @@
use Symfony\AI\Fixtures\StructuredOutput\UnionType\HumanReadableTimeUnion;
use Symfony\AI\Fixtures\StructuredOutput\UnionType\UnionTypeDto;
use Symfony\AI\Fixtures\StructuredOutput\UnionType\UnixTimestampUnion;
use Symfony\AI\Platform\Capability;
use Symfony\AI\Platform\Message\MessageBag;
use Symfony\AI\Platform\Message\UserMessage;
use Symfony\AI\Platform\Metadata\Metadata;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface;
use Symfony\AI\Platform\PlatformInterface;
use Symfony\AI\Platform\Result\ObjectResult;
use Symfony\AI\Platform\Result\TextResult;
use Symfony\Component\Serializer\SerializerInterface;

if (!class_exists(__NAMESPACE__.'\ConfigurableResponseFormatFactory')) {
class ConfigurableResponseFormatFactory implements ResponseFormatFactoryInterface {
public function __construct(private array $format = []) {}
public function create(string $structure): array { return $this->format; }
}
}
Comment on lines +41 to +46
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this? there is already Symfony\AI\Agent\Tests\StructuredOutput\ConfigurableResponseFormatFactory


final class AgentProcessorTest extends TestCase
{
public function testProcessInputWithOutputStructure()
{
$processor = new AgentProcessor(new ConfigurableResponseFormatFactory(['some' => 'format']));
$input = new Input('gpt-4', new MessageBag(), ['output_structure' => 'SomeStructure']);
$platformMock = $this->createPlatformMock('gpt-4');
$processor = new AgentProcessor($platformMock, new ConfigurableResponseFormatFactory(['some' => 'format']));
$input = new Input('gpt-4', new MessageBag(), ['output_structure' => SomeStructure::class]);

$processor->processInput($input);

Expand All @@ -45,38 +60,75 @@ public function testProcessInputWithOutputStructure()

public function testProcessInputWithoutOutputStructure()
{
$processor = new AgentProcessor(new ConfigurableResponseFormatFactory());
$platformMock = $this->createMock(PlatformInterface::class);
$processor = new AgentProcessor($platformMock, new ConfigurableResponseFormatFactory());
$input = new Input('gpt-4', new MessageBag());

$processor->processInput($input);

$this->assertSame([], $input->getOptions());
}

public function testProcessInputThrowsExceptionForMissingSupport()
{
$modelName = 'model-without-structured-output';

$modelMock = $this->createMock(Model::class);
$modelMock->method('getCapabilities')->willReturn([
Capability::INPUT_MESSAGES,
Capability::OUTPUT_TEXT,
]);
$modelMock->method('getName')->willReturn($modelName);

$modelCatalogMock = $this->createMock(ModelCatalogInterface::class);
$modelCatalogMock
->expects($this->once())
->method('getModel')
->with($modelName)
->willReturn($modelMock);

$platformMock = $this->createMock(PlatformInterface::class);
$platformMock
->expects($this->once())
->method('getModelCatalog')
->willReturn($modelCatalogMock);

$processor = new AgentProcessor($platformMock, new ConfigurableResponseFormatFactory());
$messages = new MessageBag(new UserMessage(new \Symfony\AI\Platform\Message\Content\Text('Hello')));
$options = ['output_structure' => 'App\Dto\MyStructure'];
$input = new Input($modelName, $messages, $options);

$this->expectException(MissingModelSupportException::class);
$this->expectExceptionMessage(\sprintf('Model "%s" does not support "structured output".', $modelName));

$processor->processInput($input);
}

public function testProcessOutputWithResponseFormat()
{
$processor = new AgentProcessor(new ConfigurableResponseFormatFactory(['some' => 'format']));
$platformMock = $this->createPlatformMock('gpt-4');
$processor = new AgentProcessor($platformMock, new ConfigurableResponseFormatFactory(['some' => 'format']));

$options = ['output_structure' => SomeStructure::class];
$input = new Input('gpt-4', new MessageBag(), $options);
$processor->processInput($input);

$result = new TextResult('{"some": "data"}');

$output = new Output('gpt-4', $result, new MessageBag(), $input->getOptions());

$processor->processOutput($output);

$this->assertInstanceOf(ObjectResult::class, $output->getResult());
$this->assertInstanceOf(SomeStructure::class, $output->getResult()->getContent());
$resultContent = $output->getResult()->getContent();
$this->assertInstanceOf(SomeStructure::class, $resultContent);
$this->assertInstanceOf(Metadata::class, $output->getResult()->getMetadata());
$this->assertNull($output->getResult()->getRawResult());
$this->assertSame('data', $output->getResult()->getContent()->some);
$this->assertSame('data', $resultContent->some);
}

public function testProcessOutputWithComplexResponseFormat()
{
$processor = new AgentProcessor(new ConfigurableResponseFormatFactory(['some' => 'format']));
$platformMock = $this->createPlatformMock('gpt-4');
$processor = new AgentProcessor($platformMock, new ConfigurableResponseFormatFactory(['some' => 'format']));

$options = ['output_structure' => MathReasoning::class];
$input = new Input('gpt-4', new MessageBag(), $options);
Expand Down Expand Up @@ -112,15 +164,17 @@ public function testProcessOutputWithComplexResponseFormat()
JSON);

$output = new Output('gpt-4', $result, new MessageBag(), $input->getOptions());

$processor->processOutput($output);

$this->assertInstanceOf(ObjectResult::class, $output->getResult());
$this->assertInstanceOf(MathReasoning::class, $structure = $output->getResult()->getContent());
$structure = $output->getResult()->getContent();
$this->assertInstanceOf(MathReasoning::class, $structure);
$this->assertInstanceOf(Metadata::class, $output->getResult()->getMetadata());
$this->assertNull($output->getResult()->getRawResult());
$this->assertCount(5, $structure->steps);
$this->assertInstanceOf(Step::class, $structure->steps[0]);
$this->assertSame("We want to isolate the term with x. First, let's subtract 7 from both sides of the equation.", $structure->steps[0]->explanation);
$this->assertSame("8x + 7 - 7 = -23 - 7", $structure->steps[0]->output);
$this->assertInstanceOf(Step::class, $structure->steps[1]);
$this->assertInstanceOf(Step::class, $structure->steps[2]);
$this->assertInstanceOf(Step::class, $structure->steps[3]);
Expand All @@ -129,13 +183,11 @@ public function testProcessOutputWithComplexResponseFormat()
$this->assertSame('x = -3.75', $structure->finalAnswer);
}

/**
* @param class-string $expectedTimeStructure
*/
#[DataProvider('unionTimeTypeProvider')]
public function testProcessOutputWithUnionTypeResponseFormat(TextResult $result, string $expectedTimeStructure)
{
$processor = new AgentProcessor(new ConfigurableResponseFormatFactory(['some' => 'format']));
$platformMock = $this->createPlatformMock('gpt-4');
$processor = new AgentProcessor($platformMock, new ConfigurableResponseFormatFactory(['some' => 'format']));

$options = ['output_structure' => UnionTypeDto::class];
$input = new Input('gpt-4', new MessageBag(), $options);
Expand All @@ -154,21 +206,8 @@ public function testProcessOutputWithUnionTypeResponseFormat(TextResult $result,

public static function unionTimeTypeProvider(): array
{
$unixTimestampResult = new TextResult(<<<JSON
{
"time": {
"timestamp": 2212121
}
}
JSON);

$humanReadableResult = new TextResult(<<<JSON
{
"time": {
"readableTime": "2023-10-10T10:10:10+00:00"
}
}
JSON);
$unixTimestampResult = new TextResult('{"time": {"timestamp": 2212121}}');
$humanReadableResult = new TextResult('{"time": {"readableTime": "2023-10-10T10:10:10+00:00"}}');

return [
[$unixTimestampResult, UnixTimestampUnion::class],
Expand All @@ -178,63 +217,72 @@ public static function unionTimeTypeProvider(): array

public function testProcessOutputWithCorrectPolymorphicTypesResponseFormat()
{
$processor = new AgentProcessor(new ConfigurableResponseFormatFactory(['some' => 'format']));
$platformMock = $this->createPlatformMock('gpt-4');
$processor = new AgentProcessor($platformMock, new ConfigurableResponseFormatFactory(['some' => 'format']));

$options = ['output_structure' => ListOfPolymorphicTypesDto::class];
$input = new Input('gpt-4', new MessageBag(), $options);
$processor->processInput($input);

$result = new TextResult(<<<JSON
{
"items": [
{
"type": "name",
"name": "John Doe"
},
{
"type": "age",
"age": 24
}
]
}
{"items": [{"type": "name", "name": "John Doe"}, {"type": "age", "age": 24}]}
JSON);

$output = new Output('gpt-4', $result, new MessageBag(), $input->getOptions());

$processor->processOutput($output);

$this->assertInstanceOf(ObjectResult::class, $output->getResult());

/** @var ListOfPolymorphicTypesDto $structure */
$structure = $output->getResult()->getContent();
$this->assertInstanceOf(ListOfPolymorphicTypesDto::class, $structure);

$this->assertCount(2, $structure->items);

$nameItem = $structure->items[0];
$ageItem = $structure->items[1];

$this->assertInstanceOf(ListItemName::class, $nameItem);
$this->assertInstanceOf(ListItemAge::class, $ageItem);

$this->assertSame('John Doe', $nameItem->name);
$this->assertSame(24, $ageItem->age);

$this->assertSame('name', $nameItem->type);
$this->assertSame('age', $ageItem->type);
}

public function testProcessOutputWithoutResponseFormat()
{
$platformMock = $this->createMock(PlatformInterface::class);
$resultFormatFactory = new ConfigurableResponseFormatFactory();
$serializer = self::createMock(SerializerInterface::class);
$processor = new AgentProcessor($resultFormatFactory, $serializer);
$serializer = $this->createMock(SerializerInterface::class);
$processor = new AgentProcessor($platformMock, $resultFormatFactory, $serializer);

$result = new TextResult('');
$output = new Output('gpt4', $result, new MessageBag());

$processor->processOutput($output);

$this->assertSame($result, $output->getResult());
}

private function createPlatformMock(string $modelName): PlatformInterface
{
$modelMock = $this->createMock(Model::class);
$modelMock->method('getCapabilities')->willReturn([
Capability::INPUT_MESSAGES,
Capability::OUTPUT_TEXT,
Capability::OUTPUT_STRUCTURED,
]);
$modelMock->method('getName')->willReturn($modelName);

$modelCatalogMock = $this->createMock(ModelCatalogInterface::class);
$modelCatalogMock
->method('getModel')
->with($modelName)
->willReturn($modelMock);

$platformMock = $this->createMock(PlatformInterface::class);
$platformMock
->method('getModelCatalog')
->willReturn($modelCatalogMock);

return $platformMock;
}
}
3 changes: 2 additions & 1 deletion src/ai-bundle/config/services.php
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@
->alias(ResponseFormatFactoryInterface::class, 'ai.agent.response_format_factory')
->set('ai.agent.structured_output_processor', StructureOutputProcessor::class)
->args([
service('ai.platform'),
service('ai.agent.response_format_factory'),
service('serializer'),
service('serializer')->nullOnInvalid(),
])

// tools
Expand Down
5 changes: 0 additions & 5 deletions src/platform/src/Bridge/Gemini/ModelCatalog.php
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ public function __construct(array $additionalModels = [])
Capability::INPUT_AUDIO,
Capability::INPUT_PDF,
Capability::OUTPUT_STREAMING,
Capability::OUTPUT_STRUCTURED,
Capability::TOOL_CALLING,
],
],
Expand All @@ -81,7 +80,6 @@ public function __construct(array $additionalModels = [])
Capability::INPUT_AUDIO,
Capability::INPUT_PDF,
Capability::OUTPUT_STREAMING,
Capability::OUTPUT_STRUCTURED,
Capability::TOOL_CALLING,
],
],
Expand All @@ -93,7 +91,6 @@ public function __construct(array $additionalModels = [])
Capability::INPUT_AUDIO,
Capability::INPUT_PDF,
Capability::OUTPUT_STREAMING,
Capability::OUTPUT_STRUCTURED,
Capability::TOOL_CALLING,
],
],
Expand All @@ -105,7 +102,6 @@ public function __construct(array $additionalModels = [])
Capability::INPUT_AUDIO,
Capability::INPUT_PDF,
Capability::OUTPUT_STREAMING,
Capability::OUTPUT_STRUCTURED,
Capability::TOOL_CALLING,
],
],
Expand All @@ -117,7 +113,6 @@ public function __construct(array $additionalModels = [])
Capability::INPUT_AUDIO,
Capability::INPUT_PDF,
Capability::OUTPUT_STREAMING,
Capability::OUTPUT_STRUCTURED,
Capability::TOOL_CALLING,
],
],
Expand Down
2 changes: 0 additions & 2 deletions src/platform/src/Bridge/VertexAi/ModelCatalog.php
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ public function __construct(array $additionalModels = [])
Capability::INPUT_PDF,
Capability::OUTPUT_TEXT,
Capability::OUTPUT_STREAMING,
Capability::OUTPUT_STRUCTURED,
Capability::TOOL_CALLING,
],
],
Expand All @@ -92,7 +91,6 @@ public function __construct(array $additionalModels = [])
Capability::INPUT_PDF,
Capability::OUTPUT_TEXT,
Capability::OUTPUT_STREAMING,
Capability::OUTPUT_STRUCTURED,
Capability::TOOL_CALLING,
],
],
Expand Down
Loading