In [1]:
import pickle
from tqdm import tqdm
import pandas as pd
import zstandard as zstd
import numpy as np
from DPTreeNode import DPTreeNode
import torch
from utils import (read_zst, parse_cols, calculate_top_entropy_columns, parse_age,
                   column_translate, pathology_translate)
from sklearn.preprocessing import MinMaxScaler
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='sklearn')
from LSTMModel import LSTMModel

# Load every needed in AD system
tree_root = pickle.load(open('./stored_files/inquiry_system.pkl', 'rb'))

input_size = 518  # number of features
hidden_size = 64
output_size = 49  # number of unique diagnoses
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTMModel(input_size, hidden_size, output_size).to(device)
model.load_state_dict(torch.load('./stored_files//best_model.pth'))
model.eval()
age_scaler = pickle.load(open('./stored_files/age_scaler.pkl', 'rb'))
column_translations = column_translate()
pathology_translations = pathology_translate()

In [17]:
# Load and process train data (for the inquiry system to infer)
# -------------------------------------------
train_wo_sex = read_zst('./stored_files/train_x.zst')
train_wo_sex = train_wo_sex.drop('SEX', axis=1).sort_index(axis=1)
train_age = train_wo_sex['AGE'].copy()
train_wo_sex['AGE'] = train_wo_sex['AGE'].apply(parse_age)

n = len(train_wo_sex)
train_wo_sex_tensor = torch.tensor(train_wo_sex.values, dtype=torch.int8).cuda()
age_idx = train_wo_sex.columns.get_loc('AGE')
original_column_names = train_wo_sex.columns.copy()
train_wo_sex.columns = range(len(train_wo_sex.columns))
train_wo_age = train_wo_sex.copy()
train_wo_age[age_idx] = 0

train_wo_age_tensor = train_wo_sex_tensor.clone()
train_wo_age_tensor[:, age_idx] = 0

train_wo_sex = train_wo_sex.replace(0, -1)

# Load and process test data (for demo)
# -------------------------------------------
test_wo_sex = read_zst('./stored_files/test_x.zst')
test_sex = test_wo_sex['SEX']
test_wo_sex = test_wo_sex.drop('SEX', axis=1).sort_index(axis=1)
test_age = test_wo_sex['AGE'].copy()
test_wo_sex['AGE'] = test_wo_sex['AGE'].apply(parse_age)

test_wo_sex.columns = range(len(test_wo_sex.columns))
test_wo_sex_tensor = torch.tensor(test_wo_sex.values, dtype=torch.int32).cuda()
test_wo_age_tensor = test_wo_sex_tensor.clone()
test_wo_age_tensor[:, age_idx] = 0
test_wo_sex = test_wo_sex.replace(0, -1)

# Parse column names
prefix_to_columns, column_to_prefix = parse_cols(original_column_names)
# Load test_y just to know all possible pathologies
# -------------------------------------------
test_y = read_zst('./stored_files/test_y.zst')
all_pathologies = test_y.columns
del test_y

In [3]:
# Create a patient instance for demo
single_inference = test_wo_sex.iloc[0]
instance_sex = 'F' if test_sex.iloc[0] else 'M'
instance_age = test_age.iloc[0]
instance_tensor = torch.tensor(single_inference.values, dtype=torch.int8).cuda()
instance_wo_age = instance_tensor.clone()
instance_wo_age[0] = 0

