diff --git a/requirements/local.txt b/requirements/local.txt index 7cc8276..7b5518c 100644 --- a/requirements/local.txt +++ b/requirements/local.txt @@ -14,6 +14,7 @@ pytest-sugar==0.9.4 # https://github.com/Frozenball/pytest-sugar mock==4.0.3 # https://pypi.org/project/mock/ freezegun==1.1.0 # https://pypi.org/project/freezegun/0.1.11/ faker==8.5.1 # https://github.com/joke2k/faker +parameterized==0.8.1 # https://github.com/wolever/parameterized # Code quality # ------------------------------------------------------------------------------ diff --git a/tmh_registry/registry/api/serializers.py b/tmh_registry/registry/api/serializers.py index 6001940..26a5c1b 100644 --- a/tmh_registry/registry/api/serializers.py +++ b/tmh_registry/registry/api/serializers.py @@ -5,6 +5,7 @@ from rest_framework.serializers import ModelSerializer, SerializerMethodField from ...users.api.serializers import MedicalPersonnelSerializer +from ...users.models import MedicalPersonnel from ..models import Episode, Hospital, Patient, PatientHospitalMapping @@ -180,7 +181,7 @@ class PatientHospitalMappingReadSerializer(ModelSerializer): class Meta: model = PatientHospitalMapping - fields = ["patient", "hospital", "patient_hospital"] + fields = ["patient", "hospital", "patient_hospital_id"] class PatientHospitalMappingWriteSerializer(ModelSerializer): @@ -191,6 +192,10 @@ class Meta: model = PatientHospitalMapping fields = ["patient_id", "hospital_id", "patient_hospital_id"] + def to_representation(self, instance): + serializer = PatientHospitalMappingReadSerializer(instance) + return serializer.data + def create(self, validated_data): validated_data["patient_id"] = validated_data["patient_id"].id validated_data["hospital_id"] = validated_data["hospital_id"].id @@ -233,3 +238,102 @@ def create(self, validated_data): ) return new_mapping + + +class EpisodeReadSerializer(ModelSerializer): + patient_hospital_mapping = PatientHospitalMappingReadSerializer() + surgeons = MedicalPersonnelSerializer(many=True) + + class Meta: + model = Episode + fields = [ + "patient_hospital_mapping", + "created", + "surgery_date", + "episode_type", + "surgeons", + "comments", + "cepod", + "side", + "occurence", + "type", + "complexity", + "mesh_type", + "anaesthetic_type", + "diathermy_used", + ] + + +class EpisodeWriteSerializer(ModelSerializer): + patient_id = PrimaryKeyRelatedField( + write_only=True, queryset=Patient.objects.all() + ) + hospital_id = PrimaryKeyRelatedField( + write_only=True, queryset=Hospital.objects.all() + ) + surgeon_ids = PrimaryKeyRelatedField( + write_only=True, many=True, queryset=MedicalPersonnel.objects.all() + ) + + class Meta: + model = Episode + fields = [ + "patient_id", + "hospital_id", + "surgery_date", + "episode_type", + "surgeon_ids", + "comments", + "cepod", + "side", + "occurence", + "type", + "complexity", + "mesh_type", + "anaesthetic_type", + "diathermy_used", + ] + + def to_representation(self, instance): + serializer = EpisodeReadSerializer(instance) + return serializer.data + + def create(self, validated_data): + patient = validated_data["patient_id"] + hospital = validated_data["hospital_id"] + surgeons = validated_data["surgeon_ids"] + + patient_hospital_mapping = PatientHospitalMapping.objects.filter( + patient_id=patient.id, + hospital_id=hospital.id, + ).first() + if patient_hospital_mapping is None: + raise ValidationError( + { + "error": "PatientHospitalMapping for patient_id {patient_id} and hospital_id {hospital_id} " + "does not exist.".format( + patient_id=patient.id, + hospital_id=hospital.id, + ) + } + ) + + episode = Episode.objects.create( + patient_hospital_mapping=patient_hospital_mapping, + surgery_date=validated_data["surgery_date"], + episode_type=validated_data["episode_type"], + comments=validated_data["comments"], + cepod=validated_data["cepod"], + side=validated_data["side"], + occurence=validated_data["occurence"], + type=validated_data["type"], + complexity=validated_data["complexity"], + mesh_type=validated_data["mesh_type"], + anaesthetic_type=validated_data["anaesthetic_type"], + diathermy_used=validated_data["diathermy_used"], + ) + + if surgeons: + episode.surgeons.set(surgeons) + + return episode diff --git a/tmh_registry/registry/api/viewsets.py b/tmh_registry/registry/api/viewsets.py index 6ddb55b..6f997b9 100644 --- a/tmh_registry/registry/api/viewsets.py +++ b/tmh_registry/registry/api/viewsets.py @@ -9,11 +9,14 @@ from drf_yasg.openapi import IN_QUERY, TYPE_INTEGER, TYPE_STRING, Parameter from drf_yasg.utils import swagger_auto_schema from rest_framework import mixins, viewsets +from rest_framework.mixins import CreateModelMixin from rest_framework.viewsets import GenericViewSet -from ..models import Hospital, Patient, PatientHospitalMapping +from ..models import Episode, Hospital, Patient, PatientHospitalMapping from .serializers import ( CreatePatientSerializer, + EpisodeReadSerializer, + EpisodeWriteSerializer, HospitalSerializer, PatientHospitalMappingReadSerializer, PatientHospitalMappingWriteSerializer, @@ -91,7 +94,7 @@ class Meta: type=TYPE_INTEGER, ), ], - responses={200: ReadPatientSerializer(many=True)}, + responses={200: ReadPatientSerializer()}, ), ) class PatientViewSet( @@ -112,7 +115,13 @@ def get_serializer_class(self): raise NotImplementedError -class PatientHospitalMappingViewset(mixins.CreateModelMixin, GenericViewSet): +@method_decorator( + name="create", + decorator=swagger_auto_schema( + responses={201: PatientHospitalMappingReadSerializer()} + ), +) +class PatientHospitalMappingViewset(CreateModelMixin, GenericViewSet): queryset = PatientHospitalMapping.objects.all() def get_serializer_class(self): @@ -122,3 +131,19 @@ def get_serializer_class(self): return PatientHospitalMappingWriteSerializer raise NotImplementedError + + +@method_decorator( + name="create", + decorator=swagger_auto_schema(responses={201: EpisodeReadSerializer()}), +) +class EpisodeViewset(CreateModelMixin, GenericViewSet): + queryset = Episode.objects.all() + + def get_serializer_class(self): + if self.action in ["list", "retrieve"]: + return EpisodeReadSerializer + if self.action == "create": + return EpisodeWriteSerializer + + raise NotImplementedError diff --git a/tmh_registry/registry/tests/api/viewsets/test_episodes.py b/tmh_registry/registry/tests/api/viewsets/test_episodes.py new file mode 100644 index 0000000..c3fe0dd --- /dev/null +++ b/tmh_registry/registry/tests/api/viewsets/test_episodes.py @@ -0,0 +1,123 @@ +from datetime import date + +from django.test import TestCase +from parameterized import parameterized +from rest_framework.authtoken.models import Token +from rest_framework.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST +from rest_framework.test import APIClient + +from tmh_registry.registry.factories import ( + HospitalFactory, + PatientFactory, + PatientHospitalMappingFactory, +) +from tmh_registry.registry.models import Episode, PatientHospitalMapping +from tmh_registry.users.factories import MedicalPersonnelFactory + + +class TestEpisodesViewSet(TestCase): + @classmethod + def setUpClass(cls) -> None: + super(TestEpisodesViewSet, cls).setUpClass() + + cls.hospital = HospitalFactory() + + cls.patient = PatientFactory(full_name="John Doe") + cls.patient.created_at = date(year=2021, month=4, day=11) + cls.patient.save() + + cls.medical_personnel = MedicalPersonnelFactory() + cls.token = Token.objects.create(user=cls.medical_personnel.user) + + def get_episode_test_data(self): + return { + "patient_id": self.patient.id, + "hospital_id": self.hospital.id, + "surgery_date": "2021-10-12", + "episode_type": Episode.EpisodeChoices.UMBILICAL.value, + "surgeon_ids": [self.medical_personnel.id], + "comments": "A random comment", + "cepod": Episode.CepodChoices.PLANNED.value, + "side": Episode.SideChoices.LEFT.value, + "occurence": Episode.OccurenceChoices.RECURRENT.value, + "type": Episode.TypeChoices.INDIRECT.value, + "complexity": Episode.ComplexityChoices.INCARCERATED.value, + "mesh_type": Episode.MeshTypeChoices.TNMHP.value, + "anaesthetic_type": Episode.AnaestheticChoices.SPINAL.value, + "diathermy_used": True, + } + + def setUp(self) -> None: + self.client = APIClient() + self.client.credentials(HTTP_AUTHORIZATION="Token " + self.token.key) + + self.patient_hospital_mapping = PatientHospitalMappingFactory( + patient=self.patient, hospital=self.hospital + ) + + def test_create_episode_when_no_patient_hospital_mapping_exists(self): + PatientHospitalMapping.objects.all().delete() + + data = self.get_episode_test_data() + response = self.client.post( + "/api/v1/episodes/", data=data, format="json" + ) + + self.assertEqual(HTTP_400_BAD_REQUEST, response.status_code) + + @parameterized.expand( + [ + ("episode_type", "hIaTuS"), + ("cepod", "WRONG_OPTION"), + ("side", "WRONG_OPTION"), + ("occurence", "WRONG_OPTION"), + ("type", "WRONG_OPTION"), + ("complexity", "WRONG_OPTION"), + ("mesh_type", "WRONG_OPTION"), + ("anaesthetic_type", "WRONG_OPTION"), + ] + ) + def test_with_non_acceptable_value_for_a_field(self, field, value): + data = self.get_episode_test_data() + data[field] = value + + response = self.client.post( + "/api/v1/episodes/", data=data, format="json" + ) + + self.assertEqual(HTTP_400_BAD_REQUEST, response.status_code) + + def test_create_episode_successful(self): + data = self.get_episode_test_data() + response = self.client.post( + "/api/v1/episodes/", data=data, format="json" + ) + + self.assertEqual(HTTP_201_CREATED, response.status_code) + + self.assertEqual( + response.data["patient_hospital_mapping"]["patient_hospital_id"], + self.patient_hospital_mapping.patient_hospital_id, + ) + self.assertEqual(response.data["surgery_date"], data["surgery_date"]) + self.assertEqual(response.data["episode_type"], data["episode_type"]) + + self.assertEqual(len(response.data["surgeons"]), 1) + self.assertEqual( + response.data["surgeons"][0]["user"]["email"], + self.medical_personnel.user.email, + ) + + self.assertEqual(response.data["comments"], data["comments"]) + self.assertEqual(response.data["cepod"], data["cepod"]) + self.assertEqual(response.data["side"], data["side"]) + self.assertEqual(response.data["occurence"], data["occurence"]) + self.assertEqual(response.data["type"], data["type"]) + self.assertEqual(response.data["complexity"], data["complexity"]) + self.assertEqual(response.data["mesh_type"], data["mesh_type"]) + self.assertEqual( + response.data["anaesthetic_type"], data["anaesthetic_type"] + ) + self.assertEqual( + response.data["diathermy_used"], data["diathermy_used"] + ) diff --git a/tmh_registry/registry/tests/api/viewsets/test_patient_hospital_mappings.py b/tmh_registry/registry/tests/api/viewsets/test_patient_hospital_mappings.py index 4507e15..b044edf 100644 --- a/tmh_registry/registry/tests/api/viewsets/test_patient_hospital_mappings.py +++ b/tmh_registry/registry/tests/api/viewsets/test_patient_hospital_mappings.py @@ -32,8 +32,8 @@ def test_create_successful(self): ) self.assertEqual(HTTP_201_CREATED, response.status_code) - self.assertEqual(response.data["patient_id"], patient.id) - self.assertEqual(response.data["hospital_id"], hospital.id) + self.assertEqual(response.data["patient"]["id"], patient.id) + self.assertEqual(response.data["hospital"]["id"], hospital.id) self.assertEqual( response.data["patient_hospital_id"], data["patient_hospital_id"] ) diff --git a/tmh_registry/registry/urls.py b/tmh_registry/registry/urls.py index 36f80c0..e5af3e5 100644 --- a/tmh_registry/registry/urls.py +++ b/tmh_registry/registry/urls.py @@ -2,6 +2,7 @@ from rest_framework.routers import DefaultRouter from .api.viewsets import ( + EpisodeViewset, HospitalViewSet, PatientHospitalMappingViewset, PatientViewSet, @@ -11,6 +12,7 @@ router.register(r"hospitals", HospitalViewSet) router.register(r"patients", PatientViewSet) router.register(r"patient-hospital-mappings", PatientHospitalMappingViewset) +router.register(r"episodes", EpisodeViewset) urlpatterns = [ path("", include(router.urls)),