The idea is to train an RL agent on the generated data, using a pre-trained model to judge its diagnoses.

First we'd need to define the environment.

The environment would need to be able to do the following:

- Generate patient records
- Receive questions about what symptoms a patient might have
- Return the appropriate rewards
    - 0 for an enquiry
    - -1 for a repeated enquiry (at which point the episode terminates)
    - +1 for a disease enquiry whose prediction correctly matches the disease.
        - might be an idea to make this top 5, (as long as the PR generated is > 0)
        - But first we try with top1

In [1]:
import json
import pathlib
import os
from collections import namedtuple

from itertools import repeat

import numpy as np
from scipy.sparse import csc_matrix

In [2]:
# Complete record for one generated patient, including the condition label.
AiMedPatient = namedtuple(
    'AiMedPatient',
    ['age', 'race', 'gender', 'symptoms', 'condition'],
)
# Same demographic fields plus a symptoms field, but without the condition.
AiMedState = namedtuple('AiMedState', 'age race gender symptoms')

In [3]:
class AiMedEnv:
    """Symptom-inquiry environment driven by a CSV file of generated patients.

    Each episode (see ``reset``) reads the next patient record, reveals one of
    the patient's symptoms at random and then accepts actions:

    * ``0 .. num_symptoms - 1``: inquire about that symptom index.
    * ``num_symptoms .. num_symptoms + num_conditions - 1``: diagnose the
      condition with index ``action - num_symptoms``.

    The state's ``symptoms`` field is a ``(num_symptoms, 3)`` uint8 array with
    one row per symptom, one-hot over ``[absent, present, unknown]``.

    Rewards: 0 for a new inquiry, -1 (terminal) for a repeated inquiry or a
    diagnosis made before any inquiry, and for a diagnosis 1 if it equals the
    classifier's prediction for the symptoms confirmed so far, else 0
    (terminal either way).

    The data file stays open for the lifetime of the environment; call
    ``close_data_file`` (or use the instance as a context manager) when done.
    """

    def __init__(
            self,
            data_file,
            symptom_map_file,
            condition_map_file,
            clf
    ):
        """
        data_file: A file of generated patient, symptoms, condition data (CSV with a header line)
        symptom_map_file: the encoding file for symptoms (JSON object keyed by symptom code)
        condition_map_file: the encoding file for conditions (JSON object keyed by condition code)
        clf: a classifier whose ``predict`` accepts a sparse (1, 3 + num_symptoms) matrix built from
        patient demography and confirmed symptoms, and returns the predicted condition index.

        Raises ValueError if any of the three files does not exist.
        """
        self.data_file = data_file
        self.symptom_map_file = symptom_map_file
        self.condition_map_file = condition_map_file
        self.clf = clf

        self.line_number = 0
        self.state = None
        self.patient = None
        self.data = None
        self.symptom_map = None
        self.condition_map = None
        self.initial_symptom_map = None
        self.num_symptoms = None
        self.num_conditions = None

        # action-type tags returned by is_valid_action
        self.is_inquiry = 1
        self.is_diagnose = 2

        # indices of the symptoms already revealed/asked in the current episode
        self.inquiry_list = set()

        self.RACE_CODE = {'white': 0, 'black': 1, 'asian': 2, 'native': 3, 'other': 4}

        self.check_file_exists()

        # Load the JSON maps before opening the long-lived data handle, so a
        # malformed map file cannot leave an open file behind.
        self.load_symptom_map()
        self.load_condition_map()
        self.load_data_file()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close_data_file()
        return False

    def check_file_exists(self):
        """Raise ValueError if any of the configured input files is missing."""
        files = [self.data_file, self.symptom_map_file, self.condition_map_file]
        for file in files:
            if not os.path.exists(file):
                raise ValueError("File: %s does not exist" % file)

    def load_data_file(self):
        """Open the patient data file for sequential reading."""
        # release a previously opened handle instead of orphaning it
        self.close_data_file()
        self.data = open(self.data_file)

    def close_data_file(self):
        """Close the patient data file if it has been opened."""
        if self.data is not None:
            self.data.close()

    def load_symptom_map(self):
        """Build ``symptom_map`` (code -> index, in sorted code order) and ``num_symptoms``."""
        with open(self.symptom_map_file) as fp:
            symptoms = json.load(fp)
        self.symptom_map = {code: idx for idx, code in enumerate(sorted(symptoms))}
        self.num_symptoms = len(self.symptom_map)

    def load_condition_map(self):
        """Build ``condition_map`` (code -> index, in sorted code order) and ``num_conditions``."""
        with open(self.condition_map_file) as fp:
            conditions = json.load(fp)
        self.condition_map = {code: idx for idx, code in enumerate(sorted(conditions))}
        self.num_conditions = len(self.condition_map)

    def _next_record_line(self):
        """Return the next non-blank line (stripped), or '' once the end of file is reached."""
        # readline() signals EOF with '' (never None), which ends this iterator
        for raw in iter(self.data.readline, ""):
            stripped = raw.strip()
            if stripped:
                return stripped
        return ""

    def read_line(self):
        """Return the next patient record line, wrapping to the first record at end of file.

        The header line is skipped on the very first read and again after each
        wrap-around. Blank lines are skipped. Returns '' only when the file
        holds no records at all.
        """
        if self.line_number == 0:
            self.data.readline()  # header line

        line = self._next_record_line()
        if not line:
            # EOF: rewind and start again from the first record
            self.data.seek(0)
            self.data.readline()  # header line
            line = self._next_record_line()

        self.line_number += 1
        return line

    def parse_line(self, line):
        """Convert one comma-separated record into an ``AiMedPatient``.

        Columns used: 1 = gender, 2 = race, 4 = age, 6 = condition code,
        8 = ';'-separated symptom codes. Gender is encoded 0 for 'M' and 1 for
        anything else. Race and condition are looked up with ``dict.get``, so
        an unknown value yields ``None``.

        Raises ValueError if a symptom code is not in the symptom map.
        """
        parts = line.split(",")
        _gender = parts[1]
        _race = parts[2]

        age = int(parts[4])
        condition = parts[6]
        symptom_list = parts[8]

        gender = 0 if _gender == 'M' else 1
        race = self.RACE_CODE.get(_race)
        condition = self.condition_map.get(condition)

        symptoms = [0] * self.num_symptoms
        for item in symptom_list.split(";"):
            idx = self.symptom_map.get(item)
            if idx is None:
                raise ValueError("Unknown symptom code: %r" % item)
            symptoms[idx] = 1

        # ('age', 'race', 'gender', 'symptoms', 'condition')
        return AiMedPatient(age, race, gender, np.array(symptoms), condition)

    def reset(self):
        """Start a new episode with the next patient and return the initial state.

        The state starts with every symptom unknown except one randomly chosen
        symptom the patient actually has, which is marked as present.
        """
        line = self.read_line()
        self.patient = self.parse_line(line)
        self.state = self.generate_state(
            self.patient.age,
            self.patient.race,
            self.patient.gender
        )
        self.inquiry_list = set()

        self.pick_initial_symptom()
        return self.state

    def pick_initial_symptom(self):
        """Reveal one of the patient's symptoms, chosen uniformly at random."""
        _existing_symptoms = np.where(self.patient.symptoms == 1)[0]

        initial_symptom = int(np.random.choice(_existing_symptoms))

        self.state.symptoms[initial_symptom] = np.array([0, 1, 0])
        # counts as already asked, so inquiring about it again is a repeat
        self.inquiry_list.add(initial_symptom)

    def generate_state(self, age, race, gender):
        """Return a fresh ``AiMedState`` in which every symptom is marked unknown."""
        _symptoms = np.zeros((self.num_symptoms, 3), dtype=np.uint8)  # all symptoms start as unknown
        _symptoms[:, 2] = 1

        return AiMedState(age, race, gender, _symptoms)

    def is_valid_action(self, action):
        """Classify an action index.

        Returns ``(True, self.is_inquiry, symptom_idx)`` for
        ``0 <= action < num_symptoms``, ``(True, self.is_diagnose, condition_idx)``
        for the next ``num_conditions`` indices, and ``(False, None, None)``
        for anything else (including negative values).
        """
        if action < 0:
            return False, None, None

        if action < self.num_symptoms:
            return True, self.is_inquiry, action  # it's an inquiry action

        # offset past the inquiry block; a modulo here would wrap
        # out-of-range actions back onto valid conditions
        condition_idx = action - self.num_symptoms
        if condition_idx < self.num_conditions:
            return True, self.is_diagnose, condition_idx  # it's a diagnose action

        return False, None, None

    def take_action(self, action):
        """Dispatch an action index to ``inquire`` or ``diagnose``.

        Returns ``(state, reward, done)``. Raises ValueError for an index
        outside the action space.
        """
        is_valid, action_type, action_value = self.is_valid_action(action)
        if not is_valid:
            raise ValueError("Invalid action: %s" % action)
        if action_type == self.is_inquiry:
            return self.inquire(action_value)
        return self.diagnose(action_value)

    def patient_has_symptom(self, symptom_idx):
        """Return whether the current patient has the symptom at ``symptom_idx``."""
        return self.patient.symptoms[symptom_idx] == 1

    def inquire(self, action_value):
        """
        Ask whether the patient has the symptom at index ``action_value``.

        returns state, reward, done
        """
        if action_value in self.inquiry_list:
            # repeated inquiry
            return self.state, -1, True  # we terminate on a repeated inquiry

        # does the patient have the symptom
        if self.patient_has_symptom(action_value):
            value = np.array([0, 1, 0])
        else:
            value = np.array([1, 0, 0])

        self.state.symptoms[action_value] = value
        self.inquiry_list.add(action_value)

        return self.state, 0, False

    def get_patient_vector(self):
        """Return the classifier input as a ``(1, 3 + num_symptoms)`` uint8 array.

        Layout: ``[gender, race, age, symptom_0, symptom_1, ...]`` where a
        symptom entry is 1 only if it has been confirmed present; absent and
        unknown symptoms are both 0. Values must fit in uint8.
        """
        patient_vector = np.zeros(3 + self.num_symptoms, dtype=np.uint8)
        patient_vector[0] = self.state.gender
        patient_vector[1] = self.state.race
        patient_vector[2] = self.state.age

        # column 1 of the state is the "present" flag; shift past the 3 demographic slots
        has_symptom = np.where(self.state.symptoms[:, 1] == 1)[0] + 3
        patient_vector[has_symptom] = 1

        return patient_vector.reshape(1, -1)

    def predict_condition(self):
        """Return ``clf.predict`` for the current patient vector (passed as a CSC sparse matrix)."""
        patient_vector = self.get_patient_vector()
        patient_vector = csc_matrix(patient_vector)

        prediction = self.clf.predict(patient_vector)

        return prediction

    def diagnose(self, action_value):
        """
        Diagnose condition index ``action_value`` and end the episode.

        returns state, reward, done
        """
        # enforce that there should be at least one inquiry in addition to the initial symptom
        if len(self.inquiry_list) < 2:
            return self.state, -1, True  # a diagnosis before any inquiry is penalised and terminal

        # the diagnosis is scored against the classifier's prediction for the
        # symptoms confirmed so far
        prediction = self.predict_condition()[0]

        is_correct = action_value == prediction
        reward = 1 if is_correct else 0

        return self.state, reward, True

