Skip to content

Commit

Permalink
Always use Model attribute definitions in create table schema. (#996)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpinner-lyft committed Nov 12, 2021
1 parent 918f98a commit ece33c0
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 75 deletions.
38 changes: 24 additions & 14 deletions pynamodb/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from pynamodb._compat import GenericMeta
from pynamodb.constants import (
INCLUDE, ALL, KEYS_ONLY, ATTR_NAME, ATTR_TYPE, KEY_TYPE, KEY_SCHEMA,
ATTR_DEFINITIONS, META_CLASS_NAME
ATTR_DEFINITIONS, META_CLASS_NAME, PROJECTION_TYPE, NON_KEY_ATTRIBUTES,
PAY_PER_REQUEST_BILLING_MODE, READ_CAPACITY_UNITS, WRITE_CAPACITY_UNITS,
)
from pynamodb.attributes import Attribute
from pynamodb.expressions.condition import Condition
Expand Down Expand Up @@ -151,27 +152,27 @@ def _get_schema(cls) -> Dict:
"""
Returns the schema for this index
"""
attr_definitions = []
schema = []
schema = {
'index_name': cls.Meta.index_name,
'key_schema': [],
'projection': {
PROJECTION_TYPE: cls.Meta.projection.projection_type,
},
}
for attr_name, attr_cls in cls._get_attributes().items():
attr_definitions.append({
'attribute_name': attr_cls.attr_name,
'attribute_type': attr_cls.attr_type
})
if attr_cls.is_hash_key:
schema.append({
schema['key_schema'].append({
ATTR_NAME: attr_cls.attr_name,
KEY_TYPE: HASH
})
elif attr_cls.is_range_key:
schema.append({
schema['key_schema'].append({
ATTR_NAME: attr_cls.attr_name,
KEY_TYPE: RANGE
})
return {
'key_schema': schema,
'attribute_definitions': attr_definitions
}
if cls.Meta.projection.non_key_attributes:
schema['projection'][NON_KEY_ATTRIBUTES] = cls.Meta.projection.non_key_attributes
return schema

@classmethod
def _get_attributes(cls):
Expand All @@ -189,7 +190,16 @@ class GlobalSecondaryIndex(Index[_M]):
"""
A global secondary index
"""
pass

@classmethod
def _get_schema(cls) -> Dict:
schema = super()._get_schema()
if getattr(cls.Meta, 'billing_mode', None) != PAY_PER_REQUEST_BILLING_MODE:
schema['provisioned_throughput'] = {
READ_CAPACITY_UNITS: cls.Meta.read_capacity_units,
WRITE_CAPACITY_UNITS: cls.Meta.write_capacity_units
}
return schema


class LocalSecondaryIndex(Index[_M]):
Expand Down
63 changes: 19 additions & 44 deletions pynamodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,16 +814,6 @@ def create_table(
schema['write_capacity_units'] = write_capacity_units
if billing_mode is not None:
schema['billing_mode'] = billing_mode
index_data = cls._get_indexes()
schema['global_secondary_indexes'] = index_data.get('global_secondary_indexes')
schema['local_secondary_indexes'] = index_data.get('local_secondary_indexes')
index_attrs = index_data.get('attribute_definitions')
attr_keys = [attr.get('attribute_name') for attr in schema['attribute_definitions']]
for attr in index_attrs:
attr_name = attr.get('attribute_name')
if attr_name not in attr_keys:
schema['attribute_definitions'].append(attr)
attr_keys.append(attr_name)
cls._get_connection().create_table(
**schema
)
Expand Down Expand Up @@ -868,7 +858,9 @@ def _get_schema(cls) -> Dict[str, Any]:
"""
schema: Dict[str, List] = {
'attribute_definitions': [],
'key_schema': []
'key_schema': [],
'global_secondary_indexes': [],
'local_secondary_indexes': [],
}
for attr_name, attr_cls in cls.get_attributes().items():
if attr_cls.is_hash_key or attr_cls.is_range_key:
Expand All @@ -886,41 +878,24 @@ def _get_schema(cls) -> Dict[str, Any]:
'key_type': RANGE,
'attribute_name': attr_cls.attr_name
})
return schema

@classmethod
def _get_indexes(cls):
"""
Returns a list of the secondary indexes
"""
index_data: Dict[str, List[Any]] = {
'global_secondary_indexes': [],
'local_secondary_indexes': [],
'attribute_definitions': []
}
for name, index in cls._indexes.items():
schema = index._get_schema()
idx = {
'index_name': index.Meta.index_name,
'key_schema': schema['key_schema'],
'projection': {
PROJECTION_TYPE: index.Meta.projection.projection_type,
},
}
for index in cls._indexes.values():
index_schema = index._get_schema()
if isinstance(index, GlobalSecondaryIndex):
if getattr(cls.Meta, 'billing_mode', None) != PAY_PER_REQUEST_BILLING_MODE:
idx['provisioned_throughput'] = {
READ_CAPACITY_UNITS: index.Meta.read_capacity_units,
WRITE_CAPACITY_UNITS: index.Meta.write_capacity_units
}
index_data['attribute_definitions'].extend(schema['attribute_definitions'])
if index.Meta.projection.non_key_attributes:
idx['projection'][NON_KEY_ATTRIBUTES] = index.Meta.projection.non_key_attributes
if isinstance(index, GlobalSecondaryIndex):
index_data['global_secondary_indexes'].append(idx)
schema['global_secondary_indexes'].append(index_schema)
else:
index_data['local_secondary_indexes'].append(idx)
return index_data
schema['local_secondary_indexes'].append(index_schema)
attr_names = {key_schema[ATTR_NAME]
for index_schema in (*schema['global_secondary_indexes'], *schema['local_secondary_indexes'])
for key_schema in index_schema['key_schema']}
attr_keys = {attr.get('attribute_name') for attr in schema['attribute_definitions']}
for attr_name in attr_names:
if attr_name not in attr_keys:
attr_cls = cls.get_attributes()[cls._dynamo_to_python_attr(attr_name)]
schema['attribute_definitions'].append({
'attribute_name': attr_cls.attr_name,
'attribute_type': attr_cls.attr_type
})
return schema

def _get_save_args(self, null_check: bool = True, condition: Optional[Condition] = None) -> Tuple[Iterable[Any], Dict[str, Any]]:
"""
Expand Down
34 changes: 17 additions & 17 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2359,27 +2359,29 @@ def fake_dynamodb(*args, **kwargs):

with patch(PATCH_METHOD, new=fake_db) as req:
IndexedModel.create_table(read_capacity_units=2, write_capacity_units=2)
params = {
'AttributeDefinitions': [
{'attribute_name': 'email', 'attribute_type': 'S'},
{'attribute_name': 'numbers', 'attribute_type': 'NS'}
],
'KeySchema': [
{'AttributeName': 'numbers', 'KeyType': 'RANGE'},
{'AttributeName': 'email', 'KeyType': 'HASH'}
]
}
schema = IndexedModel.email_index._get_schema()
args = req.call_args[0][1]
self.assert_dict_lists_equal(
args['AttributeDefinitions'],
[
{'AttributeName': 'user_name', 'AttributeType': 'S'},
{'AttributeName': 'email', 'AttributeType': 'S'},
{'AttributeName': 'numbers', 'AttributeType': 'NS'}
]
)
self.assert_dict_lists_equal(
args['GlobalSecondaryIndexes'][0]['KeySchema'],
[
{'AttributeName': 'email', 'KeyType': 'HASH'},
{'AttributeName': 'numbers', 'KeyType': 'RANGE'}
]
)
self.assertEqual(
args['GlobalSecondaryIndexes'][0]['ProvisionedThroughput'],
{
'ReadCapacityUnits': 2,
'WriteCapacityUnits': 1
}
)
self.assert_dict_lists_equal(schema['key_schema'], params['KeySchema'])
self.assert_dict_lists_equal(schema['attribute_definitions'], params['AttributeDefinitions'])

def test_local_index(self):
"""
Expand All @@ -2395,7 +2397,7 @@ def test_local_index(self):
req.return_value = LOCAL_INDEX_TABLE_DATA
LocalIndexedModel('foo')

schema = IndexedModel._get_indexes()
schema = IndexedModel._get_schema()

expected = {
'local_secondary_indexes': [
Expand Down Expand Up @@ -2426,8 +2428,7 @@ def test_local_index(self):
}
],
'attribute_definitions': [
{'attribute_type': 'S', 'attribute_name': 'email'},
{'attribute_type': 'NS', 'attribute_name': 'numbers'},
{'attribute_type': 'S', 'attribute_name': 'user_name'},
{'attribute_type': 'S', 'attribute_name': 'email'},
{'attribute_type': 'NS', 'attribute_name': 'numbers'}
]
Expand Down Expand Up @@ -2475,7 +2476,6 @@ def fake_dynamodb(*args, **kwargs):
}
schema = LocalIndexedModel.email_index._get_schema()
args = req.call_args[0][1]
self.assert_dict_lists_equal(schema['attribute_definitions'], params['AttributeDefinitions'])
self.assert_dict_lists_equal(schema['key_schema'], params['KeySchema'])
self.assertTrue('ProvisionedThroughput' not in args['LocalSecondaryIndexes'][0])

Expand Down

0 comments on commit ece33c0

Please sign in to comment.