Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove cast to builtin types in deserialization methods #348

Merged
merged 1 commit into from
Feb 6, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 44 additions & 54 deletions apischema/deserialization/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,15 +257,14 @@ class ListCheckOnlyMethod(DeserializationMethod):
def deserialize(self, data: Any) -> Any:
if not isinstance(data, list):
raise bad_type(data, list)
data2: list = data
elt_errors = None
for i, elt in enumerate(data2):
for i, elt in enumerate(data):
try:
self.value_method.deserialize(elt)
except ValidationError as err:
elt_errors = set_child_error(elt_errors, i, err)
validate_constraints(data2, self.constraints, elt_errors)
return data2
validate_constraints(data, self.constraints, elt_errors)
return data


@dataclass
Expand All @@ -276,15 +275,14 @@ class ListMethod(DeserializationMethod):
def deserialize(self, data: Any) -> Any:
if not isinstance(data, list):
raise bad_type(data, list)
data2: list = data
elt_errors = None
values: list = [None] * len(data2)
for i, elt in enumerate(data2):
values: list = [None] * len(data)
for i, elt in enumerate(data):
try:
values[i] = self.value_method.deserialize(elt)
except ValidationError as err:
elt_errors = set_child_error(elt_errors, i, err)
validate_constraints(data2, self.constraints, elt_errors)
validate_constraints(data, self.constraints, elt_errors)
return values


Expand All @@ -296,15 +294,14 @@ class SetMethod(DeserializationMethod):
def deserialize(self, data: Any) -> Any:
if not isinstance(data, list):
raise bad_type(data, list)
data2: list = data
elt_errors: dict = {}
values: set = set()
for i, elt in enumerate(data2):
for i, elt in enumerate(data):
try:
values.add(self.value_method.deserialize(elt))
except ValidationError as err:
elt_errors = set_child_error(elt_errors, i, err)
validate_constraints(data2, self.constraints, elt_errors)
validate_constraints(data, self.constraints, elt_errors)
return values


Expand Down Expand Up @@ -355,16 +352,15 @@ class MappingCheckOnly(DeserializationMethod):
def deserialize(self, data: Any) -> Any:
if not isinstance(data, dict):
raise bad_type(data, dict)
data2: dict = data
item_errors = None
for key, value in data2.items():
for key, value in data.items():
try:
self.key_method.deserialize(key)
self.value_method.deserialize(value)
except ValidationError as err:
item_errors = set_child_error(item_errors, key, err)
validate_constraints(data2, self.constraints, item_errors)
return data2
validate_constraints(data, self.constraints, item_errors)
return data


@dataclass
Expand All @@ -376,17 +372,16 @@ class MappingMethod(DeserializationMethod):
def deserialize(self, data: Any) -> Any:
if not isinstance(data, dict):
raise bad_type(data, dict)
data2: dict = data
item_errors = None
items: dict = {}
for key, value in data2.items():
for key, value in data.items():
try:
items[self.key_method.deserialize(key)] = self.value_method.deserialize(
value
)
except ValidationError as err:
item_errors = set_child_error(item_errors, key, err)
validate_constraints(data2, self.constraints, item_errors)
validate_constraints(data, self.constraints, item_errors)
return items


Expand Down Expand Up @@ -487,30 +482,29 @@ class SimpleObjectMethod(DeserializationMethod):
def deserialize(self, data: Any) -> Any:
if not isinstance(data, dict):
raise bad_type(data, dict)
data2: dict = data
fields_count = 0
field_errors = None
for i in range(len(self.fields)):
field: Field = self.fields[i]
if field.alias in data2:
if field.alias in data:
fields_count += 1
try:
field.method.deserialize(data2[field.alias])
field.method.deserialize(data[field.alias])
except ValidationError as err:
if field.required or not field.fall_back_on_default:
field_errors = set_child_error(field_errors, field.alias, err)
elif field.required:
field_errors = set_child_error(
field_errors, field.alias, ValidationError(self.missing)
)
if len(data2) != fields_count and not self.typed_dict:
for key in data2.keys() - self.all_aliases:
if len(data) != fields_count and not self.typed_dict:
for key in data.keys() - self.all_aliases:
field_errors = set_child_error(
field_errors, key, ValidationError(self.unexpected)
)
if field_errors:
raise ValidationError([], field_errors)
return self.constructor.construct(data2)
return self.constructor.construct(data)


