Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/local.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pytest-sugar==0.9.4 # https://github.com/Frozenball/pytest-sugar
mock==4.0.3 # https://pypi.org/project/mock/
freezegun==1.1.0 # https://pypi.org/project/freezegun/0.1.11/
faker==8.5.1 # https://github.com/joke2k/faker
parameterized==0.8.1 # https://github.com/wolever/parameterized

# Code quality
# ------------------------------------------------------------------------------
Expand Down
106 changes: 105 additions & 1 deletion tmh_registry/registry/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from rest_framework.serializers import ModelSerializer, SerializerMethodField

from ...users.api.serializers import MedicalPersonnelSerializer
from ...users.models import MedicalPersonnel
from ..models import Episode, Hospital, Patient, PatientHospitalMapping


Expand Down Expand Up @@ -180,7 +181,7 @@ class PatientHospitalMappingReadSerializer(ModelSerializer):

class Meta:
model = PatientHospitalMapping
fields = ["patient", "hospital", "patient_hospital"]
fields = ["patient", "hospital", "patient_hospital_id"]


class PatientHospitalMappingWriteSerializer(ModelSerializer):
Expand All @@ -191,6 +192,10 @@ class Meta:
model = PatientHospitalMapping
fields = ["patient_id", "hospital_id", "patient_hospital_id"]

def to_representation(self, instance):
serializer = PatientHospitalMappingReadSerializer(instance)
return serializer.data

def create(self, validated_data):
validated_data["patient_id"] = validated_data["patient_id"].id
validated_data["hospital_id"] = validated_data["hospital_id"].id
Expand Down Expand Up @@ -233,3 +238,102 @@ def create(self, validated_data):
)

return new_mapping


class EpisodeReadSerializer(ModelSerializer):
patient_hospital_mapping = PatientHospitalMappingReadSerializer()
surgeons = MedicalPersonnelSerializer(many=True)

class Meta:
model = Episode
fields = [
"patient_hospital_mapping",
"created",
"surgery_date",
"episode_type",
"surgeons",
"comments",
"cepod",
"side",
"occurence",
"type",
"complexity",
"mesh_type",
"anaesthetic_type",
"diathermy_used",
]


class EpisodeWriteSerializer(ModelSerializer):
patient_id = PrimaryKeyRelatedField(
write_only=True, queryset=Patient.objects.all()
)
hospital_id = PrimaryKeyRelatedField(
write_only=True, queryset=Hospital.objects.all()
)
surgeon_ids = PrimaryKeyRelatedField(
write_only=True, many=True, queryset=MedicalPersonnel.objects.all()
)

class Meta:
model = Episode
fields = [
"patient_id",
"hospital_id",
"surgery_date",
"episode_type",
"surgeon_ids",
"comments",
"cepod",
"side",
"occurence",
"type",
"complexity",
"mesh_type",
"anaesthetic_type",
"diathermy_used",
]

def to_representation(self, instance):
serializer = EpisodeReadSerializer(instance)
return serializer.data

def create(self, validated_data):
patient = validated_data["patient_id"]
hospital = validated_data["hospital_id"]
surgeons = validated_data["surgeon_ids"]

patient_hospital_mapping = PatientHospitalMapping.objects.filter(
patient_id=patient.id,
hospital_id=hospital.id,
).first()
if patient_hospital_mapping is None:
raise ValidationError(
{
"error": "PatientHospitalMapping for patient_id {patient_id} and hospital_id {hospital_id} "
"does not exist.".format(
patient_id=patient.id,
hospital_id=hospital.id,
)
}
)

episode = Episode.objects.create(
patient_hospital_mapping=patient_hospital_mapping,
surgery_date=validated_data["surgery_date"],
episode_type=validated_data["episode_type"],
comments=validated_data["comments"],
cepod=validated_data["cepod"],
side=validated_data["side"],
occurence=validated_data["occurence"],
type=validated_data["type"],
complexity=validated_data["complexity"],
mesh_type=validated_data["mesh_type"],
anaesthetic_type=validated_data["anaesthetic_type"],
diathermy_used=validated_data["diathermy_used"],
)

