diff --git a/Makefile b/Makefile index 9887eda..3648cf0 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,8 @@ help: @echo " build to build the stack in docker compose." @echo " run to run the stack in docker compose." @echo " migrate to run the DB migrations in docker compose." - @echo " test to run the stack tests in docker compose." + @echo " test [target=] to run the stack tests in docker compose. Optionally specify a target." + @echo " test-one test_name= to run a specific stack test in docker compose. Optionally, you can specify the test's file as well." @echo " cover to print the code coverage in docker compose." @echo " teardown to tear down the stack in docker compose." @@ -27,7 +28,10 @@ migrate: docker-compose -f ${COMPOSE_ENV}.yml exec -T django bash -c "python manage.py makemigrations && python manage.py migrate" test: - docker-compose -f ${COMPOSE_ENV}.yml exec -T django coverage run --rcfile=.pre-commit/setup.cfg -m pytest --disable-pytest-warnings; + docker-compose -f ${COMPOSE_ENV}.yml exec -T django coverage run --rcfile=.pre-commit/setup.cfg -m pytest ${target} --disable-pytest-warnings; + +test-one: + docker-compose -f ${COMPOSE_ENV}.yml exec -T django coverage run --rcfile=.pre-commit/setup.cfg -m pytest ${file} -k ${test_name} --disable-pytest-warnings; cover: docker-compose -f ${COMPOSE_ENV}.yml exec -T django coverage report diff --git a/tmh_registry/registry/api/serializers.py b/tmh_registry/registry/api/serializers.py index 62b5835..20aac93 100644 --- a/tmh_registry/registry/api/serializers.py +++ b/tmh_registry/registry/api/serializers.py @@ -1,5 +1,6 @@ from rest_framework.exceptions import ValidationError -from rest_framework.fields import IntegerField, SerializerMethodField +from rest_framework.fields import IntegerField +from rest_framework.relations import PrimaryKeyRelatedField from rest_framework.serializers import ModelSerializer from ..models import Hospital, Patient, PatientHospitalMapping @@ -11,30 +12,15 @@ class Meta: fields = ["id", "name", "address"] -class ReadPatientSerializer(ModelSerializer): - age = IntegerField(allow_null=True) - hospitals = SerializerMethodField() +class PatientHospitalMappingPatientSerializer(ModelSerializer): + class Meta: + model = PatientHospitalMapping + fields = ["patient_hospital_id", "hospital_id"] - def get_hospitals(self, obj): - hospitals = Hospital.objects.filter( - id__in=obj.hospital_mappings.all().values_list( - "hospital_id", flat=True - ) - ) - hospital_data = HospitalSerializer(hospitals, many=True).data - - # enrich with patient_hospital_id - idx = 0 - for hospital in hospital_data: - patient_hospital_id = PatientHospitalMapping.objects.get( - hospital_id=hospital["id"], patient_id=obj.id - ).patient_hospital_id - hospital_data[idx].update( - {"patient_hospital_id": patient_hospital_id} - ) - idx += 1 - return hospital_data +class ReadPatientSerializer(ModelSerializer): + age = IntegerField(allow_null=True) + hospital_mappings = PatientHospitalMappingPatientSerializer(many=True) class Meta: model = Patient @@ -50,7 +36,7 @@ class Meta: "phone_1", "phone_2", "address", - "hospitals", + "hospital_mappings", ] def to_representation(self, instance): @@ -153,3 +139,64 @@ def create(self, validated_data): ) return new_patient + + +class PatientHospitalMappingReadSerializer(ModelSerializer): + patient = ReadPatientSerializer() + hospital = HospitalSerializer() + + class Meta: + model = PatientHospitalMapping + fields = ["patient", "hospital", "patient_hospital"] + + +class PatientHospitalMappingWriteSerializer(ModelSerializer): + patient_id = PrimaryKeyRelatedField(queryset=Patient.objects.all()) + hospital_id = PrimaryKeyRelatedField(queryset=Hospital.objects.all()) + + class Meta: + model = PatientHospitalMapping + fields = ["patient_id", "hospital_id", "patient_hospital_id"] + + def create(self, validated_data): + validated_data["patient_id"] = validated_data["patient_id"].id + validated_data["hospital_id"] = validated_data["hospital_id"].id + + existing_mapping = PatientHospitalMapping.objects.filter( + patient_id=validated_data["patient_id"], + hospital_id=validated_data["hospital_id"], + ) + if existing_mapping.exists(): + raise ValidationError( + { + "error": "PatientHospitalMapping for patient_id {patient_id} and hospital_id {hospital_id} " + "already exists!".format( + patient_id=validated_data["patient_id"], + hospital_id=validated_data["hospital_id"], + ) + } + ) + + existing_patient_hospital_id = PatientHospitalMapping.objects.filter( + patient_hospital_id=validated_data["patient_hospital_id"], + hospital_id=validated_data["hospital_id"], + ) + if existing_patient_hospital_id.exists(): + raise ValidationError( + { + "error": "Patient Hospital ID {patient_hospital_id} already exists for another patient in " + "this hospital".format( + patient_hospital_id=validated_data[ + "patient_hospital_id" + ] + ) + } + ) + + new_mapping = PatientHospitalMapping.objects.create( + patient_hospital_id=validated_data["patient_hospital_id"], + hospital_id=validated_data["hospital_id"], + patient_id=validated_data["patient_id"], + ) + + return new_mapping diff --git a/tmh_registry/registry/api/viewsets.py b/tmh_registry/registry/api/viewsets.py index 523b49a..7f35a7a 100644 --- a/tmh_registry/registry/api/viewsets.py +++ b/tmh_registry/registry/api/viewsets.py @@ -3,10 +3,12 @@ from rest_framework import mixins, viewsets from rest_framework.viewsets import GenericViewSet -from ..models import Hospital, Patient +from ..models import Hospital, Patient, PatientHospitalMapping from .serializers import ( CreatePatientSerializer, HospitalSerializer, + PatientHospitalMappingReadSerializer, + PatientHospitalMappingWriteSerializer, ReadPatientSerializer, ) @@ -39,3 +41,15 @@ def get_serializer_class(self): return CreatePatientSerializer raise NotImplementedError + + +class PatientHospitalMappingViewset(mixins.CreateModelMixin, GenericViewSet): + queryset = PatientHospitalMapping.objects.all() + + def get_serializer_class(self): + if self.action in ["list", "retrieve"]: + return PatientHospitalMappingReadSerializer + if self.action == "create": + return PatientHospitalMappingWriteSerializer + + raise NotImplementedError diff --git a/tmh_registry/registry/tests/api/viewsets/test_patient_hospital_mappings.py b/tmh_registry/registry/tests/api/viewsets/test_patient_hospital_mappings.py new file mode 100644 index 0000000..4507e15 --- /dev/null +++ b/tmh_registry/registry/tests/api/viewsets/test_patient_hospital_mappings.py @@ -0,0 +1,86 @@ +from django.test import TestCase +from rest_framework.authtoken.models import Token +from rest_framework.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST +from rest_framework.test import APIClient + +from tmh_registry.users.factories import MedicalPersonnelFactory + +from ....factories import ( + HospitalFactory, + PatientFactory, + PatientHospitalMappingFactory, +) + + +class TestPatientHospitalMappingViewset(TestCase): + def setUp(self) -> None: + self.token = Token.objects.create(user=MedicalPersonnelFactory().user) + self.client = APIClient() + self.client.credentials(HTTP_AUTHORIZATION="Token " + self.token.key) + + def test_create_successful(self): + patient = PatientFactory() + hospital = HospitalFactory() + data = { + "patient_id": patient.id, + "hospital_id": hospital.id, + "patient_hospital_id": "blabla", + } + + response = self.client.post( + "/api/v1/patient-hospital-mappings/", data=data, format="json" + ) + + self.assertEqual(HTTP_201_CREATED, response.status_code) + self.assertEqual(response.data["patient_id"], patient.id) + self.assertEqual(response.data["hospital_id"], hospital.id) + self.assertEqual( + response.data["patient_hospital_id"], data["patient_hospital_id"] + ) + + def test_create_when_mapping_already_exists(self): + mapping = PatientHospitalMappingFactory() + data = { + "patient_id": mapping.patient.id, + "hospital_id": mapping.hospital.id, + "patient_hospital_id": mapping.patient_hospital_id, + } + + response = self.client.post( + "/api/v1/patient-hospital-mappings/", data=data, format="json" + ) + + self.assertEqual(HTTP_400_BAD_REQUEST, response.status_code) + + def test_create_when_mapping_already_exists_with_different_patient_hospital_id( + self, + ): + mapping = PatientHospitalMappingFactory() + data = { + "patient_id": mapping.patient.id, + "hospital_id": mapping.hospital.id, + "patient_hospital_id": "whatever", + } + + response = self.client.post( + "/api/v1/patient-hospital-mappings/", data=data, format="json" + ) + + self.assertEqual(HTTP_400_BAD_REQUEST, response.status_code) + + def test_create_when_patient_hospital_id_already_exists_for_another_patient( + self, + ): + mapping = PatientHospitalMappingFactory() + patient = PatientFactory() + data = { + "patient_id": patient.id, + "hospital_id": mapping.hospital.id, + "patient_hospital_id": mapping.patient_hospital_id, + } + + response = self.client.post( + "/api/v1/patient-hospital-mappings/", data=data, format="json" + ) + + self.assertEqual(HTTP_400_BAD_REQUEST, response.status_code) diff --git a/tmh_registry/registry/tests/api/viewsets/test_patients.py b/tmh_registry/registry/tests/api/viewsets/test_patients.py index eedd197..f7387a0 100644 --- a/tmh_registry/registry/tests/api/viewsets/test_patients.py +++ b/tmh_registry/registry/tests/api/viewsets/test_patients.py @@ -69,10 +69,18 @@ def test_get_patients_list_successful(self): self.assertEqual(self.patient.id, response.data["results"][0]["id"]) self.assertNotIn("hospital_id", response.data["results"][0]) - self.assertEqual(1, len(response.data["results"][0]["hospitals"])) + self.assertEqual( + 1, len(response.data["results"][0]["hospital_mappings"]) + ) self.assertEqual( self.mapping.patient_hospital_id, - response.data["results"][0]["hospitals"][0]["patient_hospital_id"], + response.data["results"][0]["hospital_mappings"][0][ + "patient_hospital_id" + ], + ) + self.assertEqual( + self.hospital.id, + response.data["results"][0]["hospital_mappings"][0]["hospital_id"], ) def test_get_patients_list_unauthorized(self): @@ -112,16 +120,19 @@ def test_get_patients_detail_successful(self): self.assertEqual(HTTP_200_OK, response.status_code) self.assertEqual(self.patient.id, response.data["id"]) - self.assertEqual(1, len(response.data["hospitals"])) + self.assertEqual(1, len(response.data["hospital_mappings"])) patient_hospital_id = PatientHospitalMapping.objects.get( hospital=self.hospital.id, patient=self.patient.id ).patient_hospital_id - self.assertEqual(self.hospital.id, response.data["hospitals"][0]["id"]) + self.assertEqual( + self.hospital.id, + response.data["hospital_mappings"][0]["hospital_id"], + ) self.assertEqual( patient_hospital_id, - response.data["hospitals"][0]["patient_hospital_id"], + response.data["hospital_mappings"][0]["patient_hospital_id"], ) def test_get_patients_detail_unauthorized(self): @@ -165,14 +176,15 @@ def test_get_patients_detail_with_multiple_hospitals(self): self.assertEqual(HTTP_200_OK, response.status_code) self.assertEqual(self.patient.id, response.data["id"]) - self.assertEqual(2, len(response.data["hospitals"])) + self.assertEqual(2, len(response.data["hospital_mappings"])) - for hospital in response.data["hospitals"]: + for hospital_mapping in response.data["hospital_mappings"]: patient_hospital_id = PatientHospitalMapping.objects.get( - hospital_id=hospital["id"], patient_id=self.patient.id + hospital_id=hospital_mapping["hospital_id"], + patient_id=self.patient.id, ).patient_hospital_id self.assertEqual( - patient_hospital_id, hospital["patient_hospital_id"] + patient_hospital_id, hospital_mapping["patient_hospital_id"] ) ######################## @@ -195,11 +207,14 @@ def test_create_patients_successful(self): self.assertEqual(data["phone_1"], response.data["phone_1"]) self.assertEqual(data["phone_2"], response.data["phone_2"]) - self.assertEqual(1, len(response.data["hospitals"])) - self.assertEqual(self.hospital.id, response.data["hospitals"][0]["id"]) + self.assertEqual(1, len(response.data["hospital_mappings"])) + self.assertEqual( + self.hospital.id, + response.data["hospital_mappings"][0]["hospital_id"], + ) self.assertEqual( data["patient_hospital_id"], - response.data["hospitals"][0]["patient_hospital_id"], + response.data["hospital_mappings"][0]["patient_hospital_id"], ) self.assertEqual( 1, @@ -274,7 +289,7 @@ def test_create_patients_without_year_of_birth_but_with_age(self): ) self.assertEqual(data["phone_1"], response.data["phone_1"]) self.assertEqual(data["phone_2"], response.data["phone_2"]) - self.assertEqual(1, len(response.data["hospitals"])) + self.assertEqual(1, len(response.data["hospital_mappings"])) def test_create_patients_without_year_of_birth_and_age(self): data = self.get_patient_test_data() diff --git a/tmh_registry/registry/urls.py b/tmh_registry/registry/urls.py index d390687..36f80c0 100644 --- a/tmh_registry/registry/urls.py +++ b/tmh_registry/registry/urls.py @@ -1,11 +1,16 @@ from django.urls import include, path from rest_framework.routers import DefaultRouter -from .api.viewsets import HospitalViewSet, PatientViewSet +from .api.viewsets import ( + HospitalViewSet, + PatientHospitalMappingViewset, + PatientViewSet, +) router = DefaultRouter() router.register(r"hospitals", HospitalViewSet) router.register(r"patients", PatientViewSet) +router.register(r"patient-hospital-mappings", PatientHospitalMappingViewset) urlpatterns = [ path("", include(router.urls)),