In [4]:
# test that the agent works as it should!

In [5]:
# CSV of patient records passed to AiMedEnv as data_file.
# NOTE: absolute, machine-specific path; adjust before running elsewhere.
test_data_file = "/Users/teliov/TUD/Thesis/Medvice/Notebooks/data/06_18_nlice_plus/ai/output_med_ai_ext/symptoms/csv/test_symptoms.csv"

In [6]:
# JSON symptom database passed to AiMedEnv as symptom_map_file.
# NOTE: absolute, machine-specific path; adjust before running elsewhere.
test_symptom_map_file = "/Users/teliov/TUD/Thesis/Medvice/Notebooks/data/06_18_nlice_plus/extended/symptom_db.json"

In [7]:
# JSON condition database passed to AiMedEnv as condition_map_file.
# NOTE: absolute, machine-specific path; adjust before running elsewhere.
test_condition_map_file = "/Users/teliov/TUD/Thesis/Medvice/Notebooks/data/06_18_nlice_plus/extended/condition_db.json"

In [8]:
import joblib

# load the nb classifier
# NOTE(review): joblib.load unpickles the file, so only load files from a trusted source.
clf_file = "/Users/teliov/TUD/Thesis/Medvice/Notebooks/data/06_18_nlice_plus/extended/data/output/nb/nb_serialized_sparse.joblib"
clf_data = joblib.load(clf_file)
# presumably the serialized bundle is a dict with the fitted estimator under "clf" (TODO confirm)
clf = clf_data.get("clf")