if surgeons:
episode.surgeons.set(surgeons)

return episode
31 changes: 28 additions & 3 deletions tmh_registry/registry/api/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
from drf_yasg.openapi import IN_QUERY, TYPE_INTEGER, TYPE_STRING, Parameter
from drf_yasg.utils import swagger_auto_schema
from rest_framework import mixins, viewsets
from rest_framework.mixins import CreateModelMixin
from rest_framework.viewsets import GenericViewSet

from ..models import Hospital, Patient, PatientHospitalMapping
from ..models import Episode, Hospital, Patient, PatientHospitalMapping
from .serializers import (
CreatePatientSerializer,
EpisodeReadSerializer,
EpisodeWriteSerializer,
HospitalSerializer,
PatientHospitalMappingReadSerializer,
PatientHospitalMappingWriteSerializer,
Expand Down Expand Up @@ -91,7 +94,7 @@ class Meta:
type=TYPE_INTEGER,
),
],
responses={200: ReadPatientSerializer(many=True)},
responses={200: ReadPatientSerializer()},
),
)
class PatientViewSet(
Expand All @@ -112,7 +115,13 @@ def get_serializer_class(self):
raise NotImplementedError


class PatientHospitalMappingViewset(mixins.CreateModelMixin, GenericViewSet):
@method_decorator(
name="create",
decorator=swagger_auto_schema(
responses={201: PatientHospitalMappingReadSerializer()}
),
)
class PatientHospitalMappingViewset(CreateModelMixin, GenericViewSet):
queryset = PatientHospitalMapping.objects.all()

def get_serializer_class(self):
Expand All @@ -122,3 +131,19 @@ def get_serializer_class(self):
return PatientHospitalMappingWriteSerializer

raise NotImplementedError


@method_decorator(
name="create",
decorator=swagger_auto_schema(responses={201: EpisodeReadSerializer()}),
)
class EpisodeViewset(CreateModelMixin, GenericViewSet):
queryset = Episode.objects.all()

def get_serializer_class(self):
if self.action in ["list", "retrieve"]:
return EpisodeReadSerializer
if self.action == "create":
return EpisodeWriteSerializer

raise NotImplementedError
123 changes: 123 additions & 0 deletions tmh_registry/registry/tests/api/viewsets/test_episodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from datetime import date

from django.test import TestCase
from parameterized import parameterized
from rest_framework.authtoken.models import Token
from rest_framework.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST
from rest_framework.test import APIClient

from tmh_registry.registry.factories import (
HospitalFactory,
PatientFactory,
PatientHospitalMappingFactory,
)
from tmh_registry.registry.models import Episode, PatientHospitalMapping
from tmh_registry.users.factories import MedicalPersonnelFactory


class TestEpisodesViewSet(TestCase):
@classmethod
def setUpClass(cls) -> None:
super(TestEpisodesViewSet, cls).setUpClass()

cls.hospital = HospitalFactory()

cls.patient = PatientFactory(full_name="John Doe")
cls.patient.created_at = date(year=2021, month=4, day=11)
cls.patient.save()

cls.medical_personnel = MedicalPersonnelFactory()
cls.token = Token.objects.create(user=cls.medical_personnel.user)

def get_episode_test_data(self):
return {
"patient_id": self.patient.id,
"hospital_id": self.hospital.id,
"surgery_date": "2021-10-12",
"episode_type": Episode.EpisodeChoices.UMBILICAL.value,
"surgeon_ids": [self.medical_personnel.id],
"comments": "A random comment",
"cepod": Episode.CepodChoices.PLANNED.value,
"side": Episode.SideChoices.LEFT.value,
"occurence": Episode.OccurenceChoices.RECURRENT.value,
"type": Episode.TypeChoices.INDIRECT.value,
"complexity": Episode.ComplexityChoices.INCARCERATED.value,
"mesh_type": Episode.MeshTypeChoices.TNMHP.value,
"anaesthetic_type": Episode.AnaestheticChoices.SPINAL.value,
"diathermy_used": True,
}

def setUp(self) -> None:
self.client = APIClient()
self.client.credentials(HTTP_AUTHORIZATION="Token " + self.token.key)

self.patient_hospital_mapping = PatientHospitalMappingFactory(
patient=self.patient, hospital=self.hospital
)

def test_create_episode_when_no_patient_hospital_mapping_exists(self):
PatientHospitalMapping.objects.all().delete()

data = self.get_episode_test_data()
response = self.client.post(
"/api/v1/episodes/", data=data, format="json"
)

self.assertEqual(HTTP_400_BAD_REQUEST, response.status_code)

@parameterized.expand(
[
("episode_type", "hIaTuS"),
("cepod", "WRONG_OPTION"),
("side", "WRONG_OPTION"),
("occurence", "WRONG_OPTION"),
("type", "WRONG_OPTION"),
("complexity", "WRONG_OPTION"),
("mesh_type", "WRONG_OPTION"),
("anaesthetic_type", "WRONG_OPTION"),
]
)
def test_with_non_acceptable_value_for_a_field(self, field, value):
data = self.get_episode_test_data()
data[field] = value

response = self.client.post(
"/api/v1/episodes/", data=data, format="json"
)

self.assertEqual(HTTP_400_BAD_REQUEST, response.status_code)

def test_create_episode_successful(self):
data = self.get_episode_test_data()
response = self.client.post(
"/api/v1/episodes/", data=data, format="json"
)

self.assertEqual(HTTP_201_CREATED, response.status_code)

self.assertEqual(
response.data["patient_hospital_mapping"]["patient_hospital_id"],
self.patient_hospital_mapping.patient_hospital_id,
)
self.assertEqual(response.data["surgery_date"], data["surgery_date"])
self.assertEqual(response.data["episode_type"], data["episode_type"])

self.assertEqual(len(response.data["surgeons"]), 1)
self.assertEqual(
response.data["surgeons"][0]["user"]["email"],
self.medical_personnel.user.email,
)

self.assertEqual(response.data["comments"], data["comments"])
self.assertEqual(response.data["cepod"], data["cepod"])
self.assertEqual(response.data["side"], data["side"])
self.assertEqual(response.data["occurence"], data["occurence"])
self.assertEqual(response.data["type"], data["type"])
self.assertEqual(response.data["complexity"], data["complexity"])
self.assertEqual(response.data["mesh_type"], data["mesh_type"])
self.assertEqual(
response.data["anaesthetic_type"], data["anaesthetic_type"]
)
self.assertEqual(
response.data["diathermy_used"], data["diathermy_used"]
)
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def test_create_successful(self):
)

self.assertEqual(HTTP_201_CREATED, response.status_code)
self.assertEqual(response.data["patient_id"], patient.id)
self.assertEqual(response.data["hospital_id"], hospital.id)
self.assertEqual(response.data["patient"]["id"], patient.id)
self.assertEqual(response.data["hospital"]["id"], hospital.id)
self.assertEqual(
response.data["patient_hospital_id"], data["patient_hospital_id"]
)
Expand Down
2 changes: 2 additions & 0 deletions tmh_registry/registry/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from rest_framework.routers import DefaultRouter

from .api.viewsets import (
EpisodeViewset,
HospitalViewSet,
PatientHospitalMappingViewset,
PatientViewSet,
Expand All @@ -11,6 +12,7 @@
router.register(r"hospitals", HospitalViewSet)
router.register(r"patients", PatientViewSet)
router.register(r"patient-hospital-mappings", PatientHospitalMappingViewset)
router.register(r"episodes", EpisodeViewset)

urlpatterns = [
path("", include(router.urls)),
Expand Down