Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions fixtures/StructuredOutput/MathReasoningWithAttributes.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<?php

/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <fabien@symfony.com>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace Symfony\AI\Fixtures\StructuredOutput;

use Symfony\Component\Serializer\Attribute\Ignore;
use Symfony\Component\Serializer\Attribute\SerializedName;

final class MathReasoningWithAttributes
{
/**
* @param Step[] $steps
*/
public function __construct(
public array $steps,
#[SerializedName('foo')]
public string $finalAnswer,
#[Ignore]
public float $result,
) {
}
}
4 changes: 3 additions & 1 deletion src/ai-bundle/config/services.php
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
use Symfony\AI\Platform\Contract;
use Symfony\AI\Platform\Contract\JsonSchema\DescriptionParser;
use Symfony\AI\Platform\Contract\JsonSchema\Factory as SchemaFactory;
use Symfony\AI\Platform\Serializer\StructuredOutputSerializer;
use Symfony\AI\Platform\StructuredOutput\PlatformSubscriber;
use Symfony\AI\Platform\StructuredOutput\ResponseFormatFactory;
use Symfony\AI\Platform\StructuredOutput\ResponseFormatFactoryInterface;
Expand Down Expand Up @@ -122,10 +123,11 @@
service('type_info.resolver')->nullOnInvalid(),
])
->alias(ResponseFormatFactoryInterface::class, 'ai.platform.response_format_factory')
->set('ai.platform.structured_output_serializer', StructuredOutputSerializer::class)
->set('ai.platform.structured_output_subscriber', PlatformSubscriber::class)
->args([
service('ai.agent.response_format_factory'),
service('serializer'),
service('ai.platform.structured_output_serializer'),
])
->tag('kernel.event_subscriber')

Expand Down
1 change: 0 additions & 1 deletion src/platform/src/Contract/JsonSchema/Factory.php
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ private function findDiscriminatorMapping(string $className): ?array
* @see https://github.com/symfony/ai/pull/585#issuecomment-3303631346
*/
$reflectionProperty = new \ReflectionProperty($result, 'mapping');
$reflectionProperty->setAccessible(true);

return $reflectionProperty->getValue($result);
}
Expand Down
52 changes: 52 additions & 0 deletions src/platform/src/Serializer/StructuredOutputSerializer.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
<?php

/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <fabien@symfony.com>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace Symfony\AI\Platform\Serializer;

use Symfony\Component\PropertyInfo\Extractor\PhpDocExtractor;
use Symfony\Component\PropertyInfo\Extractor\ReflectionExtractor;
use Symfony\Component\PropertyInfo\PropertyInfoExtractor;
use Symfony\Component\Serializer\Encoder\JsonEncoder;
use Symfony\Component\Serializer\Mapping\ClassDiscriminatorFromClassMetadata;
use Symfony\Component\Serializer\Mapping\Factory\ClassMetadataFactory;
use Symfony\Component\Serializer\Mapping\Loader\AttributeLoader;
use Symfony\Component\Serializer\Normalizer\ArrayDenormalizer;
use Symfony\Component\Serializer\Normalizer\BackedEnumNormalizer;
use Symfony\Component\Serializer\Normalizer\ObjectNormalizer;
use Symfony\Component\Serializer\Serializer;

class StructuredOutputSerializer extends Serializer
{
/*
* Custom serializer made to deserialize StructuredOutput.
*
* Since field name are generated by the `Symfony\AI\Platform\Contract\JsonSchema\Factory`
* without using the serializer (and the serializer metadata/attributes), we have to ignore them
* again when deserializing the data by not passing `classMetadataFactory` to ObjectNormalizer.
*/
public function __construct()
{
$classMetadataFactory = new ClassMetadataFactory(new AttributeLoader());
$discriminator = new ClassDiscriminatorFromClassMetadata($classMetadataFactory);
$propertyInfo = new PropertyInfoExtractor([], [new PhpDocExtractor(), new ReflectionExtractor()]);

$normalizers = [
new BackedEnumNormalizer(),
new ObjectNormalizer(
propertyTypeExtractor: $propertyInfo,
classDiscriminatorResolver: $discriminator,
),
new ArrayDenormalizer(),
];

parent::__construct($normalizers, [new JsonEncoder()]);
}
}
36 changes: 5 additions & 31 deletions src/platform/src/StructuredOutput/PlatformSubscriber.php
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,8 @@
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Exception\MissingModelSupportException;
use Symfony\AI\Platform\Result\DeferredResult;
use Symfony\AI\Platform\Serializer\StructuredOutputSerializer;
use Symfony\Component\EventDispatcher\EventSubscriberInterface;
use Symfony\Component\PropertyInfo\Extractor\PhpDocExtractor;
use Symfony\Component\PropertyInfo\Extractor\ReflectionExtractor;
use Symfony\Component\PropertyInfo\PropertyInfoExtractor;
use Symfony\Component\Serializer\Encoder\JsonEncoder;
use Symfony\Component\Serializer\Mapping\ClassDiscriminatorFromClassMetadata;
use Symfony\Component\Serializer\Mapping\Factory\ClassMetadataFactory;
use Symfony\Component\Serializer\Mapping\Loader\AttributeLoader;
use Symfony\Component\Serializer\Normalizer\ArrayDenormalizer;
use Symfony\Component\Serializer\Normalizer\BackedEnumNormalizer;
use Symfony\Component\Serializer\Normalizer\ObjectNormalizer;
use Symfony\Component\Serializer\Serializer;
use Symfony\Component\Serializer\SerializerInterface;