In [9]:
# Build the environment under test (named `agent` throughout these checks)
# from the data file, the two encoding maps and the loaded classifier.
agent = AiMedEnv(
    data_file=test_data_file,
    symptom_map_file=test_symptom_map_file,
    condition_map_file=test_condition_map_file,
    clf=clf
)

In [10]:
# check that the initial calls worked:
# the constructor should have loaded both maps and opened the data file,
# but must not have read a patient or started an episode yet.
assert len(agent.inquiry_list) == 0, "Agent has already populated inquiry list"
assert agent.state is None, "State is not None just after init"
assert agent.data is not None, "Data file has not been loaded"
assert agent.patient is None, "Patient is not None just after init"
assert agent.line_number == 0, "Lines have been read from the file just after init!"
assert agent.symptom_map is not None,  "Symptom map has not been loaded"
assert agent.condition_map is not None, "Condition map has not been loaded"
# expected sizes are specific to the symptom/condition files loaded above
assert len(agent.symptom_map) == 33,  "Symptom map not loaded properly"
assert len(agent.condition_map) == 14, "Condition map not loaded properly"

In [11]:
# check that read_line works: the first call skips the header line and
# returns the first patient record
line = agent.read_line()

# read_line always returns a string, so check that it is non-empty
assert line, "line is None or empty!"
assert agent.line_number == 1, "Line number is not 1"

