diff --git a/sponsors/api.py b/sponsors/api.py index 5244cbde4..7b99d0e91 100644 --- a/sponsors/api.py +++ b/sponsors/api.py @@ -7,12 +7,14 @@ from rest_framework.views import APIView from rest_framework.response import Response from sponsors.models import BenefitFeature, LogoPlacement, Sponsorship +from sponsors.models.enums import PublisherChoices, LogoPlacementChoices class LogoPlacementSerializer(serializers.Serializer): publisher = serializers.CharField() flight = serializers.CharField() sponsor = serializers.CharField() + sponsor_slug = serializers.CharField() description = serializers.CharField() logo = serializers.URLField() start_date = serializers.DateField() @@ -22,6 +24,33 @@ class LogoPlacementSerializer(serializers.Serializer): level_order = serializers.IntegerField() +class FilterLogoPlacementsSerializer(serializers.Serializer): + publisher = serializers.ChoiceField( + choices=[(c.value, c.name.replace("_", " ").title()) for c in PublisherChoices], + required=False, + ) + flight = serializers.ChoiceField( + choices=[(c.value, c.name.replace("_", " ").title()) for c in LogoPlacementChoices], + required=False, + ) + + @property + def by_publisher(self): + return self.validated_data.get("publisher") + + @property + def by_flight(self): + return self.validated_data.get("flight") + + def skip_logo(self, logo): + if self.by_publisher and self.by_publisher != logo.publisher: + return True + if self.by_flight and self.by_flight != logo.logo_place: + return True + else: + return False + + class SponsorPublisherPermission(permissions.BasePermission): message = 'Must have publisher permission.' @@ -33,18 +62,21 @@ def has_permission(self, request, view): class LogoPlacementeAPIList(APIView): - authentication_classes = [TokenAuthentication] permission_classes = [SponsorPublisherPermission] serializer_class = LogoPlacementSerializer def get(self, request, *args, **kwargs): placements = [] + logo_filters = FilterLogoPlacementsSerializer(data=request.GET) + if not logo_filters.is_valid(): + return Response(logo_filters.errors, status=400) sponsorships = Sponsorship.objects.enabled().with_logo_placement() for sponsorship in sponsorships.select_related("sponsor").iterator(): sponsor = sponsorship.sponsor base_data = { "sponsor": sponsor.name, + "sponsor_slug": sponsor.slug, "level_name": sponsorship.level_name, "level_order": sponsorship.package.order, "description": sponsor.description, @@ -55,7 +87,8 @@ def get(self, request, *args, **kwargs): } benefits = BenefitFeature.objects.filter(sponsor_benefit__sponsorship_id=sponsorship.pk) - for logo in benefits.instance_of(LogoPlacement): + logos = [l for l in benefits.instance_of(LogoPlacement) if not logo_filters.skip_logo(l)] + for logo in logos: placement = base_data.copy() placement["publisher"] = logo.publisher placement["flight"] = logo.logo_place diff --git a/sponsors/models/sponsors.py b/sponsors/models/sponsors.py index 6781fb879..2e15b6742 100644 --- a/sponsors/models/sponsors.py +++ b/sponsors/models/sponsors.py @@ -4,6 +4,7 @@ from allauth.account.models import EmailAddress from django.conf import settings from django.db import models +from django.template.defaultfilters import slugify from django.urls import reverse from django_countries.fields import CountryField from ordered_model.models import OrderedModel @@ -102,6 +103,10 @@ def primary_contact(self): except SponsorContact.DoesNotExist: return None + @property + def slug(self): + return slugify(self.name) + @property def admin_url(self): return reverse("admin:sponsors_sponsor_change", args=[self.pk]) diff --git a/sponsors/tests/test_api.py b/sponsors/tests/test_api.py index f5c2481fb..ab0b7d15c 100644 --- a/sponsors/tests/test_api.py +++ b/sponsors/tests/test_api.py @@ -1,3 +1,5 @@ +from urllib.parse import urlencode + from django.contrib.auth.models import Permission from django.urls import reverse_lazy from django.utils.text import slugify @@ -20,6 +22,15 @@ def setUp(self): self.authorization = f'Token {token.key}' self.sponsors = baker.make(Sponsor, _create_files=True, _quantity=3) + sponsorships = baker.make_recipe("sponsors.tests.finalized_sponsorship", sponsor=iter(self.sponsors), + _quantity=3) + self.sp1, self.sp2, self.sp3 = sponsorships + baker.make_recipe("sponsors.tests.logo_at_download_feature", sponsor_benefit__sponsorship=self.sp1) + baker.make_recipe("sponsors.tests.logo_at_sponsors_feature", sponsor_benefit__sponsorship=self.sp1) + baker.make_recipe("sponsors.tests.logo_at_sponsors_feature", sponsor_benefit__sponsorship=self.sp2) + baker.make_recipe("sponsors.tests.logo_at_pypi_feature", sponsor_benefit__sponsorship=self.sp3, + link_to_sponsors_page=True, describe_as_sponsor=True) + def tearDown(self): for sponsor in Sponsor.objects.all(): if sponsor.web_logo: @@ -28,12 +39,6 @@ def tearDown(self): sponsor.print_logo.delete() def test_list_logo_placement_as_expected(self): - sp1, sp2, sp3 = baker.make_recipe("sponsors.tests.finalized_sponsorship", sponsor=iter(self.sponsors), _quantity=3) - baker.make_recipe("sponsors.tests.logo_at_download_feature", sponsor_benefit__sponsorship=sp1) - baker.make_recipe("sponsors.tests.logo_at_sponsors_feature", sponsor_benefit__sponsorship=sp1) - baker.make_recipe("sponsors.tests.logo_at_sponsors_feature", sponsor_benefit__sponsorship=sp2) - baker.make_recipe("sponsors.tests.logo_at_pypi_feature", sponsor_benefit__sponsorship=sp3, link_to_sponsors_page=True, describe_as_sponsor=True) - response = self.client.get(self.url, HTTP_AUTHORIZATION=self.authorization) data = response.json() @@ -50,15 +55,15 @@ def test_list_logo_placement_as_expected(self): [p for p in data if p["publisher"] == PublisherChoices.FOUNDATION.value][0]['sponsor_url'] ) self.assertEqual( - f"http://testserver/psf/sponsors/#{slugify(sp3.sponsor.name)}", + f"http://testserver/psf/sponsors/#{slugify(self.sp3.sponsor.name)}", [p for p in data if p["publisher"] == PublisherChoices.PYPI.value][0]['sponsor_url'] ) self.assertCountEqual( - [sp1.sponsor.description, sp1.sponsor.description, sp2.sponsor.description], + [self.sp1.sponsor.description, self.sp1.sponsor.description, self.sp2.sponsor.description], [p['description'] for p in data if p["publisher"] == PublisherChoices.FOUNDATION.value] ) self.assertEqual( - [f"{sp3.sponsor.name} is a {sp3.level_name} sponsor of the Python Software Foundation."], + [f"{self.sp3.sponsor.name} is a {self.sp3.level_name} sponsor of the Python Software Foundation."], [p['description'] for p in data if p["publisher"] == PublisherChoices.PYPI.value] ) @@ -86,3 +91,41 @@ def test_user_must_have_required_permission(self): self.user.user_permissions.remove(self.permission) response = self.client.get(self.url, HTTP_AUTHORIZATION=self.authorization) self.assertEqual(403, response.status_code) + + def test_filter_sponsorship_by_publisher(self): + querystring = urlencode({ + "publisher": PublisherChoices.PYPI.value, + }) + url = f"{self.url}?{querystring}" + response = self.client.get(url, HTTP_AUTHORIZATION=self.authorization) + data = response.json() + + self.assertEqual(200, response.status_code) + self.assertEqual(1, len(data)) + self.assertEqual(self.sp3.sponsor.name, data[0]["sponsor"]) + + def test_filter_sponsorship_by_flight(self): + querystring = urlencode({ + "flight": LogoPlacementChoices.SIDEBAR.value, + }) + url = f"{self.url}?{querystring}" + response = self.client.get(url, HTTP_AUTHORIZATION=self.authorization) + data = response.json() + + self.assertEqual(200, response.status_code) + self.assertEqual(1, len(data)) + self.assertEqual(self.sp3.sponsor.name, data[0]["sponsor"]) + self.assertEqual(self.sp3.sponsor.slug, data[0]["sponsor_slug"]) + + def test_bad_request_for_invalid_filters(self): + querystring = urlencode({ + "flight": "invalid-flight", + "publisher": "invalid-publisher" + }) + url = f"{self.url}?{querystring}" + response = self.client.get(url, HTTP_AUTHORIZATION=self.authorization) + data = response.json() + + self.assertEqual(400, response.status_code) + self.assertIn("flight", data) + self.assertIn("publisher", data)