Skip to content

Commit

Permalink
Support model class inheritance. Fixes #164 (#862)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpinner-lyft committed Oct 8, 2020
1 parent 9d57373 commit d2be48c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 19 deletions.
27 changes: 11 additions & 16 deletions pynamodb/models.py
Expand Up @@ -12,7 +12,7 @@
from pynamodb.expressions.update import Action
from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError, InvalidStateError, PutError
from pynamodb.attributes import (
Attribute, AttributeContainer, AttributeContainerMeta, TTLAttribute, VersionAttribute
AttributeContainer, AttributeContainerMeta, TTLAttribute, VersionAttribute
)
from pynamodb.connection.table import TableConnection
from pynamodb.expressions.condition import Condition
Expand Down Expand Up @@ -151,10 +151,6 @@ def commit(self) -> None:
unprocessed_items = data.get(UNPROCESSED_ITEMS, {}).get(self.model.Meta.table_name)


class DefaultMeta(object):
pass


class MetaModel(AttributeContainerMeta):
table_name: str
read_capacity_units: Optional[int]
Expand Down Expand Up @@ -184,17 +180,26 @@ def __init__(self, name: str, bases: Any, attrs: Dict[str, Any]) -> None:
cls = cast(Type['Model'], self)
for attr_name, attribute in cls.get_attributes().items():
if attribute.is_hash_key:
if cls._hash_keyname and cls._hash_keyname != attr_name:
raise ValueError(f"{cls.__name__} has more than one hash key: {cls._hash_keyname}, {attr_name}")
cls._hash_keyname = attr_name
if attribute.is_range_key:
if cls._range_keyname and cls._range_keyname != attr_name:
raise ValueError(f"{cls.__name__} has more than one range key: {cls._range_keyname}, {attr_name}")
cls._range_keyname = attr_name
if isinstance(attribute, VersionAttribute):
if cls._version_attribute_name:
if cls._version_attribute_name and cls._version_attribute_name != attr_name:
raise ValueError(
"The model has more than one Version attribute: {}, {}"
.format(cls._version_attribute_name, attr_name)
)
cls._version_attribute_name = attr_name

ttl_attr_names = [name for name, attr in cls.get_attributes().items() if isinstance(attr, TTLAttribute)]
if len(ttl_attr_names) > 1:
raise ValueError("{} has more than one TTL attribute: {}".format(
cls.__name__, ", ".join(ttl_attr_names)))

if isinstance(attrs, dict):
for attr_name, attr_obj in attrs.items():
if attr_name == META_CLASS_NAME:
Expand Down Expand Up @@ -226,16 +231,6 @@ def __init__(self, name: str, bases: Any, attrs: Dict[str, Any]) -> None:
attr_obj.Meta.model = cls
if not hasattr(attr_obj.Meta, "index_name"):
attr_obj.Meta.index_name = attr_name
elif isinstance(attr_obj, Attribute):
if attr_obj.attr_name is None:
attr_obj.attr_name = attr_name

ttl_attr_names = [name for name, attr_obj in attrs.items() if isinstance(attr_obj, TTLAttribute)]
if len(ttl_attr_names) > 1:
raise ValueError("The model has more than one TTL attribute: {}".format(", ".join(ttl_attr_names)))

if META_CLASS_NAME not in attrs:
setattr(cls, META_CLASS_NAME, DefaultMeta)

# create a custom Model.DoesNotExist derived from pynamodb.exceptions.DoesNotExist,
# so that "except Model.DoesNotExist:" would not catch other models' exceptions
Expand Down
40 changes: 37 additions & 3 deletions tests/test_model.py
Expand Up @@ -2,7 +2,6 @@
Test model API
"""
import base64
import random
import json
import copy
from datetime import datetime
Expand All @@ -14,9 +13,7 @@
import pytest

from .deep_eq import deep_eq
from pynamodb.util import snake_to_camel_case
from pynamodb.exceptions import DoesNotExist, TableError, PutError, AttributeDeserializationError
from pynamodb.types import RANGE
from pynamodb.constants import (
ITEM, STRING, ALL, KEYS_ONLY, INCLUDE, REQUEST_ITEMS, UNPROCESSED_KEYS, CAMEL_COUNT,
RESPONSES, KEYS, ITEMS, LAST_EVALUATED_KEY, EXCLUSIVE_START_KEY, ATTRIBUTES, BINARY,
Expand Down Expand Up @@ -2424,6 +2421,17 @@ def test_old_style_model_exception(self):
with self.assertRaises(AttributeError):
OldStyleModel.exists()

def test_no_table_name_exception(self):
"""
Display warning for Models without table names
"""
class MissingTableNameModel(Model):
class Meta:
pass
user_name = UnicodeAttribute(hash_key=True)
with self.assertRaises(AttributeError):
MissingTableNameModel.exists()

def _get_office_employee(self):
justin = Person(
fname='Justin',
Expand Down Expand Up @@ -3214,6 +3222,24 @@ def test_deserialized_with_ttl(self):
def test_deserialized_with_invalid_type(self):
self.assertRaises(AttributeDeserializationError, TTLModel.from_raw_data, {'my_ttl': {'S': '1546300800'}})

def test_multiple_hash_keys(self):
with self.assertRaises(ValueError):
class BadHashKeyModel(Model):
class Meta:
table_name = 'BadHashKeyModel'

foo = UnicodeAttribute(hash_key=True)
bar = UnicodeAttribute(hash_key=True)

def test_multiple_range_keys(self):
with self.assertRaises(ValueError):
class BadRangeKeyModel(Model):
class Meta:
table_name = 'BadRangeKeyModel'

foo = UnicodeAttribute(range_key=True)
bar = UnicodeAttribute(range_key=True)

def test_multiple_version_attributes(self):
with self.assertRaises(ValueError):
class BadVersionedModel(Model):
Expand All @@ -3222,3 +3248,11 @@ class Meta:

version = VersionAttribute()
another_version = VersionAttribute()

def test_inherit_metaclass(self):
class ParentModel(Model):
class Meta:
table_name = 'foo'
class ChildModel(ParentModel):
pass
self.assertEqual(ParentModel.Meta.table_name, ChildModel.Meta.table_name)

0 comments on commit d2be48c

Please sign in to comment.