Skip to content

Commit

Permalink
Initialize indexes in the model metaclass. (#994)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpinner-lyft committed Nov 10, 2021
1 parent efa3a79 commit f868e6d
Showing 1 changed file with 46 additions and 43 deletions.
89 changes: 46 additions & 43 deletions pynamodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def __new__(cls, name, bases, namespace, discriminator=None):

def __init__(self, name, bases, namespace, discriminator=None) -> None:
super().__init__(name, bases, namespace, discriminator)
MetaModel._initialize_indexes(self)
cls = cast(Type['Model'], self)
for attr_name, attribute in cls.get_attributes().items():
if attribute.is_hash_key:
Expand Down Expand Up @@ -253,10 +254,6 @@ def __init__(self, name, bases, namespace, discriminator=None) -> None:
setattr(attr_obj, 'aws_secret_access_key', None)
if not hasattr(attr_obj, 'aws_session_token'):
setattr(attr_obj, 'aws_session_token', None)
elif isinstance(attr_obj, Index):
attr_obj.Meta.model = cls
if not hasattr(attr_obj.Meta, "index_name"):
attr_obj.Meta.index_name = attr_name

# create a custom Model.DoesNotExist derived from pynamodb.exceptions.DoesNotExist,
# so that "except Model.DoesNotExist:" would not catch other models' exceptions
Expand All @@ -267,6 +264,19 @@ def __init__(self, name, bases, namespace, discriminator=None) -> None:
}
cls.DoesNotExist = type('DoesNotExist', (DoesNotExist, ), exception_attrs)

@staticmethod
def _initialize_indexes(cls):
"""
Initialize indexes on the class.
"""
cls._indexes = {}
for name, index in getmembers(cls, lambda o: isinstance(o, Index)):
if not hasattr(index.Meta, "model"):
index.Meta.model = cls
if not hasattr(index.Meta, "index_name"):
index.Meta.index_name = name
cls._indexes[index.Meta.index_name] = index


class Model(AttributeContainer, metaclass=MetaModel):
"""
Expand All @@ -280,13 +290,12 @@ class Model(AttributeContainer, metaclass=MetaModel):
# DynamoDB attributes
_hash_keyname: Optional[str] = None
_range_keyname: Optional[str] = None
_indexes: Optional[Dict[str, List[Any]]] = None
_connection: Optional[TableConnection] = None
_index_classes: Optional[Dict[str, Any]] = None
DoesNotExist: Type[DoesNotExist] = DoesNotExist
_version_attribute_name: Optional[str] = None

Meta: MetaProtocol
_indexes: Dict[str, Index]

def __init__(
self,
Expand Down Expand Up @@ -582,9 +591,8 @@ def count(
raise ValueError('A hash_key must be given to use filters')
return cls.describe_table().get(ITEM_COUNT)

cls._get_indexes()
if cls._index_classes and index_name:
hash_key = cls._index_classes[index_name]._hash_key_attribute().serialize(hash_key)
if index_name:
hash_key = cls._indexes[index_name]._hash_key_attribute().serialize(hash_key)
else:
hash_key = cls._serialize_keys(hash_key)[0]

Expand Down Expand Up @@ -649,9 +657,8 @@ def query(
:param page_size: Page size of the query to DynamoDB
:param rate_limit: If set then consumed capacity will be limited to this amount per second
"""
cls._get_indexes()
if index_name and cls._index_classes:
hash_key = cls._index_classes[index_name]._hash_key_attribute().serialize(hash_key)
if index_name:
hash_key = cls._indexes[index_name]._hash_key_attribute().serialize(hash_key)
else:
hash_key = cls._serialize_keys(hash_key)[0]

Expand Down Expand Up @@ -886,38 +893,34 @@ def _get_indexes(cls):
"""
Returns a list of the secondary indexes
"""
if cls._indexes is None:
cls._indexes = {
'global_secondary_indexes': [],
'local_secondary_indexes': [],
'attribute_definitions': []
index_data: Dict[str, List[Any]] = {
'global_secondary_indexes': [],
'local_secondary_indexes': [],
'attribute_definitions': []
}
for name, index in cls._indexes.items():
schema = index._get_schema()
idx = {
'index_name': index.Meta.index_name,
'key_schema': schema['key_schema'],
'projection': {
PROJECTION_TYPE: index.Meta.projection.projection_type,
},
}
cls._index_classes = {}
for name, index in getmembers(cls, lambda o: isinstance(o, Index)):
cls._index_classes[index.Meta.index_name] = index
schema = index._get_schema()
idx = {
'index_name': index.Meta.index_name,
'key_schema': schema.get('key_schema'),
'projection': {
PROJECTION_TYPE: index.Meta.projection.projection_type,
},

}
if isinstance(index, GlobalSecondaryIndex):
if getattr(cls.Meta, 'billing_mode', None) != PAY_PER_REQUEST_BILLING_MODE:
idx['provisioned_throughput'] = {
READ_CAPACITY_UNITS: index.Meta.read_capacity_units,
WRITE_CAPACITY_UNITS: index.Meta.write_capacity_units
}
cls._indexes['attribute_definitions'].extend(schema.get('attribute_definitions'))
if index.Meta.projection.non_key_attributes:
idx['projection'][NON_KEY_ATTRIBUTES] = index.Meta.projection.non_key_attributes
if isinstance(index, GlobalSecondaryIndex):
cls._indexes['global_secondary_indexes'].append(idx)
else:
cls._indexes['local_secondary_indexes'].append(idx)
return cls._indexes
if isinstance(index, GlobalSecondaryIndex):
if getattr(cls.Meta, 'billing_mode', None) != PAY_PER_REQUEST_BILLING_MODE:
idx['provisioned_throughput'] = {
READ_CAPACITY_UNITS: index.Meta.read_capacity_units,
WRITE_CAPACITY_UNITS: index.Meta.write_capacity_units
}
index_data['attribute_definitions'].extend(schema['attribute_definitions'])
if index.Meta.projection.non_key_attributes:
idx['projection'][NON_KEY_ATTRIBUTES] = index.Meta.projection.non_key_attributes
if isinstance(index, GlobalSecondaryIndex):
index_data['global_secondary_indexes'].append(idx)
else:
index_data['local_secondary_indexes'].append(idx)
return index_data

def _get_save_args(self, null_check: bool = True, condition: Optional[Condition] = None) -> Tuple[Iterable[Any], Dict[str, Any]]:
"""
Expand Down

0 comments on commit f868e6d

Please sign in to comment.