New @strict_dataclass decorator for dataclass validation #2895

Draft: wants to merge 9 commits into base: main

Wauplin
Contributor

@Wauplin Wauplin commented Feb 28, 2025

Follow-up after huggingface/transformers#36329 and slack discussions (private).

The idea is to add a layer of validation on top of Python's built-in dataclasses.

Example

from huggingface_hub.utils import strict_dataclass, validated_field

def positive_int(value: int):
    if not value >= 0:
        raise ValueError(f"Value must be positive, got {value}")

@strict_dataclass
class User:
    name: str
    age: int = validated_field(positive_int)

user = User(name="John", age=30)
(...)

# assign invalid type
user.age = "31"
# huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
#   TypeError: Field 'age' expected int, got str (value: '31')

# assign invalid value
user.age = -1
# huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
#   ValueError: Value must be positive, got -1

What does it do?

  1. Provides a @strict_dataclass decorator built on top of @dataclass. Field values of a decorated class are validated.
  2. Fields are validated against their type annotation (str, bool, dict, etc.). Type annotations can be deeply nested (e.g. Dict[str, List[Optional[DummyClass]]] is validated correctly).
  3. Users can define custom validators, in addition to the type check, using validated_field (built on top of field).
  4. Fields are validated on value assignment, meaning at initialization but also each time someone updates a value.
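
To make point 4 concrete, here is a minimal sketch of how validation on assignment can be wired with a custom `__setattr__`. This is not the implementation in this PR: `checked_dataclass` and `checked_field` are hypothetical stand-ins for `strict_dataclass` and `validated_field`, and the type check only handles plain classes (no nested generics).

```python
from dataclasses import dataclass, field, fields
from typing import Any, Callable


def checked_field(validator: Callable[[Any], None], **kwargs: Any):
    """Hypothetical stand-in for `validated_field`: keeps the validator in the field metadata."""
    return field(metadata={"validator": validator}, **kwargs)


def checked_dataclass(cls):
    """Hypothetical stand-in for `@strict_dataclass` (plain class annotations only)."""
    cls = dataclass(cls)
    specs = {f.name: f for f in fields(cls)}

    def __setattr__(self, name: str, value: Any) -> None:
        spec = specs.get(name)
        if spec is not None:
            # Type check, skipped for annotations that are not plain classes.
            if isinstance(spec.type, type) and not isinstance(value, spec.type):
                raise TypeError(
                    f"Field {name!r} expected {spec.type.__name__}, "
                    f"got {type(value).__name__} (value: {value!r})"
                )
            # Custom validator, if one was attached with `checked_field`.
            validator = spec.metadata.get("validator")
            if validator is not None:
                validator(value)
        object.__setattr__(self, name, value)

    # The generated __init__ assigns with `self.x = x`, so it goes through this hook too.
    cls.__setattr__ = __setattr__
    return cls


def non_negative_int(value: int) -> None:
    if value < 0:
        raise ValueError(f"Value must be non-negative, got {value}")


@checked_dataclass
class User:
    name: str
    age: int = checked_field(non_negative_int, default=0)


user = User(name="John", age=30)
for bad_value in ("31", -1):
    try:
        user.age = bad_value
    except (TypeError, ValueError) as err:
        print(f"{type(err).__name__}: {err}")
```

Because the value is checked before `object.__setattr__` is called, a rejected assignment leaves the previous value in place.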

What doesn't it do (yet)?

  • It doesn't have the concept of a "class validator" that checks all fields are coherent with each other. To implement them, we would need to execute them once in __post_init__ and then "on demand", perhaps with a .validate() method. We cannot run them on each field assignment, as that would prevent modifying related values (if values A and B must be coherent, we want to be able to change both A and B and then validate).
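
A rough sketch of what such a class validator could look like, using a plain `@dataclass`. Nothing here exists in this PR: the `Range` class and its `validate()` method are hypothetical and only illustrate the "once in `__post_init__`, then on demand" idea.

```python
from dataclasses import dataclass


@dataclass
class Range:
    low: int
    high: int

    def __post_init__(self) -> None:
        # Run the cross-field check once, right after initialization.
        self.validate()

    def validate(self) -> None:
        """Cross-field check, meant to be called on demand after a batch of updates."""
        if self.low > self.high:
            raise ValueError(f"low ({self.low}) must not exceed high ({self.high})")


r = Range(low=0, high=10)

# Moving the range to (20, 30) passes through an incoherent state (low=20, high=10).
# A check on every assignment would reject the first line below.
r.low = 20
r.high = 30
r.validate()  # coherent again, so this does not raise
```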

Why not choose pydantic? (or attrs? or marshmallow_dataclass?)

  • See the discussion in huggingface/transformers#36329 ("Question for community: We're considering adding pydantic as a base requirement to 🤗 transformers") about adding pydantic as a new dependency. It would be a heavy addition and would require careful logic to support both v1 and v2.
  • we do not want most of pydantic's features, especially the ones related to automatic casting, jsonschema, serialization, aliases, ...
  • we do not need to be able to instantiate a class from a dictionary
  • we do not want to mutate data. In this PR, "validation" refers to "checking if a value is valid". In Pydantic, "validation" refers to "casting a value, possibly mutating it, and then checking if it's valid".
  • we do not need blazing fast validation. @strict_dataclass is not meant for heavy loads where performance is critical. The common use case will be to validate a model configuration (done only once, and negligible compared to running a model). This allows us to keep the code minimal.
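
To illustrate the "check, don't cast" distinction from the list above, here is a small sketch in plain Python. Both helpers are hypothetical: `check_int` follows the meaning of "validation" used in this PR, while `coerce_int` only mimics the casting style and is not Pydantic code.

```python
from typing import Any


def check_int(value: Any) -> None:
    """Check-only validation: raise if `value` is not an int, never modify it."""
    if not isinstance(value, int):
        raise TypeError(f"expected int, got {type(value).__name__} (value: {value!r})")


def coerce_int(value: Any) -> int:
    """Casting-style validation: convert first, then return the (possibly new) value."""
    return int(value)


# Casting style: the string is silently turned into an int.
converted = coerce_int("31")
print(type(converted).__name__)

# Check-only style: the same input is rejected and left untouched.
try:
    check_int("31")
except TypeError as err:
    print(f"rejected: {err}")
```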

Plan:

  • test it on real use cases (typically transformers @gante)
  • iterate on the design until we have something satisfying
  • (optional) find a good design to define "class validators".
  • (optional) add a set of generic validators that could be reused downstream
  • document it, add tests, etc.
  • merge?

We won't push for it / release it until we are sure at least the transformers use case is covered.

Notes:

This @strict_dataclass might be useful in huggingface_hub itself in the future but that's not its primary goal for now.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker left a comment

This man is coooking 👨🏻‍🍳

@gante
Member

gante commented Mar 4, 2025

@Wauplin A transformers MVP is implemented here: huggingface/transformers#36534

I've run into a constraint: some validated fields are class-dependent multiple-choice values, or have an admissible range. We can write a validator for each of these cases but, ideally, a single composable validator would be shared across models.

An example:

(validator)

def choice_str(value: str, choices: Iterable[str]):
    """Ensures that `value` is one of the choices in `choices`."""
    if value not in choices:
        raise ValueError(f"Value must be one of {choices}, got {value}")

(validated field, what I would like to be able to do)

    position_embedding_type: str = validated_field(
        choice_str, choices=["absolute", "relative_key", "relative_key_query"], default="absolute"
    )

An alternative would be to define the model-specific validator as

albert_position_embedding_type_validator = functools.partial(choice_str, choices=...)

But I was wondering whether it would be possible to pipe validator arguments from validated_field() into the validators 🤗

@Wauplin
Contributor Author

Wauplin commented Mar 4, 2025

I would prefer not to pipe arguments into the validator on demand. It adds one more level of complexity to validated_field (i.e. checking if there are more args, checking if the callable accepts them, etc.) and can be avoided quite easily, as you mentioned. In practice, I'm not a fan of partial, which can make things more difficult to read IMO, but what you can do is define your validator like this:

from typing import Any, List

def one_of(choices: List):
    def _inner(value: Any):
        """Check `value` is one of the options in `choices`."""
        if value not in choices:
            raise ValueError(f"Value must be one of {choices}, got {value}.")

    return _inner

@strict_dataclass
class AlbertConfig:
    position_embedding_type: str = validated_field(one_of(["absolute", "relative_key", "relative_key_query"]), default="absolute")
    ...

I've renamed choice_str to one_of to be more generic (the str type is already validated by the type annotation), and made it return a function directly. In practice it's the same as the partial you suggested, but more Pythonic, I feel.
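
For reference, a small self-contained sketch showing that the closure and the functools.partial variant discussed above behave the same when called with a single value. `CHOICES`, `via_partial` and `via_closure` are illustrative names only; no `validated_field` is involved.

```python
import functools
from typing import Any, Callable, Iterable, List

CHOICES = ["absolute", "relative_key", "relative_key_query"]


def choice_str(value: str, choices: Iterable[str]) -> None:
    """Ensures that `value` is one of the choices in `choices`."""
    if value not in choices:
        raise ValueError(f"Value must be one of {choices}, got {value}")


def one_of(choices: List[Any]) -> Callable[[Any], None]:
    def _inner(value: Any) -> None:
        """Check `value` is one of the options in `choices`."""
        if value not in choices:
            raise ValueError(f"Value must be one of {choices}, got {value}.")

    return _inner


# Two ways to obtain a single-argument validator bound to CHOICES.
via_partial = functools.partial(choice_str, choices=CHOICES)
via_closure = one_of(CHOICES)

for validator in (via_partial, via_closure):
    validator("absolute")  # accepted by both
    try:
        validator("rotary")
    except ValueError as err:
        print(err)
```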

@Wauplin
Contributor Author

Wauplin commented Mar 4, 2025

@gante I pushed b1720fa to support the Literal type annotation (I just realized it wasn't supported yet). My comment above is still valid if you want to support a dynamic list of choices, but if your list of choices is static, it's much better to embed it directly in the type annotation, like this:

@strict_dataclass
class AlbertConfig:
    position_embedding_type: Literal["absolute", "relative_key", "relative_key_query"] = "absolute"
    ...

(much better for readability + IDE autocompletion)
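
As a sketch of how a Literal annotation can be checked at runtime, the standard typing helpers `get_origin` and `get_args` are enough. The `check_literal` helper and the `PositionEmbeddingType` alias below are hypothetical and are not taken from commit b1720fa.

```python
from typing import Any, Literal, get_args, get_origin

PositionEmbeddingType = Literal["absolute", "relative_key", "relative_key_query"]


def check_literal(value: Any, annotation: Any) -> None:
    """Raise if `value` is not one of the options of a `Literal[...]` annotation."""
    if get_origin(annotation) is not Literal:
        raise TypeError(f"{annotation!r} is not a Literal annotation")
    allowed = get_args(annotation)  # tuple of the literal options
    if value not in allowed:
        raise ValueError(f"Value must be one of {allowed}, got {value!r}")


check_literal("absolute", PositionEmbeddingType)  # accepted

try:
    check_literal("rotary", PositionEmbeddingType)
except ValueError as err:
    print(err)
```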

Contributor

@hanouticelina hanouticelina left a comment

very nice and clean PR 🔥
(I had to update my review, I just saw you added Literal type annotation 👍 )

@gante
Member

gante commented Mar 7, 2025

Literal is very cool!

I'm going to redo the transformers-side PR with the latest suggestions 👍
