Skip to content

Commit

Permalink
create_table should include polymorphic indexes (#1107)
Browse files Browse the repository at this point in the history
With polymorphism, an index's key attribute might be defined only in a derived model (e.g. if it pertains only to a certain kind of model — for other models, it'd be a sparse index), and so it makes sense to add that index only to the derived model, e.g.
```python
class BaseModel(Model):
    ...
    cls = DescriminatorAttribute()

class DerivedIndex(GlobalSecondaryIndex):
    class Meta:
        index_name = 'derived'
        ...
    ham = UnicodeAttribute(hash_key=True)

class DerivedModel(BaseModel, discriminator='spam'):
    ham  = UnicodeAttribute()
    index = DerivedIndex()
```

A common pattern in a development environment is to use `Model.create_table`. For a polymorphic model, given you have to pick _one_ model to invoke `create_table` on, it only serves right that you invoke it on the base model.

Before this change, the base model's "schema"  didn't include all the derived models' indexes. With this change, it will.
  • Loading branch information
ikonst committed Nov 16, 2022
1 parent 4e7054d commit 2ec2f35
Show file tree
Hide file tree
Showing 11 changed files with 304 additions and 135 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Expand Up @@ -73,7 +73,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e.[signals] sphinx sphinx-rtd-theme
python -m pip install -r docs/requirements.txt
- name: Build docs
run: |
sphinx-build -W docs /tmp/docs-build
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -1 +1,2 @@
.[signals]
sphinx-rtd-theme==0.4.3
51 changes: 51 additions & 0 deletions pynamodb/_schema.py
@@ -0,0 +1,51 @@
import sys
from typing import Dict
from typing import List

if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict

if sys.version_info >= (3, 11):
from typing import NotRequired
else:
from typing_extensions import NotRequired


class SchemaAttrDefinition(TypedDict):
AttributeName: str
AttributeType: str


class KeySchema(TypedDict):
AttributeName: str
KeyType: str


class Projection(TypedDict):
ProjectionType: str
NonKeyAttributes: NotRequired[List[str]]


class IndexSchema(TypedDict):
index_name: str
key_schema: List[Dict[str, str]]
projection: Dict[str, str]
attribute_definitions: List[SchemaAttrDefinition]


class ProvisionedThroughput(TypedDict, total=False):
ReadCapacityUnits: int
WriteCapacityUnits: int


class GlobalSecondaryIndexSchema(IndexSchema):
provisioned_throughput: ProvisionedThroughput


class ModelSchema(TypedDict):
attribute_definitions: List[SchemaAttrDefinition]
key_schema: List[KeySchema]
global_secondary_indexes: List[GlobalSecondaryIndexSchema]
local_secondary_indexes: List[IndexSchema]
2 changes: 1 addition & 1 deletion pynamodb/attributes.py
Expand Up @@ -494,7 +494,7 @@ def register_class(self, cls: type, discriminator: Any):

self._discriminator_map[discriminator] = cls

def get_registered_subclasses(self, cls: type) -> List[type]:
def get_registered_subclasses(self, cls: Type[_T]) -> List[Type[_T]]:
return [k for k in self._class_map.keys() if issubclass(k, cls)]

def get_discriminator(self, cls: type) -> Optional[Any]:
Expand Down
15 changes: 10 additions & 5 deletions pynamodb/constants.py
@@ -1,6 +1,11 @@
"""
Pynamodb constants
"""
import sys
if sys.version_info >= (3, 8):
from typing import Final
else:
from typing_extensions import Final

# Operations
TRANSACT_WRITE_ITEMS = 'TransactWriteItems'
Expand Down Expand Up @@ -41,8 +46,8 @@
TABLE_STATUS = 'TableStatus'
TABLE_NAME = 'TableName'
KEY_SCHEMA = 'KeySchema'
ATTR_NAME = 'AttributeName'
ATTR_TYPE = 'AttributeType'
ATTR_NAME: Final = 'AttributeName'
ATTR_TYPE: Final = 'AttributeType'
ITEM_COUNT = 'ItemCount'
CAMEL_COUNT = 'Count'
PUT_REQUEST = 'PutRequest'
Expand All @@ -51,7 +56,7 @@
TABLE_KEY = 'Table'
RESPONSES = 'Responses'
RANGE_KEY = 'RangeKey'
KEY_TYPE = 'KeyType'
KEY_TYPE: Final = 'KeyType'
UPDATE = 'Update'
SELECT = 'Select'
ACTIVE = 'ACTIVE'
Expand Down Expand Up @@ -100,8 +105,8 @@

