Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nested discriminated union schema gen, pt 2 #8932

Merged
merged 3 commits into from Mar 4, 2024

Conversation

sydney-runkle
Copy link
Member

@sydney-runkle sydney-runkle commented Mar 1, 2024

Trying to keep this change as simple as possible -- more _generate_schema.py refactoring to come in future PRs.

This removes the usage of metadata keys to manage tagged union application in the latter part of schema simplification. I've made a comment on the most important line that fixed the bug we were experiencing with recursive discriminated unions.

See the code snippet below as an example of:
a) vastly increased performance compared to the faulty 2.6 logic
b) correct schema generation

from __future__ import annotations
import time
from typing import Literal, Annotated

from pydantic import Field, TypeAdapter, BaseModel
from pydantic._internal._core_utils import pretty_print_core_schema

class NestedState(BaseModel):
    state_type: Literal["nested"]
    substate: AnyState

class LoopState(BaseModel):
    state_type: Literal["loop"]
    substate: AnyState

class LeafState(BaseModel):
    state_type: Literal["leaf"]

AnyState = Annotated[NestedState | LoopState | LeafState, Field(..., discriminator="state_type")]

def build_nested_state(n):
    if n <= 0:
        return {"state_type": "leaf"}
    else:
        return {"state_type": "loop", "substate": {"state_type": "nested", "substate": build_nested_state(n-1)}}
        
adapter = TypeAdapter(AnyState)

start = time.time()
adapter.validate_python(build_nested_state(9))
print(time.time() - start)
#> 6.818771362304688e-05

start = time.time()
adapter.validate_python(build_nested_state(10))
print(time.time() - start)
#> 1.6927719116210938e-05

start = time.time()
adapter.validate_python(build_nested_state(11))
print(time.time() - start)
#> 1.6927719116210938e-05

pretty_print_core_schema(adapter.core_schema, True)
"""
{
    'type': 'definitions',
    'schema': {
        'type': 'tagged-union',
        'choices': {
            'nested': {'type': 'definition-ref', 'schema_ref': '__main__.NestedState:4629701792'},
            'loop': {'type': 'definition-ref', 'schema_ref': '__main__.LoopState:4898226368'},
            'leaf': {'type': 'definition-ref', 'schema_ref': '__main__.LeafState:4898242800'}
        },
        'discriminator': 'state_type',
        'strict': False,
        'from_attributes': True
    },
    'definitions': [
        {
            'type': 'model',
            'cls': <class '__main__.LeafState'>,
            'schema': {'type': 'model-fields', 'fields': {'state_type': {'type': 'model-field', 'schema': {'type': 'literal', 'expected': ['leaf']}}}, 'model_name': 'LeafState'},
            'ref': '__main__.LeafState:4898242800'
        },
        {
            'type': 'model',
            'cls': <class '__main__.LoopState'>,
            'schema': {
                'type': 'model-fields',
                'fields': {
                    'state_type': {'type': 'model-field', 'schema': {'type': 'literal', 'expected': ['loop']}},
                    'substate': {
                        'type': 'model-field',
                        'schema': {
                            'type': 'default',
                            'schema': {
                                'type': 'tagged-union',
                                'choices': {
                                    'nested': {'type': 'definition-ref', 'schema_ref': '__main__.NestedState:4629701792'},
                                    'loop': {'type': 'definition-ref', 'schema_ref': '__main__.LoopState:4898226368'},
                                    'leaf': {'type': 'definition-ref', 'schema_ref': '__main__.LeafState:4898242800'}
                                },
                                'discriminator': 'state_type',
                                'strict': False,
                                'from_attributes': True
                            },
                            'default': Ellipsis
                        }
                    }
                },
                'model_name': 'LoopState'
            },
            'ref': '__main__.LoopState:4898226368'
        },
        {
            'type': 'model',
            'cls': <class '__main__.NestedState'>,
            'schema': {
                'type': 'model-fields',
                'fields': {
                    'state_type': {'type': 'model-field', 'schema': {'type': 'literal', 'expected': ['nested']}},
                    'substate': {
                        'type': 'model-field',
                        'schema': {
                            'type': 'default',
                            'schema': {
                                'type': 'tagged-union',
                                'choices': {
                                    'nested': {'type': 'definition-ref', 'schema_ref': '__main__.NestedState:4629701792'},
                                    'loop': {'type': 'definition-ref', 'schema_ref': '__main__.LoopState:4898226368'},
                                    'leaf': {'type': 'definition-ref', 'schema_ref': '__main__.LeafState:4898242800'}
                                },
                                'discriminator': 'state_type',
                                'strict': False,
                                'from_attributes': True
                            },
                            'default': Ellipsis
                        }
                    }
                },
                'model_name': 'NestedState'
            },
            'ref': '__main__.NestedState:4629701792'
        }
    ]
}
"""