/**
Expand All @@ -40,29 +30,13 @@ final class PlatformSubscriber implements EventSubscriberInterface

private string $outputType;

private SerializerInterface $serializer;

public function __construct(
private readonly ResponseFormatFactoryInterface $responseFormatFactory = new ResponseFormatFactory(),
private ?SerializerInterface $serializer = null,
?SerializerInterface $serializer = null,
) {
if (null !== $this->serializer) {
return;
}

$classMetadataFactory = new ClassMetadataFactory(new AttributeLoader());
$discriminator = new ClassDiscriminatorFromClassMetadata($classMetadataFactory);
$propertyInfo = new PropertyInfoExtractor([], [new PhpDocExtractor(), new ReflectionExtractor()]);

$normalizers = [
new BackedEnumNormalizer(),
new ObjectNormalizer(
classMetadataFactory: $classMetadataFactory,
propertyTypeExtractor: $propertyInfo,
classDiscriminatorResolver: $discriminator,
),
new ArrayDenormalizer(),
];

$this->serializer = new Serializer($normalizers, [new JsonEncoder()]);
$this->serializer = $serializer ?? new StructuredOutputSerializer();
}

public static function getSubscribedEvents(): array
Expand Down
23 changes: 17 additions & 6 deletions src/platform/tests/StructuredOutput/PlatformSubscriberTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use PHPUnit\Framework\TestCase;
use Symfony\AI\Fixtures\SomeStructure;
use Symfony\AI\Fixtures\StructuredOutput\MathReasoning;
use Symfony\AI\Fixtures\StructuredOutput\MathReasoningWithAttributes;
use Symfony\AI\Fixtures\StructuredOutput\PolymorphicType\ListItemAge;
use Symfony\AI\Fixtures\StructuredOutput\PolymorphicType\ListItemName;
use Symfony\AI\Fixtures\StructuredOutput\PolymorphicType\ListOfPolymorphicTypesDto;
Expand All @@ -35,7 +36,6 @@
use Symfony\AI\Platform\Result\TextResult;
use Symfony\AI\Platform\StructuredOutput\PlatformSubscriber;
use Symfony\AI\Platform\Test\PlainConverter;
use Symfony\Component\Serializer\SerializerInterface;

final class PlatformSubscriberTest extends TestCase
{
Expand Down Expand Up @@ -95,12 +95,16 @@ public function testProcessOutputWithResponseFormat()
$this->assertSame('data', $deferredResult->asObject()->some);
}

public function testProcessOutputWithComplexResponseFormat()
/**
* @param class-string $class
*/
#[DataProvider('complexFormatDataProvider')]
public function testProcessOutputWithComplexResponseFormat(string $class)
{
$processor = new PlatformSubscriber(new ConfigurableResponseFormatFactory(['some' => 'format']));

$model = new Model('gpt-4', [Capability::OUTPUT_STRUCTURED]);
$options = ['response_format' => MathReasoning::class];
$options = ['response_format' => $class];
$invocationEvent = new InvocationEvent($model, new MessageBag(), $options);
$processor->processInput($invocationEvent);

Expand Down Expand Up @@ -139,7 +143,7 @@ public function testProcessOutputWithComplexResponseFormat()

$deferredResult = $resultEvent->getDeferredResult();
$this->assertInstanceOf(ObjectResult::class, $result = $deferredResult->getResult());
$this->assertInstanceOf(MathReasoning::class, $structure = $deferredResult->asObject());
$this->assertInstanceOf($class, $structure = $deferredResult->asObject());
$this->assertInstanceOf(Metadata::class, $result->getMetadata());
$this->assertCount(5, $structure->steps);
$this->assertInstanceOf(Step::class, $structure->steps[0]);
Expand All @@ -151,6 +155,14 @@ public function testProcessOutputWithComplexResponseFormat()
$this->assertSame(-3.75, $structure->result);
}

public static function complexFormatDataProvider(): iterable
{
return [
[MathReasoning::class],
[MathReasoningWithAttributes::class],
];
}

/**
* @param class-string $expectedTimeStructure
*/
Expand Down Expand Up @@ -254,8 +266,7 @@ public function testProcessOutputWithCorrectPolymorphicTypesResponseFormat()
public function testProcessOutputWithoutResponseFormat()
{
$resultFormatFactory = new ConfigurableResponseFormatFactory();
$serializer = self::createMock(SerializerInterface::class);
$processor = new PlatformSubscriber($resultFormatFactory, $serializer);
$processor = new PlatformSubscriber($resultFormatFactory);

$converter = new PlainConverter($result = new TextResult('{"some": "data"}'));
$deferred = new DeferredResult($converter, new InMemoryRawResult());
Expand Down