Skip to content

Commit

Permalink
Support complete condition expression syntax. (#329)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpinner-lyft committed Aug 17, 2017
1 parent 3560228 commit ea2a2ee
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 47 deletions.
21 changes: 21 additions & 0 deletions pynamodb/attributes.py
Expand Up @@ -76,6 +76,11 @@ def __eq__(self, other):
return self is other
return AttributePath(self).__eq__(other)

def __ne__(self, other):
if other is None or isinstance(other, Attribute): # handle object identity comparison
return self is not other
return AttributePath(self).__ne__(other)

def __lt__(self, other):
return AttributePath(self).__lt__(other)

Expand All @@ -94,9 +99,25 @@ def __getitem__(self, idx):
def between(self, lower, upper):
return AttributePath(self).between(lower, upper)

def is_in(self, *values):
return AttributePath(self).is_in(*values)

def exists(self):
return AttributePath(self).exists()

def not_exists(self):
return AttributePath(self).not_exists()

def is_type(self):
# What makes sense here? Are we using this to check if deserialization will be successful?
return AttributePath(self).is_type(ATTR_TYPE_MAP[self.attr_type])

def startswith(self, prefix):
return AttributePath(self).startswith(prefix)

def contains(self, item):
return AttributePath(self).contains(item)


class AttributePath(Path):

Expand Down
214 changes: 178 additions & 36 deletions pynamodb/expressions/condition.py
@@ -1,25 +1,20 @@
from copy import copy
from pynamodb.constants import AND, BETWEEN
from pynamodb.constants import AND, ATTR_TYPE_MAP, BETWEEN, IN, OR, SHORT_ATTR_TYPES, STRING_SHORT
from pynamodb.expressions.util import get_value_placeholder, substitute_names
from six.moves import range


class Path(object):

def __init__(self, path, attribute_name=False):
self.path = path
self.attribute_name = attribute_name

def __getitem__(self, idx):
# list dereference operator
if not isinstance(idx, int):
raise TypeError("list indices must be integers, not {0}".format(type(idx).__name__))
element_path = copy(self)
element_path.path = '{0}[{1}]'.format(self.path, idx)
return element_path
class Operand(object):
"""
Operand is the base class for objects that support creating conditions from comparators.
"""

def __eq__(self, other):
return self._compare('=', other)

def __ne__(self, other):
return self._compare('<>', other)

def __lt__(self, other):
return self._compare('<', other)

Expand All @@ -41,13 +36,85 @@ def between(self, lower, upper):
# work but similar expressions like value1 <= attribute & attribute < value2 fail seems too brittle.
return Between(self, self._serialize(lower), self._serialize(upper))

def is_in(self, *values):
values = [self._serialize(value) for value in values]
return In(self, *values)

def _serialize(self, value):
# Check to see if value is already serialized
if isinstance(value, dict) and len(value) == 1 and list(value.keys())[0] in SHORT_ATTR_TYPES:
return value
# Serialize value based on its type
from pynamodb.attributes import _get_class_for_serialize
attr_class = _get_class_for_serialize(value)
return {ATTR_TYPE_MAP[attr_class.attr_type]: attr_class.serialize(value)}


class Size(Operand):
"""
Size is a special operand that represents the result of calling the 'size' function on a Path operand.
"""

def __init__(self, path):
# prevent circular import -- AttributePath imports Path
from pynamodb.attributes import Attribute, AttributePath
if isinstance(path, Path):
self.path = Path
elif isinstance(path, Attribute):
self.path = AttributePath(path)
else:
self.path = Path(path)

def _serialize(self, value):
if not isinstance(value, int):
raise TypeError("size must be compared to an integer, not {0}".format(type(value).__name__))
return {'N': str(value)}

def __str__(self):
return "size({0})".format(self.path)

def __repr__(self):
return "Size({0})".format(repr(self.path))


# match dynamo function syntax: size(path)
def size(path):
return Size(path)


class Path(Operand):
"""
Path is an operand that represents either an attribute name or document path.
In addition to supporting comparisons, Path also supports creating conditions from functions.
"""

def __init__(self, path, attribute_name=False):
self.path = path
self.attribute_name = attribute_name

def __getitem__(self, idx):
# list dereference operator
if not isinstance(idx, int):
raise TypeError("list indices must be integers, not {0}".format(type(idx).__name__))
element_path = copy(self)
element_path.path = '{0}[{1}]'.format(self.path, idx)
return element_path

def exists(self):
return Exists(self)

def not_exists(self):
return NotExists(self)

def is_type(self, attr_type):
return IsType(self, attr_type)

def startswith(self, prefix):
# A 'pythonic' replacement for begins_with to match string behavior (e.g. "foo".startswith("f"))
return BeginsWith(self, self._serialize(prefix))

def _serialize(self, value):
# Allow subclasses to define value serialization.
return value
def contains(self, item):
return Contains(self, self._serialize(item))

def __str__(self):
if self.attribute_name and '.' in self.path:
Expand All @@ -67,34 +134,47 @@ def __init__(self, path, operator, *values):
self.path = path
self.operator = operator
self.values = values
self.logical_operator = None
self.other_condition = None

def serialize(self, placeholder_names, expression_attribute_values):
split = not self.path.attribute_name
path = substitute_names(self.path.path, placeholder_names, split=split)
values = [get_value_placeholder(value, expression_attribute_values) for value in self.values]
condition = self.format_string.format(*values, path=path, operator=self.operator)
if self.logical_operator:
other_condition = self.other_condition.serialize(placeholder_names, expression_attribute_values)
return '{0} {1} {2}'.format(condition, self.logical_operator, other_condition)
return condition
path = self._get_path(self.path, placeholder_names)
values = self._get_values(placeholder_names, expression_attribute_values)
return self.format_string.format(*values, path=path, operator=self.operator)

def _get_path(self, path, placeholder_names):
if isinstance(path, Path):
split = not path.attribute_name
return substitute_names(path.path, placeholder_names, split=split)
elif isinstance(path, Size):
return "size ({0})".format(self._get_path(path.path, placeholder_names))
else:
return path

def _get_values(self, placeholder_names, expression_attribute_values):
return [
value.serialize(placeholder_names, expression_attribute_values)
if isinstance(value, Condition)
else get_value_placeholder(value, expression_attribute_values)
for value in self.values
]

def __and__(self, other):
if not isinstance(other, Condition):
raise TypeError("unsupported operand type(s) for &: '{0}' and '{1}'",
self.__class__.__name__, other.__class__.__name__)
self.logical_operator = AND
self.other_condition = other
return self
return And(self, other)

def __or__(self, other):
if not isinstance(other, Condition):
raise TypeError("unsupported operand type(s) for |: '{0}' and '{1}'",
self.__class__.__name__, other.__class__.__name__)
return Or(self, other)

def __invert__(self):
return Not(self)

def __repr__(self):
values = [value.items()[0][1] for value in self.values]
condition = self.format_string.format(*values, path=self.path, operator = self.operator)
if self.logical_operator:
other_conditions = repr(self.other_condition)
return '{0} {1} {2}'.format(condition, self.logical_operator, other_conditions)
return condition
values = [repr(value) if isinstance(value, Condition) else value.items()[0][1] for value in self.values]
return self.format_string.format(*values, path=self.path, operator = self.operator)

def __nonzero__(self):
# Prevent users from accidentally comparing the condition object instead of the attribute instance
Expand All @@ -112,8 +192,70 @@ def __init__(self, path, lower, upper):
super(Between, self).__init__(path, BETWEEN, lower, upper)


class In(Condition):
def __init__(self, path, *values):
super(In, self).__init__(path, IN, *values)
list_format = ', '.join('{' + str(i) + '}' for i in range(len(values)))
self.format_string = '{path} {operator} (' + list_format + ')'


class Exists(Condition):
format_string = '{operator} ({path})'

def __init__(self, path):
super(Exists, self).__init__(path, 'attribute_exists')


class NotExists(Condition):
format_string = '{operator} ({path})'

def __init__(self, path):
super(NotExists, self).__init__(path, 'attribute_not_exists')


class IsType(Condition):
format_string = '{operator} ({path}, {0})'

def __init__(self, path, attr_type):
if attr_type not in SHORT_ATTR_TYPES:
raise ValueError("{0} is not a valid attribute type. Must be one of {1}".format(
attr_type, SHORT_ATTR_TYPES))
super(IsType, self).__init__(path, 'attribute_type', {STRING_SHORT: attr_type})


class BeginsWith(Condition):
format_string = '{operator} ({path}, {0})'

def __init__(self, path, prefix):
super(BeginsWith, self).__init__(path, 'begins_with', prefix)


class Contains(Condition):
format_string = '{operator} ({path}, {0})'

def __init__(self, path, item):
(attr_type, value), = item.items()
if attr_type != STRING_SHORT:
raise ValueError("{0} must be a string".format(value))
super(Contains, self).__init__(path, 'contains', item)


class And(Condition):
format_string = '({0} {operator} {1})'

def __init__(self, condition1, condition2):
super(And, self).__init__(None, AND, condition1, condition2)


class Or(Condition):
format_string = '({0} {operator} {1})'

def __init__(self, condition1, condition2):
super(Or, self).__init__(None, OR, condition1, condition2)


class Not(Condition):
format_string = '({operator} {0})'

def __init__(self, condition):
super(Not, self).__init__(None, 'NOT', condition)
4 changes: 2 additions & 2 deletions pynamodb/tests/test_base_connection.py
Expand Up @@ -1348,7 +1348,7 @@ def test_query(self):
'ScanIndexForward': True,
'Select': 'ALL_ATTRIBUTES',
'ReturnConsumedCapacity': 'TOTAL',
'KeyConditionExpression': '#0 = :0 AND begins_with (#1, :1)',
'KeyConditionExpression': '(#0 = :0 AND begins_with (#1, :1))',
'ExpressionAttributeNames': {
'#0': 'ForumName',
'#1': 'Subject'
Expand All @@ -1374,7 +1374,7 @@ def test_query(self):
)
params = {
'ReturnConsumedCapacity': 'TOTAL',
'KeyConditionExpression': '#0 = :0 AND begins_with (#1, :1)',
'KeyConditionExpression': '(#0 = :0 AND begins_with (#1, :1))',
'ExpressionAttributeNames': {
'#0': 'ForumName',
'#1': 'Subject'
Expand Down

0 comments on commit ea2a2ee

Please sign in to comment.