Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ help:
@echo " build to build the stack in docker compose."
@echo " run to run the stack in docker compose."
@echo " migrate to run the DB migrations in docker compose."
@echo " test to run the stack tests in docker compose."
@echo " test [target=<target-folder-or-file>] to run the stack tests in docker compose. Optionally specify a target."
@echo " test-one test_name=<test_name> to run a specific stack test in docker compose. Optionally, you can specify the test's file as well."
@echo " cover to print the code coverage in docker compose."
@echo " teardown to tear down the stack in docker compose."

Expand All @@ -27,7 +28,10 @@ migrate:
docker-compose -f ${COMPOSE_ENV}.yml exec -T django bash -c "python manage.py makemigrations && python manage.py migrate"

test:
docker-compose -f ${COMPOSE_ENV}.yml exec -T django coverage run --rcfile=.pre-commit/setup.cfg -m pytest --disable-pytest-warnings;
docker-compose -f ${COMPOSE_ENV}.yml exec -T django coverage run --rcfile=.pre-commit/setup.cfg -m pytest ${target} --disable-pytest-warnings;

test-one:
docker-compose -f ${COMPOSE_ENV}.yml exec -T django coverage run --rcfile=.pre-commit/setup.cfg -m pytest ${file} -k ${test_name} --disable-pytest-warnings;

cover:
docker-compose -f ${COMPOSE_ENV}.yml exec -T django coverage report
Expand Down
95 changes: 71 additions & 24 deletions tmh_registry/registry/api/serializers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from rest_framework.exceptions import ValidationError
from rest_framework.fields import IntegerField, SerializerMethodField
from rest_framework.fields import IntegerField
from rest_framework.relations import PrimaryKeyRelatedField
from rest_framework.serializers import ModelSerializer

from ..models import Hospital, Patient, PatientHospitalMapping
Expand All @@ -11,30 +12,15 @@ class Meta:
fields = ["id", "name", "address"]


class ReadPatientSerializer(ModelSerializer):
age = IntegerField(allow_null=True)
hospitals = SerializerMethodField()
class PatientHospitalMappingPatientSerializer(ModelSerializer):
class Meta:
model = PatientHospitalMapping
fields = ["patient_hospital_id", "hospital_id"]

def get_hospitals(self, obj):
hospitals = Hospital.objects.filter(
id__in=obj.hospital_mappings.all().values_list(
"hospital_id", flat=True
)
)
hospital_data = HospitalSerializer(hospitals, many=True).data

# enrich with patient_hospital_id
idx = 0
for hospital in hospital_data:
patient_hospital_id = PatientHospitalMapping.objects.get(
hospital_id=hospital["id"], patient_id=obj.id
).patient_hospital_id
hospital_data[idx].update(
{"patient_hospital_id": patient_hospital_id}
)
idx += 1

return hospital_data
class ReadPatientSerializer(ModelSerializer):
age = IntegerField(allow_null=True)
hospital_mappings = PatientHospitalMappingPatientSerializer(many=True)

class Meta:
model = Patient
Expand All @@ -50,7 +36,7 @@ class Meta:
"phone_1",
"phone_2",
"address",
"hospitals",
"hospital_mappings",
]

def to_representation(self, instance):
Expand Down Expand Up @@ -153,3 +139,64 @@ def create(self, validated_data):
)

return new_patient


class PatientHospitalMappingReadSerializer(ModelSerializer):
patient = ReadPatientSerializer()
hospital = HospitalSerializer()

class Meta:
model = PatientHospitalMapping
fields = ["patient", "hospital", "patient_hospital"]


class PatientHospitalMappingWriteSerializer(ModelSerializer):
patient_id = PrimaryKeyRelatedField(queryset=Patient.objects.all())
hospital_id = PrimaryKeyRelatedField(queryset=Hospital.objects.all())

class Meta:
model = PatientHospitalMapping
fields = ["patient_id", "hospital_id", "patient_hospital_id"]

def create(self, validated_data):
validated_data["patient_id"] = validated_data["patient_id"].id
validated_data["hospital_id"] = validated_data["hospital_id"].id

existing_mapping = PatientHospitalMapping.objects.filter(
patient_id=validated_data["patient_id"],
hospital_id=validated_data["hospital_id"],
)
if existing_mapping.exists():
raise ValidationError(
{
"error": "PatientHospitalMapping for patient_id {patient_id} and hospital_id {hospital_id} "
"already exists!".format(
patient_id=validated_data["patient_id"],
hospital_id=validated_data["hospital_id"],
)
}
)

