Skip to content

Commit

Permalink
Isolating regex logic from assert util, to be easily used in other co…
Browse files Browse the repository at this point in the history
…ntexts
  • Loading branch information
romgar committed Sep 27, 2015
1 parent 4a94a4e commit d64d0a8
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 13 deletions.
8 changes: 8 additions & 0 deletions tests/compat.py
@@ -0,0 +1,8 @@

def get_model_name(klass):
if hasattr(klass._meta, 'model_name'):
model_name = klass._meta.model_name
else: # Django < 1.6
model_name = klass._meta.module_name

return model_name
41 changes: 28 additions & 13 deletions tests/utils.py
Expand Up @@ -3,14 +3,19 @@
from django.conf import settings
from django.db import connection

from .compat import get_model_name


class assert_number_queries(object):

def __init__(self, number):
self.number = number

def matched_queries(self):
return connection.queries

def query_count(self):
return len(connection.queries)
return len(self.matched_queries())

def __enter__(self):
self.DEBUG = settings.DEBUG
Expand All @@ -23,20 +28,30 @@ def __exit__(self, type, value, traceback):
settings.DEBUG = self.DEBUG


class assert_select_number_queries_on_model(assert_number_queries):
class RegexMixin(object):
regex = None

def __init__(self, model_class, number):
super(assert_select_number_queries_on_model, self).__init__(number)
self.model_class = model_class
def matched_queries(self):
matched_queries = super(RegexMixin, self).matched_queries()

def query_count(self):
if self.regex is not None:
pattern = re.compile(self.regex)
regex_compliant_queries = [query for query in connection.queries if pattern.match(query.get('sql'))]

return regex_compliant_queries

if hasattr(self.model_class._meta, 'model_name'):
model_name = self.model_class._meta.model_name
else: # < 1.6
model_name = self.model_class._meta.module_name

pattern = re.compile(r'^.*SELECT.*FROM "tests_%s".*$' % model_name)
cnt = len([query for query in connection.queries if pattern.match(query.get('sql'))])
class assert_number_of_queries_on_regex(RegexMixin, assert_number_queries):

def __init__(self, number, regex=None):
super(assert_number_of_queries_on_regex, self).__init__(number)
self.regex = regex


class assert_select_number_queries_on_model(assert_number_of_queries_on_regex):

def __init__(self, model_class, number):
super(assert_select_number_queries_on_model, self).__init__(number)

return cnt
model_name = get_model_name(model_class)
self.regex = r'^.*SELECT.*FROM "tests_%s".*$' % model_name

0 comments on commit d64d0a8

Please sign in to comment.