The above schema is now actually correct, and isn't cluttered with metadata tags relating to the application of discriminated unions :)

Copy link

cloudflare-pages bot commented Mar 1, 2024

Deploying with  Cloudflare Pages  Cloudflare Pages

Latest commit: 2ad301b
Status: ✅  Deploy successful!
Preview URL: https://f3ed5104.pydantic-docs2.pages.dev
Branch Preview URL: https://fix-discriminators-for-good.pydantic-docs2.pages.dev

View logs

Copy link

codspeed-hq bot commented Mar 1, 2024

CodSpeed Performance Report

Merging #8932 will not alter performance

Comparing fix-discriminators-for-good (2ad301b) with main (4fed81b)

Summary

✅ 10 untouched benchmarks

@sydney-runkle sydney-runkle added the relnotes-fix Used for bugfixes. label Mar 1, 2024
@sydney-runkle
Copy link
Member Author

Please review


s = recurse(s, inner)
if s['type'] == 'tagged-union':
return s

metadata = s.get('metadata', {})
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually get rid of this once we use it

elif choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise MissingDefinitionForUnionRef(choice['schema_ref'])
self._handle_choice(self.definitions[choice['schema_ref']])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Important!!! If a choice is of type definition-ref, we want to just reuse that ref for the given choice. Before, we were going through this whole thing of fetching the value from definitions, then using that, but that ends up not working for nested / recursive schemas.

Our schema walking logic walks through both the schema and the definitions, so we can rest easy knowing that unions will be converted to tagged unions in the definitions list as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work with JSON schema generation if e.g. the ref'ed schema is itself a discriminated union? Maybe that can't happen, and either way this seems like an improvement if no tests fail, but still

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. I believe that the walk core schema logic handles definitions schemas such that the function being applied during the walk is also applied to all of the definitions in a definitions schema, so that's why I felt comfortable removing this step.

This can be seen via the example test I added - discriminated union transformation logic is applied to the 2 schemas in the definitions list that require said changes!

@sydney-runkle
Copy link
Member Author

I'll note, we do trace all of the way down each schema now to check for discriminated unions, which isn't the most performant, although we certainly get some nice benefits here from using refs instead of constantly pulling from definitions and then having to simplify nasty schemas.

I think it'd be beneficial for us to have some codecov tests for schema building, etc for PRs like this one.

@alexmojaki alexmojaki removed their assignment Mar 2, 2024
@@ -1868,8 +1868,6 @@ class LeafState(BaseModel):
state_type: Literal['leaf']

AnyState = Annotated[Union[NestedState, LoopState, LeafState], Field(..., discriminator='state_type')]
NestedState.model_rebuild()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you expand this section, you can see that this test showcases the example I've shown in the PR description 👍

Copy link
Member

@adriangb adriangb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! The edited test shows that a model rebuild is no longer necessary, and hence this does fix things.

@sydney-runkle
Copy link
Member Author

I've added a few tests showing the now-correct behavior for the example given in #7487, along with some extra variants for BaseModel and the pydantic dataclass.

Though these tests could be consolidated a bit with parametrization, I think that each of these test cases is simple enough that it's fine to leave them as singular tests that can easily be run independently.

@sydney-runkle
Copy link
Member Author

Regarding my comment about other high-priority fixes for discriminated unions, this was one I had in mind: #8628.

I think we could include that fix + this improvement in 2.6.4.

@sydney-runkle sydney-runkle merged commit af41c8e into main Mar 4, 2024
54 checks passed
@sydney-runkle sydney-runkle deleted the fix-discriminators-for-good branch March 4, 2024 14:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants