Skip to content

Commit

Permalink
Require Attribute default to be immutable or callable (#1096)
Browse files Browse the repository at this point in the history
Making it harder to accidentally mutate a default.
  • Loading branch information
ikonst committed Nov 4, 2022
1 parent 43a303b commit 0a986dd
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 14 deletions.
50 changes: 48 additions & 2 deletions pynamodb/attributes.py
Expand Up @@ -53,10 +53,47 @@

_A = TypeVar('_A', bound='Attribute')

_IMMUTABLE_TYPES = (str, int, float, datetime, timedelta, bytes, bool, tuple, frozenset, type(None))
_IMMUTABLE_TYPE_NAMES = ', '.join(map(lambda x: x.__name__, _IMMUTABLE_TYPES))


class Attribute(Generic[_T]):
"""
An attribute of a model
An attribute of a model or an index.
:param hash_key: If `True`, this attribute is a model's or an index's hash key (partition key).
:param range_key: If `True`, this attribute is a model's or an index's range key (sort key).
:param null: If `True`, a `None` value would be considered valid and would result in the attribute
not being set in the underlying DynamoDB item. If `False` (default), an exception will be raised when
the attribute is persisted with a `None` value.
.. note::
This is different from :class:`pynamodb.attributes.NullAttribute`, which manifests in a `NULL`-typed
DynamoDB attribute value.
:param default: A default value that will be assigned in new models (when they are initialized)
and existing models (when they are loaded).
.. note::
Starting with PynamoDB 6.0, the default must be either an immutable value (of one of the built-in
immutable types) or a callable. This prevents a common class of errors caused by unintentionally mutating
the default value. A simple workaround is to pass an initializer (e.g. change :code:`default={}` to
:code:`default=dict`) or wrap in a lambda (e.g. change :code:`default={'foo': 'bar'}` to
:code:`default=lambda: {'foo': 'bar'}`).
:param default_for_new: Like `default`, but used only for new models. Use this to assign a default
for new models that you don't want to apply to existing models when they are loaded and then re-saved.
.. note::
Starting with PynamoDB 6.0, the default must be either an immutable value (of one of the built-in
immutable types) or a callable.
:param attr_name: The name that is used for the attribute in the underlying DynamoDB item;
use this to assign a "pythonic" name that is different from the persisted name, i.e.
.. code-block:: python
number_of_threads = NumberAttribute(attr_name='thread_count')
"""
attr_type: str
null = False
Expand All @@ -72,8 +109,17 @@ def __init__(
) -> None:
if default and default_for_new:
raise ValueError("An attribute cannot have both default and default_for_new parameters")
if not callable(default) and not isinstance(default, _IMMUTABLE_TYPES):
raise ValueError(
f"An attribute's 'default' must be immutable ({_IMMUTABLE_TYPE_NAMES}) or a callable "
"(see https://pynamodb.readthedocs.io/en/latest/api.html#pynamodb.attributes.Attribute)"
)
if not callable(default_for_new) and not isinstance(default_for_new, _IMMUTABLE_TYPES):
raise ValueError(
f"An attribute's 'default_for_new' must be immutable ({_IMMUTABLE_TYPE_NAMES}) or a callable "
"(see https://pynamodb.readthedocs.io/en/latest/api.html#pynamodb.attributes.Attribute)"
)
self.default = default
# This default is only set for new objects (ie: it's not set for re-saved objects)
self.default_for_new = default_for_new

if null is not None:
Expand Down
39 changes: 27 additions & 12 deletions tests/test_attributes.py
Expand Up @@ -15,7 +15,7 @@
from pynamodb.attributes import (
BinarySetAttribute, BinaryAttribute, DynamicMapAttribute, NumberSetAttribute, NumberAttribute,
UnicodeAttribute, UnicodeSetAttribute, UTCDateTimeAttribute, BooleanAttribute, MapAttribute, NullAttribute,
ListAttribute, JSONAttribute, TTLAttribute, VersionAttribute)
ListAttribute, JSONAttribute, TTLAttribute, VersionAttribute, Attribute)
from pynamodb.constants import (
DATETIME_FORMAT, DEFAULT_ENCODING, NUMBER, STRING, STRING_SET, NUMBER_SET, BINARY_SET,
BINARY, BOOLEAN,
Expand Down Expand Up @@ -49,7 +49,7 @@ class CustomAttrMap(MapAttribute):


class DefaultsMap(MapAttribute):
map_field = MapAttribute(default={})
map_field = MapAttribute(default=dict)
string_set_field = UnicodeSetAttribute(null=True)


Expand Down Expand Up @@ -125,6 +125,21 @@ def test_json_attr(self):
assert self.instance.json_attr == {'foo': 'bar', 'bar': 42}


class TestDefault:
def test_default(self):
Attribute(default='test')
Attribute(default_for_new='test')

with pytest.raises(ValueError, match="'default' must be immutable (.*) or a callable"):
Attribute(default=[])

with pytest.raises(ValueError, match="'default_for_new' must be immutable (.*) or a callable"):
Attribute(default_for_new=[])

Attribute(default=list)
Attribute(default_for_new=list)


class TestUTCDateTimeAttribute:
"""
Tests UTCDateTime attributes
Expand Down Expand Up @@ -260,8 +275,8 @@ def test_binary_set_attribute(self):
attr = BinarySetAttribute()
assert attr is not None

attr = BinarySetAttribute(default={b'foo', b'bar'})
assert attr.default == {b'foo', b'bar'}
attr = BinarySetAttribute(default=lambda: {b'foo', b'bar'})
assert attr.default() == {b'foo', b'bar'}


class TestNumberAttribute:
Expand Down Expand Up @@ -320,8 +335,8 @@ def test_number_set_attribute(self):
attr = NumberSetAttribute()
assert attr is not None

attr = NumberSetAttribute(default={1, 2})
assert attr.default == {1, 2}
attr = NumberSetAttribute(default=lambda: {1, 2})
assert attr.default() == {1, 2}


class TestUnicodeAttribute:
Expand Down Expand Up @@ -416,8 +431,8 @@ def test_unicode_set_attribute(self):
attr = UnicodeSetAttribute()
assert attr is not None
assert attr.attr_type == STRING_SET
attr = UnicodeSetAttribute(default={'foo', 'bar'})
assert attr.default == {'foo', 'bar'}
attr = UnicodeSetAttribute(default=lambda: {'foo', 'bar'})
assert attr.default() == {'foo', 'bar'}


class TestBooleanAttribute:
Expand Down Expand Up @@ -526,8 +541,8 @@ def test_json_attribute(self):
assert attr is not None

assert attr.attr_type == STRING
attr = JSONAttribute(default={})
assert attr.default == {}
attr = JSONAttribute(default=lambda: {})
assert attr.default() == {}

def test_json_serialize(self):
"""
Expand Down Expand Up @@ -1018,7 +1033,7 @@ def __eq__(self, other):

inp = [person1, person2]

list_attribute = ListAttribute(default=[], of=Person)
list_attribute = ListAttribute(default=list, of=Person)
serialized = list_attribute.serialize(inp)
deserialized = list_attribute.deserialize(serialized)
assert sorted(deserialized) == sorted(inp)
Expand All @@ -1042,7 +1057,7 @@ def __eq__(self, other):

inp = [attribute1, attribute2]

list_attribute = ListAttribute(default=[], of=CustomMapAttribute)
list_attribute = ListAttribute(default=list, of=CustomMapAttribute)
serialized = list_attribute.serialize(inp)
deserialized = list_attribute.deserialize(serialized)

Expand Down

0 comments on commit 0a986dd

Please sign in to comment.