diff --git a/tests/compat.py b/tests/compat.py new file mode 100644 index 0000000..ab84e55 --- /dev/null +++ b/tests/compat.py @@ -0,0 +1,8 @@ + +def get_model_name(klass): + if hasattr(klass._meta, 'model_name'): + model_name = klass._meta.model_name + else: # Django < 1.6 + model_name = klass._meta.module_name + + return model_name diff --git a/tests/utils.py b/tests/utils.py index 621c0ca..67780c3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,14 +3,19 @@ from django.conf import settings from django.db import connection +from .compat import get_model_name + class assert_number_queries(object): def __init__(self, number): self.number = number + def matched_queries(self): + return connection.queries + def query_count(self): - return len(connection.queries) + return len(self.matched_queries()) def __enter__(self): self.DEBUG = settings.DEBUG @@ -23,20 +28,30 @@ def __exit__(self, type, value, traceback): settings.DEBUG = self.DEBUG -class assert_select_number_queries_on_model(assert_number_queries): +class RegexMixin(object): + regex = None - def __init__(self, model_class, number): - super(assert_select_number_queries_on_model, self).__init__(number) - self.model_class = model_class + def matched_queries(self): + matched_queries = super(RegexMixin, self).matched_queries() - def query_count(self): + if self.regex is not None: + pattern = re.compile(self.regex) + regex_compliant_queries = [query for query in connection.queries if pattern.match(query.get('sql'))] + + return regex_compliant_queries - if hasattr(self.model_class._meta, 'model_name'): - model_name = self.model_class._meta.model_name - else: # < 1.6 - model_name = self.model_class._meta.module_name - pattern = re.compile(r'^.*SELECT.*FROM "tests_%s".*$' % model_name) - cnt = len([query for query in connection.queries if pattern.match(query.get('sql'))]) +class assert_number_of_queries_on_regex(RegexMixin, assert_number_queries): + + def __init__(self, number, regex=None): + super(assert_number_of_queries_on_regex, self).__init__(number) + self.regex = regex + + +class assert_select_number_queries_on_model(assert_number_of_queries_on_regex): + + def __init__(self, model_class, number): + super(assert_select_number_queries_on_model, self).__init__(number) - return cnt + model_name = get_model_name(model_class) + self.regex = r'^.*SELECT.*FROM "tests_%s".*$' % model_name