# Create Table arguments
PROVISIONED_THROUGHPUT = 'ProvisionedThroughput'
READ_CAPACITY_UNITS = 'ReadCapacityUnits'
WRITE_CAPACITY_UNITS = 'WriteCapacityUnits'
READ_CAPACITY_UNITS: Final = 'ReadCapacityUnits'
WRITE_CAPACITY_UNITS: Final = 'WriteCapacityUnits'
BILLING_MODE = 'BillingMode'

# Attribute Types
Expand Down
60 changes: 41 additions & 19 deletions pynamodb/indexes.py
Expand Up @@ -5,16 +5,17 @@
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
from typing import TYPE_CHECKING

from pynamodb._schema import IndexSchema, GlobalSecondaryIndexSchema
from pynamodb._schema import ModelSchema
from pynamodb.constants import (
INCLUDE, ALL, KEYS_ONLY, ATTR_NAME, ATTR_TYPE, KEY_TYPE, KEY_SCHEMA,
ATTR_DEFINITIONS, PROJECTION_TYPE, NON_KEY_ATTRIBUTES,
INCLUDE, ALL, KEYS_ONLY, ATTR_NAME, ATTR_TYPE, KEY_TYPE,
PROJECTION_TYPE, NON_KEY_ATTRIBUTES,
READ_CAPACITY_UNITS, WRITE_CAPACITY_UNITS,
)
from pynamodb.attributes import Attribute
from pynamodb.expressions.condition import Condition
from pynamodb.pagination import ResultIterator
from pynamodb.types import HASH, RANGE

if TYPE_CHECKING:
from pynamodb.models import Model

Expand Down Expand Up @@ -136,28 +137,32 @@ def _hash_key_attribute(cls):
if attr_cls.is_hash_key:
return attr_cls

def _update_model_schema(self, schema: ModelSchema) -> None:
raise NotImplementedError