def extend_errors(
Expand Down Expand Up @@ -564,7 +558,6 @@ def __post_init__(self):
def deserialize(self, data: Any) -> Any:
if not isinstance(data, dict):
raise bad_type(data, dict)
data2: dict = data
values: dict = {}
fields_count = 0
errors = None
Expand All @@ -577,7 +570,7 @@ def deserialize(self, data: Any) -> Any:
field: Field = self.fields[i]
if field.required:
try:
value: object = data2[field.alias]
value: object = data[field.alias]
except KeyError:
field_errors = set_child_error(
field_errors, field.alias, ValidationError(self.missing)
Expand All @@ -588,29 +581,29 @@ def deserialize(self, data: Any) -> Any:
values[field.name] = field.method.deserialize(value)
except ValidationError as err:
field_errors = set_child_error(field_errors, field.alias, err)
elif field.alias in data2:
elif field.alias in data:
fields_count += 1
try:
values[field.name] = field.method.deserialize(data2[field.alias])
values[field.name] = field.method.deserialize(data[field.alias])
except ValidationError as err:
if not field.fall_back_on_default:
field_errors = set_child_error(field_errors, field.alias, err)
elif field.required_by is not None and not field.required_by.isdisjoint(
data2
data
):
requiring: list = sorted(field.required_by & data2.keys())
requiring: list = sorted(field.required_by & data.keys())
msg: str = self.missing + f" (required by {requiring})"
field_errors = set_child_error(
field_errors, field.alias, ValidationError([msg])
)
if self.aggregate_fields:
remain = data2.keys() - self.all_aliases
remain = data.keys() - self.all_aliases
for i in range(len(self.flattened_fields)):
flattened_field: FlattenedField = self.flattened_fields[i]
flattened: dict = {
alias: data2[alias]
alias: data[alias]
for alias in flattened_field.aliases
if alias in data2
if alias in data
}
remain.difference_update(flattened)
try:
Expand All @@ -626,9 +619,7 @@ def deserialize(self, data: Any) -> Any:
for i in range(len(self.pattern_fields)):
pattern_field: PatternField = self.pattern_fields[i]
matched: dict = {
key: data2[key]
for key in remain
if pattern_field.pattern.match(key)
key: data[key] for key in remain if pattern_field.pattern.match(key)
}
remain.difference_update(matched)
try:
Expand All @@ -642,7 +633,7 @@ def deserialize(self, data: Any) -> Any:
field_errors, err.children
)
if self.additional_field is not None:
additional: dict = {key: data2[key] for key in remain}
additional: dict = {key: data[key] for key in remain}
try:
values[
self.additional_field.name
Expand All @@ -662,17 +653,17 @@ def deserialize(self, data: Any) -> Any:
)
elif self.typed_dict:
for key in remain:
values[key] = data2[key]
elif len(data2) != fields_count:
values[key] = data[key]
elif len(data) != fields_count:
if not self.additional_properties:
for key in data2.keys() - self.all_aliases:
for key in data.keys() - self.all_aliases:
if key != self.discriminator:
field_errors = set_child_error(
field_errors, key, ValidationError(self.unexpected)
)
elif self.typed_dict:
for key in data2.keys() - self.all_aliases:
values[key] = data2[key]
for key in data.keys() - self.all_aliases:
values[key] = data[key]
if self.validators:
init = None
if self.init_defaults:
Expand Down Expand Up @@ -796,23 +787,23 @@ class TupleMethod(DeserializationMethod):
def deserialize(self, data: Any) -> Any:
if not isinstance(data, list):
raise bad_type(data, list)
data2: list = data
if len(data2) != len(self.elt_methods):
if len(data2) < len(self.elt_methods):
raise ValidationError(format_error(self.min_len_error, data2))
elif len(data2) > len(self.elt_methods):
raise ValidationError(format_error(self.max_len_error, data2))
data_len = len(data)
if data_len != len(self.elt_methods):
if data_len < len(self.elt_methods):
raise ValidationError(format_error(self.min_len_error, data))
elif data_len > len(self.elt_methods):
raise ValidationError(format_error(self.max_len_error, data))
else:
raise NotImplementedError
elt_errors: dict = {}
elts: list = [None] * len(self.elt_methods)
for i in range(len(self.elt_methods)):
elt_method: DeserializationMethod = self.elt_methods[i]
try:
elts[i] = elt_method.deserialize(data2[i])
elts[i] = elt_method.deserialize(data[i])
except ValidationError as err:
elt_errors[i] = err
validate_constraints(data2, self.constraints, elt_errors)
validate_constraints(data, self.constraints, elt_errors)
return tuple(elts)


Expand Down Expand Up @@ -925,17 +916,16 @@ class DiscriminatorMethod(DeserializationMethod):
def deserialize(self, data: Any):
if not isinstance(data, dict):
raise bad_type(data, dict)
data2: dict = data
if self.alias not in data2:
if self.alias not in data:
raise ValidationError([], {self.alias: ValidationError(self.missing)})
try:
method: DeserializationMethod = self.mapping[data2[self.alias]]
method: DeserializationMethod = self.mapping[data[self.alias]]
except (TypeError, KeyError):
raise ValidationError(
[],
{
self.alias: ValidationError(
format_error(self.error, data2[self.alias])
format_error(self.error, data[self.alias])
)
},
)
Expand Down