
[BUG] Nested Enum Classes With Duplicate Names #537

Open
OtherBarry opened this issue Aug 26, 2022 · 9 comments

@OtherBarry
Contributor

Problem
It's a pretty common pattern in Django to define enum/choices classes as nested classes within the model class that uses them. If multiple nested classes share the same name, even though they live under different outer classes, the generated OpenAPI schema only includes one of them.

Using __qualname__ instead of __name__ should solve this issue for nested classes.

It might be worth looking into namespacing schemas by prefixing them with their module or Django app, as I imagine this issue occurs with any duplicated names.
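For reference, a minimal standalone sketch of the `__name__` vs `__qualname__` difference for nested classes (the Order/Product names are hypothetical):

```python
from enum import Enum


class Order:
    class Status(str, Enum):
        FOO = "foo"


class Product:
    class Status(str, Enum):
        BAR = "bar"


# __name__ is just "Status" for both enums, so they collide;
# __qualname__ keeps the outer class and stays unique
print(Order.Status.__name__)        # Status
print(Product.Status.__name__)      # Status
print(Order.Status.__qualname__)    # Order.Status
print(Product.Status.__qualname__)  # Product.Status
```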

Example

from enum import Enum

from ninja import NinjaAPI, Schema

api = NinjaAPI()


class SomeSchema(Schema):
    class Status(str, Enum):
        FOO = "foo"

    status: Status


class OtherSchema(Schema):
    class Status(str, Enum):
        BAR = "bar"

    status: Status


@api.post("/some")
def get_some_thing(request, data: SomeSchema):
    return data.status


@api.post("/other")
def get_other_thing(request, data: OtherSchema):
    return data.status

Only one of the status enums will be present in the resulting schema.
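Concretely, the components section of the generated document ends up with just one shared definition, roughly like this (hand-written and abridged; which of the two enums survives can vary):

```json
{
  "components": {
    "schemas": {
      "Status": {
        "title": "Status",
        "description": "An enumeration.",
        "enum": ["foo"]
      }
    }
  }
}
```

Both SomeSchema.status and OtherSchema.status then point at the same #/components/schemas/Status reference.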

Versions

  • Python version: 3.10.4
  • Django version: 3.2.14
  • Django-Ninja version: 0.19.1

jmduke commented Dec 4, 2022

@OtherBarry did you figure out a workaround for this? I'm running into precisely this problem and can't find a way to even monkeypatch it.

@OtherBarry
Contributor Author

@jmduke I ended up just renaming my enums, so instead of Order.Status and Product.Status I have Order.OrderStatus and Product.ProductStatus, which solved the problem for me.

The __qualname__ fix is a pretty easy solution, but might be a bit of a pain to monkeypatch. I believe the relevant function is here.


jmduke commented Dec 5, 2022

Thanks! Because I'm too stubborn (and the monkeypatching only happens offline, since I'm not serving the schema dynamically) I went with that approach. Your pointer just changes the title of the object, not the ref ID; monkey-patching get_model_name_map in pydantic.schema ended up doing the trick.


OtherBarry commented Dec 5, 2022

Awesome! Are you able to make a pull request in pydantic with the change? Or just post your monkeypatch here and I'll look into it.


OtherBarry commented Dec 6, 2022

Took a look at the pydantic code and it seems like they actually handle name collisions reasonably well. The issue is that django-ninja calls model_schema() on each model individually, instead of calling schema() once across all models, so collisions aren't handled.

@vitalik is this something that's a relatively easy fix? I don't know the schema generation code at all so will take me a while to find out.

@jmduke can you post your monkeypatch here so that other people with this issue (namely me) can use it in the meantime?


jmduke commented Dec 6, 2022

@OtherBarry you beat me to the explanation :) There are a couple of TODOs here that I think hint at the issue, but I couldn't quite wrap my head around the indirection that goes on in this module. I agree that a better approach for django-ninja might be to delegate as much of the mapping and conflict resolution as possible to pydantic, which appears to handle it quite well.

To answer your question, though, the monkey-patch in question:

from pydantic import schema

def monkey_patched_get_model_name_map(
    unique_models: schema.TypeModelSet,
) -> dict[schema.TypeModelOrEnum, str]:
    """
    Process a set of models and generate unique names for them to be used as keys in the JSON Schema
    definitions. By default the names are the same as the class name. But if two models in different Python
    modules have the same name (e.g. "users.Model" and "items.Model"), the generated names will be
    based on the Python module path for those conflicting models to prevent name collisions.
    :param unique_models: a Python set of models
    :return: dict mapping models to names
    """
    name_model_map = {}
    conflicting_names: set[str] = set()
    for model in unique_models:
        model_name = schema.normalize_name(model.__qualname__.replace(".", ""))
        if model_name in conflicting_names:
            model_name = schema.get_long_model_name(model)
            name_model_map[model_name] = model
        elif model_name in name_model_map:
            conflicting_names.add(model_name)
            conflicting_model = name_model_map.pop(model_name)
            name_model_map[
                schema.get_long_model_name(conflicting_model)
            ] = conflicting_model
            name_model_map[schema.get_long_model_name(model)] = model
        else:
            name_model_map[model_name] = model
    return {v: k for k, v in name_model_map.items()}


schema.get_model_name_map = monkey_patched_get_model_name_map

The only line here that changes from the original is:

model_name = schema.normalize_name(model.__qualname__.replace(".", ""))

A couple notes:

  • I added the replace to prettify the results a bit: I wanted a ref of OrderStatus rather than Order.Status.
  • One potential solution would have been to just call schema.get_long_model_name instead of using model.__qualname__, but I avoided this on aesthetic grounds too, because the long model name includes the entire module path, which I didn't want to expose in the ref.
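As a small standalone illustration of the two naming options (Order is a hypothetical outer class):

```python
from enum import Enum


class Order:
    class Status(str, Enum):
        PENDING = "pending"


# the __qualname__-based ref with the dot stripped, as in the patch above
print(Order.Status.__qualname__.replace(".", ""))  # OrderStatus

# a module-path-based name would look more like this, exposing the
# full import path in the ref (output depends on where this runs)
print(f"{Order.Status.__module__}.{Order.Status.__qualname__}".replace(".", "__"))
```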


jmduke commented Nov 15, 2023

@vitalik You mentioned in #862 that this is no longer possible in pydantic2, and indeed after trying to migrate my setup to django-ninja@1.0rc I ran into the issue as outlined in #537. Is there a recommended path forward? This is a blocker for me, and I imagine it's not a particularly uncommon use case.


furious-luke commented May 22, 2024

Edit: This patch is only partially correct, please see my next comment in addition to this one.

Hey @jmduke! I'm not sure if you're still blocked by this, but I came across the same issue in my work and put together a quick monkeypatch to temporarily resolve the issue. I'll add the monkey patch below.

In my case, I only care about resolving clashing names for nested TextChoices on my models. To that end, I've only patched the function related to generating titles for enumerations. Similar to the original patch, there is only one line changed from the original Pydantic function (it's the line containing the __qualname__ access).

Anyhow, I hope you find it useful!

from enum import Enum
import inspect
from operator import attrgetter
from typing import Any, Literal

from pydantic import ConfigDict
from pydantic_core import core_schema, CoreSchema
from pydantic.json_schema import JsonSchemaValue
import pydantic._internal._std_types_schema
from pydantic._internal._core_utils import get_type_ref
from pydantic._internal._schema_generation_shared import GetJsonSchemaHandler


def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema:
    cases: list[Any] = list(enum_type.__members__.values())

    enum_ref = get_type_ref(enum_type)
    description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
    if description == 'An enumeration.':  # This is the default value provided by enum.EnumMeta.__new__; don't use it
        description = None
    js_updates = {'title': enum_type.__qualname__.replace(".", ""), 'description': description}
    js_updates = {k: v for k, v in js_updates.items() if v is not None}

    sub_type: Literal['str', 'int', 'float'] | None = None
    if issubclass(enum_type, int):
        sub_type = 'int'
        value_ser_type: core_schema.SerSchema = core_schema.simple_ser_schema('int')
    elif issubclass(enum_type, str):
        # this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
        sub_type = 'str'
        value_ser_type = core_schema.simple_ser_schema('str')
    elif issubclass(enum_type, float):
        sub_type = 'float'
        value_ser_type = core_schema.simple_ser_schema('float')
    else:
        # TODO this is an ugly hack, how do we trigger an Any schema for serialization?
        value_ser_type = core_schema.plain_serializer_function_ser_schema(lambda x: x)

    if cases:

        def get_json_schema(schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
            json_schema = handler(schema)
            original_schema = handler.resolve_ref_schema(json_schema)
            original_schema.update(js_updates)
            return json_schema

        # we don't want to add the missing to the schema if it's the default one
        default_missing = getattr(enum_type._missing_, '__func__', None) == Enum._missing_.__func__  # type: ignore
        enum_schema = core_schema.enum_schema(
            enum_type,
            cases,
            sub_type=sub_type,
            missing=None if default_missing else enum_type._missing_,
            ref=enum_ref,
            metadata={'pydantic_js_functions': [get_json_schema]},
        )

        if config.get('use_enum_values', False):
            enum_schema = core_schema.no_info_after_validator_function(
                attrgetter('value'), enum_schema, serialization=value_ser_type
            )

        return enum_schema

    else:

        def get_json_schema_no_cases(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
            json_schema = handler(core_schema.enum_schema(enum_type, cases, sub_type=sub_type, ref=enum_ref))
            original_schema = handler.resolve_ref_schema(json_schema)
            original_schema.update(js_updates)
            return json_schema

        # Use an isinstance check for enums with no cases.
        # The most important use case for this is creating TypeVar bounds for generics that should
        # be restricted to enums. This is more consistent than it might seem at first, since you can only
        # subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases.
        # We use the get_json_schema function when an Enum subclass has been declared with no cases
        # so that we can still generate a valid json schema.
        return core_schema.is_instance_schema(
            enum_type,
            metadata={'pydantic_js_functions': [get_json_schema_no_cases]},
        )


pydantic._internal._std_types_schema.get_enum_core_schema = get_enum_core_schema


furious-luke commented May 22, 2024

After some more testing I found the above patch fails to correct the JSON schema definition refs. The code in Pydantic that generates, caches, and uses these references is pretty complicated, so I'm sure there's a better way, but I've made another monkeypatch to resolve the ref issue, too.

The solution I've used is rearranging the preferential order of the reference identifiers generated by Pydantic to prefer the most specific option. It'll result in ugly refs. I originally said those never get presented to the user, but they do actually appear in the JSON schema when downloaded; in my particular case that won't cause any issues, it's just ugly.

Anyway, here's the full patch, including the above one, and the additional function to correct the refs:

import re
from enum import Enum
import inspect
from operator import attrgetter
from typing import Any, Literal

from pydantic import ConfigDict
from pydantic_core import core_schema, CoreSchema
from pydantic.json_schema import JsonSchemaValue, CoreModeRef, DefsRef, _MODE_TITLE_MAPPING
import pydantic._internal._std_types_schema
from pydantic._internal._core_utils import get_type_ref
from pydantic._internal._schema_generation_shared import GetJsonSchemaHandler


def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema:
    cases: list[Any] = list(enum_type.__members__.values())

    enum_ref = get_type_ref(enum_type)
    description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
    if description == 'An enumeration.':  # This is the default value provided by enum.EnumMeta.__new__; don't use it
        description = None
    js_updates = {'title': enum_type.__qualname__.replace(".", ""), 'description': description}
    js_updates = {k: v for k, v in js_updates.items() if v is not None}

    sub_type: Literal['str', 'int', 'float'] | None = None
    if issubclass(enum_type, int):
        sub_type = 'int'
        value_ser_type: core_schema.SerSchema = core_schema.simple_ser_schema('int')
    elif issubclass(enum_type, str):
        # this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
        sub_type = 'str'
        value_ser_type = core_schema.simple_ser_schema('str')
    elif issubclass(enum_type, float):
        sub_type = 'float'
        value_ser_type = core_schema.simple_ser_schema('float')
    else:
        # TODO this is an ugly hack, how do we trigger an Any schema for serialization?
        value_ser_type = core_schema.plain_serializer_function_ser_schema(lambda x: x)

    if cases:

        def get_json_schema(schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
            json_schema = handler(schema)
            original_schema = handler.resolve_ref_schema(json_schema)
            original_schema.update(js_updates)
            return json_schema

        # we don't want to add the missing to the schema if it's the default one
        default_missing = getattr(enum_type._missing_, '__func__', None) == Enum._missing_.__func__  # type: ignore
        enum_schema = core_schema.enum_schema(
            enum_type,
            cases,
            sub_type=sub_type,
            missing=None if default_missing else enum_type._missing_,
            ref=enum_ref,
            metadata={'pydantic_js_functions': [get_json_schema]},
        )

        if config.get('use_enum_values', False):
            enum_schema = core_schema.no_info_after_validator_function(
                attrgetter('value'), enum_schema, serialization=value_ser_type
            )

        return enum_schema

    else:

        def get_json_schema_no_cases(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
            json_schema = handler(core_schema.enum_schema(enum_type, cases, sub_type=sub_type, ref=enum_ref))
            original_schema = handler.resolve_ref_schema(json_schema)
            original_schema.update(js_updates)
            return json_schema

        # Use an isinstance check for enums with no cases.
        # The most important use case for this is creating TypeVar bounds for generics that should
        # be restricted to enums. This is more consistent than it might seem at first, since you can only
        # subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases.
        # We use the get_json_schema function when an Enum subclass has been declared with no cases
        # so that we can still generate a valid json schema.
        return core_schema.is_instance_schema(
            enum_type,
            metadata={'pydantic_js_functions': [get_json_schema_no_cases]},
        )


pydantic._internal._std_types_schema.get_enum_core_schema = get_enum_core_schema


def get_defs_ref(self, core_mode_ref: CoreModeRef) -> DefsRef:
    """Override this method to change the way that definitions keys are generated from a core reference.

    Args:
        core_mode_ref: The core reference.

    Returns:
        The definitions key.
    """
    # Split the core ref into "components"; generic origins and arguments are each separate components
    core_ref, mode = core_mode_ref
    components = re.split(r'([\][,])', core_ref)
    # Remove IDs from each component
    components = [x.rsplit(':', 1)[0] for x in components]
    core_ref_no_id = ''.join(components)
    # Remove everything before the last period from each "component"
    components = [re.sub(r'(?:[^.[\]]+\.)+((?:[^.[\]]+))', r'\1', x) for x in components]
    short_ref = ''.join(components)

    mode_title = _MODE_TITLE_MAPPING[mode]

    # It is important that the generated defs_ref values be such that at least one choice will not
    # be generated for any other core_ref. Currently, this should be the case because we include
    # the id of the source type in the core_ref
    name = DefsRef(self.normalize_name(short_ref))
    name_mode = DefsRef(self.normalize_name(short_ref) + f'-{mode_title}')
    module_qualname = DefsRef(self.normalize_name(core_ref_no_id))
    module_qualname_mode = DefsRef(f'{module_qualname}-{mode_title}')
    module_qualname_id = DefsRef(self.normalize_name(core_ref))
    occurrence_index = self._collision_index.get(module_qualname_id)
    if occurrence_index is None:
        self._collision_counter[module_qualname] += 1
        occurrence_index = self._collision_index[module_qualname_id] = self._collision_counter[module_qualname]

    module_qualname_occurrence = DefsRef(f'{module_qualname}__{occurrence_index}')
    module_qualname_occurrence_mode = DefsRef(f'{module_qualname_mode}__{occurrence_index}')

    self._prioritized_defsref_choices[module_qualname_occurrence_mode] = [
        module_qualname_occurrence_mode,
        name,
        name_mode,
        module_qualname,
        module_qualname_mode,
        module_qualname_occurrence,
    ]

    return module_qualname_occurrence_mode


pydantic.json_schema.GenerateJsonSchema.get_defs_ref = get_defs_ref
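To make the ref manipulation in get_defs_ref above a bit more concrete, here is the same splitting applied to a made-up core ref (the module path and numeric type id are fabricated for illustration):

```python
import re

# a made-up core ref in roughly the shape pydantic v2 generates:
# "<module path>.<qualname>:<type id>"
core_ref = "myapp.models.Order.Status:94612345678912"

# split on generic brackets/commas, keeping the separators (a no-op here)
components = re.split(r'([\][,])', core_ref)
# drop the trailing ":<id>" from each component
components = [x.rsplit(':', 1)[0] for x in components]
core_ref_no_id = ''.join(components)
# keep only the last dotted segment of each component
short_ref = ''.join(re.sub(r'(?:[^.[\]]+\.)+([^.[\]]+)', r'\1', x) for x in components)

print(core_ref_no_id)  # myapp.models.Order.Status
print(short_ref)       # Status
```

The patch then prefers the module-qualified, occurrence-indexed variant as the definitions key, which is why two clashing nested Status enums end up with distinct (if ugly) refs.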
