Skip to content

Commit

Permalink
Allow model to be explicitly specified in prefetch,fixes coleifer#707
Browse files Browse the repository at this point in the history
  • Loading branch information
coleifer committed Oct 2, 2015
1 parent ac4a4b0 commit 303d215
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 4 deletions.
15 changes: 11 additions & 4 deletions peewee.py
Expand Up @@ -4517,26 +4517,33 @@ def __ne__(self, other):
def prefetch_add_subquery(sq, subqueries):
fixed_queries = [PrefetchResult(sq)]
for i, subquery in enumerate(subqueries):
if isinstance(subquery, tuple):
subquery, target_model = subquery
else:
target_model = None
if not isinstance(subquery, Query) and issubclass(subquery, Model):
subquery = subquery.select()
subquery_model = subquery.model_class
fkf = backref = None
for j in reversed(range(i + 1)):
last_query = fixed_queries[j][0]
last_model = last_query.model_class
prefetch_result = fixed_queries[j]
last_query = prefetch_result.query
last_model = prefetch_result.model
foreign_key = subquery_model._meta.rel_for_model(last_model)
if foreign_key:
fkf = getattr(subquery_model, foreign_key.name)
to_field = getattr(last_model, foreign_key.to_field.name)
else:
backref = last_model._meta.rel_for_model(subquery_model)

if fkf or backref:
if (fkf or backref) and ((target_model is last_model) or
(target_model is None)):
break

if not (fkf or backref):
tgt_err = ' using %s' % target_model if target_model else ''
raise AttributeError('Error: unable to find foreign key for '
'query: %s' % subquery)
'query: %s%s' % (subquery, tgt_err))

if fkf:
inner_query = last_query.select(to_field)
Expand Down
7 changes: 7 additions & 0 deletions playhouse/tests/models.py
Expand Up @@ -337,6 +337,12 @@ class ServerDefaultModel(TestModel):
timestamp = DateTimeField(constraints=[
SQL('DEFAULT CURRENT_TIMESTAMP')])

class SpecialComment(TestModel):
user = ForeignKeyField(User, related_name='special_comments')
blog = ForeignKeyField(Blog, null=True, related_name='special_comments')
name = CharField()


class EmptyModel(TestModel):
pass

Expand Down Expand Up @@ -390,5 +396,6 @@ class EmptyModel(TestModel):
CommentCategory,
BlogData,
ServerDefaultModel,
SpecialComment,
EmptyModel,
]
45 changes: 45 additions & 0 deletions playhouse/tests/test_query_results.py
Expand Up @@ -879,6 +879,7 @@ class BaseTestPrefetch(ModelTestCase):
Category,
UserCategory,
Relationship,
SpecialComment,
]

user_data = [
Expand Down Expand Up @@ -1097,6 +1098,50 @@ def test_prefetch_self_join(self):

self.assertEqual(names_and_children, self.category_tree)

def test_prefetch_specific_model(self):
# User -> Blog
# -> SpecialComment (fk to user and blog)
Comment.delete().execute()
Blog.delete().execute()
User.delete().execute()
u1 = User.create(username='u1')
u2 = User.create(username='u2')
for i in range(1, 3):
for user in (u1, u2):
b = Blog.create(user=user, title='%s-b%s' % (user.username, i))
SpecialComment.create(
user=user,
blog=b,
name='%s-c%s' % (user.username, i))

u3 = User.create(username='u3')
SpecialComment.create(user=u3, name='u3-c1')

u4 = User.create(username='u4')
Blog.create(user=u4, title='u4-b1')

u5 = User.create(username='u5')

with self.assertQueryCount(3):
user_pf = prefetch(
User.select(),
Blog,
(SpecialComment, User))
results = []
for user in user_pf:
results.append((
user.username,
[b.title for b in user.blog_set_prefetch],
[c.name for c in user.special_comments_prefetch]))

self.assertEqual(results, [
('u1', ['u1-b1', 'u1-b2'], ['u1-c1', 'u1-c2']),
('u2', ['u2-b1', 'u2-b2'], ['u2-c1', 'u2-c2']),
('u3', [], ['u3-c1']),
('u4', ['u4-b1'], []),
('u5', [], []),
])


class TestAggregateRows(BaseTestPrefetch):
def test_aggregate_users(self):
Expand Down

0 comments on commit 303d215

Please sign in to comment.