Allow disambiguating union types by unique literal types #273

Open
@phiresky

Description

Right now, it seems like union types are only supported if (a) the options are all attrs classes or None and (b) the attrs classes have unique attributes.

I use a ton of "tagged unions", so it would be great if there were support for those. (Each member has one attribute whose value unambiguously identifies it.) Here's an example:

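from typing import Literal

import attrs


# attrs requires fields without defaults to come before fields with defaults,
# so the literal tag field is declared last.
@attrs.define
class MatchedThing:
    xxxxxx: str
    matched: Literal[True] = True


@attrs.define
class UnmatchedThing:
    yyyyyy: str
    matched: Literal[False] = False


@attrs.define
class MaybeSomething:
    something: MatchedThing | UnmatchedThing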

In this example, I'm using a literal boolean, but it is much more common to use a literal string, for example type: Literal["request", "response", "error"]. Other possibilities are literal integers (covered by the code below) and literal enum values (not sure if those work with the code below).
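On the enum question, a quick stdlib-only check may help. The `Kind` enum and the two lookup tables are hypothetical names used only for illustration. `typing.get_args` on a `Literal` of enum members returns the members themselves, so a table built from those values is keyed by members, while unstructured JSON usually carries the underlying raw value. That suggests the implementation below would need an extra value-based lookup for enum literals, although this is only a sketch and has not been tested against cattrs.

```python
from enum import Enum
from typing import Literal, get_args


class Kind(Enum):  # hypothetical enum, for illustration only
    REQUEST = "request"
    RESPONSE = "response"


# get_args on a Literal of enum members returns the members themselves.
members = get_args(Literal[Kind.REQUEST, Kind.RESPONSE])
print(members)

# A table keyed by members does not match the raw string found in JSON...
by_member = {m: m.name for m in members}
print("request" in by_member)  # False

# ...so a lookup on unstructured data has to be keyed by the underlying value.
by_value = {m.value: m for m in members}
print(by_value["request"] is Kind.REQUEST)  # True
```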

This pattern is very common in JSON files.
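To make the string-tag case concrete, here is a minimal stdlib-only sketch of such a payload being dispatched by hand. The message classes, `TAG_TO_CLASS`, and `parse_message` are hypothetical and not part of attrs or cattrs; plain dataclasses are used only to keep the example dependency-free.

```python
import json
from dataclasses import dataclass
from typing import Literal, Union


@dataclass
class Request:
    method: str
    type: Literal["request"] = "request"


@dataclass
class Response:
    result: str
    type: Literal["response"] = "response"


@dataclass
class ErrorMessage:
    message: str
    type: Literal["error"] = "error"


Message = Union[Request, Response, ErrorMessage]

# Hand-written tag table. The point of this proposal is to derive such a
# table automatically from the Literal annotations.
TAG_TO_CLASS = {"request": Request, "response": Response, "error": ErrorMessage}


def parse_message(raw: str) -> Message:
    data = json.loads(raw)
    cls = TAG_TO_CLASS[data["type"]]  # the "type" field selects the class
    return cls(**data)


print(parse_message('{"type": "error", "message": "boom"}'))
# ErrorMessage(message='boom', type='error')
```

Writing the table by hand works, but it has to be kept in sync with the annotations, which is the duplication that automatic disambiguation would remove.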

Here's an implementation that seems to work for me (attrs classes only, though the same should work with typing.get_type_hints(anything)):

import typing
from collections.abc import Mapping
from typing import Any, Callable, get_origin

import attrs
import cattrs


def create_literal_field_disambiguator_function(*classes: type) -> Callable[..., Any]:
    """Given attrs classes, generate a disambiguation function.
    The function is based on a common field with unique literal value types."""
    common_attrs: set[str] = set.intersection(
        *({at.name for at in attrs.fields(get_origin(cl) or cl)} for cl in classes)
    )

    def check_attr_is_unique(attr: str) -> dict[Any, type] | None:
        # maps each literal value to the single class that declares it
        attr_value_to_class = {}
        for cl in classes:
            field_type = getattr(attrs.fields(get_origin(cl) or cl), attr).type
            if not cattrs.converters.is_literal(field_type):
                # type is not a literal
                return None
            for value in typing.get_args(field_type):
                if value in attr_value_to_class:
                    # this value is ambiguous -> attribute not usable for disambiguation
                    return None
                attr_value_to_class[value] = cl
        return attr_value_to_class

    unique_attr_value_to_class = None
    # try each attribute to see if the types are Literal and the values are unique
    for attr in common_attrs:
        unique_attr_value_to_class = check_attr_is_unique(attr)
        if unique_attr_value_to_class:
            # found unique attribute; `attr` keeps this name for dis_func below
            break
    else:
        raise ValueError(
            f"{classes} cannot be disambiguated with a common literal attribute. "
            f"checked types of common fields: {common_attrs}"
        )

    def dis_func(data: Mapping, _type: Any) -> type:
        # returns the class to structure `data` as (a missing or unknown tag raises KeyError)
        if not isinstance(data, Mapping):
            raise ValueError("Only input mappings are supported.")
        return unique_attr_value_to_class[data[attr]]

    return dis_func

...
def get_converter():
    ...
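    def make_union_hook(t):
        dis_func = create_literal_field_disambiguator_function(*typing.get_args(t))
        # A structure hook must return the structured instance, so structure
        # the data as the class picked by the disambiguator.
        return lambda data, _: converter.structure(data, dis_func(data, t))

    converter.register_structure_hook_factory(
        cattrs.converters.is_attrs_union, make_union_hook
    )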
...

It should probably be added to the existing disambiguation.py machinery, but those parts aren't publicly exposed.
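The remark above that the same approach should work with `typing.get_type_hints` can be sketched without attrs or cattrs. This is a minimal sketch, assuming plain dataclasses. `literal_tag_table` and `structure_tagged` are hypothetical helper names, not cattrs API, and the sketch does not handle nested structuring or missing tags.

```python
from dataclasses import dataclass
from typing import Any, Literal, Mapping, get_args, get_origin, get_type_hints


def literal_tag_table(*classes: type) -> tuple[str, dict[Any, type]]:
    """Find a field shared by all classes whose Literal values are unique.

    Returns the field name and a mapping from literal value to class.
    """
    hints = {cl: get_type_hints(cl) for cl in classes}
    common = set.intersection(*(set(h) for h in hints.values()))
    for name in sorted(common):  # sorted only to make the choice deterministic
        table: dict[Any, type] = {}
        for cl in classes:
            field_type = hints[cl][name]
            if get_origin(field_type) is not Literal:
                break  # not a Literal field, try the next common name
            values = get_args(field_type)
            if any(v in table for v in values):
                break  # a value is claimed by two classes, so it is ambiguous
            table.update({v: cl for v in values})
        else:
            return name, table
    raise ValueError(f"{classes} share no field with unique Literal values")


@dataclass
class MatchedThing:
    xxxxxx: str
    matched: Literal[True] = True


@dataclass
class UnmatchedThing:
    yyyyyy: str
    matched: Literal[False] = False


def structure_tagged(data: Mapping[str, Any], *classes: type) -> Any:
    # Pick the class from the tag value, then build it from the mapping.
    name, table = literal_tag_table(*classes)
    return table[data[name]](**data)


print(structure_tagged({"matched": False, "yyyyyy": "b"}, MatchedThing, UnmatchedThing))
# UnmatchedThing(yyyyyy='b', matched=False)
```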