In [18]:
# For comments in this part please check inquiry_system.ipynb
result = None
num_question = 20
percent = 500
when_to_consider_age = 3
root = tree_root
reveal = []
similarity = None
patient_data = pd.Series(single_inference, index=original_column_names)
print(f'Sex: {instance_sex}, Age: {instance_age}')
print()
print('Agent inquiries:')
print('----------------')
for i in range(num_question):
    new_inquiry = root.next_question
    col_english = column_translations[original_column_names[new_inquiry]]
    print(f'- {col_english.split("_@_")[0]}')
    if column_to_prefix[new_inquiry]:
        new_inquiry = prefix_to_columns[column_to_prefix[new_inquiry]]
        log_answer = instance_tensor[new_inquiry].tolist()
        for response, q_num in zip(log_answer, new_inquiry):
            if response == 1:
                evidence = column_translations[original_column_names[q_num]]
                print(f'     {evidence.split("_@_")[1]}')
        if sum(log_answer) == -len(log_answer):
            print(f'     N')
    else:
        new_inquiry = [new_inquiry]
        if instance_tensor[new_inquiry].tolist()[0]:
            print(f'     Y')
        else:
            print(f'     N')
        
    reveal += new_inquiry

    if i == when_to_consider_age:
        answer = "".join(str(x) for x in instance_tensor[new_inquiry + [age_idx]].tolist())
        reveal += [age_idx]
    else:
        answer = "".join(str(x) for x in instance_tensor[new_inquiry].tolist())
        
    next_node = root.look_up_children(answer)
    if next_node:
        root = next_node
    else: 
        if i == num_question - 1: continue
        cur_node = DPTreeNode(answer, None)
        root.add_child(cur_node)
        if similarity is not None:
            if i == when_to_consider_age:
                for question_idx in new_inquiry:
                    similarity += 2 * (train_wo_age_tensor[:, question_idx] != instance_wo_age[question_idx])
                similarity += (train_wo_sex_tensor[:, age_idx] - instance_tensor[age_idx]).abs()
            else:
                for question_idx in new_inquiry:
                    similarity += 2 * (train_wo_age_tensor[:, question_idx] != instance_wo_age[question_idx])
        else:
            similarity = torch.zeros(train_wo_sex_tensor.size(0), device="cuda")
            for question_idx in reveal:
                if question_idx == age_idx:
                    similarity += (train_wo_sex_tensor[:, age_idx] - instance_tensor[age_idx]).abs()
                else:
                    similarity += 2 * (train_wo_age_tensor[:, question_idx] != instance_wo_age[question_idx])
        k = n//10000*percent
        values, top_indices = torch.topk(similarity, k, largest=False)
        top_actual_indices = top_indices
        similar_cases = train_wo_age_tensor[top_actual_indices]
        selected_col = [col for col in train_wo_sex.columns if col not in reveal]
        new_inquiry, en = calculate_top_entropy_columns(
            similar_cases[:, selected_col], selected_col, prefix_to_columns, column_to_prefix)
        cur_node.next_question = new_inquiry
        cur_node.entropy = np.int8(en//2)
        root = cur_node

    percent = percent//(1.4**(root.entropy + 1))
    percent = int(percent) if percent > 1 else 1
    torch.cuda.synchronize()

mask_instance = pd.Series(0, index=range(len(original_column_names)))
mask_instance.iloc[reveal] = single_inference.iloc[reveal]
mask_instance.index= original_column_names
mask_instance = mask_instance.astype(np.float64)
mask_instance['AGE'] = age_scaler.transform(np.array(instance_age).reshape(1, -1))[0][0]
mask_instance['SEX'] = test_sex.iloc[0]
# convert test data to PyTorch tensors and move them to the device
X_test_tensor = torch.tensor(pd.DataFrame(mask_instance).T.values, dtype=torch.float).to(device)
with torch.no_grad():
    # generate differential diagnosis
    for features in DataLoader(X_test_tensor, batch_size=32):
        features = features.unsqueeze(1) 
        outputs = model.forward(features)
        predictions = torch.sigmoid(outputs).round()
print()
print('Predicted Differential:')
print('-----------------------')
prediction = all_pathologies[predictions.cpu().numpy()[0].astype(bool)]
print(', '.join([pathology_translations[french_path] for french_path in prediction]))

Sex: F, Age: 49

Agent inquiries:
----------------
- Do you feel pain somewhere?
     lower chest
     upper chest
     hypochondrium(R)
- Does the pain radiate to another location?
     lower chest
     upper chest
- Characterize your pain:
     haunting
     sensitive
     tugging
     burning
- How intense is the pain?
     6
- How fast did the pain appear?
     2
- How precisely is the pain located?
     3
- Do you drink alcohol excessively or do you have an addiction to alcohol?
     Y
- Do you have a hiatal hernia?
     Y
- Are you significantly overweight compared to people of the same height as you?
     Y
- Do you have asthma or have you ever had to use a bronchodilator in the past?
     Y
- Are your symptoms worse when lying down and alleviated while sitting up?
     Y
- Do you think you are pregnant or are you currently pregnant?
     Y
- Do you smoke cigarettes?
     Y
- Have you recently thrown up blood or something resembling coffee beans?
     Y
- Do you have a burning s