Skip to content

Commit

Permalink
Merge pull request #510 from bast0006/bast0006-new-infraction-filters
Browse files Browse the repository at this point in the history
Add new infraction filters for the infraction rescheduler
  • Loading branch information
ChrisLovering committed Jun 4, 2021
2 parents 03c787b + b076394 commit 7d81829
Show file tree
Hide file tree
Showing 2 changed files with 241 additions and 6 deletions.
174 changes: 168 additions & 6 deletions pydis_site/apps/api/tests/test_infractions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from datetime import datetime as dt, timedelta, timezone
from unittest.mock import patch
from urllib.parse import quote
Expand All @@ -16,7 +17,7 @@ def setUp(self):
self.client.force_authenticate(user=None)

def test_detail_lookup_returns_401(self):
url = reverse('bot:infraction-detail', args=(5,), host='api')
url = reverse('bot:infraction-detail', args=(6,), host='api')
response = self.client.get(url)

self.assertEqual(response.status_code, 401)
Expand All @@ -34,7 +35,7 @@ def test_create_returns_401(self):
self.assertEqual(response.status_code, 401)

def test_partial_update_returns_401(self):
url = reverse('bot:infraction-detail', args=(5,), host='api')
url = reverse('bot:infraction-detail', args=(6,), host='api')
response = self.client.patch(url, data={'reason': 'Have a nice day.'})

self.assertEqual(response.status_code, 401)
Expand All @@ -44,7 +45,7 @@ class InfractionTests(APISubdomainTestCase):
@classmethod
def setUpTestData(cls):
cls.user = User.objects.create(
id=5,
id=6,
name='james',
discriminator=1,
)
Expand All @@ -64,6 +65,30 @@ def setUpTestData(cls):
reason='James is an ass, and we won\'t be working with him again.',
active=False
)
cls.mute_permanent = Infraction.objects.create(
user_id=cls.user.id,
actor_id=cls.user.id,
type='mute',
reason='He has a filthy mouth and I am his soap.',
active=True,
expires_at=None
)
cls.superstar_expires_soon = Infraction.objects.create(
user_id=cls.user.id,
actor_id=cls.user.id,
type='superstar',
reason='This one doesn\'t matter anymore.',
active=True,
expires_at=datetime.datetime.utcnow() + datetime.timedelta(hours=5)
)
cls.voiceban_expires_later = Infraction.objects.create(
user_id=cls.user.id,
actor_id=cls.user.id,
type='voice_ban',
reason='Jet engine mic',
active=True,
expires_at=datetime.datetime.utcnow() + datetime.timedelta(days=5)
)

def test_list_all(self):
"""Tests the list-view, which should be ordered by inserted_at (newest first)."""
Expand All @@ -73,9 +98,12 @@ def test_list_all(self):
self.assertEqual(response.status_code, 200)
infractions = response.json()

self.assertEqual(len(infractions), 2)
self.assertEqual(infractions[0]['id'], self.ban_inactive.id)
self.assertEqual(infractions[1]['id'], self.ban_hidden.id)
self.assertEqual(len(infractions), 5)
self.assertEqual(infractions[0]['id'], self.voiceban_expires_later.id)
self.assertEqual(infractions[1]['id'], self.superstar_expires_soon.id)
self.assertEqual(infractions[2]['id'], self.mute_permanent.id)
self.assertEqual(infractions[3]['id'], self.ban_inactive.id)
self.assertEqual(infractions[4]['id'], self.ban_hidden.id)

def test_filter_search(self):
url = reverse('bot:infraction-list', host='api')
Expand All @@ -98,6 +126,140 @@ def test_filter_field(self):
self.assertEqual(len(infractions), 1)
self.assertEqual(infractions[0]['id'], self.ban_hidden.id)

