diff --git a/fixtures/StructuredOutput/MathReasoningWithAttributes.php b/fixtures/StructuredOutput/MathReasoningWithAttributes.php new file mode 100644 index 000000000..bd552c7ca --- /dev/null +++ b/fixtures/StructuredOutput/MathReasoningWithAttributes.php @@ -0,0 +1,30 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\StructuredOutput; + +use Symfony\Component\Serializer\Attribute\Ignore; +use Symfony\Component\Serializer\Attribute\SerializedName; + +final class MathReasoningWithAttributes +{ + /** + * @param Step[] $steps + */ + public function __construct( + public array $steps, + #[SerializedName('foo')] + public string $finalAnswer, + #[Ignore] + public float $result, + ) { + } +} diff --git a/src/ai-bundle/config/services.php b/src/ai-bundle/config/services.php index a1f65ba9b..1baacccc9 100644 --- a/src/ai-bundle/config/services.php +++ b/src/ai-bundle/config/services.php @@ -62,6 +62,7 @@ use Symfony\AI\Platform\Contract; use Symfony\AI\Platform\Contract\JsonSchema\DescriptionParser; use Symfony\AI\Platform\Contract\JsonSchema\Factory as SchemaFactory; +use Symfony\AI\Platform\Serializer\StructuredOutputSerializer; use Symfony\AI\Platform\StructuredOutput\PlatformSubscriber; use Symfony\AI\Platform\StructuredOutput\ResponseFormatFactory; use Symfony\AI\Platform\StructuredOutput\ResponseFormatFactoryInterface; @@ -122,10 +123,11 @@ service('type_info.resolver')->nullOnInvalid(), ]) ->alias(ResponseFormatFactoryInterface::class, 'ai.platform.response_format_factory') + ->set('ai.platform.structured_output_serializer', StructuredOutputSerializer::class) ->set('ai.platform.structured_output_subscriber', PlatformSubscriber::class) ->args([ service('ai.agent.response_format_factory'), - service('serializer'), + service('ai.platform.structured_output_serializer'), ]) ->tag('kernel.event_subscriber') diff --git a/src/platform/src/Contract/JsonSchema/Factory.php b/src/platform/src/Contract/JsonSchema/Factory.php index 43f5cba73..776537e4f 100644 --- a/src/platform/src/Contract/JsonSchema/Factory.php +++ b/src/platform/src/Contract/JsonSchema/Factory.php @@ -300,7 +300,6 @@ private function findDiscriminatorMapping(string $className): ?array * @see https://github.com/symfony/ai/pull/585#issuecomment-3303631346 */ $reflectionProperty = new \ReflectionProperty($result, 'mapping'); - $reflectionProperty->setAccessible(true); return $reflectionProperty->getValue($result); } diff --git a/src/platform/src/Serializer/StructuredOutputSerializer.php b/src/platform/src/Serializer/StructuredOutputSerializer.php new file mode 100644 index 000000000..a5ea566c3 --- /dev/null +++ b/src/platform/src/Serializer/StructuredOutputSerializer.php @@ -0,0 +1,52 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Serializer; + +use Symfony\Component\PropertyInfo\Extractor\PhpDocExtractor; +use Symfony\Component\PropertyInfo\Extractor\ReflectionExtractor; +use Symfony\Component\PropertyInfo\PropertyInfoExtractor; +use Symfony\Component\Serializer\Encoder\JsonEncoder; +use Symfony\Component\Serializer\Mapping\ClassDiscriminatorFromClassMetadata; +use Symfony\Component\Serializer\Mapping\Factory\ClassMetadataFactory; +use Symfony\Component\Serializer\Mapping\Loader\AttributeLoader; +use Symfony\Component\Serializer\Normalizer\ArrayDenormalizer; +use Symfony\Component\Serializer\Normalizer\BackedEnumNormalizer; +use Symfony\Component\Serializer\Normalizer\ObjectNormalizer; +use Symfony\Component\Serializer\Serializer; + +class StructuredOutputSerializer extends Serializer +{ + /* + * Custom serializer made to deserialize StructuredOutput. + * + * Since field name are generated by the `Symfony\AI\Platform\Contract\JsonSchema\Factory` + * without using the serializer (and the serializer metadata/attributes), we have to ignore them + * again when deserializing the data by not passing `classMetadataFactory` to ObjectNormalizer. + */ + public function __construct() + { + $classMetadataFactory = new ClassMetadataFactory(new AttributeLoader()); + $discriminator = new ClassDiscriminatorFromClassMetadata($classMetadataFactory); + $propertyInfo = new PropertyInfoExtractor([], [new PhpDocExtractor(), new ReflectionExtractor()]); + + $normalizers = [ + new BackedEnumNormalizer(), + new ObjectNormalizer( + propertyTypeExtractor: $propertyInfo, + classDiscriminatorResolver: $discriminator, + ), + new ArrayDenormalizer(), + ]; + + parent::__construct($normalizers, [new JsonEncoder()]); + } +} diff --git a/src/platform/src/StructuredOutput/PlatformSubscriber.php b/src/platform/src/StructuredOutput/PlatformSubscriber.php index 0a66d7207..55a67bf64 100644 --- a/src/platform/src/StructuredOutput/PlatformSubscriber.php +++ b/src/platform/src/StructuredOutput/PlatformSubscriber.php @@ -17,18 +17,8 @@ use Symfony\AI\Platform\Exception\InvalidArgumentException; use Symfony\AI\Platform\Exception\MissingModelSupportException; use Symfony\AI\Platform\Result\DeferredResult; +use Symfony\AI\Platform\Serializer\StructuredOutputSerializer; use Symfony\Component\EventDispatcher\EventSubscriberInterface; -use Symfony\Component\PropertyInfo\Extractor\PhpDocExtractor; -use Symfony\Component\PropertyInfo\Extractor\ReflectionExtractor; -use Symfony\Component\PropertyInfo\PropertyInfoExtractor; -use Symfony\Component\Serializer\Encoder\JsonEncoder; -use Symfony\Component\Serializer\Mapping\ClassDiscriminatorFromClassMetadata; -use Symfony\Component\Serializer\Mapping\Factory\ClassMetadataFactory; -use Symfony\Component\Serializer\Mapping\Loader\AttributeLoader; -use Symfony\Component\Serializer\Normalizer\ArrayDenormalizer; -use Symfony\Component\Serializer\Normalizer\BackedEnumNormalizer; -use Symfony\Component\Serializer\Normalizer\ObjectNormalizer; -use Symfony\Component\Serializer\Serializer; use Symfony\Component\Serializer\SerializerInterface; /** @@ -40,29 +30,13 @@ final class PlatformSubscriber implements EventSubscriberInterface private string $outputType; + private SerializerInterface $serializer; + public function __construct( private readonly ResponseFormatFactoryInterface $responseFormatFactory = new ResponseFormatFactory(), - private ?SerializerInterface $serializer = null, + ?SerializerInterface $serializer = null, ) { - if (null !== $this->serializer) { - return; - } - - $classMetadataFactory = new ClassMetadataFactory(new AttributeLoader()); - $discriminator = new ClassDiscriminatorFromClassMetadata($classMetadataFactory); - $propertyInfo = new PropertyInfoExtractor([], [new PhpDocExtractor(), new ReflectionExtractor()]); - - $normalizers = [ - new BackedEnumNormalizer(), - new ObjectNormalizer( - classMetadataFactory: $classMetadataFactory, - propertyTypeExtractor: $propertyInfo, - classDiscriminatorResolver: $discriminator, - ), - new ArrayDenormalizer(), - ]; - - $this->serializer = new Serializer($normalizers, [new JsonEncoder()]); + $this->serializer = $serializer ?? new StructuredOutputSerializer(); } public static function getSubscribedEvents(): array diff --git a/src/platform/tests/StructuredOutput/PlatformSubscriberTest.php b/src/platform/tests/StructuredOutput/PlatformSubscriberTest.php index 2977ae418..f26f97ae6 100644 --- a/src/platform/tests/StructuredOutput/PlatformSubscriberTest.php +++ b/src/platform/tests/StructuredOutput/PlatformSubscriberTest.php @@ -15,6 +15,7 @@ use PHPUnit\Framework\TestCase; use Symfony\AI\Fixtures\SomeStructure; use Symfony\AI\Fixtures\StructuredOutput\MathReasoning; +use Symfony\AI\Fixtures\StructuredOutput\MathReasoningWithAttributes; use Symfony\AI\Fixtures\StructuredOutput\PolymorphicType\ListItemAge; use Symfony\AI\Fixtures\StructuredOutput\PolymorphicType\ListItemName; use Symfony\AI\Fixtures\StructuredOutput\PolymorphicType\ListOfPolymorphicTypesDto; @@ -35,7 +36,6 @@ use Symfony\AI\Platform\Result\TextResult; use Symfony\AI\Platform\StructuredOutput\PlatformSubscriber; use Symfony\AI\Platform\Test\PlainConverter; -use Symfony\Component\Serializer\SerializerInterface; final class PlatformSubscriberTest extends TestCase { @@ -95,12 +95,16 @@ public function testProcessOutputWithResponseFormat() $this->assertSame('data', $deferredResult->asObject()->some); } - public function testProcessOutputWithComplexResponseFormat() + /** + * @param class-string $class + */ + #[DataProvider('complexFormatDataProvider')] + public function testProcessOutputWithComplexResponseFormat(string $class) { $processor = new PlatformSubscriber(new ConfigurableResponseFormatFactory(['some' => 'format'])); $model = new Model('gpt-4', [Capability::OUTPUT_STRUCTURED]); - $options = ['response_format' => MathReasoning::class]; + $options = ['response_format' => $class]; $invocationEvent = new InvocationEvent($model, new MessageBag(), $options); $processor->processInput($invocationEvent); @@ -139,7 +143,7 @@ public function testProcessOutputWithComplexResponseFormat() $deferredResult = $resultEvent->getDeferredResult(); $this->assertInstanceOf(ObjectResult::class, $result = $deferredResult->getResult()); - $this->assertInstanceOf(MathReasoning::class, $structure = $deferredResult->asObject()); + $this->assertInstanceOf($class, $structure = $deferredResult->asObject()); $this->assertInstanceOf(Metadata::class, $result->getMetadata()); $this->assertCount(5, $structure->steps); $this->assertInstanceOf(Step::class, $structure->steps[0]); @@ -151,6 +155,14 @@ public function testProcessOutputWithComplexResponseFormat() $this->assertSame(-3.75, $structure->result); } + public static function complexFormatDataProvider(): iterable + { + return [ + [MathReasoning::class], + [MathReasoningWithAttributes::class], + ]; + } + /** * @param class-string $expectedTimeStructure */ @@ -254,8 +266,7 @@ public function testProcessOutputWithCorrectPolymorphicTypesResponseFormat() public function testProcessOutputWithoutResponseFormat() { $resultFormatFactory = new ConfigurableResponseFormatFactory(); - $serializer = self::createMock(SerializerInterface::class); - $processor = new PlatformSubscriber($resultFormatFactory, $serializer); + $processor = new PlatformSubscriber($resultFormatFactory); $converter = new PlainConverter($result = new TextResult('{"some": "data"}')); $deferred = new DeferredResult($converter, new InMemoryRawResult());