diff --git a/pinax/api/resource.py b/pinax/api/resource.py index dc4fd9f..704a24a 100644 --- a/pinax/api/resource.py +++ b/pinax/api/resource.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import collections import datetime from collections import namedtuple @@ -161,7 +162,16 @@ def get_attr(self, attr): return resolve_value(value) def get_relationship(self, related_name, rel): - return getattr(self.obj, related_name) + if rel.collection: + iterator = getattr(self.obj, related_name) + if not isinstance(iterator, collections.Iterable): + if not hasattr(iterator, "all"): + raise TypeError("Relationship {} must be iterable or QuerySet".format(related_name)) + else: + iterator = iterator.all() + return iterator + else: + return getattr(self.obj, related_name) def set_attr(self, attr, value): if hasattr(self, attr.obj_attr): @@ -226,9 +236,9 @@ def serialize(self, links=False, request=None): rel_initial["links"] = rel_links rel_obj = relationships.setdefault(name, rel_initial) if rel.collection: - qs = self.get_relationship(name, rel).all() + iterable = self.get_relationship(name, rel) rel_data = rel_obj.setdefault("data", []) - for v in qs: + for v in iterable: rel_data.append(rel.resource_class()(v).identifier.as_dict()) else: v = self.get_relationship(name, rel) @@ -275,13 +285,13 @@ def resolve_include(resource, path, included): raise SerializationError("'{}' is not a valid relationship to include".format(head)) rel = resource.relationships[head] if rel.collection: - for obj in getattr(resource.obj, head).all(): + for obj in resource.get_relationship(head, rel): r = rel.resource_class()(obj) if rest: resolve_include(r, rest, included) included.add(r) else: - r = rel.resource_class()(getattr(resource.obj, head)) + r = rel.resource_class()(resource.get_relationship(head, rel)) included.add(r) diff --git a/pinax/api/tests/models.py b/pinax/api/tests/models.py index 52315aa..488be2e 100644 --- a/pinax/api/tests/models.py +++ b/pinax/api/tests/models.py @@ -1,13 +1,6 @@ from django.db import models -class ArticleTag(models.Model): - name = models.CharField(max_length=50) - - def __str__(self): - return self.name - - class Author(models.Model): name = models.CharField(max_length=50) @@ -18,4 +11,16 @@ def __str__(self): class Article(models.Model): title = models.CharField(max_length=100) author = models.ForeignKey(Author) - tags = models.ManyToManyField(ArticleTag) + + @property + def tags(self): + for tag in self.articletag_set.all(): + yield tag + + +class ArticleTag(models.Model): + article = models.ForeignKey(Article) + name = models.CharField(max_length=50) + + def __str__(self): + return self.name diff --git a/pinax/api/tests/relationships.py b/pinax/api/tests/relationships.py index 3896bed..87e957d 100644 --- a/pinax/api/tests/relationships.py +++ b/pinax/api/tests/relationships.py @@ -1,7 +1,7 @@ from pinax import api from pinax.api.exceptions import ErrorResponse -from .models import Article +from .models import Article, ArticleTag class ArticleTagCollectionEndpointSet(api.RelationshipEndpointSet): @@ -35,7 +35,7 @@ def create(self, request, pk): with self.validate(self.resource_class, collection=True) as resources: tags = [resource.obj.name for resource in resources] for tag in tags: - self.article.tags.create(name=tag) + ArticleTag.objects.create(name=tag, article=self.article) return self.render(None) def update(self, request, pk): @@ -44,16 +44,16 @@ def update(self, request, pk): """ with self.validate(self.resource_class, collection=True) as resources: tags = [resource.obj.name for resource in resources] - self.article.tags.clear() + ArticleTag.objects.filter(article=self.article).delete() for tag in tags: - self.article.tags.create(name=tag) + ArticleTag.objects.create(name=tag, article=self.article) return self.render(None) def retrieve(self, request, pk): """ Identifier: List tags for an Article """ - tags = self.article.tags.all() + tags = ArticleTag.objects.filter(article=self.article) return self.render(self.resource_class.from_queryset(tags)) def destroy(self, request, pk): @@ -62,7 +62,7 @@ def destroy(self, request, pk): """ with self.validate(self.resource_class, collection=True) as resources: tags = [resource.obj.name for resource in resources] - self.article.tags.filter(name__in=tags).delete() + ArticleTag.objects.filter(article=self.article, name__in=tags).delete() return self.render_delete() diff --git a/pinax/api/tests/test_relationships.py b/pinax/api/tests/test_relationships.py index 86d46ba..6abbf1b 100644 --- a/pinax/api/tests/test_relationships.py +++ b/pinax/api/tests/test_relationships.py @@ -8,6 +8,7 @@ from .models import ( Article, + ArticleTag, Author, ) @@ -69,7 +70,7 @@ def test_create(self): } } ) - tags = self.article1.tags.values_list("name", flat=True) + tags = ArticleTag.objects.filter(article=self.article1).values_list("name", flat=True) self.assertIn(new_tag, tags) self.assertIn(another_new_tag, tags) @@ -79,9 +80,9 @@ def test_retrieve(self): """ # Create two tags for article1. first_tag = "Pinax" - self.article1.tags.create(name=first_tag) + ArticleTag.objects.create(name=first_tag, article=self.article1) second_tag = "Kel" - self.article1.tags.create(name=second_tag) + ArticleTag.objects.create(name=second_tag, article=self.article1) with mock.patch("pinax.api.authentication.Anonymous.authenticate", autospec=True) as mock_authenticate: mock_authenticate.return_value = AnonymousUser() @@ -131,11 +132,11 @@ def test_destroy(self): """ # Create a tag which should remain. remain_tag = "Pinax" - self.article1.tags.create(name=remain_tag) + ArticleTag.objects.create(name=remain_tag, article=self.article1) # Create a tag for removal. remove_tag = "Kel" - self.article1.tags.create(name=remove_tag) + ArticleTag.objects.create(name=remove_tag, article=self.article1) post_data = { "data": [ @@ -159,7 +160,7 @@ def test_destroy(self): self.assertEqual(response.status_code, 204) self.assertEqual(response["Content-Type"], "application/vnd.api+json") - article_tags = self.article1.tags.values_list("name", flat=True) + article_tags = ArticleTag.objects.filter(article=self.article1).values_list("name", flat=True) self.assertNotIn(remove_tag, article_tags) self.assertIn(remain_tag, article_tags) @@ -169,9 +170,9 @@ def test_get_all(self): """ # Create a tag for each Article. first_tag = "Pinax" - self.article1.tags.create(name=first_tag) + ArticleTag.objects.create(name=first_tag, article=self.article1) second_tag = "Kel" - self.article2.tags.create(name=second_tag) + ArticleTag.objects.create(name=second_tag, article=self.article2) with mock.patch("pinax.api.authentication.Anonymous.authenticate", autospec=True) as mock_authenticate: mock_authenticate.return_value = AnonymousUser() @@ -232,8 +233,10 @@ def test_update(self): Ensure we can completely replace all Article tags. """ # Create several tags which will get replaced. - self.article1.tags.create(name="Pinax") - self.article1.tags.create(name="Kel") + first_tag = "Pinax" + ArticleTag.objects.create(name=first_tag, article=self.article1) + second_tag = "Kel" + ArticleTag.objects.create(name=second_tag, article=self.article1) new_tag = "Futurama" another_new_tag = "Archer" @@ -275,7 +278,7 @@ def test_update(self): } ) - business_tags = self.article1.tags.values_list("name", flat=True) + business_tags = ArticleTag.objects.filter(article=self.article1).values_list("name", flat=True) self.assertSetEqual(set([new_tag, another_new_tag]), set(list(business_tags))) def test_get_article_matching_tag(self): @@ -283,15 +286,15 @@ def test_get_article_matching_tag(self): """ # Create a separate tag for each Article first_tag = "Pinax" - self.article1.tags.create(name=first_tag) + ArticleTag.objects.create(name=first_tag, article=self.article1) second_tag = "Kel" - self.article2.tags.create(name=second_tag) + ArticleTag.objects.create(name=second_tag, article=self.article2) # Create a third tag used in both Articles. third_tag = "Club" - self.article1.tags.create(name=third_tag) - self.article2.tags.create(name=third_tag) + ArticleTag.objects.create(name=third_tag, article=self.article1) + ArticleTag.objects.create(name=third_tag, article=self.article2) with mock.patch("pinax.api.authentication.Anonymous.authenticate", autospec=True) as mock_authenticate: mock_authenticate.return_value = AnonymousUser() diff --git a/pinax/api/tests/viewsets.py b/pinax/api/tests/viewsets.py index 5b6fdbd..73a87c6 100644 --- a/pinax/api/tests/viewsets.py +++ b/pinax/api/tests/viewsets.py @@ -48,7 +48,7 @@ def list(self, request): qs = self.get_queryset() tag_querystring = request.GET.get("tag", "") if tag_querystring: - qs = qs.filter(tags__name__in=[tag_querystring]) + qs = qs.filter(articletag__name__in=[tag_querystring]) return self.render(self.resource_class.from_queryset(qs)) def retrieve(self, request, pk):