Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new infraction filters for the infraction rescheduler #510

Merged
174 changes: 168 additions & 6 deletions pydis_site/apps/api/tests/test_infractions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from datetime import datetime as dt, timedelta, timezone
from unittest.mock import patch
from urllib.parse import quote
Expand All @@ -16,7 +17,7 @@ def setUp(self):
self.client.force_authenticate(user=None)

def test_detail_lookup_returns_401(self):
url = reverse('bot:infraction-detail', args=(5,), host='api')
url = reverse('bot:infraction-detail', args=(6,), host='api')
response = self.client.get(url)

self.assertEqual(response.status_code, 401)
Expand All @@ -34,7 +35,7 @@ def test_create_returns_401(self):
self.assertEqual(response.status_code, 401)

def test_partial_update_returns_401(self):
url = reverse('bot:infraction-detail', args=(5,), host='api')
url = reverse('bot:infraction-detail', args=(6,), host='api')
response = self.client.patch(url, data={'reason': 'Have a nice day.'})

self.assertEqual(response.status_code, 401)
Expand All @@ -44,7 +45,7 @@ class InfractionTests(APISubdomainTestCase):
@classmethod
def setUpTestData(cls):
cls.user = User.objects.create(
id=5,
id=6,
name='james',
discriminator=1,
)
Expand All @@ -64,6 +65,30 @@ def setUpTestData(cls):
reason='James is an ass, and we won\'t be working with him again.',
active=False
)
cls.mute_permanent = Infraction.objects.create(
user_id=cls.user.id,
actor_id=cls.user.id,
type='mute',
reason='He has a filthy mouth and I am his soap.',
active=True,
expires_at=None
)
cls.superstar_expires_soon = Infraction.objects.create(
user_id=cls.user.id,
actor_id=cls.user.id,
type='superstar',
reason='This one doesn\'t matter anymore.',
active=True,
expires_at=datetime.datetime.utcnow() + datetime.timedelta(hours=5)
)
cls.voiceban_expires_later = Infraction.objects.create(
user_id=cls.user.id,
actor_id=cls.user.id,
type='voice_ban',
reason='Jet engine mic',
active=True,
expires_at=datetime.datetime.utcnow() + datetime.timedelta(days=5)
)

def test_list_all(self):
"""Tests the list-view, which should be ordered by inserted_at (newest first)."""
Expand All @@ -73,9 +98,12 @@ def test_list_all(self):
self.assertEqual(response.status_code, 200)
infractions = response.json()

self.assertEqual(len(infractions), 2)
self.assertEqual(infractions[0]['id'], self.ban_inactive.id)
self.assertEqual(infractions[1]['id'], self.ban_hidden.id)
self.assertEqual(len(infractions), 5)
self.assertEqual(infractions[0]['id'], self.voiceban_expires_later.id)
self.assertEqual(infractions[1]['id'], self.superstar_expires_soon.id)
self.assertEqual(infractions[2]['id'], self.mute_permanent.id)
self.assertEqual(infractions[3]['id'], self.ban_inactive.id)
self.assertEqual(infractions[4]['id'], self.ban_hidden.id)

def test_filter_search(self):
url = reverse('bot:infraction-list', host='api')
Expand All @@ -98,6 +126,140 @@ def test_filter_field(self):
self.assertEqual(len(infractions), 1)
self.assertEqual(infractions[0]['id'], self.ban_hidden.id)

