diff --git a/peewee.py b/peewee.py index 6741c9b00..4e274fc3e 100644 --- a/peewee.py +++ b/peewee.py @@ -4517,13 +4517,18 @@ def __ne__(self, other): def prefetch_add_subquery(sq, subqueries): fixed_queries = [PrefetchResult(sq)] for i, subquery in enumerate(subqueries): + if isinstance(subquery, tuple): + subquery, target_model = subquery + else: + target_model = None if not isinstance(subquery, Query) and issubclass(subquery, Model): subquery = subquery.select() subquery_model = subquery.model_class fkf = backref = None for j in reversed(range(i + 1)): - last_query = fixed_queries[j][0] - last_model = last_query.model_class + prefetch_result = fixed_queries[j] + last_query = prefetch_result.query + last_model = prefetch_result.model foreign_key = subquery_model._meta.rel_for_model(last_model) if foreign_key: fkf = getattr(subquery_model, foreign_key.name) @@ -4531,12 +4536,14 @@ def prefetch_add_subquery(sq, subqueries): else: backref = last_model._meta.rel_for_model(subquery_model) - if fkf or backref: + if (fkf or backref) and ((target_model is last_model) or + (target_model is None)): break if not (fkf or backref): + tgt_err = ' using %s' % target_model if target_model else '' raise AttributeError('Error: unable to find foreign key for ' - 'query: %s' % subquery) + 'query: %s%s' % (subquery, tgt_err)) if fkf: inner_query = last_query.select(to_field) diff --git a/playhouse/tests/models.py b/playhouse/tests/models.py index 0820059d0..8abee2442 100644 --- a/playhouse/tests/models.py +++ b/playhouse/tests/models.py @@ -337,6 +337,12 @@ class ServerDefaultModel(TestModel): timestamp = DateTimeField(constraints=[ SQL('DEFAULT CURRENT_TIMESTAMP')]) +class SpecialComment(TestModel): + user = ForeignKeyField(User, related_name='special_comments') + blog = ForeignKeyField(Blog, null=True, related_name='special_comments') + name = CharField() + + class EmptyModel(TestModel): pass @@ -390,5 +396,6 @@ class EmptyModel(TestModel): CommentCategory, BlogData, ServerDefaultModel, + SpecialComment, EmptyModel, ] diff --git a/playhouse/tests/test_query_results.py b/playhouse/tests/test_query_results.py index 16809fa68..82efd564a 100644 --- a/playhouse/tests/test_query_results.py +++ b/playhouse/tests/test_query_results.py @@ -879,6 +879,7 @@ class BaseTestPrefetch(ModelTestCase): Category, UserCategory, Relationship, + SpecialComment, ] user_data = [ @@ -1097,6 +1098,50 @@ def test_prefetch_self_join(self): self.assertEqual(names_and_children, self.category_tree) + def test_prefetch_specific_model(self): + # User -> Blog + # -> SpecialComment (fk to user and blog) + Comment.delete().execute() + Blog.delete().execute() + User.delete().execute() + u1 = User.create(username='u1') + u2 = User.create(username='u2') + for i in range(1, 3): + for user in (u1, u2): + b = Blog.create(user=user, title='%s-b%s' % (user.username, i)) + SpecialComment.create( + user=user, + blog=b, + name='%s-c%s' % (user.username, i)) + + u3 = User.create(username='u3') + SpecialComment.create(user=u3, name='u3-c1') + + u4 = User.create(username='u4') + Blog.create(user=u4, title='u4-b1') + + u5 = User.create(username='u5') + + with self.assertQueryCount(3): + user_pf = prefetch( + User.select(), + Blog, + (SpecialComment, User)) + results = [] + for user in user_pf: + results.append(( + user.username, + [b.title for b in user.blog_set_prefetch], + [c.name for c in user.special_comments_prefetch])) + + self.assertEqual(results, [ + ('u1', ['u1-b1', 'u1-b2'], ['u1-c1', 'u1-c2']), + ('u2', ['u2-b1', 'u2-b2'], ['u2-c1', 'u2-c2']), + ('u3', [], ['u3-c1']), + ('u4', ['u4-b1'], []), + ('u5', [], []), + ]) + class TestAggregateRows(BaseTestPrefetch): def test_aggregate_users(self):