From de0b76d05f9ac0987fbb6a6105db32629b831af2 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Sun, 6 Feb 2022 20:39:55 +0100 Subject: [PATCH] Remove cast to builtin types in deserialization methods Cast was added in order to allow Cython to use optimized operations on containers. However, a quick benchmark showed that it was in fact surprisingly slower for small lists. --- apischema/deserialization/methods.py | 98 +++++++++++++--------------- 1 file changed, 44 insertions(+), 54 deletions(-) diff --git a/apischema/deserialization/methods.py b/apischema/deserialization/methods.py index a1d43723..f68d5050 100644 --- a/apischema/deserialization/methods.py +++ b/apischema/deserialization/methods.py @@ -257,15 +257,14 @@ class ListCheckOnlyMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, list): raise bad_type(data, list) - data2: list = data elt_errors = None - for i, elt in enumerate(data2): + for i, elt in enumerate(data): try: self.value_method.deserialize(elt) except ValidationError as err: elt_errors = set_child_error(elt_errors, i, err) - validate_constraints(data2, self.constraints, elt_errors) - return data2 + validate_constraints(data, self.constraints, elt_errors) + return data @dataclass @@ -276,15 +275,14 @@ class ListMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, list): raise bad_type(data, list) - data2: list = data elt_errors = None - values: list = [None] * len(data2) - for i, elt in enumerate(data2): + values: list = [None] * len(data) + for i, elt in enumerate(data): try: values[i] = self.value_method.deserialize(elt) except ValidationError as err: elt_errors = set_child_error(elt_errors, i, err) - validate_constraints(data2, self.constraints, elt_errors) + validate_constraints(data, self.constraints, elt_errors) return values @@ -296,15 +294,14 @@ class SetMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, list): 
raise bad_type(data, list) - data2: list = data elt_errors: dict = {} values: set = set() - for i, elt in enumerate(data2): + for i, elt in enumerate(data): try: values.add(self.value_method.deserialize(elt)) except ValidationError as err: elt_errors = set_child_error(elt_errors, i, err) - validate_constraints(data2, self.constraints, elt_errors) + validate_constraints(data, self.constraints, elt_errors) return values @@ -355,16 +352,15 @@ class MappingCheckOnly(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, dict): raise bad_type(data, dict) - data2: dict = data item_errors = None - for key, value in data2.items(): + for key, value in data.items(): try: self.key_method.deserialize(key) self.value_method.deserialize(value) except ValidationError as err: item_errors = set_child_error(item_errors, key, err) - validate_constraints(data2, self.constraints, item_errors) - return data2 + validate_constraints(data, self.constraints, item_errors) + return data @dataclass @@ -376,17 +372,16 @@ class MappingMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, dict): raise bad_type(data, dict) - data2: dict = data item_errors = None items: dict = {} - for key, value in data2.items(): + for key, value in data.items(): try: items[self.key_method.deserialize(key)] = self.value_method.deserialize( value ) except ValidationError as err: item_errors = set_child_error(item_errors, key, err) - validate_constraints(data2, self.constraints, item_errors) + validate_constraints(data, self.constraints, item_errors) return items @@ -487,15 +482,14 @@ class SimpleObjectMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, dict): raise bad_type(data, dict) - data2: dict = data fields_count = 0 field_errors = None for i in range(len(self.fields)): field: Field = self.fields[i] - if field.alias in data2: + if field.alias in data: fields_count += 1 try: - 
field.method.deserialize(data2[field.alias]) + field.method.deserialize(data[field.alias]) except ValidationError as err: if field.required or not field.fall_back_on_default: field_errors = set_child_error(field_errors, field.alias, err) @@ -503,14 +497,14 @@ def deserialize(self, data: Any) -> Any: field_errors = set_child_error( field_errors, field.alias, ValidationError(self.missing) ) - if len(data2) != fields_count and not self.typed_dict: - for key in data2.keys() - self.all_aliases: + if len(data) != fields_count and not self.typed_dict: + for key in data.keys() - self.all_aliases: field_errors = set_child_error( field_errors, key, ValidationError(self.unexpected) ) if field_errors: raise ValidationError([], field_errors) - return self.constructor.construct(data2) + return self.constructor.construct(data) def extend_errors( @@ -564,7 +558,6 @@ def __post_init__(self): def deserialize(self, data: Any) -> Any: if not isinstance(data, dict): raise bad_type(data, dict) - data2: dict = data values: dict = {} fields_count = 0 errors = None @@ -577,7 +570,7 @@ def deserialize(self, data: Any) -> Any: field: Field = self.fields[i] if field.required: try: - value: object = data2[field.alias] + value: object = data[field.alias] except KeyError: field_errors = set_child_error( field_errors, field.alias, ValidationError(self.missing) @@ -588,29 +581,29 @@ def deserialize(self, data: Any) -> Any: values[field.name] = field.method.deserialize(value) except ValidationError as err: field_errors = set_child_error(field_errors, field.alias, err) - elif field.alias in data2: + elif field.alias in data: fields_count += 1 try: - values[field.name] = field.method.deserialize(data2[field.alias]) + values[field.name] = field.method.deserialize(data[field.alias]) except ValidationError as err: if not field.fall_back_on_default: field_errors = set_child_error(field_errors, field.alias, err) elif field.required_by is not None and not field.required_by.isdisjoint( - data2 + data ): - 
requiring: list = sorted(field.required_by & data2.keys()) + requiring: list = sorted(field.required_by & data.keys()) msg: str = self.missing + f" (required by {requiring})" field_errors = set_child_error( field_errors, field.alias, ValidationError([msg]) ) if self.aggregate_fields: - remain = data2.keys() - self.all_aliases + remain = data.keys() - self.all_aliases for i in range(len(self.flattened_fields)): flattened_field: FlattenedField = self.flattened_fields[i] flattened: dict = { - alias: data2[alias] + alias: data[alias] for alias in flattened_field.aliases - if alias in data2 + if alias in data } remain.difference_update(flattened) try: @@ -626,9 +619,7 @@ def deserialize(self, data: Any) -> Any: for i in range(len(self.pattern_fields)): pattern_field: PatternField = self.pattern_fields[i] matched: dict = { - key: data2[key] - for key in remain - if pattern_field.pattern.match(key) + key: data[key] for key in remain if pattern_field.pattern.match(key) } remain.difference_update(matched) try: @@ -642,7 +633,7 @@ def deserialize(self, data: Any) -> Any: field_errors, err.children ) if self.additional_field is not None: - additional: dict = {key: data2[key] for key in remain} + additional: dict = {key: data[key] for key in remain} try: values[ self.additional_field.name @@ -662,17 +653,17 @@ def deserialize(self, data: Any) -> Any: ) elif self.typed_dict: for key in remain: - values[key] = data2[key] - elif len(data2) != fields_count: + values[key] = data[key] + elif len(data) != fields_count: if not self.additional_properties: - for key in data2.keys() - self.all_aliases: + for key in data.keys() - self.all_aliases: if key != self.discriminator: field_errors = set_child_error( field_errors, key, ValidationError(self.unexpected) ) elif self.typed_dict: - for key in data2.keys() - self.all_aliases: - values[key] = data2[key] + for key in data.keys() - self.all_aliases: + values[key] = data[key] if self.validators: init = None if self.init_defaults: @@ -796,12 
+787,12 @@ class TupleMethod(DeserializationMethod): def deserialize(self, data: Any) -> Any: if not isinstance(data, list): raise bad_type(data, list) - data2: list = data - if len(data2) != len(self.elt_methods): - if len(data2) < len(self.elt_methods): - raise ValidationError(format_error(self.min_len_error, data2)) - elif len(data2) > len(self.elt_methods): - raise ValidationError(format_error(self.max_len_error, data2)) + data_len = len(data) + if data_len != len(self.elt_methods): + if data_len < len(self.elt_methods): + raise ValidationError(format_error(self.min_len_error, data)) + elif data_len > len(self.elt_methods): + raise ValidationError(format_error(self.max_len_error, data)) else: raise NotImplementedError elt_errors: dict = {} @@ -809,10 +800,10 @@ def deserialize(self, data: Any) -> Any: for i in range(len(self.elt_methods)): elt_method: DeserializationMethod = self.elt_methods[i] try: - elts[i] = elt_method.deserialize(data2[i]) + elts[i] = elt_method.deserialize(data[i]) except ValidationError as err: elt_errors[i] = err - validate_constraints(data2, self.constraints, elt_errors) + validate_constraints(data, self.constraints, elt_errors) return tuple(elts) @@ -925,17 +916,16 @@ class DiscriminatorMethod(DeserializationMethod): def deserialize(self, data: Any): if not isinstance(data, dict): raise bad_type(data, dict) - data2: dict = data - if self.alias not in data2: + if self.alias not in data: raise ValidationError([], {self.alias: ValidationError(self.missing)}) try: - method: DeserializationMethod = self.mapping[data2[self.alias]] + method: DeserializationMethod = self.mapping[data[self.alias]] except (TypeError, KeyError): raise ValidationError( [], { self.alias: ValidationError( - format_error(self.error, data2[self.alias]) + format_error(self.error, data[self.alias]) ) }, )