Skip to content
This repository has been archived by the owner on Apr 10, 2023. It is now read-only.

Commit

Permalink
Model tagging when serializing and deserializing (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
rossmacarthur committed Mar 27, 2019
1 parent eee8b47 commit b2b45b8
Show file tree
Hide file tree
Showing 4 changed files with 543 additions and 92 deletions.
7 changes: 7 additions & 0 deletions src/serde/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
'ContextError',
'DeserializationError',
'InstantiationError',
'MetaError',
'NormalizationError',
'SerdeError',
'SerializationError',
Expand Down Expand Up @@ -73,6 +74,12 @@ class ContextError(BaseSerdeError):
"""


class MetaError(BaseSerdeError):
"""
Raised when there is a problem with the Model's Meta class.
"""


class SerdeError(BaseSerdeError):
"""
Raised when any Model stage fails.
Expand Down
251 changes: 192 additions & 59 deletions src/serde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@
from serde.exceptions import (
DeserializationError,
InstantiationError,
MetaError,
MissingDependency,
NormalizationError,
SerdeError,
SerializationError,
SkipSerialization,
ValidationError
Expand Down Expand Up @@ -185,6 +185,35 @@ def __getattr__(self, name):
return super(Fields, self).__getattribute__(name)


class Meta(object):
"""
A Meta class for a Model that carries extra configuration.
"""

@staticmethod
def default_tag_map(cls):
return cls.__name__

def __init__(self, **kwargs):
defaults = {
'abstract': False,
'tag': None,
'content': None,
'tag_map': self.default_tag_map
}

for name, default in defaults.items():
setattr(self, name, kwargs.pop(name, default))

if kwargs:
raise MetaError(
'invalid Meta field{} {}'.format(
'' if len(kwargs.keys()) == 1 else 's',
', '.join('{!r}'.format(k) for k in kwargs.keys())
)
)


class ModelType(type):
"""
A metaclass for Models.
Expand All @@ -207,6 +236,15 @@ def __new__(cls, cname, bases, attrs):
Returns:
Model: a new Model class.
"""
# Handle the Meta class.
if 'Meta' in attrs:
meta = Meta(**{
k: v for k, v in attrs.pop('Meta').__dict__.items()
if not k.startswith('_')
})
else:
meta = Meta()

def is_field(key, value):
if isinstance(value, Field):
value._name = key
Expand All @@ -227,6 +265,13 @@ def is_field(key, value):
# Order the fields by the Field identifier. This gets the order that
# they were defined on the Models. We add these to the Model.
final_attrs['_fields'] = Fields(sorted(fields.items(), key=lambda x: x[1].id))
final_attrs['_meta'] = meta

# Figure out the parent.
if not (bases == (object,) and cname == 'Model'):
final_attrs['_parent'] = next(iter(b for b in bases if issubclass(b, Model)))
else:
final_attrs['_parent'] = None

return super(ModelType, cls).__new__(cls, cname, bases, final_attrs)

Expand All @@ -246,17 +291,21 @@ def __init__(self, *args, **kwargs):
Fields in the order the Fields are defined on the Model.
**kwargs: keyword argument values for each Field on the Model.
"""
if self._meta.abstract:
raise InstantiationError(
'unable to instantiate abstract Model {!r}'.format(self.__class__.__name__)
)

try:
for name, value in zip_until_right(self._fields.keys(), args):
if name in kwargs:
raise SerdeError(
raise InstantiationError(
'__init__() got multiple values for keyword argument {!r}'
.format(name)
)

kwargs[name] = value
except ValueError:
raise SerdeError(
raise InstantiationError(
'__init__() takes a maximum of {!r} positional arguments but {!r} were given'
.format(len(self._fields) + 1, len(args) + 1)
)
Expand Down Expand Up @@ -367,23 +416,7 @@ def normalize(self):
Normalize this Model.
Override this method to add any additional normalization to the Model.
This will be called after each Field has been normalized.
::
>>> class Fruit(Model):
... name = fields.Str()
... family = fields.Str()
...
... def normalize(self):
... self.name = self.name.strip()
... self.family = self.family.strip()
>>> fruit = Fruit(name='Tangerine ', family=' Citrus fruit')
>>> fruit.name
'Tangerine'
>>> fruit.family
'Citrus fruit'
This will be called after all Fields have been normalized.
"""
pass

Expand All @@ -392,51 +425,113 @@ def validate(self):
Validate this Model.
Override this method to add any additional validation to the Model. This
will be called after each Field has been validated.
will be called after all Fields have been validated.
"""
pass

@classmethod
def _variants(cls):
"""
Returns a list of variants of this Model.
"""
return cls.__subclasses__()

@classmethod
def _variant_map(cls):
"""
Returns a map of variant identifier to variant class.
"""
variants = {cls._meta.tag_map(c): c for c in cls._variants()}

::
if not cls._meta.abstract:
variants[cls._meta.tag_map(cls)] = cls

>>> class Owner(Model):
... cats_name = fields.Optional(fields.Str)
... dogs_name = fields.Optional(fields.Str)
...
... def validate(self):
... msg = 'No one has both!'
... assert not (self.cats_name and self.dogs_name), msg
...
return variants

>>> owner = Owner(cats_name='Luna', dogs_name='Max')
Traceback (most recent call last):
...
serde.exceptions.InstantiationError: No one has both!
@classmethod
def _tagged_variant_name(cls, d):
"""
pass
Returns the name of the variant based on serialized data.
"""
# Externally tagged variant
if cls._meta.tag is True:
try:
return next(iter(d))
except StopIteration:
raise DeserializationError('expected externally tagged data')
# Internally/adjacently tagged variant
try:
return d[cls._meta.tag]
except KeyError:
raise DeserializationError('expected tag {!r}'.format(cls._meta.tag))

@classmethod
def from_dict(cls, d, strict=True):
def _tagged_variant(cls, variant_name):
"""
Convert a dictionary to an instance of this Model.
Returns the variant class corresponding to the given name.
"""
try:
return cls._variant_map()[variant_name]
except KeyError:
raise DeserializationError(
'no variant found for tag {!r}'.format(variant_name)
)

Args:
d (dict): a serialized version of this Model.
strict (bool): if set to False then no exception will be raised when
unknown dictionary keys are present.
@classmethod
def _transform_tagged_data(cls, d, variant_name):
"""
Transform the tagged content so that it can be deserialized.
"""
# Externally tagged variant
if cls._meta.tag is True:
return d[variant_name]
# Adjacently tagged variant
elif cls._meta.content:
try:
return d[cls._meta.content]
except KeyError:
raise DeserializationError(
'expected adjacently tagged data under key {!r}'.format(cls._meta.content)
)
# Internally tagged variant
return {k: v for k, v in d.items() if k != cls._meta.tag}

Returns:
Model: an instance of this Model.
@classmethod
def _transform_untagged_data(cls, d, dict):
"""
Transform the untagged content into tagged data.
"""
parent = cls

while parent:
if parent._meta.tag:
variant_name = cls._meta.tag_map(cls)

# Externally tagged variant
if parent._meta.tag is True:
d = dict([(variant_name, d)])
# Adjacently tagged variant
elif parent._meta.content:
d = dict([(parent._meta.tag, variant_name), (parent._meta.content, d)])
# Internally tagged variant
else:
d_new = dict([(parent._meta.tag, variant_name)])
d_new.update(d)
d = d_new

parent = parent._parent

return d

Raises:
`~serde.exceptions.DeserializationError`: when a Field value can not
be deserialized or there are unknown dictionary keys.
@classmethod
def _from_dict(cls, d, strict=True):
"""
Convert a dictionary to an instance of this Model.
"""
self = cls.__new__(cls)

for name, field in cls._fields.items():
value = None

if field.name in d:
value = self._deserialize_field(field, d.get(field.name))

value = self._deserialize_field(field, d[field.name]) if field.name in d else None
setattr(self, name, value)

if strict:
Expand All @@ -457,6 +552,46 @@ def from_dict(cls, d, strict=True):

return self

@classmethod
def from_dict(cls, d, strict=True):
"""
Convert a dictionary to an instance of this Model.
Args:
d (dict): a serialized version of this Model.
strict (bool): if set to False then no exception will be raised when
unknown dictionary keys are present.
Returns:
Model: an instance of this Model.
"""
# Externally/internally/adjacently tagged variant
if cls._meta.tag:
variant_name = cls._tagged_variant_name(d)
variant = cls._tagged_variant(variant_name)
d = cls._transform_tagged_data(d, variant_name)

if variant != cls:
return variant.from_dict(d, strict=strict)
# Untagged variant
elif cls._meta.tag is False:
if not cls._meta.abstract:
try:
return cls._from_dict(d, strict=strict)
except DeserializationError:
pass

# Try each variant in turn until one succeeds
for variant in cls._variants():
try:
return variant.from_dict(d, strict=strict)
except DeserializationError:
pass

raise DeserializationError('no variant succeeded deserialization')

return cls._from_dict(d, strict=strict)

@classmethod
@requires_module('cbor', package='cbor2')
def from_cbor(cls, b, strict=True, **kwargs):
Expand Down Expand Up @@ -553,23 +688,21 @@ def to_dict(self, dict=None):
Returns:
dict: the Model serialized as a dictionary.
Raises:
`~serde.exceptions.SerializationError`: when a Field value cannot be
serialized.
"""
if dict is None:
dict = OrderedDict

result = dict()
d = dict()

for name, field in self._fields.items():
try:
result[field.name] = self._serialize_field(field, getattr(self, name))
d[field.name] = self._serialize_field(field, getattr(self, name))
except SkipSerialization:
pass

return result
d = self._transform_untagged_data(d, dict=dict)

return d

@requires_module('cbor', package='cbor2')
def to_cbor(self, dict=None, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def decorated_function(*args, **kwargs):
return decorated_function


def py3_only(f):
def py3(f):
def decorated_function(*args, **kwargs):
if six.PY2:
return
Expand All @@ -23,7 +23,7 @@ def decorated_function(*args, **kwargs):
return decorated_function


def py2_only(f):
def py2(f):
def decorated_function(*args, **kwargs):
if six.PY3:
return
Expand Down

0 comments on commit b2b45b8

Please sign in to comment.