## Reinforcement Learning Environment

In [5]:
import numpy as np
from typing import Dict, Tuple, List
import csv
from dataclasses import dataclass

In [31]:
@dataclass
class Transition:
    state: np.array
    action: str 
    next_state: np.array
    reward: float 

class Env: 
    _condition_symptom_probabilities: Dict[str, Dict[str, float]] # conditions with symptoms and their probabilities
    _actions: set[str] # symptoms
    _init_state: np.array
    _current_state: np.array
    _img: np.array
    _condition: str
    _symptoms_of_condition: Dict[str, float] # symptoms of conditions
    
    def __init__(self,
                 img: np.array,
                 condition: str,
                ) -> None:  
        self._img = img
        self._condition = condition

        # init condition_symptom_probabilities from health knowledge graph
        #self._condition_symptom_probabilities= dict()
        #with open('HealthKnowledgeGraph.csv', newline='') as csvfile:
        #    reader = csv.reader(csvfile, delimiter=',')
        #    reader.__next__() # skip header
        #    for condition in reader:
        #        self._condition_symptom_probabilities[condition[0]] = dict()
        #        for symptom_prob in condition[1].split(','):
        #            # examples for symptom_prob: pain (0.318), fever (0.119) or swelling (0.112)
        #            symptom = symptom_prob.split('(')[0].strip()
        #            prob = float(symptom_prob.split('(')[1].split(')')[0])
        #            self._condition_symptom_probabilities[condition[0]][symptom] = prob

        # init condition_symptom_probabilities from slake knowledge graph
        supported_diseases=["Alveolar Proteinosis", "Pertussis", "Lobar Pneumonia"] 
        self._condition_symptom_probabilities= dict()
        with open('Slake1.0/KG/en_disease.csv', newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter='#')
            reader.__next__() # skip header 
            for row in reader:
                if(row[1]!="symptom"):
                    continue
                if(row[0] not in supported_diseases):
                    continue
                self._condition_symptom_probabilities[row[0]] = dict()
                n_symptoms=len(row[2].split(','))
                uniform_prob = 1/(2**n_symptoms)
                for symptom in row[2].split(','):
                    #assign uniform conditional probability because no conditional probability are available 
                    self._condition_symptom_probabilities[row[0]][symptom.strip()] = uniform_prob

        # check if condition is valid
        if(self._condition not in self._condition_symptom_probabilities.keys()):
            raise ValueError('Unknow Condition: ' + condition + '. Please choose one of the following: ' + str(self._condition_symptom_probabilities.keys()))
        
        # init symptoms_of_condition for easier access
        self._symptoms_of_condition = dict()
        for symptom in self._condition_symptom_probabilities[self._condition]:
            self._symptoms_of_condition[symptom] = self._condition_symptom_probabilities[self._condition][symptom] 
    
        # init actions
        self._actions = set()
        for condition in self._condition_symptom_probabilities.keys(): 
            for symptom in list(self._condition_symptom_probabilities[condition]): 
                self._actions.add(symptom)   

        # init init_state = vector with cnn output (probabilities per condition) and history of asked symptoms (0=not asked, 1=symptom is present, -1=symptom is not present)
        visual_prior = np.random.uniform(size=(len(self._condition_symptom_probabilities.keys()))) #TODO: replace with cnn output
        self._init_state = np.concatenate((visual_prior,np.zeros((len(self._actions)))), axis=0)
        self._current_state = self._init_state 

    def posterior_of_condition(self, condition: str) -> float: 
        #TODO: What is the correct likelihood calculation? If we use multiplication as in P(x,y)=P(x)*P(y), the likelihood gets smaller 
        #and nothing prevents the model from asking symptoms which are not related to the condition.
        likelihood=0
        for idx, symptom in enumerate(self._actions):
            patient_answer = self._current_state[idx+len(self._condition_symptom_probabilities.keys())]
            #if (patient_answer==1) and (symptom not in self._condition_symptom_probabilities[condition].keys()):
            #    likelihood*= 0
            #elif (patient_answer==-1) and (symptom not in self._condition_symptom_probabilities[condition].keys()):
            #    likelihood*=1
            if (symptom not in self._condition_symptom_probabilities[condition].keys()):
                #TODO: Do we have to punish the model if a symptom is positive and is not related to the condition?
                continue
            elif patient_answer==1:
                likelihood+=self._condition_symptom_probabilities[condition][symptom]
            elif patient_answer==-1:
                likelihood+=(1-self._condition_symptom_probabilities[condition][symptom]) 

        prior = self._current_state[list(self._condition_symptom_probabilities.keys()).index(condition)]
        return likelihood*prior
    
    def reward(self) -> float:
        #TODO: Is it a problem when the reward gets smaller and smaller?
        return self.posterior_of_condition(self._condition)
    
    def has_symptom(self, symptom: str) -> bool:
        if symptom not in self._symptoms_of_condition:
            return False
        else:
            phi = np.random.uniform()
            return phi <= self._symptoms_of_condition[symptom]

    def step(self, action: str) -> Transition:
        #check if action is valid
        if(action not in self._actions):
            raise ValueError('Unknow Action: ' + action + '. Please choose one of the following: ' + str(self._actions))
        
        old_state = self._current_state.copy()
        self._current_state[len(self._condition_symptom_probabilities.keys()) + list(self._actions).index(action)] = 1 if self.has_symptom(action) else -1
        
        return Transition(old_state, action, self._current_state, self.reward())
    
    def reset(self) -> None:
        self._current_state = self._init_state