In [12]:
# test parse line
# The expected values describe the first record of the test data file:
# gender 0 encodes 'M' and race 0 encodes 'white' (see parse_line / RACE_CODE).
patient = agent.parse_line(line)

assert patient.age == 70, "Invalid patient age read"
assert patient.gender == 0, "Invalid patient gender read"
assert patient.race == 0, "Invalid patient race read"

In [13]:
# Re-create the environment so the checks below start from a fresh file
# handle and a line counter of 0 (the previous instance already read a line).
agent = AiMedEnv(
    data_file=test_data_file,
    symptom_map_file=test_symptom_map_file,
    condition_map_file=test_condition_map_file,
    clf=clf
)

In [14]:
# check that reset works: it reads exactly one patient, builds the state from
# that patient's demographics and reveals a single initial symptom
agent.reset()
assert agent.state is not None, "State is still None even after reset"
assert agent.patient is not None, "Patient is still None even after reset"
assert agent.line_number == 1, "Agent has either not read a line or has read more than one line"

assert agent.patient.age == agent.state.age, "Patient age and state age are different"
assert agent.patient.gender == agent.state.gender, "Patient gender and state gender are different"
# compare against the state (the original compared patient.race with itself)
assert agent.patient.race == agent.state.race, "Patient race and state race are different"

assert len(agent.inquiry_list) == 1, "There is more than one symptom in the inquiry list"

In [15]:
# check that is_valid action works
# With 33 symptoms and 14 conditions, actions 0..32 are inquiries and
# actions 33..46 are diagnoses; anything from 47 on is out of range.

# first inquiry action
is_valid, action_type, action_value = agent.is_valid_action(0)

assert is_valid, "Expected a valid action"
assert action_type == agent.is_inquiry, "Expected an inquiry action"
assert action_value == 0, "Expected 0th action value"

# first diagnose action (33 maps to condition index 0)
is_valid, action_type, action_value = agent.is_valid_action(33)
assert is_valid, "Expected a valid action"
assert action_type == agent.is_diagnose, "Expected a diagnose action"
assert action_value == 0, "Expected 0th action value"

# 47 = 33 + 14 is one past the last diagnose action
is_valid, action_type, action_value = agent.is_valid_action(47)
assert not is_valid, "Expected an invalid action"
assert action_type is None, "Expected None"
assert action_value is None, "Expected None"

In [16]:
# let's make some inquiries
# let's ask for pain relief and ask for back pain
# pain_relief: 25f7a9c449c1f24e3063fc27fb46c9dfd92b3ba3902b09f27baf9a36; 2
# back-pain: d6101ec36a1b500951f17ef90bc481691c69268d01647b3917df6836; 28
#
# Each symptom row in the state is one-hot over [absent, present, unknown]:
# [0, 0, 1] before it is asked, [1, 0, 0] once confirmed absent.
# The expected outcomes assume the first patient in the test data file.