def test_filter_permanent_false(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?type=mute&permanent=false')

self.assertEqual(response.status_code, 200)
infractions = response.json()

self.assertEqual(len(infractions), 0)

def test_filter_permanent_true(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?type=mute&permanent=true')

self.assertEqual(response.status_code, 200)
infractions = response.json()

self.assertEqual(infractions[0]['id'], self.mute_permanent.id)

def test_filter_after(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
response = self.client.get(f'{url}?type=superstar&expires_after={target_time.isoformat()}')

self.assertEqual(response.status_code, 200)
infractions = response.json()
self.assertEqual(len(infractions), 0)

def test_filter_before(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
response = self.client.get(f'{url}?type=superstar&expires_before={target_time.isoformat()}')

self.assertEqual(response.status_code, 200)
infractions = response.json()
self.assertEqual(len(infractions), 1)
self.assertEqual(infractions[0]['id'], self.superstar_expires_soon.id)

def test_filter_after_invalid(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?expires_after=gibberish')

self.assertEqual(response.status_code, 400)
self.assertEqual(list(response.json())[0], "expires_after")

def test_filter_before_invalid(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?expires_before=000000000')

self.assertEqual(response.status_code, 400)
self.assertEqual(list(response.json())[0], "expires_before")

def test_after_before_before(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=4)
target_time_late = datetime.datetime.utcnow() + datetime.timedelta(hours=6)
response = self.client.get(
f'{url}?expires_before={target_time_late.isoformat()}'
f'&expires_after={target_time.isoformat()}'
)

self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.json()), 1)
self.assertEqual(response.json()[0]["id"], self.superstar_expires_soon.id)

def test_after_after_before_invalid(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
target_time_late = datetime.datetime.utcnow() + datetime.timedelta(hours=9)
response = self.client.get(
f'{url}?expires_before={target_time.isoformat()}'
f'&expires_after={target_time_late.isoformat()}'
)

self.assertEqual(response.status_code, 400)
errors = list(response.json())
self.assertIn("expires_before", errors)
self.assertIn("expires_after", errors)

def test_permanent_after_invalid(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
response = self.client.get(f'{url}?permanent=true&expires_after={target_time.isoformat()}')

self.assertEqual(response.status_code, 400)
errors = list(response.json())
self.assertEqual("permanent", errors[0])

def test_permanent_before_invalid(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
response = self.client.get(f'{url}?permanent=true&expires_before={target_time.isoformat()}')

self.assertEqual(response.status_code, 400)
errors = list(response.json())
self.assertEqual("permanent", errors[0])

def test_nonpermanent_before(self):
url = reverse('bot:infraction-list', host='api')
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=6)
response = self.client.get(
f'{url}?permanent=false&expires_before={target_time.isoformat()}'
)

self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.json()), 1)
self.assertEqual(response.json()[0]["id"], self.superstar_expires_soon.id)

def test_filter_manytypes(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?types=mute,ban')

self.assertEqual(response.status_code, 200)
infractions = response.json()
self.assertEqual(len(infractions), 3)

def test_types_type_invalid(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?types=mute,ban&type=superstar')

self.assertEqual(response.status_code, 400)
errors = list(response.json())
self.assertEqual("types", errors[0])

def test_sort_expiresby(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?ordering=expires_at&permanent=false')
self.assertEqual(response.status_code, 200)
infractions = response.json()

self.assertEqual(len(infractions), 3)
self.assertEqual(infractions[0]['id'], self.superstar_expires_soon.id)
self.assertEqual(infractions[1]['id'], self.voiceban_expires_later.id)
self.assertEqual(infractions[2]['id'], self.ban_hidden.id)

def test_returns_empty_for_no_match(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?type=ban&search=poop')
Expand Down
73 changes: 73 additions & 0 deletions pydis_site/apps/api/viewsets/bot/infraction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from datetime import datetime

from django.db.models import QuerySet
from django.http.request import HttpRequest
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.decorators import action
Expand Down Expand Up @@ -43,10 +46,17 @@ class InfractionViewSet(
- **offset** `int`: the initial index from which to return the results (default 0)
- **search** `str`: regular expression applied to the infraction's reason
- **type** `str`: the type of the infraction
- **types** `str`: comma separated sequence of types to filter for
- **user__id** `int`: snowflake of the user to which the infraction was applied
- **ordering** `str`: comma-separated sequence of fields to order the returned results
- **permanent** `bool`: whether or not to retrieve permanent infractions (default True)
- **expires_after** `isodatetime`: the earliest expires_at time to return infractions for
- **expires_before** `isodatetime`: the latest expires_at time to return infractions for

Invalid query parameters are ignored.
Only one of `type` and `types` may be provided. If both `expires_before` and `expires_after`
are provided, `expires_after` must come after `expires_before`.
If `permanent` is provided and true, `expires_before` and `expires_after` must not be provided.

#### Response format
Response is paginated but the result is returned without any pagination metadata.
Expand Down Expand Up @@ -156,6 +166,69 @@ def partial_update(self, request: HttpRequest, *_args, **_kwargs) -> Response:

return Response(serializer.data)

def get_queryset(self) -> QuerySet:
"""
Called to fetch the initial queryset, used to implement some of the more complex filters.

This provides the `permanent` and the `expires_gte` and `expires_lte` options.
"""
filter_permanent = self.request.query_params.get('permanent')
additional_filters = {}
if filter_permanent is not None:
additional_filters['expires_at__isnull'] = filter_permanent.lower() == 'true'

filter_expires_after = self.request.query_params.get('expires_after')
if filter_expires_after:
try:
additional_filters['expires_at__gte'] = datetime.fromisoformat(
filter_expires_after
)
except ValueError:
raise ValidationError({'expires_after': ['failed to convert to datetime']})

filter_expires_before = self.request.query_params.get('expires_before')
if filter_expires_before:
try:
additional_filters['expires_at__lte'] = datetime.fromisoformat(
filter_expires_before
)
except ValueError:
raise ValidationError({'expires_before': ['failed to convert to datetime']})

if 'expires_at__lte' in additional_filters and 'expires_at__gte' in additional_filters:
if additional_filters['expires_at__gte'] > additional_filters['expires_at__lte']:
raise ValidationError({
'expires_before': ['cannot be after expires_after'],
'expires_after': ['cannot be before expires_before'],
})

if (
('expires_at__lte' in additional_filters or 'expires_at__gte' in additional_filters)
and 'expires_at__isnull' in additional_filters
and additional_filters['expires_at__isnull']
):
raise ValidationError({
'permanent': [
'cannot filter for permanent infractions at the'
' same time as expires_at or expires_before',
]
})

if filter_expires_before:
# Filter out permanent infractions specifically if we want ones that will expire
# before a given date
additional_filters['expires_at__isnull'] = False

filter_types = self.request.query_params.get('types')
if filter_types:
if self.request.query_params.get('type'):
raise ValidationError({
'types': ['you must provide only one of "type" or "types"'],
})
additional_filters['type__in'] = [i.strip() for i in filter_types.split(",")]

return self.queryset.filter(**additional_filters)

@action(url_path='expanded', detail=False)
def list_expanded(self, *args, **kwargs) -> Response:
"""
Expand Down