existing_patient_hospital_id = PatientHospitalMapping.objects.filter(
patient_hospital_id=validated_data["patient_hospital_id"],
hospital_id=validated_data["hospital_id"],
)
if existing_patient_hospital_id.exists():
raise ValidationError(
{
"error": "Patient Hospital ID {patient_hospital_id} already exists for another patient in "
"this hospital".format(
patient_hospital_id=validated_data[
"patient_hospital_id"
]
)
}
)

new_mapping = PatientHospitalMapping.objects.create(
patient_hospital_id=validated_data["patient_hospital_id"],
hospital_id=validated_data["hospital_id"],
patient_id=validated_data["patient_id"],
)

return new_mapping
16 changes: 15 additions & 1 deletion tmh_registry/registry/api/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from rest_framework import mixins, viewsets
from rest_framework.viewsets import GenericViewSet

from ..models import Hospital, Patient
from ..models import Hospital, Patient, PatientHospitalMapping
from .serializers import (
CreatePatientSerializer,
HospitalSerializer,
PatientHospitalMappingReadSerializer,
PatientHospitalMappingWriteSerializer,
ReadPatientSerializer,
)

Expand Down Expand Up @@ -39,3 +41,15 @@ def get_serializer_class(self):
return CreatePatientSerializer

raise NotImplementedError


class PatientHospitalMappingViewset(mixins.CreateModelMixin, GenericViewSet):
queryset = PatientHospitalMapping.objects.all()

def get_serializer_class(self):
if self.action in ["list", "retrieve"]:
return PatientHospitalMappingReadSerializer
if self.action == "create":
return PatientHospitalMappingWriteSerializer

raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from django.test import TestCase
from rest_framework.authtoken.models import Token
from rest_framework.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST
from rest_framework.test import APIClient

from tmh_registry.users.factories import MedicalPersonnelFactory

from ....factories import (
HospitalFactory,
PatientFactory,
PatientHospitalMappingFactory,
)


class TestPatientHospitalMappingViewset(TestCase):
def setUp(self) -> None:
self.token = Token.objects.create(user=MedicalPersonnelFactory().user)
self.client = APIClient()
self.client.credentials(HTTP_AUTHORIZATION="Token " + self.token.key)

def test_create_successful(self):
patient = PatientFactory()
hospital = HospitalFactory()
data = {
"patient_id": patient.id,
"hospital_id": hospital.id,
"patient_hospital_id": "blabla",
}

response = self.client.post(
"/api/v1/patient-hospital-mappings/", data=data, format="json"
)

self.assertEqual(HTTP_201_CREATED, response.status_code)
self.assertEqual(response.data["patient_id"], patient.id)
self.assertEqual(response.data["hospital_id"], hospital.id)
self.assertEqual(
response.data["patient_hospital_id"], data["patient_hospital_id"]
)

def test_create_when_mapping_already_exists(self):
mapping = PatientHospitalMappingFactory()
data = {
"patient_id": mapping.patient.id,
"hospital_id": mapping.hospital.id,
"patient_hospital_id": mapping.patient_hospital_id,
}

response = self.client.post(
"/api/v1/patient-hospital-mappings/", data=data, format="json"
)

self.assertEqual(HTTP_400_BAD_REQUEST, response.status_code)

def test_create_when_mapping_already_exists_with_different_patient_hospital_id(
self,
):
mapping = PatientHospitalMappingFactory()
data = {
"patient_id": mapping.patient.id,
"hospital_id": mapping.hospital.id,
"patient_hospital_id": "whatever",
}

response = self.client.post(
"/api/v1/patient-hospital-mappings/", data=data, format="json"
)

self.assertEqual(HTTP_400_BAD_REQUEST, response.status_code)

def test_create_when_patient_hospital_id_already_exists_for_another_patient(
self,
):
mapping = PatientHospitalMappingFactory()
patient = PatientFactory()
data = {
"patient_id": patient.id,
"hospital_id": mapping.hospital.id,
"patient_hospital_id": mapping.patient_hospital_id,
}

response = self.client.post(
"/api/v1/patient-hospital-mappings/", data=data, format="json"
)

self.assertEqual(HTTP_400_BAD_REQUEST, response.status_code)
41 changes: 28 additions & 13 deletions tmh_registry/registry/tests/api/viewsets/test_patients.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,18 @@ def test_get_patients_list_successful(self):
self.assertEqual(self.patient.id, response.data["results"][0]["id"])
self.assertNotIn("hospital_id", response.data["results"][0])

