diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 64219a08..3559e60d 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -16,3 +16,7 @@ updates: directory: "benchmark" schedule: interval: "daily" + - package-ecosystem: "pip" + directory: "scripts" + schedule: + interval: "daily" diff --git a/apischema/conversions/converters.py b/apischema/conversions/converters.py index e26deec9..0fa1c722 100644 --- a/apischema/conversions/converters.py +++ b/apischema/conversions/converters.py @@ -7,9 +7,9 @@ TYPE_CHECKING, Any, Callable, - List, MutableMapping, Optional, + Tuple, Type, TypeVar, Union, @@ -37,8 +37,8 @@ pass -_deserializers: MutableMapping[AnyType, List[ConvOrFunc]] = CacheAwareDict( - defaultdict(list) +_deserializers: MutableMapping[AnyType, Tuple[ConvOrFunc, ...]] = CacheAwareDict( + defaultdict(tuple) ) _serializers: MutableMapping[AnyType, ConvOrFunc] = CacheAwareDict({}) Deserializer = TypeVar( @@ -54,7 +54,7 @@ def default_deserialization(tp): return _deserializers.get(tp) else: - default_deserialization = _deserializers.get # type: ignore + default_deserialization = _deserializers.get def default_serialization(tp: Type) -> Optional[AnyConversion]: @@ -83,7 +83,7 @@ def check_converter_type(tp: AnyType) -> AnyType: def _add_deserializer(conversion: ConvOrFunc, target: AnyType): target = check_converter_type(target) if conversion not in _deserializers[target]: - _deserializers[target].append(conversion) + _deserializers[target] = *_deserializers[target], conversion class DeserializerDescriptor(MethodWrapper[staticmethod]): diff --git a/apischema/deserialization/__init__.py b/apischema/deserialization/__init__.py index 4dd03cbb..3f867273 100644 --- a/apischema/deserialization/__init__.py +++ b/apischema/deserialization/__init__.py @@ -72,6 +72,7 @@ OptionalMethod, PatternField, RawConstructor, + RawConstructorCopy, RecMethod, SetMethod, SimpleObjectMethod, @@ -154,6 +155,26 @@ def check_only(method: DeserializationMethod) -> bool: ) +def is_raw_dataclass(cls: type) -> bool: + return ( + dataclasses.is_dataclass(cls) + and type(cls) == type # no metaclass + and "__slots__" not in cls.__dict__ + and not hasattr(cls, "__post_init__") + and all(f.init for f in dataclasses.fields(cls)) + and cls.__new__ is object.__new__ + and ( + cls.__setattr__ is object.__setattr__ + or getattr(cls, dataclasses._PARAMS).frozen # type: ignore + ) + and ( + list(inspect.signature(cls.__init__, follow_wrapped=False).parameters) # type: ignore + == ["__dataclass_self__" if "self" in dataclasses.fields(cls) else "self"] + + [f.name for f in dataclasses.fields(cls)] + ) + ) + + @dataclasses.dataclass(frozen=True) class DeserializationMethodFactory: factory: Factory @@ -481,31 +502,12 @@ def factory( ) object_constraints = constraints_validators(constraints)[dict] all_alliases = set(alias_by_name.values()) - constructor: Constructor + constructor: Optional[Constructor] = None if is_typed_dict(cls): constructor = NoConstructor(cls) elif ( settings.deserialization.override_dataclass_constructors - and dataclasses.is_dataclass(cls) - and "__slots__" not in cls.__dict__ - and not hasattr(cls, "__post_init__") - and all(f.init for f in dataclasses.fields(cls)) - and cls.__new__ is object.__new__ - and ( - cls.__setattr__ is object.__setattr__ - or getattr(cls, dataclasses._PARAMS).frozen # type: ignore - ) - and ( - list( - inspect.signature(cls.__init__, follow_wrapped=False).parameters - ) - == [ - "__dataclass_self__" - if "self" in dataclasses.fields(cls) - else "self" - ] - + [f.name for f in dataclasses.fields(cls)] - ) + and is_raw_dataclass(cls) ): constructor = FieldsConstructor( cls, @@ -521,8 +523,6 @@ def factory( if f.default_factory is not dataclasses.MISSING ), ) - else: - constructor = RawConstructor(cls) if ( not object_constraints and not flattened_fields @@ -545,7 +545,7 @@ def factory( ) ): return SimpleObjectMethod( - constructor, + constructor or RawConstructorCopy(cls), tuple(normal_fields), all_alliases, is_typed_dict(cls), @@ -553,7 +553,7 @@ def factory( settings.errors.unexpected_property, ) return ObjectMethod( - constructor, + constructor or RawConstructor(cls), object_constraints, tuple(normal_fields), tuple(flattened_fields), diff --git a/apischema/deserialization/methods.py b/apischema/deserialization/methods.py index f68d5050..c582d398 100644 --- a/apischema/deserialization/methods.py +++ b/apischema/deserialization/methods.py @@ -41,7 +41,7 @@ def validate(self, data: Any) -> bool: class MinimumConstraint(Constraint): minimum: int - def validate(self, data: int) -> bool: + def validate(self, data: Any) -> bool: return data >= self.minimum @@ -49,7 +49,7 @@ def validate(self, data: int) -> bool: class MaximumConstraint(Constraint): maximum: int - def validate(self, data: int) -> bool: + def validate(self, data: Any) -> bool: return data <= self.maximum @@ -57,7 +57,7 @@ def validate(self, data: int) -> bool: class ExclusiveMinimumConstraint(Constraint): exc_min: int - def validate(self, data: int) -> bool: + def validate(self, data: Any) -> bool: return data > self.exc_min @@ -65,7 +65,7 @@ def validate(self, data: int) -> bool: class ExclusiveMaximumConstraint(Constraint): exc_max: int - def validate(self, data: int) -> bool: + def validate(self, data: Any) -> bool: return data < self.exc_max @@ -73,7 +73,7 @@ def validate(self, data: int) -> bool: class MultipleOfConstraint(Constraint): mult_of: int - def validate(self, data: int) -> bool: + def validate(self, data: Any) -> bool: return not (data % self.mult_of) @@ -81,7 +81,7 @@ def validate(self, data: int) -> bool: class MinLengthConstraint(Constraint): min_len: int - def validate(self, data: str) -> bool: + def validate(self, data: Any) -> bool: return len(data) >= self.min_len @@ -89,7 +89,7 @@ def validate(self, data: str) -> bool: class MaxLengthConstraint(Constraint): max_len: int - def validate(self, data: str) -> bool: + def validate(self, data: Any) -> bool: return len(data) <= self.max_len @@ -97,7 +97,7 @@ def validate(self, data: str) -> bool: class PatternConstraint(Constraint): pattern: Pattern - def validate(self, data: str) -> bool: + def validate(self, data: Any) -> bool: return self.pattern.match(data) is not None @@ -105,7 +105,7 @@ def validate(self, data: str) -> bool: class MinItemsConstraint(Constraint): min_items: int - def validate(self, data: list) -> bool: + def validate(self, data: Any) -> bool: return len(data) >= self.min_items @@ -113,7 +113,7 @@ def validate(self, data: list) -> bool: class MaxItemsConstraint(Constraint): max_items: int - def validate(self, data: list) -> bool: + def validate(self, data: Any) -> bool: return len(data) <= self.max_items @@ -121,8 +121,8 @@ def to_hashable(data: Any) -> Any: if isinstance(data, list): return tuple(map(to_hashable, data)) elif isinstance(data, dict): - # Cython doesn't support tuple comprehension yet -> intermediate list - return tuple([(k, to_hashable(data[k])) for k in sorted(data)]) + sorted_keys = sorted(data) + return tuple(sorted_keys + [to_hashable(data[k]) for k in sorted_keys]) else: return data @@ -134,7 +134,7 @@ class UniqueItemsConstraint(Constraint): def __post_init__(self): assert self.unique - def validate(self, data: list) -> bool: + def validate(self, data: Any) -> bool: return len(set(map(to_hashable, data))) == len(data) @@ -142,7 +142,7 @@ def validate(self, data: list) -> bool: class MinPropertiesConstraint(Constraint): min_properties: int - def validate(self, data: dict) -> bool: + def validate(self, data: Any) -> bool: return len(data) >= self.min_properties @@ -150,7 +150,7 @@ def validate(self, data: dict) -> bool: class MaxPropertiesConstraint(Constraint): max_properties: int - def validate(self, data: dict) -> bool: + def validate(self, data: Any) -> bool: return len(data) <= self.max_properties @@ -158,13 +158,16 @@ def format_error(err: Union[str, Callable[[Any], str]], data: Any) -> str: return err if isinstance(err, str) else err(data) +ErrorDict = Dict[ErrorKey, ValidationError] + + def validate_constraints( - data: Any, constraints: Tuple[Constraint, ...], children_errors: Optional[dict] + data: Any, constraints: Tuple[Constraint, ...], children_errors: Optional[ErrorDict] ) -> Any: for i in range(len(constraints)): constraint: Constraint = constraints[i] if not constraint.validate(data): - errors: list = [format_error(constraint.error, data)] + errors: List[str] = [format_error(constraint.error, data)] for j in range(i + 1, len(constraints)): constraint = constraints[j] if not constraint.validate(data): @@ -176,10 +179,8 @@ def validate_constraints( def set_child_error( - errors: Optional[Dict[ErrorKey, ValidationError]], - key: ErrorKey, - error: ValidationError, -): + errors: Optional[ErrorDict], key: ErrorKey, error: ValidationError +) -> ErrorDict: if errors is None: return {key: error} else: @@ -257,7 +258,7 @@ class ListCheckOnlyMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, list): raise bad_type(data, list) - elt_errors = None + elt_errors: Optional[ErrorDict] = None for i, elt in enumerate(data): try: self.value_method.deserialize(elt) @@ -275,7 +276,7 @@ class ListMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, list): raise bad_type(data, list) - elt_errors = None + elt_errors: Optional[ErrorDict] = None values: list = [None] * len(data) for i, elt in enumerate(data): try: @@ -294,7 +295,7 @@ class SetMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, list): raise bad_type(data, list) - elt_errors: dict = {} + elt_errors: ErrorDict = {} values: set = set() for i, elt in enumerate(data): try: @@ -352,7 +353,7 @@ class MappingCheckOnly(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, dict): raise bad_type(data, dict) - item_errors = None + item_errors: Optional[ErrorDict] = None for key, value in data.items(): try: self.key_method.deserialize(key) @@ -372,7 +373,7 @@ class MappingMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, dict): raise bad_type(data, dict) - item_errors = None + item_errors: Optional[ErrorDict] = None items: dict = {} for key, value in data.items(): try: @@ -431,7 +432,16 @@ def construct(self, fields: Dict[str, Any]) -> Any: return fields +def PyObject_Call(obj, args, kwargs): + return obj(*args, **kwargs) + + class RawConstructor(Constructor): + def construct(self, fields: Dict[str, Any]) -> Any: + return PyObject_Call(self.cls, (), fields) + + +class RawConstructorCopy(Constructor): def construct(self, fields: Dict[str, Any]) -> Any: return self.cls(**fields) @@ -454,18 +464,16 @@ class FieldsConstructor(Constructor): default_fields: Tuple[DefaultField, ...] factory_fields: Tuple[FactoryField, ...] - def construct(self, fields: Dict[str, Any]) -> Any: - obj: object = object.__new__(self.cls) + def construct(self, fields: Any) -> Any: # fields can be a dict subclass + obj = object.__new__(self.cls) obj_dict: dict = obj.__dict__ obj_dict.update(fields) if len(fields) != self.nb_fields: - for i in range(len(self.default_fields)): - default_field: DefaultField = self.default_fields[i] - if default_field.name not in fields: + for default_field in self.default_fields: + if default_field.name not in obj_dict: obj_dict[default_field.name] = default_field.default_value - for i in range(len(self.factory_fields)): - factory_field: FactoryField = self.factory_fields[i] - if factory_field.name not in fields: + for factory_field in self.factory_fields: + if factory_field.name not in obj_dict: obj_dict[factory_field.name] = factory_field.factory() return obj @@ -482,10 +490,9 @@ class SimpleObjectMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, dict): raise bad_type(data, dict) - fields_count = 0 - field_errors = None - for i in range(len(self.fields)): - field: Field = self.fields[i] + fields_count: int = 0 + field_errors: Optional[dict] = None + for field in self.fields: if field.alias in data: fields_count += 1 try: @@ -559,47 +566,34 @@ def deserialize(self, data: Any) -> Any: if not isinstance(data, dict): raise bad_type(data, dict) values: dict = {} - fields_count = 0 - errors = None + fields_count: int = 0 + errors: Optional[list] = None try: validate_constraints(data, self.constraints, None) except ValidationError as err: errors = list(err.messages) - field_errors = None - for i in range(len(self.fields)): - field: Field = self.fields[i] - if field.required: - try: - value: object = data[field.alias] - except KeyError: - field_errors = set_child_error( - field_errors, field.alias, ValidationError(self.missing) - ) - else: - fields_count += 1 - try: - values[field.name] = field.method.deserialize(value) - except ValidationError as err: - field_errors = set_child_error(field_errors, field.alias, err) - elif field.alias in data: + field_errors: Optional[dict] = None + for field in self.fields: + if field.alias in data: fields_count += 1 try: values[field.name] = field.method.deserialize(data[field.alias]) except ValidationError as err: - if not field.fall_back_on_default: + if field.required or not field.fall_back_on_default: field_errors = set_child_error(field_errors, field.alias, err) + elif field.required: + field_errors = set_child_error( + field_errors, field.alias, ValidationError(self.missing) + ) elif field.required_by is not None and not field.required_by.isdisjoint( data ): - requiring: list = sorted(field.required_by & data.keys()) - msg: str = self.missing + f" (required by {requiring})" - field_errors = set_child_error( - field_errors, field.alias, ValidationError([msg]) - ) + requiring = sorted(field.required_by & data.keys()) + error = ValidationError([self.missing + f" (required by {requiring})"]) + field_errors = set_child_error(field_errors, field.alias, error) if self.aggregate_fields: remain = data.keys() - self.all_aliases - for i in range(len(self.flattened_fields)): - flattened_field: FlattenedField = self.flattened_fields[i] + for flattened_field in self.flattened_fields: flattened: dict = { alias: data[alias] for alias in flattened_field.aliases @@ -616,8 +610,7 @@ def deserialize(self, data: Any) -> Any: field_errors = update_children_errors( field_errors, err.children ) - for i in range(len(self.pattern_fields)): - pattern_field: PatternField = self.pattern_fields[i] + for pattern_field in self.pattern_fields: matched: dict = { key: data[key] for key in remain if pattern_field.pattern.match(key) } @@ -683,16 +676,15 @@ def deserialize(self, data: Any) -> Any: error = ValidationError(errors or [], field_errors or {}) invalid_fields = self.post_init_modified if field_errors: - invalid_fields |= field_errors.keys() + invalid_fields = invalid_fields | field_errors.keys() try: - valid_validators = [ - v - for v in validators - if v.dependencies.isdisjoint(invalid_fields) - ] validate( ValidatorMock(self.constructor.cls, values), - valid_validators, + [ + v + for v in validators + if v.dependencies.isdisjoint(invalid_fields) + ], init, aliaser=self.aliaser, ) @@ -715,7 +707,7 @@ def deserialize(self, data: Any) -> Any: class IntMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: - if not isinstance(data, int): + if not isinstance(data, int) or isinstance(data, bool): raise bad_type(data, int) return data @@ -795,14 +787,13 @@ def deserialize(self, data: Any) -> Any: raise ValidationError(format_error(self.max_len_error, data)) else: raise NotImplementedError - elt_errors: dict = {} + elt_errors: Optional[ErrorDict] = None elts: list = [None] * len(self.elt_methods) - for i in range(len(self.elt_methods)): - elt_method: DeserializationMethod = self.elt_methods[i] + for i, elt_method in enumerate(self.elt_methods): try: elts[i] = elt_method.deserialize(data[i]) except ValidationError as err: - elt_errors[i] = err + set_child_error(elt_errors, i, err) validate_constraints(data, self.constraints, elt_errors) return tuple(elts) @@ -845,8 +836,7 @@ class UnionMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: error = None - for i in range(len(self.alt_methods)): - alt_method: DeserializationMethod = self.alt_methods[i] + for i, alt_method in enumerate(self.alt_methods): try: return alt_method.deserialize(data) except ValidationError as err: @@ -886,9 +876,8 @@ class ConversionUnionMethod(DeserializationMethod): alternatives: Tuple[ConversionAlternative, ...] def deserialize(self, data: Any) -> Any: - error: Optional[ValidationError] = None - for i in range(len(self.alternatives)): - alternative: ConversionAlternative = self.alternatives[i] + error = None + for alternative in self.alternatives: try: value = alternative.method.deserialize(data) except ValidationError as err: diff --git a/apischema/json_schema/types.py b/apischema/json_schema/types.py index c12ea924..cf83e664 100644 --- a/apischema/json_schema/types.py +++ b/apischema/json_schema/types.py @@ -32,28 +32,27 @@ class JsonType(str, Enum): @staticmethod def from_type(cls: Type) -> "JsonType": - return TYPE_TO_JSON_TYPE[cls] + try: + return TYPE_TO_JSON_TYPE[cls] + except KeyError: # pragma: no cover + raise TypeError(f"Invalid JSON type {cls}") + + def __repr__(self): + return f"'{self.value}'" # pragma: no cover def __str__(self): return self.value -class JsonTypes(Dict[type, JsonType]): - def __missing__(self, key): - raise TypeError(f"Invalid JSON type {key}") - - -TYPE_TO_JSON_TYPE = JsonTypes( - { - NoneType: JsonType.NULL, - bool: JsonType.BOOLEAN, - str: JsonType.STRING, - int: JsonType.INTEGER, - float: JsonType.NUMBER, - list: JsonType.ARRAY, - dict: JsonType.OBJECT, - } -) +TYPE_TO_JSON_TYPE = { + NoneType: JsonType.NULL, + bool: JsonType.BOOLEAN, + str: JsonType.STRING, + int: JsonType.INTEGER, + float: JsonType.NUMBER, + list: JsonType.ARRAY, + dict: JsonType.OBJECT, +} def bad_type(data: Any, *expected: type) -> ValidationError: diff --git a/apischema/serialization/methods.py b/apischema/serialization/methods.py index 1506d586..4dd304a4 100644 --- a/apischema/serialization/methods.py +++ b/apischema/serialization/methods.py @@ -72,7 +72,9 @@ class AnyMethod(SerializationMethod): factory: Callable[[AnyType], SerializationMethod] def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any: - method = self.factory(obj.__class__) # tmp variable for substitution + method: SerializationMethod = self.factory( + obj.__class__ + ) # tmp variable for substitution return method.serialize(obj, path) @@ -233,9 +235,7 @@ def update_result(self, obj: Any, result: dict): if self.typed_dict else (not self.exclude_unset or self.name in getattr(obj, FIELDS_SET_ATTR)) ): - value: object = ( - obj[self.name] if self.typed_dict else getattr(obj, self.name) - ) + value = obj[self.name] if self.typed_dict else getattr(obj, self.name) if not self.skippable or not ( (self.skip_if is not None and self.skip_if(value)) or (self.undefined and value is Undefined) @@ -277,8 +277,7 @@ class ObjectMethod(SerializationMethod): def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any: result: dict = {} - for i in range(len(self.fields)): - field: BaseField = self.fields[i] + for field in self.fields: field.update_result(obj, result) return result @@ -302,8 +301,7 @@ class TupleCheckOnlyMethod(SerializationMethod): elt_methods: Tuple[SerializationMethod, ...] def serialize(self, obj: tuple, path: Union[int, str, None] = None) -> Any: - for i in range(len(self.elt_methods)): - method: SerializationMethod = self.elt_methods[i] + for i, method in enumerate(self.elt_methods): method.serialize(obj[i], i) return obj @@ -315,8 +313,7 @@ class TupleMethod(SerializationMethod): def serialize(self, obj: tuple, path: Union[int, str, None] = None) -> Any: elts: list = [None] * len(self.elt_methods) - for i in range(len(self.elt_methods)): - method: SerializationMethod = self.elt_methods[i] + for i, method in enumerate(self.elt_methods): elts[i] = method.serialize(obj[i], i) return elts @@ -371,8 +368,7 @@ class UnionMethod(SerializationMethod): fallback: Fallback def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any: - for i in range(len(self.alternatives)): - alternative: UnionAlternative = self.alternatives[i] + for alternative in self.alternatives: if isinstance(obj, alternative.cls): try: return alternative.serialize(obj, path) diff --git a/scripts/cythonize.py b/scripts/cythonize.py index db82b855..f52d0643 100755 --- a/scripts/cythonize.py +++ b/scripts/cythonize.py @@ -6,18 +6,21 @@ import re import sys from contextlib import contextmanager +from dataclasses import dataclass from functools import lru_cache from pathlib import Path from types import FunctionType from typing import ( AbstractSet, Any, + Callable, Iterable, List, Mapping, Match, NamedTuple, Optional, + Pattern, TextIO, Tuple, Type, @@ -172,33 +175,72 @@ def module_methods(module: str) -> Mapping[str, Method]: return methods_by_name +ReRepl = Callable[[Match], str] + + +@dataclass +class LineSubstitutor: + lines: Iterable[str] + + def __call__(self, pattern: Pattern) -> Callable[[ReRepl], ReRepl]: + def decorator(repl: ReRepl) -> ReRepl: + self.lines = (re.sub(pattern, repl, l) for l in self.lines) + return repl + + return decorator + + def get_body(func: FunctionType, cls: Optional[type] = None) -> Iterable[str]: lines, _ = inspect.getsourcelines(func) line_iter = iter(lines) for line in line_iter: - if line.rstrip().endswith(":"): + if line.split("#")[0].rstrip().endswith(":"): break else: raise NotImplementedError + substitutor = LineSubstitutor(line_iter) if cls is not None: - def replace_super(match: Match): + @substitutor(re.compile(r"super\(\)\.(\w+)\(")) + def replace_super(match: Match) -> str: assert cls is not None super_cls = cls.__bases__[0].__name__ return f"{super_cls}_{match.group(1)}(<{super_cls}>self, " - super_regex = re.compile(r"super\(\).(\w+)\(") - line_iter = (super_regex.sub(replace_super, line) for line in line_iter) + @substitutor( + re.compile( + r"(\s+)for ((\w+) in self\.(\w+)|(\w+), (\w+) in enumerate\(self\.(\w+)\)):" + ) + ) + def replace_for_loop(match: Match) -> str: + assert cls is not None + tab = match.group(1) + index = match.group(5) or "__i" + elt = match.group(3) or match.group(6) + field = match.group(4) or match.group(7) + field_type = get_type_hints(cls)[field] + assert ( + field_type.__origin__ in (Tuple, tuple) + and field_type.__args__[1] is ... + ) + elt_type = cython_type(field_type.__args__[0], func.__module__) + return f"{tab}for {index} in range(len(self.{field})):\n{tab} {elt}: {elt_type} = self.{field}[{index}]" + + @substitutor(re.compile(r"^(\s+\w+:)([^#=]*)(?==)")) + def replace_variable_annotations(match: Match) -> str: + tp = eval(match.group(2), func.__globals__) + return match.group(1) + cython_type(tp, func.__module__) + methods = module_methods(func.__module__) + method_names = "|".join(methods) - def replace_method(match: Match): + @substitutor(re.compile(rf"([\w.]+)\.({method_names})\(")) + def replace_method(match: Match) -> str: self, name = match.groups() cls, _ = methods[name] return f"{cls.__name__}_{name}({self}, " - method_names = "|".join(methods) - method_regex = re.compile(rf"([\w\.]+)\.({method_names})\(") - return (method_regex.sub(replace_method, line) for line in line_iter) + return substitutor.lines def import_lines(path: Union[str, Path]) -> Iterable[str]: @@ -290,13 +332,15 @@ def generate(package: str) -> str: with open(pyx_file_name, "w") as pyx_file: pyx = IndentedWriter(pyx_file) pyx.writeln("cimport cython") + pyx.writeln("from cpython cimport *") pyx.writelines(import_lines(ROOT_DIR / "apischema" / package / "methods.py")) for cls in module_elements(module, type): write_class(pyx, cls) # type: ignore pyx.writeln() for func in module_elements(module, FunctionType): - write_function(pyx, func) # type: ignore - pyx.writeln() + if not func.__name__.startswith("Py"): + write_function(pyx, func) # type: ignore + pyx.writeln() methods = module_methods(module) for method in methods.values(): write_methods(pyx, method) diff --git a/scripts/cythonize.sh b/scripts/cythonize.sh index 1f829d8d..15b99fef 100755 --- a/scripts/cythonize.sh +++ b/scripts/cythonize.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -python3 -m pip install cython +python3 -m pip install -r $(dirname $0)/requirements.cython.txt $(dirname $0)/cythonize.py \ No newline at end of file diff --git a/scripts/requirements.cython.txt b/scripts/requirements.cython.txt new file mode 100644 index 00000000..becb4887 --- /dev/null +++ b/scripts/requirements.cython.txt @@ -0,0 +1 @@ +Cython==0.29.28 diff --git a/setup.cfg b/setup.cfg index 5e3ecf3e..a3c245ab 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [flake8] max-line-length = 88 -ignore = E203, E302, E501, W503, E731, E741 +ignore = E203, E302, E501, W503, E731, E741, F402 [isort] profile = black diff --git a/tests/integration/test_deserializer_registration_reset_deserialization_cache.py b/tests/integration/test_deserializer_registration_reset_deserialization_cache.py new file mode 100644 index 00000000..fb4250e5 --- /dev/null +++ b/tests/integration/test_deserializer_registration_reset_deserialization_cache.py @@ -0,0 +1,16 @@ +import pytest + +from apischema import ValidationError, deserialize, deserializer +from apischema.conversions import Conversion, catch_value_error + + +class Foo(int): + pass + + +def test_deserializer_registration_reset_deserialization_cache(): + assert deserialize(Foo, 1) == Foo(1) + deserializer(Conversion(catch_value_error(Foo), source=str, target=Foo)) + assert deserialize(Foo, "1") == Foo(1) + with pytest.raises(ValidationError): + deserialize(Foo, 1) diff --git a/tests/requirements.txt b/tests/requirements.txt index 016a230b..0dbebe95 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -8,4 +8,3 @@ pytest-cov==4.0.0 pytest-asyncio==0.20.3 SQLAlchemy==1.4.45 typing-extensions==4.4.0 -Cython==0.29.32 diff --git a/tests/unit/test_deserialization_serialization.py b/tests/unit/test_deserialization_serialization.py index 49184531..cd540df7 100644 --- a/tests/unit/test_deserialization_serialization.py +++ b/tests/unit/test_deserialization_serialization.py @@ -56,6 +56,10 @@ class Dataclass: opt: Optional[int] = field(default=None, metadata=schema(min=100)) +def test_bool_as_int_error(): + error(True, int) + + @pytest.mark.parametrize("data", ["", 0]) def test_any(data): bijection(Any, data, data)