### Test cases

In [33]:
#Testing simulated patient answers
myEnv=Env(np.array([]), 'Pertussis')
print("Symptoms for Pertussis:")
print(myEnv._condition_symptom_probabilities['Pertussis'])
print("Expected uniform conditional proabability: 1\(", 2**len(myEnv._condition_symptom_probabilities['Pertussis'].keys()), ")")
n=0
prob=0
for i in range(10000):
    n+=1
    if myEnv.has_symptom('spastic cough'):
        prob+=1 
print("\n Probability of spastic cough after 10000 samples: " + str(prob/n))

Symptoms for Pertussis:
{'spastic cough': 0.5}
Expected uniform conditional proabability: 1\( 2 )

 Probability of spastic cough after 10000 samples: 0.5042


In [27]:
#Testing reward
myEnv=Env(np.array([]), 'Alveolar Proteinosis')
print("prior of condition:")
print(myEnv._current_state[list(myEnv._condition_symptom_probabilities.keys()).index("Alveolar Proteinosis")])

myEnv.step('cyanosis')
result=myEnv._current_state[len(myEnv._condition_symptom_probabilities.keys()) + list(myEnv._actions).index('cyanosis')] 
print("Probability of cyanosis: " + str(myEnv._condition_symptom_probabilities['Alveolar Proteinosis']['cyanosis']))
print("Result patient asking if he has cyanosis: " + str(result))

myEnv.step('chest pain')
result=myEnv._current_state[len(myEnv._condition_symptom_probabilities.keys()) + list(myEnv._actions).index('chest pain')] 
print("Probability of chest pain: " + str(myEnv._condition_symptom_probabilities['Alveolar Proteinosis']['chest pain']))
print("Result patient asking if he has chest pain: " + str(result))

myEnv.step('dyspnea')
result=myEnv._current_state[len(myEnv._condition_symptom_probabilities.keys()) + list(myEnv._actions).index('dyspnea')] 
print("Probability of dyspnea: " + str(myEnv._condition_symptom_probabilities['Alveolar Proteinosis']['dyspnea']))
print("Result patient asking if he has dyspnea: " + str(result))

print("Reward: " + str(myEnv.reward()))

prior of condition:
0.2845665892817385
Probability of cyanosis: 0.8
Result patient asking if he has cyanosis: 1.0
Probability of chest pain: 0.8
Result patient asking if he has chest pain: 1.0
Probability of dyspnea: 0.8
Result patient asking if he has dyspnea: -1.0
Reward: 0.03642452342806252