self.assertEqual(1, len(response.data["results"][0]["hospitals"]))
self.assertEqual(
1, len(response.data["results"][0]["hospital_mappings"])
)
self.assertEqual(
self.mapping.patient_hospital_id,
response.data["results"][0]["hospitals"][0]["patient_hospital_id"],
response.data["results"][0]["hospital_mappings"][0][
"patient_hospital_id"
],
)
self.assertEqual(
self.hospital.id,
response.data["results"][0]["hospital_mappings"][0]["hospital_id"],
)

def test_get_patients_list_unauthorized(self):
Expand Down Expand Up @@ -112,16 +120,19 @@ def test_get_patients_detail_successful(self):

self.assertEqual(HTTP_200_OK, response.status_code)
self.assertEqual(self.patient.id, response.data["id"])
self.assertEqual(1, len(response.data["hospitals"]))
self.assertEqual(1, len(response.data["hospital_mappings"]))

patient_hospital_id = PatientHospitalMapping.objects.get(
hospital=self.hospital.id, patient=self.patient.id
).patient_hospital_id

self.assertEqual(self.hospital.id, response.data["hospitals"][0]["id"])
self.assertEqual(
self.hospital.id,
response.data["hospital_mappings"][0]["hospital_id"],
)
self.assertEqual(
patient_hospital_id,
response.data["hospitals"][0]["patient_hospital_id"],
response.data["hospital_mappings"][0]["patient_hospital_id"],
)

def test_get_patients_detail_unauthorized(self):
Expand Down Expand Up @@ -165,14 +176,15 @@ def test_get_patients_detail_with_multiple_hospitals(self):

self.assertEqual(HTTP_200_OK, response.status_code)
self.assertEqual(self.patient.id, response.data["id"])
self.assertEqual(2, len(response.data["hospitals"]))
self.assertEqual(2, len(response.data["hospital_mappings"]))

for hospital in response.data["hospitals"]:
for hospital_mapping in response.data["hospital_mappings"]:
patient_hospital_id = PatientHospitalMapping.objects.get(
hospital_id=hospital["id"], patient_id=self.patient.id
hospital_id=hospital_mapping["hospital_id"],
patient_id=self.patient.id,
).patient_hospital_id
self.assertEqual(
patient_hospital_id, hospital["patient_hospital_id"]
patient_hospital_id, hospital_mapping["patient_hospital_id"]
)

########################
Expand All @@ -195,11 +207,14 @@ def test_create_patients_successful(self):
self.assertEqual(data["phone_1"], response.data["phone_1"])
self.assertEqual(data["phone_2"], response.data["phone_2"])

self.assertEqual(1, len(response.data["hospitals"]))
self.assertEqual(self.hospital.id, response.data["hospitals"][0]["id"])
self.assertEqual(1, len(response.data["hospital_mappings"]))
self.assertEqual(
self.hospital.id,
response.data["hospital_mappings"][0]["hospital_id"],
)
self.assertEqual(
data["patient_hospital_id"],
response.data["hospitals"][0]["patient_hospital_id"],
response.data["hospital_mappings"][0]["patient_hospital_id"],
)
self.assertEqual(
1,
Expand Down Expand Up @@ -274,7 +289,7 @@ def test_create_patients_without_year_of_birth_but_with_age(self):
)
self.assertEqual(data["phone_1"], response.data["phone_1"])
self.assertEqual(data["phone_2"], response.data["phone_2"])
self.assertEqual(1, len(response.data["hospitals"]))
self.assertEqual(1, len(response.data["hospital_mappings"]))

def test_create_patients_without_year_of_birth_and_age(self):
data = self.get_patient_test_data()
Expand Down
7 changes: 6 additions & 1 deletion tmh_registry/registry/urls.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from django.urls import include, path
from rest_framework.routers import DefaultRouter

from .api.viewsets import HospitalViewSet, PatientViewSet
from .api.viewsets import (
HospitalViewSet,
PatientHospitalMappingViewset,
PatientViewSet,
)

router = DefaultRouter()
router.register(r"hospitals", HospitalViewSet)
router.register(r"patients", PatientViewSet)
router.register(r"patient-hospital-mappings", PatientHospitalMappingViewset)

urlpatterns = [
path("", include(router.urls)),
Expand Down