Skip to content

Commit

Permalink
Support pagination in high-level api query and scan methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
jpinner-lyft committed Oct 17, 2017
1 parent edf3ebd commit 654bec4
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 118 deletions.
107 changes: 41 additions & 66 deletions pynamodb/models.py
Expand Up @@ -17,6 +17,7 @@
from pynamodb.types import HASH, RANGE
from pynamodb.compat import NullHandler
from pynamodb.indexes import Index, GlobalSecondaryIndex
from pynamodb.pagination import ResultIterator
from pynamodb.settings import get_settings_value
from pynamodb.constants import (
ATTR_TYPE_MAP, ATTR_DEFINITIONS, ATTR_NAME, ATTR_TYPE, KEY_SCHEMA,
Expand Down Expand Up @@ -551,27 +552,29 @@ def count(cls,
non_key_attribute_classes=non_key_attribute_classes,
filters=filters)

count_buffer = 0
last_evaluated_key = None
started = False
while not started or last_evaluated_key:
started = True
data = cls._get_connection().query(
hash_key,
range_key_condition=range_key_condition,
filter_condition=filter_condition,
index_name=index_name,
consistent_read=consistent_read,
key_conditions=key_conditions,
query_filters=query_filters,
exclusive_start_key=last_evaluated_key,
limit=limit,
select=COUNT
)
count_buffer += data.get(CAMEL_COUNT)
last_evaluated_key = data.get(LAST_EVALUATED_KEY, None)
query_args = (hash_key,)
query_kwargs = dict(
range_key_condition=range_key_condition,
filter_condition=filter_condition,
index_name=index_name,
consistent_read=consistent_read,
key_conditions=key_conditions,
query_filters=query_filters,
limit=limit,
select=COUNT
)

result_iterator = ResultIterator(
cls._get_connection().query,
query_args,
query_kwargs,
limit=limit
)

return count_buffer
# iterate through results
list(result_iterator)

return result_iterator.total_count

@classmethod
def query(cls,
Expand Down Expand Up @@ -630,6 +633,7 @@ def query(cls,
non_key_attribute_classes=non_key_attribute_classes,
filters=filters)

query_args = (hash_key,)
query_kwargs = dict(
range_key_condition=range_key_condition,
filter_condition=filter_condition,
Expand All @@ -644,27 +648,13 @@ def query(cls,
conditional_operator=conditional_operator
)

data = cls._get_connection().query(hash_key, **query_kwargs)

last_evaluated_key = data.get(LAST_EVALUATED_KEY, None)

for item in data.get(ITEMS):
if limit is not None:
if limit == 0:
return
limit -= 1
yield cls.from_raw_data(item)

while last_evaluated_key:
query_kwargs['exclusive_start_key'] = last_evaluated_key
data = cls._get_connection().query(hash_key, **query_kwargs)
for item in data.get(ITEMS):
if limit is not None:
if limit == 0:
return
limit -= 1
yield cls.from_raw_data(item)
last_evaluated_key = data.get(LAST_EVALUATED_KEY, None)
return ResultIterator(
cls._get_connection().query,
query_args,
query_kwargs,
map_fn=cls.from_raw_data,
limit=limit
)

@classmethod
def rate_limited_scan(cls,
Expand Down Expand Up @@ -771,10 +761,12 @@ def scan(cls,
filters=filters
)
key_filter.update(scan_filter)

if page_size is None:
page_size = limit

data = cls._get_connection().scan(
scan_args = ()
scan_kwargs = dict(
filter_condition=filter_condition,
exclusive_start_key=last_evaluated_key,
segment=segment,
Expand All @@ -784,31 +776,14 @@ def scan(cls,
conditional_operator=conditional_operator,
consistent_read=consistent_read
)
last_evaluated_key = data.get(LAST_EVALUATED_KEY, None)
for item in data.get(ITEMS):
yield cls.from_raw_data(item)
if limit is not None:
limit -= 1
if not limit:
return
while last_evaluated_key:
data = cls._get_connection().scan(
filter_condition=filter_condition,
exclusive_start_key=last_evaluated_key,
limit=page_size,
scan_filter=key_filter,
segment=segment,
total_segments=total_segments,
conditional_operator=conditional_operator
)
for item in data.get(ITEMS):
yield cls.from_raw_data(item)
if limit is not None:
limit -= 1
if not limit:
return

last_evaluated_key = data.get(LAST_EVALUATED_KEY, None)
return ResultIterator(
cls._get_connection().scan,
scan_args,
scan_kwargs,
map_fn=cls.from_raw_data,
limit=limit
)

@classmethod
def exists(cls):
Expand Down
63 changes: 63 additions & 0 deletions pynamodb/pagination.py
@@ -0,0 +1,63 @@
from pynamodb.constants import CAMEL_COUNT, ITEMS, LAST_EVALUATED_KEY


class ResultIterator(object):
    """
    ResultIterator handles Query and Scan result pagination.

    http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Query.html#Query.Pagination
    http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Scan.html#Scan.Pagination

    :param operation: callable issuing the request (e.g. ``Connection.query`` or
        ``Connection.scan``); called once per page with ``*args, **kwargs``.
    :param args: positional arguments forwarded to ``operation`` on every page request.
    :param kwargs: keyword arguments forwarded to ``operation``;
        ``exclusive_start_key`` is injected/overwritten here to request later pages.
    :param map_fn: optional callable applied to each raw item before it is yielded.
    :param limit: maximum number of items to yield across all pages (``None`` = unlimited).
    """
    def __init__(self, operation, args, kwargs, map_fn=None, limit=None):
        self._operation = operation
        self._args = args
        self._kwargs = kwargs
        self._map_fn = map_fn
        self._limit = limit
        self._needs_execute = True
        self._total_count = 0
        # Initialize per-page state so the public properties are safe to read
        # before iteration begins (previously `last_evaluated_key` raised
        # AttributeError if accessed before the first page was fetched).
        self._count = 0
        self._items = None
        self._index = 0
        self._last_evaluated_key = None

    def _execute(self):
        # Fetch one page of results and reset the page cursor.
        data = self._operation(*self._args, **self._kwargs)
        self._count = data.get(CAMEL_COUNT, 0)
        self._items = data.get(ITEMS)  # not returned if 'Select' is set to 'COUNT'
        self._last_evaluated_key = data.get(LAST_EVALUATED_KEY)
        # With no items (COUNT select) mark the page exhausted immediately so
        # __next__ keeps paginating for the count without yielding anything.
        self._index = 0 if self._items else self._count
        self._total_count += self._count

    def __iter__(self):
        return self

    def __next__(self):
        # A limit of 0 means the caller's budget is spent; never fetch more.
        if self._limit == 0:
            raise StopIteration

        if self._needs_execute:
            self._needs_execute = False
            self._execute()

        # Current page exhausted but DynamoDB reports more data: keep paging.
        while self._index == self._count and self._last_evaluated_key:
            self._kwargs['exclusive_start_key'] = self._last_evaluated_key
            self._execute()

        if self._index == self._count:
            raise StopIteration

        item = self._items[self._index]
        self._index += 1
        if self._limit is not None:
            self._limit -= 1
        if self._map_fn:
            item = self._map_fn(item)
        return item

    # Python 2 iterator protocol compatibility.
    def next(self):
        return self.__next__()

    @property
    def last_evaluated_key(self):
        """Pagination key of the most recently fetched page (None before
        iteration starts or once results are exhausted)."""
        return self._last_evaluated_key

    @property
    def total_count(self):
        """Sum of DynamoDB 'Count' values over all pages fetched so far."""
        return self._total_count

0 comments on commit 654bec4

Please sign in to comment.