def test_filter_permanent_false(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?type=mute&permanent=false')

self.assertEqual(response.status_code, 200)
infractions = response.json()

self.assertEqual(len(infractions), 0)

def test_filter_permanent_true(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?type=mute&permanent=true')

self.assertEqual(response.status_code, 200)
infractions = response.json()

self.assertEqual(infractions[0]['id'], self.mute_permanent.id)

def test_filter_after(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
response = self.client.get(f'{url}?type=superstar&expires_after={target_time.isoformat()}')

self.assertEqual(response.status_code, 200)
infractions = response.json()
self.assertEqual(len(infractions), 0)

def test_filter_before(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
response = self.client.get(f'{url}?type=superstar&expires_before={target_time.isoformat()}')

self.assertEqual(response.status_code, 200)
infractions = response.json()
self.assertEqual(len(infractions), 1)
self.assertEqual(infractions[0]['id'], self.superstar_expires_soon.id)

def test_filter_after_invalid(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?expires_after=gibberish')

self.assertEqual(response.status_code, 400)
self.assertEqual(list(response.json())[0], "expires_after")

def test_filter_before_invalid(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?expires_before=000000000')

self.assertEqual(response.status_code, 400)
self.assertEqual(list(response.json())[0], "expires_before")

def test_after_before_before(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=4)
target_time_late = datetime.datetime.utcnow() + datetime.timedelta(hours=6)
response = self.client.get(
f'{url}?expires_before={target_time_late.isoformat()}'
f'&expires_after={target_time.isoformat()}'
)

self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.json()), 1)
self.assertEqual(response.json()[0]["id"], self.superstar_expires_soon.id)

def test_after_after_before_invalid(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
target_time_late = datetime.datetime.utcnow() + datetime.timedelta(hours=9)
response = self.client.get(
f'{url}?expires_before={target_time.isoformat()}'
f'&expires_after={target_time_late.isoformat()}'
)

self.assertEqual(response.status_code, 400)
errors = list(response.json())
self.assertIn("expires_before", errors)
self.assertIn("expires_after", errors)

def test_permanent_after_invalid(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
response = self.client.get(f'{url}?permanent=true&expires_after={target_time.isoformat()}')

self.assertEqual(response.status_code, 400)
errors = list(response.json())
self.assertEqual("permanent", errors[0])

def test_permanent_before_invalid(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
response = self.client.get(f'{url}?permanent=true&expires_before={target_time.isoformat()}')

self.assertEqual(response.status_code, 400)
errors = list(response.json())
self.assertEqual("permanent", errors[0])

def test_nonpermanent_before(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=6)
response = self.client.get(
f'{url}?permanent=false&expires_before={target_time.isoformat()}'
)

self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.json()), 1)
self.assertEqual(response.json()[0]["id"], self.superstar_expires_soon.id)

def test_filter_manytypes(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?types=mute,ban')

self.assertEqual(response.status_code, 200)
infractions = response.json()
self.assertEqual(len(infractions), 3)

def test_types_type_invalid(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?types=mute,ban&type=superstar')

self.assertEqual(response.status_code, 400)
errors = list(response.json())
self.assertEqual("types", errors[0])

def test_sort_expiresby(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?ordering=expires_at&permanent=false')
self.assertEqual(response.status_code, 200)
infractions = response.json()

self.assertEqual(len(infractions), 3)
self.assertEqual(infractions[0]['id'], self.superstar_expires_soon.id)
self.assertEqual(infractions[1]['id'], self.voiceban_expires_later.id)
self.assertEqual(infractions[2]['id'], self.ban_hidden.id)

def test_returns_empty_for_no_match(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?type=ban&search=poop')
Expand Down
73 changes: 73 additions & 0 deletions pydis_site/apps/api/viewsets/bot/infraction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from datetime import datetime

from django.db.models import QuerySet
from django.http.request import HttpRequest
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.decorators import action
Expand Down Expand Up @@ -43,10 +46,17 @@ class InfractionViewSet(
- **offset** `int`: the initial index from which to return the results (default 0)
- **search** `str`: regular expression applied to the infraction's reason
- **type** `str`: the type of the infraction
- **types** `str`: comma separated sequence of types to filter for
- **user__id** `int`: snowflake of the user to which the infraction was applied
- **ordering** `str`: comma-separated sequence of fields to order the returned results
- **permanent** `bool`: whether or not to retrieve permanent infractions (default True)
- **expires_after** `isodatetime`: the earliest expires_at time to return infractions for
- **expires_before** `isodatetime`: the latest expires_at time to return infractions for
Invalid query parameters are ignored.
Only one of `type` and `types` may be provided. If both `expires_before` and `expires_after`
are provided, `expires_after` must come after `expires_before`.
If `permanent` is provided and true, `expires_before` and `expires_after` must not be provided.
#### Response format
Response is paginated but the result is returned without any pagination metadata.
Expand Down Expand Up @@ -156,6 +166,69 @@ def partial_update(self, request: HttpRequest, *_args, **_kwargs) -> Response:

return Response(serializer.data)

def get_queryset(self) -> QuerySet:
"""
Called to fetch the initial queryset, used to implement some of the more complex filters.
This provides the `permanent` and the `expires_gte` and `expires_lte` options.
"""
filter_permanent = self.request.query_params.get('permanent')
additional_filters = {}
if filter_permanent is not None:
additional_filters['expires_at__isnull'] = filter_permanent.lower() == 'true'

filter_expires_after = self.request.query_params.get('expires_after')
if filter_expires_after:
try:
additional_filters['expires_at__gte'] = datetime.fromisoformat(
filter_expires_after
)
except ValueError:
raise ValidationError({'expires_after': ['failed to convert to datetime']})

filter_expires_before = self.request.query_params.get('expires_before')
if filter_expires_before:
try:
additional_filters['expires_at__lte'] = datetime.fromisoformat(
filter_expires_before
)
except ValueError:
raise ValidationError({'expires_before': ['failed to convert to datetime']})

if 'expires_at__lte' in additional_filters and 'expires_at__gte' in additional_filters:
if additional_filters['expires_at__gte'] > additional_filters['expires_at__lte']:
raise ValidationError({
'expires_before': ['cannot be after expires_after'],
'expires_after': ['cannot be before expires_before'],
})

if (
('expires_at__lte' in additional_filters or 'expires_at__gte' in additional_filters)
and 'expires_at__isnull' in additional_filters
and additional_filters['expires_at__isnull']
):
raise ValidationError({
'permanent': [
'cannot filter for permanent infractions at the'
' same time as expires_at or expires_before',
]
})

if filter_expires_before:
# Filter out permanent infractions specifically if we want ones that will expire
# before a given date
additional_filters['expires_at__isnull'] = False

filter_types = self.request.query_params.get('types')
if filter_types:
if self.request.query_params.get('type'):
raise ValidationError({
'types': ['you must provide only one of "type" or "types"'],
})
additional_filters['type__in'] = [i.strip() for i in filter_types.split(",")]

return self.queryset.filter(**additional_filters)

@action(url_path='expanded', detail=False)
def list_expanded(self, *args, **kwargs) -> Response:
"""
Expand Down

0 comments on commit 7d81829

Please sign in to comment.