Skip to content

Commit

Permalink
Simplify map attribute deserialization. (#839)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpinner-lyft committed Sep 10, 2020
1 parent 29179cd commit 14f81c7
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 94 deletions.
81 changes: 34 additions & 47 deletions pynamodb/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
from dateutil.tz import tzutc
from inspect import getfullargspec
from inspect import getmembers
from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, Text, TypeVar, Type, Union, Set, overload
from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, TypeVar, Type, Union, Set, overload
from typing import TYPE_CHECKING

from pynamodb._compat import GenericMeta
from pynamodb.constants import (
BINARY, BINARY_SET, BOOLEAN, DATETIME_FORMAT, DEFAULT_ENCODING,
LIST, MAP, NULL, NUMBER, NUMBER_SET, STRING, STRING_SET
)
from pynamodb.exceptions import AttributeDeserializationError
from pynamodb.expressions.operand import Path


Expand Down Expand Up @@ -71,12 +72,11 @@ def __init__(
self.is_hash_key = hash_key
self.is_range_key = range_key

# AttributeContainerMeta._initialize_attributes will ensure this is a
# string
# AttributeContainerMeta._initialize_attributes will ensure this is a string
self.attr_path: List[str] = [attr_name] # type: ignore

@property
def attr_name(self) -> Optional[str]:
def attr_name(self) -> str:
return self.attr_path[-1]

@attr_name.setter
Expand Down Expand Up @@ -120,8 +120,10 @@ def deserialize(self, value: Any) -> Any:
"""
return value

def get_value(self, value: Any) -> Any:
return value.get(self.attr_type)
def get_value(self, value: Dict[str, Any]) -> Any:
if self.attr_type not in value:
raise AttributeDeserializationError(self.attr_name, self.attr_type)
return value[self.attr_type]

def __iter__(self):
# Because we define __getitem__ below for condition expression support
Expand Down Expand Up @@ -278,7 +280,7 @@ def get_attributes(cls) -> Dict[str, Attribute]:
return cls._attributes # type: ignore

@classmethod
def _dynamo_to_python_attr(cls, dynamo_key: str) -> Optional[str]:
def _dynamo_to_python_attr(cls, dynamo_key: str) -> str:
"""
Convert a DynamoDB attribute name to the internal Python name.
Expand Down Expand Up @@ -311,6 +313,18 @@ def _set_attributes(self, **attributes: Attribute) -> None:
raise ValueError("Attribute {} specified does not exist".format(attr_name))
setattr(self, attr_name, attr_value)

def _deserialize(self, attribute_values: Dict[str, Dict[str, Any]]) -> None:
"""
Sets attributes sent back from DynamoDB on this object
"""
self.attribute_values = {}
self._set_defaults(_user_instantiated=False)
for name, attr in self.get_attributes().items():
attribute_value = attribute_values.get(attr.attr_name)
if attribute_value and NULL not in attribute_value:
value = attr.deserialize(attr.get_value(attribute_value))
setattr(self, name, value)

def __eq__(self, other: Any) -> bool:
# This is required so that MapAttribute can call this method.
return self is other
Expand Down Expand Up @@ -803,8 +817,7 @@ def serialize(self, values):
continue

# If this is a subclassed MapAttribute, there may be an alternate attr name
attr = self.get_attributes().get(k)
attr_name = attr.attr_name if attr else k
attr_name = attr_class.attr_name if not self.is_raw() else k

serialized = attr_class.serialize(v)
if self._should_skip(serialized):
Expand All @@ -819,23 +832,17 @@ def deserialize(self, values):
"""
Decode as a dict.
"""
deserialized_dict: Dict[str, Any] = dict()
for k in values:
v = values[k]
attr_value = _get_value_for_deserialize(v)
key = self._dynamo_to_python_attr(k)
attr_class = self._get_deserialize_class(key, v)
if key is None or attr_class is None:
continue
deserialized_value = None
if attr_value is not None:
deserialized_value = attr_class.deserialize(attr_value)

deserialized_dict[key] = deserialized_value

# If this is a subclass of a MapAttribute (i.e typed), instantiate an instance
if not self.is_raw():
return type(self)(**deserialized_dict)
# If this is a subclass of a MapAttribute (i.e typed), instantiate an instance
instance = type(self)()
instance._deserialize(values)
return instance

deserialized_dict: Dict[str, Any] = dict()
for k, v in values.items():
attr_type, attr_value = next(iter(v.items()))
attr_class = DESERIALIZE_CLASS_MAP[attr_type]
deserialized_dict[k] = attr_class.deserialize(attr_value)
return deserialized_dict

@classmethod
Expand All @@ -850,7 +857,7 @@ def as_dict(self):

def _should_skip(self, value):
# Continue to serialize NULL values in "raw" map attributes for backwards compatibility.
# This special case behavior for "raw" attribtues should be removed in the future.
# This special case behavior for "raw" attributes should be removed in the future.
return not self.is_raw() and value is None

@classmethod
Expand All @@ -859,32 +866,12 @@ def _get_serialize_class(cls, key, value):
return cls.get_attributes().get(key)
return _get_class_for_serialize(value)

@classmethod
def _get_deserialize_class(cls, key, value):
if not cls.is_raw():
return cls.get_attributes().get(key)
return _get_class_for_deserialize(value)


def _get_value_for_deserialize(value):
key = next(iter(value.keys()))
if key == NULL:
return None
return value[key]


def _get_class_for_deserialize(value):
value_type = next(iter(value.keys()))
if value_type not in DESERIALIZE_CLASS_MAP:
raise ValueError('Unknown value: ' + str(value))
return DESERIALIZE_CLASS_MAP[value_type]


def _get_class_for_serialize(value):
if value is None:
return NullAttribute()
if isinstance(value, MapAttribute):
return type(value)()
return value
value_type = type(value)
if value_type not in SERIALIZE_CLASS_MAP:
raise ValueError('Unknown value: {}'.format(value_type))
Expand Down
4 changes: 2 additions & 2 deletions pynamodb/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ class AttributeDeserializationError(TypeError):
"""
Raised when attribute type is invalid
"""
def __init__(self, attr_name: str):
msg = "Deserialization error on `{}`".format(attr_name)
def __init__(self, attr_name: str, attr_type: str):
msg = "Cannot deserialize '{}' attribute from type: {}".format(attr_name, attr_type)
super(AttributeDeserializationError, self).__init__(msg)


Expand Down
29 changes: 4 additions & 25 deletions pynamodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
Tuple, Union, cast

from pynamodb.expressions.update import Action
from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError, InvalidStateError, PutError, AttributeDeserializationError
from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError, InvalidStateError, PutError
from pynamodb.attributes import (
Attribute, AttributeContainer, AttributeContainerMeta, MapAttribute, TTLAttribute, VersionAttribute
)
Expand Down Expand Up @@ -546,16 +546,9 @@ def from_raw_data(cls: Type[_T], data: Dict[str, Any]) -> _T:
if data is None:
raise ValueError("Received no data to construct object")

attributes: Dict[str, Any] = {}
for name, value in data.items():
attr_name = cls._dynamo_to_python_attr(name)
attr = cls.get_attributes().get(attr_name, None) # type: ignore
if attr:
try:
attributes[attr_name] = attr.deserialize(attr.get_value(value)) # type: ignore
except TypeError as e:
raise AttributeDeserializationError(attr_name=attr_name) from e # type: ignore
return cls(_user_instantiated=False, **attributes)
model = cls(_user_instantiated=False)
model._deserialize(data)
return model

@classmethod
def count(
Expand Down Expand Up @@ -1114,20 +1107,6 @@ def _get_connection(cls) -> TableConnection:
aws_session_token=cls.Meta.aws_session_token)
return cls._connection

def _deserialize(self, attrs):
"""
Sets attributes sent back from DynamoDB on this object
:param attrs: A dictionary of attributes to update this item with.
"""
for name, attr in self.get_attributes().items():
value = attrs.get(attr.attr_name, None)
if value is not None:
value = value.get(attr.attr_type, None)
if value is not None:
value = attr.deserialize(value)
setattr(self, name, value)

def _serialize(self, attr_map=False, null_check=True) -> Dict[str, Any]:
"""
Serializes all model attributes for use with DynamoDB
Expand Down
12 changes: 6 additions & 6 deletions tests/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@
'N': '31'
},
'is_dude': {
'N': '1'
'BOOL': True
}
}
},
Expand Down Expand Up @@ -641,7 +641,7 @@
'N': '31'
},
'is_dude': {
'N': '1'
'BOOL': True
}
}
},
Expand Down Expand Up @@ -677,7 +677,7 @@
'N': '30'
},
'is_dude': {
'N': '1'
'BOOL': True
}
}
},
Expand Down Expand Up @@ -713,7 +713,7 @@
'N': '32'
},
'is_dude': {
'N': '1'
'BOOL': True
}
}
},
Expand Down Expand Up @@ -749,7 +749,7 @@
'N': '30'
},
'is_dude': {
'N': '0'
'BOOL': False
}
}
},
Expand Down Expand Up @@ -926,7 +926,7 @@
'N': '31'
},
'is_dude': {
'N': '1'
'BOOL': True
}
}

Expand Down
15 changes: 1 addition & 14 deletions tests/test_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pynamodb.attributes import (
BinarySetAttribute, BinaryAttribute, NumberSetAttribute, NumberAttribute,
UnicodeAttribute, UnicodeSetAttribute, UTCDateTimeAttribute, BooleanAttribute, MapAttribute,
ListAttribute, JSONAttribute, TTLAttribute, _get_value_for_deserialize, _fast_parse_utc_datestring,
ListAttribute, JSONAttribute, TTLAttribute, _fast_parse_utc_datestring,
VersionAttribute)
from pynamodb.constants import (
DATETIME_FORMAT, DEFAULT_ENCODING, NUMBER, STRING, STRING_SET, NUMBER_SET, BINARY_SET,
Expand Down Expand Up @@ -890,19 +890,6 @@ class MyModel(Model):
assert mid_map_b_map_attr.attr_path == ['dyn_out_map', 'mid_map_b', 'dyn_in_map_b', 'dyn_map_attr']


class TestValueDeserialize:
def test__get_value_for_deserialize(self):
expected = '3'
data = {'N': '3'}
actual = _get_value_for_deserialize(data)
assert expected == actual

def test__get_value_for_deserialize_null(self):
data = {'NULL': 'True'}
actual = _get_value_for_deserialize(data)
assert actual is None


class TestListAttribute:

def test_untyped_list(self):
Expand Down

0 comments on commit 14f81c7

Please sign in to comment.