diff --git a/comments/api/tests.py b/comments/api/tests.py index a2691e4..e80797c 100644 --- a/comments/api/tests.py +++ b/comments/api/tests.py @@ -101,3 +101,33 @@ def test_update(self): self.assertEqual(comment.created_at, before_created_at) self.assertNotEqual(comment.created_at, now) self.assertNotEqual(comment.updated_at, before_updated_at) + + def test_list(self): + # must have tweet_id + response = self.anonymous_client.get(COMMENT_URL) + self.assertEqual(response.status_code, 400) + + # you can visit with tweet_id + # but at the beginning, no comment + response = self.anonymous_client.get( + COMMENT_URL, + {'tweet_id' : self.tweet.id}, + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.data['comments']), 0) + + # comments order by time + self.create_comment(self.linghu, self.tweet, '1') + self.create_comment(self.dongxie, self.tweet, '2') + self.create_comment(self.dongxie, self.create_tweet(self.dongxie), '3') + response = self.anonymous_client.get(COMMENT_URL, {'tweet_id' : self.tweet.id}) + self.assertEqual(len(response.data['comments']), 2) + self.assertEqual(response.data['comments'][0]['content'], '1') + self.assertEqual(response.data['comments'][1]['content'], '2') + + #同时提供 userid 和 tweetid, 只有 tweetid 会在 filter 中生效 + response = self.anonymous_client.get(COMMENT_URL, { + 'tweet_id' : self.tweet.id, + 'user_id' : self.linghu.id, + }) + self.assertEqual(len(response.data['comments']), 2) \ No newline at end of file diff --git a/comments/api/views.py b/comments/api/views.py index 92e1e68..0505fc9 100644 --- a/comments/api/views.py +++ b/comments/api/views.py @@ -13,9 +13,10 @@ class CommentViewSet(viewsets.GenericViewSet): serializer_class = CommentSerializerForCreate queryset = Comment.objects.all() + filterset_fields = ('tweet_id', ) # POST /api/comments/ -> create - # GET /api/comments/ -> list + # GET /api/comments/?tweet_id= -> list # Get /api/comments/1/ -> retrieve # DELETE /api/comments/1/ -> destroy # PATCH /api/comments/1/ -> partial_update @@ -28,6 +29,25 @@ def get_permissions(self): return [IsAuthenticated(), IsObjectOwner()] return [AllowAny()] + def list(self, request, *args, **kwargs): + if 'tweet_id' not in request.query_params: + return Response({ + 'message' : 'missing tweet_id in request', + 'success' : False, + }, status=status.HTTP_400_BAD_REQUEST) + # 这种写法在后期需要添加其他属性进行筛选就很方便 + queryset = self.get_queryset() + comments = self.filter_queryset(queryset)\ + .prefetch_related('user')\ + .order_by('created_at') + # tweet_id = request.query_params['tweet_id'] + # comments = Comment.objects.filter(tweet_id=tweet_id) + serializer = CommentSerializer(comments, many=True) + return Response({ + 'comments' : serializer.data, + }, status=status.HTTP_200_OK) + + def create(self, request, *args, **kwargs): data = { diff --git a/comments/models.py b/comments/models.py index a8c6a84..6fa6988 100644 --- a/comments/models.py +++ b/comments/models.py @@ -6,7 +6,7 @@ class Comment(models.Model): user = models.ForeignKey(User, on_delete=models.SET_NULL, null=True) tweet = models.ForeignKey(Tweet, on_delete=models.SET_NULL, null=True) - content = models.CharField(max_length=140) + content = models.TextField(max_length=140) created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) diff --git a/requirements.txt b/requirements.txt index fcfd2c0..9e67f19 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ chardet==3.0.4 cryptography==2.1.4 Django==3.1.3 django-debug-toolbar==3.2.4 +django-filter==21.1 djangorestframework==3.12.2 idna==2.6 keyring==10.6.0 diff --git a/twitter/settings.py b/twitter/settings.py index 651e9a5..787a555 100644 --- a/twitter/settings.py +++ b/twitter/settings.py @@ -40,6 +40,8 @@ 'django.contrib.staticfiles', # third party 'rest_framework', + 'debug_toolbar', + 'django_filters', # project apps 'accounts', @@ -47,8 +49,7 @@ 'friendships', 'newsfeeds', 'comments', - # debug tool - 'debug_toolbar', + ] MIDDLEWARE = [ @@ -118,7 +119,10 @@ REST_FRAMEWORK = { 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', - 'PAGE_SIZE': 10 + 'PAGE_SIZE': 10, + 'DEFAULT_FILTER_BACKENDS' : [ + 'django_filters.rest_framework.DjangoFilterBackend', + ], } @@ -140,3 +144,8 @@ # https://docs.djangoproject.com/en/3.1/howto/static-files/ STATIC_URL = '/static/' + +try: + from .local_settings import * +except: + pass