# pain_relief; patient does not have pain relief
pain_idx = 2
current_pain_state = agent.state.symptoms[pain_idx].copy()
assert np.array_equal(current_pain_state, np.array([0, 0, 1]))

state, reward, done = agent.inquire(pain_idx)
pain_state = state.symptoms[pain_idx].copy()

assert not np.array_equal(current_pain_state, pain_state)
assert np.array_equal(pain_state, np.array([1, 0, 0]))
assert reward == 0, "Expected 0 reward"
assert not done, "Expected done to be false"

# back-pain; patient does not have back-pain
back_pain_idx = 28
current_back_pain = agent.state.symptoms[back_pain_idx].copy()
assert np.array_equal(current_back_pain, np.array([0, 0, 1]))

state, reward, done = agent.inquire(back_pain_idx)
back_pain_state = state.symptoms[back_pain_idx].copy()

assert not np.array_equal(current_back_pain, back_pain_state)
assert np.array_equal(back_pain_state, np.array([1, 0, 0]))
assert reward == 0, "Expected 0 reward"
assert not done, "Expected done to be false"

In [17]:
# check that the inquiry list has been updated:
# the initial symptom revealed by reset() plus the two inquiries made above
assert len(agent.inquiry_list) == 3, "Expected inquiry list to be 3"

In [18]:
# check that patient_vector is well defined
# (this is the dense row that predict_condition converts for the classifier)
patient_vector = agent.get_patient_vector()

In [19]:
# a single row with 3 demographic slots followed by the 33 symptom slots
num_rows, num_cols = patient_vector.shape
assert num_rows == 1, "Expecting 1 row"
assert num_cols == 36, "Expecting 36 columns"

In [20]:
# the first three columns hold gender, race and age, in that order
assert patient_vector[0, 0] == agent.state.gender, "Gender value mismatch"
assert patient_vector[0, 1] == agent.state.race,  "Race value mismatch"
assert patient_vector[0, 2] == agent.state.age,  "Age value mismatch"

In [21]:
# check that headache is true and every other symptom is false
# headache: 67fe1b0607dced2d78d47eb7b8f2b599c0823043d54f0d875d9e5505; 12 (or 15 accounting for age, race and gender)
# Only symptoms confirmed present are set in the vector, so the two negative
# inquiries above do not show up here.
# NOTE(review): scanning the whole row for 1s also covers the gender/race/age
# columns; this works for the first test patient (gender 0, race 0, age 70).

present_symptoms = np.where(patient_vector[0, :] == 1)[0]
assert len(present_symptoms) == 1, "got more present symptoms than expected"
assert present_symptoms[0] == 15, "present symptom is not headache"

In [22]:
# now need to check the prediction
# we know that the clf would predict tension type headache when headache is the only symptom present
# tension_type headache = 6, 61cc5f297c68dd09757d19a041cde247dcf693a6c9eeb770e16d4b84

In [23]:
# the classifier is queried with the current patient vector; index 6 is the
# condition index expected for tension type headache
prediction = agent.predict_condition()

assert prediction[0] == 6 , "Expected prediction to be tension type headache"

In [24]:
# test the diagnose action
# the action value for tension type headache is its condition index, 6;
# it matches the classifier's prediction, so the reward is +1 and the episode ends
state, reward, done = agent.diagnose(6)

assert reward == 1, "Expected positive reward"
assert done , "Expected that we're done"

In [25]:
# test the take action method
# 39 corresponds to the action for tension type headache i.e 33 + 6
# (take_action routes actions past the 33 inquiry slots to diagnose)
state, reward, done = agent.take_action(39)
assert reward == 1, "Expected positive reward"
assert done , "Expected that we're done"

In [26]:
# test that a repeated inquiry ends the episode with a negative reward
# we'll inquire again about back pain (index 28, already asked above)
state, reward, done = agent.take_action(28)
assert reward == -1, "Expected negative reward"
assert done , "Expected that we're done"

In [27]:
# test that a wrong diagnosis returns zero reward but ends the process
# action 40 is the diagnose action for condition index 7 (40 - 33), which
# differs from the classifier's prediction of 6
state, reward, done = agent.take_action(40)
assert reward == 0, "Expected zero reward"
assert done , "Expected that we're done"

In [28]:
# Agent is all good!