diff --git a/config/settings/base.py b/config/settings/base.py index 738eedf..5422192 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -62,6 +62,7 @@ # "django.contrib.humanize", # Handy template tags "django.contrib.admin", "django.forms", + "django_filters", ] THIRD_PARTY_APPS = [ "rest_framework", @@ -258,6 +259,7 @@ # ------------------------------------------------------------------------------- # django-rest-framework - https://www.django-rest-framework.org/api-guide/settings/ REST_FRAMEWORK = { + "DEFAULT_FILTER_BACKENDS": ["django_filters.rest_framework.DjangoFilterBackend"], "DEFAULT_AUTHENTICATION_CLASSES": ( "rest_framework.authentication.SessionAuthentication", "rest_framework.authentication.TokenAuthentication", diff --git a/requirements/base.txt b/requirements/base.txt index c2305fd..dceae5a 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -18,6 +18,7 @@ django-timezone-field==4.0 # https://github.com/mfogel/django-timezone-field/ django-cors-headers==3.7.0 # https://github.com/adamchainz/django-cors-headers django-csp==3.7 # https://github.com/mozilla/django-csp whitenoise==5.2.0 # https://github.com/evansd/whitenoise +django-filter==2.4.0 # Django REST Framework djangorestframework==3.12.2 # https://github.com/encode/django-rest-framework diff --git a/tmh_registry/common/models/__init__.py b/tmh_registry/common/models/__init__.py new file mode 100644 index 0000000..57bde91 --- /dev/null +++ b/tmh_registry/common/models/__init__.py @@ -0,0 +1 @@ +from .time_stamp_mixin import TimeStampMixin diff --git a/tmh_registry/common/models/time_stamp_mixin.py b/tmh_registry/common/models/time_stamp_mixin.py new file mode 100644 index 0000000..155dad2 --- /dev/null +++ b/tmh_registry/common/models/time_stamp_mixin.py @@ -0,0 +1,9 @@ +from django.db.models import DateField, Model + + +class TimeStampMixin(Model): + created_at = DateField(auto_now_add=True) + updated_at = DateField(auto_now=True) + + class Meta: + abstract = True diff --git a/tmh_registry/registry/api/serializers.py b/tmh_registry/registry/api/serializers.py index b82a5c0..6001940 100644 --- a/tmh_registry/registry/api/serializers.py +++ b/tmh_registry/registry/api/serializers.py @@ -58,6 +58,7 @@ class Meta: fields = [ "id", "full_name", + "created_at", "national_id", "age", "day_of_birth", diff --git a/tmh_registry/registry/api/viewsets.py b/tmh_registry/registry/api/viewsets.py index 7f35a7a..20b9804 100644 --- a/tmh_registry/registry/api/viewsets.py +++ b/tmh_registry/registry/api/viewsets.py @@ -1,6 +1,11 @@ +from django.db.models import Q from django.utils.decorators import method_decorator +from django_filters import CharFilter, NumberFilter # pylint: disable=E0401 +from django_filters.rest_framework import FilterSet # pylint: disable=E0401 +from drf_yasg.openapi import IN_QUERY, TYPE_INTEGER, TYPE_STRING, Parameter from drf_yasg.utils import swagger_auto_schema from rest_framework import mixins, viewsets +from rest_framework.filters import OrderingFilter from rest_framework.viewsets import GenericViewSet from ..models import Hospital, Patient, PatientHospitalMapping @@ -14,24 +19,83 @@ class HospitalViewSet(viewsets.ReadOnlyModelViewSet): - queryset = Hospital.objects.all() serializer_class = HospitalSerializer +class PatientFilterSet(FilterSet): + hospital_id = NumberFilter( + method="filter_hospital", + label="Filter based on hospital", + ) + search_term = CharFilter( + method="filter_search_term", + label="Filter based on search_term", + ) + + def filter_hospital(self, queryset, name, value): + if value: + patient_ids = PatientHospitalMapping.objects.filter( + hospital_id=value + ).values_list("patient_id", flat=True) + queryset = Patient.objects.filter(id__in=patient_ids) + return queryset + + def filter_search_term(self, queryset, name, value): + if value: + queryset = Patient.objects.filter( + Q(full_name__icontains=value) | Q(national_id__iexact=value) + ) + return queryset + + class Meta: + model = Patient + fields = ["hospital_id"] + + @method_decorator( name="create", decorator=swagger_auto_schema( responses={201: ReadPatientSerializer(many=True)} ), ) +@method_decorator( + name="list", + decorator=swagger_auto_schema( + manual_parameters=[ + Parameter( + "ordering", + IN_QUERY, + description="Choose with which field you want to order with. Possible options: [full_name, created_at]", + type=TYPE_STRING, + ), + Parameter( + "hospital_id", + IN_QUERY, + description="Filter with patients of a specific hospital.", + type=TYPE_INTEGER, + ), + Parameter( + "search_term", + IN_QUERY, + description="Filter patients with search term. A patient will be returned if national id is an exact " + "match or full name is even partially matched.", + type=TYPE_INTEGER, + ), + ], + responses={200: ReadPatientSerializer(many=True)}, + ), +) class PatientViewSet( mixins.CreateModelMixin, mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet, ): + filter_backends = [OrderingFilter] + ordering_fields = ("full_name", "created_at") + filterset_class = PatientFilterSet queryset = Patient.objects.all() def get_serializer_class(self): diff --git a/tmh_registry/registry/factories.py b/tmh_registry/registry/factories.py index c612745..0b4a894 100644 --- a/tmh_registry/registry/factories.py +++ b/tmh_registry/registry/factories.py @@ -23,6 +23,8 @@ class Meta: model = Patient full_name = LazyAttribute(lambda n: faker.name()) + created_at = LazyAttribute(lambda n: faker.date()) + updated_at = LazyAttribute(lambda n: faker.date()) national_id = LazyAttribute( lambda n: faker.numerify(text="####################") ) diff --git a/tmh_registry/registry/migrations/0011_auto_20211119_1518.py b/tmh_registry/registry/migrations/0011_auto_20211119_1518.py new file mode 100644 index 0000000..74c8648 --- /dev/null +++ b/tmh_registry/registry/migrations/0011_auto_20211119_1518.py @@ -0,0 +1,27 @@ +# Generated by Django 3.1.3 on 2021-11-19 15:18 + +from django.db import migrations, models +import django.utils.timezone + + +class Migration(migrations.Migration): + + dependencies = [ + ("registry", "0010_auto_20211021_1449"), + ] + + operations = [ + migrations.AddField( + model_name="patient", + name="created_at", + field=models.DateField( + auto_now_add=True, default=django.utils.timezone.now + ), + preserve_default=False, + ), + migrations.AddField( + model_name="patient", + name="updated_at", + field=models.DateField(auto_now=True), + ), + ] diff --git a/tmh_registry/registry/migrations/0012_auto_20211119_1656.py b/tmh_registry/registry/migrations/0012_auto_20211119_1656.py new file mode 100644 index 0000000..4469076 --- /dev/null +++ b/tmh_registry/registry/migrations/0012_auto_20211119_1656.py @@ -0,0 +1,23 @@ +# Generated by Django 3.1.3 on 2021-11-19 16:56 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("registry", "0011_auto_20211119_1518"), + ] + + operations = [ + migrations.AlterField( + model_name="patient", + name="gender", + field=models.CharField( + blank=True, + choices=[("Male", "Male"), ("Female", "Female")], + max_length=32, + null=True, + ), + ), + ] diff --git a/tmh_registry/registry/models.py b/tmh_registry/registry/models.py index 428f759..58d2b13 100644 --- a/tmh_registry/registry/models.py +++ b/tmh_registry/registry/models.py @@ -3,7 +3,8 @@ from django.db import models from django.db.models.enums import TextChoices -from ..users.models import MedicalPersonnel +from tmh_registry.common.models import TimeStampMixin +from tmh_registry.users.models import MedicalPersonnel class Hospital(models.Model): @@ -14,10 +15,10 @@ def __str__(self): return self.name -class Patient(models.Model): +class Patient(TimeStampMixin): class Gender(TextChoices): - MALE = ("MALE", "Male") - FEMALE = ("FEMALE", "Female") + MALE = ("Male", "Male") + FEMALE = ("Female", "Female") full_name = models.CharField(max_length=255) national_id = models.CharField( diff --git a/tmh_registry/registry/tests/api/viewsets/test_patients.py b/tmh_registry/registry/tests/api/viewsets/test_patients.py index 10e76d5..810ac63 100644 --- a/tmh_registry/registry/tests/api/viewsets/test_patients.py +++ b/tmh_registry/registry/tests/api/viewsets/test_patients.py @@ -19,7 +19,7 @@ PatientFactory, PatientHospitalMappingFactory, ) -from ....models import Patient, PatientHospitalMapping +from ....models import PatientHospitalMapping @mark.registry @@ -31,7 +31,11 @@ def setUpClass(cls) -> None: super(TestPatientsViewSet, cls).setUpClass() cls.hospital = HospitalFactory() - cls.patient = PatientFactory() + + cls.patient = PatientFactory(full_name="John Doe") + cls.patient.created_at = datetime.date(year=2021, month=4, day=11) + cls.patient.save() + cls.patient_hospital_mapping = PatientHospitalMapping.objects.create( patient=cls.patient, hospital=cls.hospital ) @@ -53,7 +57,7 @@ def get_patient_test_data(self): "day_of_birth": 3, "month_of_birth": 10, "year_of_birth": 1994, - "gender": Patient.Gender.FEMALE, + "gender": "Female", "phone_1": 234633241, "phone_2": 324362141, "address": "16 Test Street, Test City, Test Country", @@ -87,6 +91,81 @@ def test_get_patients_list_successful(self): response.data["results"][0]["hospital_mappings"][0]["hospital_id"], ) + def test_get_patients_list_with_hospital_id_successful(self): + response = self.client.get( + f"/api/v1/patients/?hospital_id={self.hospital.id}", format="json" + ) + self.assertEqual(HTTP_200_OK, response.status_code) + self.assertEqual(1, response.data["count"]) + self.assertEqual(self.patient.id, response.data["results"][0]["id"]) + + def test_get_patients_list_with_full_name_search_term_successful(self): + fullname_search_term = self.patient.full_name[:-3] + response = self.client.get( + f"/api/v1/patients/?search_term={fullname_search_term}", + format="json", + ) + self.assertEqual(HTTP_200_OK, response.status_code) + self.assertEqual(1, response.data["count"]) + + def test_get_patients_list_with_national_id_search_term_successful(self): + response = self.client.get( + f"/api/v1/patients/?search_term={self.patient.national_id}", + format="json", + ) + self.assertEqual(HTTP_200_OK, response.status_code) + self.assertEqual(1, response.data["count"]) + + def test_get_patients_list_full_name_ordering(self): + patient2 = PatientFactory(full_name="Zachary Unknown") + + response = self.client.get( + "/api/v1/patients/?ordering=full_name", + format="json", + ) + self.assertEqual(HTTP_200_OK, response.status_code) + self.assertEqual(2, response.data["count"]) + + self.assertEqual(self.patient.id, response.data["results"][0]["id"]) + self.assertEqual(patient2.id, response.data["results"][1]["id"]) + + # descending + response = self.client.get( + "/api/v1/patients/?ordering=-full_name", + format="json", + ) + self.assertEqual(HTTP_200_OK, response.status_code) + self.assertEqual(2, response.data["count"]) + + self.assertEqual(patient2.id, response.data["results"][0]["id"]) + self.assertEqual(self.patient.id, response.data["results"][1]["id"]) + + def test_get_patients_list_created_at_ordering(self): + patient2 = PatientFactory() + patient2.created_at = datetime.date(year=2021, month=11, day=4) + patient2.save() + + response = self.client.get( + "/api/v1/patients/?ordering=created_at", + format="json", + ) + self.assertEqual(HTTP_200_OK, response.status_code) + self.assertEqual(2, response.data["count"]) + + self.assertEqual(self.patient.id, response.data["results"][0]["id"]) + self.assertEqual(patient2.id, response.data["results"][1]["id"]) + + # descending + response = self.client.get( + "/api/v1/patients/?ordering=-created_at", + format="json", + ) + self.assertEqual(HTTP_200_OK, response.status_code) + self.assertEqual(2, response.data["count"]) + + self.assertEqual(patient2.id, response.data["results"][0]["id"]) + self.assertEqual(self.patient.id, response.data["results"][1]["id"]) + def test_get_patients_list_unauthorized(self): self.client = APIClient() response = self.client.get("/api/v1/patients/", format="json") @@ -217,6 +296,7 @@ def test_create_patients_successful(self): datetime.datetime.today().year - data["year_of_birth"], response.data["age"], ) + self.assertEqual("Female", response.data["gender"]) self.assertEqual(data["phone_1"], response.data["phone_1"]) self.assertEqual(data["phone_2"], response.data["phone_2"])