@classmethod
def _get_schema(cls) -> Dict:
def _get_schema(cls) -> IndexSchema:
"""
Returns the schema for this index
"""
schema = {
schema: IndexSchema = {
'index_name': cls.Meta.index_name,
'key_schema': [],
'projection': {
PROJECTION_TYPE: cls.Meta.projection.projection_type,
},
'attribute_definitions': [],
}

for attr_cls in cls.Meta.attributes.values():
if attr_cls.is_hash_key:
schema['key_schema'].append({
if attr_cls.is_hash_key or attr_cls.is_range_key:
schema['attribute_definitions'].append({
ATTR_NAME: attr_cls.attr_name,
KEY_TYPE: HASH
ATTR_TYPE: attr_cls.attr_type,
})
elif attr_cls.is_range_key:
schema['key_schema'].append({
ATTR_NAME: attr_cls.attr_name,
KEY_TYPE: RANGE
KEY_TYPE: HASH if attr_cls.is_hash_key else RANGE,
})
if cls.Meta.projection.non_key_attributes:
schema['projection'][NON_KEY_ATTRIBUTES] = cls.Meta.projection.non_key_attributes
Expand All @@ -168,24 +173,41 @@ class GlobalSecondaryIndex(Index[_M]):
"""
A global secondary index
"""

@classmethod
def _get_schema(cls) -> Dict:
schema = super()._get_schema()
provisioned_throughput = {}
def _update_model_schema(cls, schema: ModelSchema) -> None:
index_schema: GlobalSecondaryIndexSchema = {
**cls._get_schema(), # type:ignore[misc] # https://github.com/python/mypy/pull/13353
'provisioned_throughput': {},
}

if hasattr(cls.Meta, 'read_capacity_units'):
provisioned_throughput[READ_CAPACITY_UNITS] = cls.Meta.read_capacity_units
index_schema['provisioned_throughput'][READ_CAPACITY_UNITS] = cls.Meta.read_capacity_units
if hasattr(cls.Meta, 'write_capacity_units'):
provisioned_throughput[WRITE_CAPACITY_UNITS] = cls.Meta.write_capacity_units
schema['provisioned_throughput'] = provisioned_throughput
return schema
index_schema['provisioned_throughput'][WRITE_CAPACITY_UNITS] = cls.Meta.write_capacity_units

schema['global_secondary_indexes'].append(index_schema)
# With polymorphism, indexes can use the same attribute, e.g. index1 on (thread_id, created_at)
# and index2 on (thread_id, updated_at). We need to deduplicate.
for attr_def in index_schema['attribute_definitions']:
if attr_def not in schema['attribute_definitions']:
schema['attribute_definitions'].append(attr_def)


class LocalSecondaryIndex(Index[_M]):
"""
A local secondary index
"""
pass

@classmethod
def _update_model_schema(cls, schema: ModelSchema) -> None:
index_schema = cls._get_schema()
schema['local_secondary_indexes'].append(index_schema)
# With polymorphism, indexes can use the same attribute, e.g. index1 on (thread_id, created_at)
# and index2 on (thread_id, updated_at). We need to deduplicate.
for attr_def in index_schema['attribute_definitions']:
if attr_def not in schema['attribute_definitions']:
schema['attribute_definitions'].append(attr_def)



class Projection:
Expand Down
86 changes: 39 additions & 47 deletions pynamodb/models.py
Expand Up @@ -25,6 +25,7 @@
from typing import Union
from typing import cast

from pynamodb._schema import ModelSchema
from pynamodb.connection.base import MetaTable

if sys.version_info >= (3, 8):
Expand All @@ -41,23 +42,21 @@
from pynamodb.connection.table import TableConnection
from pynamodb.expressions.condition import Condition
from pynamodb.types import HASH, RANGE
from pynamodb.indexes import Index, GlobalSecondaryIndex, LocalSecondaryIndex
from pynamodb.indexes import Index
from pynamodb.pagination import ResultIterator
from pynamodb.settings import get_settings_value, OperationSettings
from pynamodb import constants
from pynamodb.constants import (
ATTR_DEFINITIONS, ATTR_NAME, ATTR_TYPE, KEY_SCHEMA,
KEY_TYPE, ITEM, READ_CAPACITY_UNITS, WRITE_CAPACITY_UNITS,
RANGE_KEY, ATTRIBUTES, PUT, DELETE, RESPONSES,
INDEX_NAME, PROVISIONED_THROUGHPUT, PROJECTION, ALL_NEW,
GLOBAL_SECONDARY_INDEXES, LOCAL_SECONDARY_INDEXES, KEYS,
PROJECTION_TYPE, NON_KEY_ATTRIBUTES,
TABLE_STATUS, ACTIVE, RETURN_VALUES, BATCH_GET_PAGE_LIMIT,
ATTR_NAME, ATTR_TYPE,
KEY_TYPE, ITEM,
ATTRIBUTES, PUT, DELETE, RESPONSES,
ALL_NEW,
KEYS,
TABLE_STATUS, ACTIVE, BATCH_GET_PAGE_LIMIT,
UNPROCESSED_KEYS, PUT_REQUEST, DELETE_REQUEST,
BATCH_WRITE_PAGE_LIMIT,
META_CLASS_NAME, REGION, HOST, NULL,
COUNT, ITEM_COUNT, KEY, UNPROCESSED_ITEMS, STREAM_VIEW_TYPE,
STREAM_SPECIFICATION, STREAM_ENABLED, BILLING_MODE, PAY_PER_REQUEST_BILLING_MODE, TAGS, TABLE_NAME
COUNT, ITEM_COUNT, KEY, UNPROCESSED_ITEMS,
)
from pynamodb.util import attribute_value_to_json
from pynamodb.util import json_to_attribute_value
Expand Down Expand Up @@ -794,27 +793,33 @@ def create_table(
"""
if not cls.exists():
schema = cls._get_schema()
operation_kwargs: Dict[str, Any] = {
'attribute_definitions': schema['attribute_definitions'],
'key_schema': schema['key_schema'],
'global_secondary_indexes': schema['global_secondary_indexes'],
'local_secondary_indexes': schema['local_secondary_indexes'],
}
if hasattr(cls.Meta, 'read_capacity_units'):
schema['read_capacity_units'] = cls.Meta.read_capacity_units
operation_kwargs['read_capacity_units'] = cls.Meta.read_capacity_units
if hasattr(cls.Meta, 'write_capacity_units'):
schema['write_capacity_units'] = cls.Meta.write_capacity_units
operation_kwargs['write_capacity_units'] = cls.Meta.write_capacity_units
if hasattr(cls.Meta, 'stream_view_type'):
schema['stream_specification'] = {
operation_kwargs['stream_specification'] = {
'stream_enabled': True,
'stream_view_type': cls.Meta.stream_view_type
}
if hasattr(cls.Meta, 'billing_mode'):
schema['billing_mode'] = cls.Meta.billing_mode
operation_kwargs['billing_mode'] = cls.Meta.billing_mode
if hasattr(cls.Meta, 'tags'):
schema['tags'] = cls.Meta.tags
operation_kwargs['tags'] = cls.Meta.tags
if read_capacity_units is not None:
schema['read_capacity_units'] = read_capacity_units
operation_kwargs['read_capacity_units'] = read_capacity_units
if write_capacity_units is not None:
schema['write_capacity_units'] = write_capacity_units
operation_kwargs['write_capacity_units'] = write_capacity_units
if billing_mode is not None:
schema['billing_mode'] = billing_mode
operation_kwargs['billing_mode'] = billing_mode
cls._get_connection().create_table(
**schema
**operation_kwargs
)
if wait:
while True:
Expand Down Expand Up @@ -850,11 +855,12 @@ def update_ttl(cls, ignore_update_ttl_errors: bool) -> None:

# Private API below
@classmethod
def _get_schema(cls) -> Dict[str, Any]:
def _get_schema(cls) -> ModelSchema:
"""
Returns the schema for this table
"""
schema: Dict[str, List] = {

schema: ModelSchema = {
'attribute_definitions': [],
'key_schema': [],
'global_secondary_indexes': [],
Expand All @@ -866,35 +872,21 @@ def _get_schema(cls) -> Dict[str, Any]:
ATTR_NAME: attr_cls.attr_name,
ATTR_TYPE: attr_cls.attr_type
})
if attr_cls.is_hash_key:
schema['key_schema'].append({
KEY_TYPE: HASH,
ATTR_NAME: attr_cls.attr_name
})
elif attr_cls.is_range_key:
schema['key_schema'].append({
KEY_TYPE: RANGE,
KEY_TYPE: HASH if attr_cls.is_hash_key else RANGE,
ATTR_NAME: attr_cls.attr_name
})
for index in cls._indexes.values():
index_schema = index._get_schema()
if isinstance(index, GlobalSecondaryIndex):
if getattr(cls.Meta, 'billing_mode', None) == PAY_PER_REQUEST_BILLING_MODE:
index_schema.pop('provisioned_throughput', None)
schema['global_secondary_indexes'].append(index_schema)
else:
schema['local_secondary_indexes'].append(index_schema)
attr_names = {key_schema[ATTR_NAME]
for index_schema in (*schema['global_secondary_indexes'], *schema['local_secondary_indexes'])
for key_schema in index_schema['key_schema']}
attr_keys = {attr[ATTR_NAME] for attr in schema['attribute_definitions']}
for attr_name in attr_names:
if attr_name not in attr_keys:
attr_cls = cls.get_attributes()[cls._dynamo_to_python_attr(attr_name)]
schema['attribute_definitions'].append({
ATTR_NAME: attr_cls.attr_name,
ATTR_TYPE: attr_cls.attr_type
})

indexes = cls._indexes.copy()
# add indexes from derived classes that we might initialize
discriminator_attr = cls._get_discriminator_attribute()
if discriminator_attr is not None:
for model_cls in discriminator_attr.get_registered_subclasses(Model):
indexes.update(model_cls._indexes)

for index in indexes.values():
index._update_model_schema(schema)

return schema

def _get_save_args(self, condition: Optional[Condition] = None) -> Tuple[Iterable[Any], Dict[str, Any]]:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -3,7 +3,7 @@

install_requires = [
'botocore>=1.12.54',
'typing-extensions>=3.7; python_version<"3.8"'
'typing-extensions>=4; python_version<"3.11"',
]

setup(
Expand Down

0 comments on commit 2ec2f35

Please sign in to comment.