diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 02fb02b8d..ac6a41996 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -6,6 +6,7 @@ import json import logging import random +import re import sys import time import uuid @@ -1240,11 +1241,12 @@ def query(self, # FilterExpression does not allow key attributes. Check for hash and range key name placeholders hash_key_placeholder = name_placeholders.get(hash_keyname) range_key_placeholder = range_keyname and name_placeholders.get(range_keyname) - if ( - hash_key_placeholder in filter_expression or - (range_key_placeholder and range_key_placeholder in filter_expression) - ): - raise ValueError("'filter_condition' cannot contain key attributes") + if re.search(hash_key_placeholder + r"\D", filter_expression): + raise ValueError("'filter_condition' cannot contain hash key. {} found in {}" + .format(hash_key_placeholder, filter_expression)) + if range_key_placeholder and re.search(range_key_placeholder + r"\D", filter_expression): + raise ValueError("'filter_condition' cannot contain range key. {} found in {}" + .format(range_key_placeholder, filter_expression)) operation_kwargs[FILTER_EXPRESSION] = filter_expression if attributes_to_get: projection_expression = create_projection_expression(attributes_to_get, name_placeholders) diff --git a/tests/test_base_connection.py b/tests/test_base_connection.py index 224c08407..c211e4831 100644 --- a/tests/test_base_connection.py +++ b/tests/test_base_connection.py @@ -1332,6 +1332,57 @@ def test_query(self): } self.assertEqual(req.call_args[0][1], params) + with patch(PATCH_METHOD) as req: + req.return_value = {} + conn.query( + table_name=table_name, + hash_key="FooForum", + range_key_condition=Path('Subject').startswith('thread'), + filter_condition=Path('a2').exists() + | Path('a3').exists() + | Path('a4').exists() + | Path('a5').exists() + | Path('a6').exists() + | Path('a7').exists() + | Path('a8').exists() + | Path('a9').exists() + | Path('a10').exists() + ) + + params = { + 'TableName': 'Thread', + 'KeyConditionExpression': '(#0 = :0 AND begins_with (#1, :1))', + 'FilterExpression': + '((((((((attribute_exists (#2) ' + 'OR attribute_exists (#3)) ' + 'OR attribute_exists (#4)) ' + 'OR attribute_exists (#5)) ' + 'OR attribute_exists (#6)) ' + 'OR attribute_exists (#7)) ' + 'OR attribute_exists (#8)) ' + 'OR attribute_exists (#9)) ' + 'OR attribute_exists (#10))', + 'ExpressionAttributeNames': { + '#0': 'ForumName', + '#1': 'Subject', + '#2': 'a2', + '#3': 'a3', + '#4': 'a4', + '#5': 'a5', + '#6': 'a6', + '#7': 'a7', + '#8': 'a8', + '#9': 'a9', + '#10': 'a10' + }, + 'ExpressionAttributeValues': { + ':0': {'S': 'FooForum'}, + ':1': {'S': 'thread'} + }, + 'ReturnConsumedCapacity': 'TOTAL' + } + self.assertEqual(req.call_args[0][1], params) + def test_scan(self): """ Connection.scan