# Reproduction of Context-aware Health Event Prediction via Transition Functions on Dynamic Disease Graphs (Lu et al., 2022)

**UIUC, CS598 DL4H, Spring 2023**

**Authors:** Shiyu (Sherry) Li and Wei-Lun (Will) Tsai; {shiyuli2, wltsai2}@illinois.edu

**Original paper:** Chang Lu, Tian Han, and Yue Ning. 2022. [Context-aware Health Event Prediction
via Transition Functions on Dynamic Disease
Graphs.](https://arxiv.org/pdf/2112.05195.pdf) Proceedings of the AAAI
Conference on Artificial Intelligence, 36(4):4567–4574.

**Original codebase:** [github.com/LuChang-CS/Chet](https://github.com/LuChang-CS/Chet)

## 1. Preprocess the data

As a part of the initial setup described in the [README](https://github.com/willtsai/uiuc-cs598-dlh/blob/main/README.md), we have downloaded the raw data and placed it in the `data` directory. We will now preprocess the data to be used in the model.

In [1]:
from run_preprocess import pre_process

pre_process(dataset_names=['mimic3','mimic4'], data_saved=False)
print('***processing complete***')

parsing the csv file of admission ...
	58976 in 58976 rows
parsing csv file of diagnosis ...
	651047 in 651047 rows
calibrating patients by admission ...
calibrating admission by patients ...
saving parsed data ...
patient num: 7493
max admission num: 42
mean admission num: 2.66
max code num in an admission: 39
mean code num in an admission: 13.06
max code num in a visit: 39
encoding code ...
There are 4880 codes
generating code levels ...
	100%
There are 6000 train, 493 valid, 1000 test samples
generating code code adjacent matrix ...
	6000 / 6000
building train codes features and labels ...
	6000 / 6000
building valid codes features and labels ...
	493 / 493
building test codes features and labels ...
	1000 / 1000
building train codes features and labels for CGL...
building train/valid/test codes features and labels for CGL...
	6000 / 6000
building valid codes features and labels for CGL...
building train/valid/test codes features and labels for CGL...
	493 / 493
building test cod

ValueError: Usecols do not match columns, columns expected but not found: ['admittime', 'hadm_id', 'subject_id']
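
The `ValueError` above is raised by pandas when the names passed to `usecols` are not present in the CSV header. One quick way to diagnose it is to compare the header against the expected column names. The sketch below uses only the standard library; the helper name `missing_columns` and the sample header are hypothetical, and the idea that the mismatch is a letter-case difference is an assumption, not something the log confirms.

```python
import csv
import io

def missing_columns(csv_text, expected):
    """Hypothetical helper: list expected columns absent from the CSV header,
    and map those that exist under a different letter case."""
    header = next(csv.reader(io.StringIO(csv_text)))
    present = set(header)
    by_lower = {name.lower(): name for name in header}
    missing = [c for c in expected if c not in present]
    # Columns that are present but differ only in letter case
    case_only = {c: by_lower[c.lower()] for c in missing if c.lower() in by_lower}
    return missing, case_only

# Hypothetical header, for illustration only
sample = "SUBJECT_ID,HADM_ID,ADMITTIME\n1,100,2130-01-01\n"
missing, case_only = missing_columns(sample, ['admittime', 'hadm_id', 'subject_id'])
print(missing)
print(case_only)
```

If `case_only` is non-empty, normalizing the header case before selecting columns may be enough; if it is empty, the file is likely a different table or version than the preprocessing script expects.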


## 2. Set hyperparameters and seed

We keep the same hyperparameters and seed as the original paper.

In [None]:
import torch
import numpy as np
import random

# Keep the same hyperparameters and seed as the original paper
code_size = 48
graph_size = 32
hidden_size = 150  # rnn hidden size
t_attention_size = 32
t_output_size = hidden_size
batch_size = 32
seed = 6669
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Config for hardware to use
if torch.cuda.is_available():
    device = torch.device('cuda')
# elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
#     device = torch.device('mps')
else:
    device = torch.device('cpu')
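
As a quick illustration of what fixing the seed buys us, the sketch below re-creates a generator with the same seed and checks that it reproduces the same draws. It uses only Python's `random` module; the helper `draw` is our own illustration. NumPy and PyTorch generators are seeded in the cell above on the same principle, although exact reproducibility there can also depend on hardware and library versions.

```python
import random

def draw(seed, n=5):
    """Hypothetical helper: draw n floats from a generator created with the given seed."""
    rng = random.Random(seed)  # independent generator, leaves global state untouched
    return [rng.random() for _ in range(n)]

first = draw(6669)
second = draw(6669)   # same seed, so the same sequence
other = draw(6670)    # different seed

print(first == second)
print(first == other)
```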

## 3. Model

Here we implement the model and its layers as described in the paper.

### 3.1 Optimized dynamic graph layer

Here we define the optimized dynamic graph layer for the model. This layer performs the following steps:
- Aggregate global/local context with the optimized graph layer with the embedding matrices
- Calculate hidden embeddings for diagnoses and neighbors

In [None]:
import torch
import torch.nn as nn

class GraphLayer(nn.Module):
    def __init__(self, adj, code_num, code_size, graph_size):
        super().__init__()
        self.embedding = nn.Parameter(data=nn.init.xavier_uniform_(torch.empty(code_num, code_size)))
        self.adj = adj
        # Fully connected layer
        self.fc = nn.Linear(code_size, graph_size)
        self.LeakyReLU = nn.LeakyReLU()

    def forward(self, code_x, neighbor):
        # embedding matrices for diseases appearing in current diagnoses
        # M_embedding_matrices = self.embedding(code_x)
        # embedding matrices for diseases appearing in direct neighbors
        # N_embedding_matrices = self.embedding(neighbor)
        # static adjacency matrix
        # keep these unsqueeze for now, may need change if we change the data loader
        center_codes = torch.unsqueeze(code_x, dim=-1)
        neighbor_codes = torch.unsqueeze(neighbor, dim=-1)

        center_embeddings = center_codes * self.embedding
        neighbor_embeddings = neighbor_codes * self.embedding

        adj_mul_center = torch.matmul(self.adj, center_embeddings)
        adj_mul_neighbor = torch.matmul(self.adj, neighbor_embeddings)

        # All calculations here use the memory-efficient form proved by the authors in "Subgraphs' Adjacency Matrix Calculation"
        # aggregated diagnosis local context and diagnosis global context
        aggregated_diagnosis_embedding = center_embeddings + center_codes * adj_mul_center + center_codes * adj_mul_neighbor
        # aggregated neighbor global context
        aggregated_neighbor_embedding = neighbor_embeddings + neighbor_codes * adj_mul_neighbor + neighbor_codes * adj_mul_center

        # hidden embeddings of diagnoses and neighbors
        hidden_diagnosis_embedding = self.LeakyReLU(self.fc(aggregated_diagnosis_embedding))
        hidden_neighbor_embedding = self.LeakyReLU(self.fc(aggregated_neighbor_embedding))
        return hidden_diagnosis_embedding, hidden_neighbor_embedding
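
To see why the masked form in `GraphLayer.forward` avoids building a per-visit adjacency matrix, the NumPy sketch below compares it with the explicit subgraph computation on toy data. The sizes, masks, and variable names are our own illustration; `m` stands for the diagnosis mask (`code_x`) and `n` for the neighbor mask (`neighbor`) of a single visit.

```python
import numpy as np

rng = np.random.default_rng(0)
code_num, code_size = 6, 4
A = rng.random((code_num, code_num))            # toy static adjacency matrix
E = rng.random((code_num, code_size))           # toy code embeddings
m = np.array([1, 0, 1, 0, 0, 1], dtype=float)   # diagnosis mask for one visit
n = np.array([0, 1, 0, 0, 1, 0], dtype=float)   # neighbor mask for the same visit

# Explicit form: build two (code_num x code_num) subgraph adjacency matrices
A_local = np.diag(m) @ A @ np.diag(m)    # diagnosis-to-diagnosis edges
A_global = np.diag(m) @ A @ np.diag(n)   # diagnosis-to-neighbor edges
explicit = m[:, None] * E + A_local @ E + A_global @ E

# Masked form, mirroring the diagnosis branch of GraphLayer.forward
center = m[:, None] * E
neighbor = n[:, None] * E
masked = center + m[:, None] * (A @ center) + m[:, None] * (A @ neighbor)

print(np.allclose(explicit, masked))
```

Both forms give the same aggregated diagnosis embedding, but the masked form reuses the single static matrix `A` for every visit instead of materializing new subgraph matrices.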


### 3.2 Transition functions layer

Here we define the transition functions layer for the model. It takes the hidden embeddings from the optimized dynamic graph layer as input and combines a GRU, an M-GRU (a customized GRU for matrices), and single-head attention.

In [None]:
import math

class SingleHeadAttentionLayer(nn.Module):
    def __init__(self, query_size, key_size, value_size, attention_size):
        super().__init__()
        self.attention_size = attention_size
        self.query_dense = nn.Linear(query_size, attention_size)
        self.key_dense = nn.Linear(key_size, attention_size)
        self.value_dense = nn.Linear(query_size, value_size)
        
    def forward(self, q, k, v):
        query = self.query_dense(q)
        key = self.key_dense(k)
        value = self.value_dense(v)
        attention = torch.matmul(query, key.T) / math.sqrt(self.attention_size)
        attention = torch.softmax(attention, dim=-1)
        output = torch.matmul(attention, value)
        return output
    
class TransitionLayer(nn.Module):
    def __init__(self, code_num, code_size, graph_size, hidden_size, t_attention_size, t_output_size):
        super().__init__()
        self.unrelated_embedding = nn.Parameter(data=nn.init.xavier_uniform_(torch.empty(code_num, graph_size)))
        self.gru = nn.GRUCell(input_size=graph_size, hidden_size=hidden_size)
        self.attention = SingleHeadAttentionLayer(graph_size, graph_size, t_output_size, t_attention_size)
        self.activation = nn.Tanh()

        self.code_num = code_num
        self.hidden_size = hidden_size
        self.code_size = code_size

    def forward(self, t, co_embeddings, divided, no_embeddings, hidden_state=None):
        m_p, m_en, m_eu = divided[:, 0], divided[:, 1], divided[:, 2]
        mp_idx, men_idx, meu_idx = torch.where(m_p > 0)[0], torch.where(m_en > 0)[0], torch.where(m_eu > 0)[0]
        h_new = torch.zeros((self.code_num, self.hidden_size), dtype=co_embeddings.dtype).to(co_embeddings.device)
        output_mp = 0
        output_meneu = 0

        if len(mp_idx) > 0:
            h = hidden_state[mp_idx] if hidden_state is not None else None
            h_p = self.gru(co_embeddings[mp_idx], h)
            h_new[mp_idx] = h_p
            output_mp, _ = torch.max(h_p, dim=-2)
        if t == 0 or len(men_idx) + len(meu_idx) == 0:
            output = output_mp
        else:
            q = torch.vstack([no_embeddings[men_idx], self.unrelated_embedding[meu_idx]])
            v = torch.vstack([co_embeddings[men_idx], co_embeddings[meu_idx]])
            h_tilda = self.activation(self.attention(q, q, v))
            h_new[men_idx] = h_tilda[:len(men_idx)]
            h_new[meu_idx] = h_tilda[len(men_idx):]
            output_meneu, _ = torch.max(h_tilda, dim=-2)
            if len(mp_idx) == 0:
                output = output_meneu
            else:
                output, _ = torch.max(torch.vstack([output_mp, output_meneu]), dim=-2)

        return output, h_new
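
To make the attention step concrete, here is a minimal NumPy sketch of the scaled dot-product attention computed in `SingleHeadAttentionLayer.forward`. The learned `nn.Linear` projections are omitted for brevity, and the function name and toy shapes are our own illustration rather than part of the model code.

```python
import math
import numpy as np

def single_head_attention(query, key, value):
    """Scaled dot-product attention without the learned projections."""
    scores = query @ key.T / math.sqrt(query.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)         # row-wise softmax
    return weights @ value, weights

rng = np.random.default_rng(0)
q = rng.standard_normal((5, 32))    # 5 codes, attention size 32
v = rng.standard_normal((5, 150))   # 5 codes, output size 150
out, w = single_head_attention(q, q, v)  # query and key share the same input, as in TransitionLayer

print(out.shape)
print(w.sum(axis=-1))
```

Each row of `w` sums to one, so every output row is a convex combination of the value rows.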

### 3.3 Embedding layer

Here we define the embedding layer for the model. It applies dot-product attention with a learned context vector to the per-visit outputs and returns their weighted sum as the final hidden embedding.

In [None]:
import torch.nn as nn

class EmbeddingWithAttentionLayer(nn.Module):
    def __init__(self, value_size, attention_size):
        super().__init__()
        self.attention_size = attention_size
        # define context vector
        self.context = nn.Parameter(data=nn.init.xavier_uniform_(torch.empty(attention_size, 1)))
        self.linear = nn.Linear(value_size, attention_size)

    def forward(self, x):
        # linear projection into the attention space
        t = self.linear(x)
        # calculate attention score
        score = torch.softmax(torch.matmul(t, self.context).squeeze(dim=-1), dim=-1)
        # final hidden embedding
        output = torch.sum(x * torch.unsqueeze(score, dim=-1), dim=-2)
        return output


### 3.4 Model and classifier

In [None]:
class Classifier(nn.Module):
    def __init__(self, input_size, output_size, dropout_rate):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)
        self.activation = torch.nn.Sigmoid()
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        output = self.activation(self.dropout(self.linear(x)))
        return output

In [None]:
# still need further editing to integrate
class Model(nn.Module):
    def __init__(self, code_num, code_size,
                 adj, graph_size, hidden_size, t_attention_size, t_output_size,
                 output_size, dropout_rate):
        super().__init__()
        self.graph_layer = GraphLayer(adj, code_num, code_size, graph_size)
        self.transition_layer = TransitionLayer(code_num, code_size, graph_size, hidden_size, t_attention_size, t_output_size)
        self.attention = EmbeddingWithAttentionLayer(hidden_size, 32)
        self.classifier = Classifier(hidden_size, output_size, dropout_rate)

    def forward(self, code_x, divided, neighbors, lens):
        output = []
        for code_x_i, divided_i, neighbor_i, len_i in zip(code_x, divided, neighbors, lens):
            no_embeddings_i_prev = None
            output_i = []
            h_t = None
            for t, (c_it, d_it, n_it, len_it) in enumerate(zip(code_x_i, divided_i, neighbor_i, range(len_i))):
                co_embeddings, no_embeddings = self.graph_layer(c_it, n_it)
                output_it, h_t = self.transition_layer(t, co_embeddings, d_it, no_embeddings_i_prev, h_t)
                no_embeddings_i_prev = no_embeddings
                output_i.append(output_it)
            output_i = self.attention(torch.vstack(output_i))
            output.append(output_i)
        output = torch.vstack(output)
        output = self.classifier(output)
        return output

## 4. Define functions for training and evaluation

### 4.1 Historical hot function

We re-use the `historical_hot()` function directly from the [original codebase](https://github.com/LuChang-CS/Chet/blob/master/train.py). The function will be used later in model training and evaluation.

In [None]:
def historical_hot(code_x, code_num, lens):
    result = np.zeros((len(code_x), code_num), dtype=int)
    for i, (x, l) in enumerate(zip(code_x, lens)):
        result[i] = x[l - 1]
    return result
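
A small usage example may help here. The sketch below repeats the function so that it runs on its own and applies it to toy data. We assume `code_x` is a padded multi-hot array of shape (patients, max visits, `code_num`) and that `lens` holds the number of valid visits per patient; the toy values are ours.

```python
import numpy as np

def historical_hot(code_x, code_num, lens):
    result = np.zeros((len(code_x), code_num), dtype=int)
    for i, (x, l) in enumerate(zip(code_x, lens)):
        result[i] = x[l - 1]
    return result

# Two patients, up to three visits, four codes (toy data)
code_x = np.zeros((2, 3, 4), dtype=int)
code_x[0, 0] = [1, 0, 0, 0]
code_x[0, 1] = [0, 1, 1, 0]   # last valid visit of patient 0
code_x[1, 0] = [0, 0, 0, 1]
code_x[1, 1] = [1, 0, 0, 1]
code_x[1, 2] = [0, 0, 1, 0]   # last valid visit of patient 1
lens = np.array([2, 3])

hist = historical_hot(code_x, 4, lens)
print(hist)
```

As written, the function returns the multi-hot vector of each patient's last valid visit, skipping any padded visits after it.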

### 4.2 Data loader function

We create a `data_loader()` function to load the data needed for training and evaluating the model. The function is based on the [data loading code](https://github.com/LuChang-CS/Chet/blob/master/train.py#L45-L52) from the original authors and also re-uses several of the data loading helper functions from [`utils.py`](https://github.com/LuChang-CS/Chet/blob/master/utils.py) in the original codebase.

In [None]:
import os
from utils import load_adj
from utils import EHRDataset
from utils import MultiStepLRScheduler

def data_loader(task, dataset_path):
    print('from {} for task {}:'.format(dataset_path, task))
    print('loading code adjacency matrix ...')
    code_adj = load_adj(dataset_path, device=device)
    code_num = len(code_adj)
    print('loading train data ...')
    train_data = EHRDataset(os.path.join(dataset_path, "train/"), label=task, batch_size=batch_size, shuffle=True, device=device)
    print('loading valid data ...')
    valid_data = EHRDataset(os.path.join(dataset_path, "valid/"), label=task, batch_size=batch_size, shuffle=False, device=device)
    print('loading test data ...')
    test_data = EHRDataset(os.path.join(dataset_path, "test/"), label=task, batch_size=batch_size, shuffle=False, device=device)

    return {
        'dataset_name': dataset_path.split('/')[1],
        'code_adj': code_adj, 
        'code_num': code_num, 
        'train_data': train_data, 
        'valid_data': valid_data, 
        'test_data': test_data, 
        }
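
Note that `dataset_name` is taken from the second component of `dataset_path`, so it relies on relative paths of the form `data/<name>/standard/`. The sketch below isolates that expression in a hypothetical helper to show what it returns for a relative and an absolute path; the example paths are illustrative.

```python
def dataset_name(dataset_path):
    """Hypothetical helper isolating the expression used in data_loader()."""
    return dataset_path.split('/')[1]

relative = dataset_name("data/mimic3/standard/")
absolute = dataset_name("/home/user/data/mimic3/standard/")

print(relative)
print(absolute)
```

With an absolute path the second component is no longer the dataset name, so the result keys and parameter directories would be mislabeled.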



### 4.3 Model training function

In [None]:
import time
from utils import format_time

def train_chet(path, task, output_size, evaluate_fn, code_adj, code_num, dropout_rate, 
               train_data, valid_data, init_lr, lrs, milestones, epochs, test_historical):
    loss_fn = torch.nn.BCELoss()
    
    # Keep the same model param storage path as the original paper
    param_path = os.path.join('data', 'params', path, task)
    if not os.path.exists(param_path):
        os.makedirs(param_path)

    # Keep the same model, optimizer, and scheduler as the original paper,
    # but slightly modified to leverage the new config dict
    model = Model(code_num=code_num, code_size=code_size,
                    adj=code_adj, graph_size=graph_size, hidden_size=hidden_size, t_attention_size=t_attention_size,
                    t_output_size=t_output_size,
                    output_size=output_size, dropout_rate=dropout_rate).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=init_lr)
    scheduler = MultiStepLRScheduler(optimizer, epochs, init_lr, milestones, lrs)

    # Keep the same param printing code as the original paper
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)

    # Keep the same training loop code as the original paper, but note that
    # the train, valid, and test data will change based on the task and
    # dataset of the current loop
    epoch_lrs, valid_losses, mean_losses, time_costs, f1_scores, auc_or_topks = [], [], [], [], [], []
    for epoch in range(epochs):
        print('Epoch %d / %d:' % (epoch + 1, epochs))
        model.train()
        total_loss = 0.0
        total_num = 0
        steps = len(train_data)
        st = time.time()
        scheduler.step()
        current_lr = scheduler.lrs[epoch]
        for step in range(len(train_data)):
            optimizer.zero_grad()
            code_x, visit_lens, divided, y, neighbors = train_data[step]
            output = model(code_x, divided, neighbors, visit_lens).squeeze()
            loss = loss_fn(output, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * output_size * len(code_x)
            total_num += len(code_x)

            end_time = time.time()
            remaining_time = format_time((end_time - st) / (step + 1) * (steps - step - 1))
            print('\r    Step %d / %d, LR: %s, remaining time: %s, loss: %.4f'
                % (step + 1, steps, current_lr, remaining_time, total_loss / total_num), end='')
        train_data.on_epoch_end()
        et = time.time()
        time_cost = format_time(et - st)
        mean_loss = total_loss / total_num
        print('\r    Step %d / %d, LR: %s, time cost: %s, loss: %.4f' % (steps, steps, current_lr, time_cost, mean_loss))
        valid_loss, f1_score, auc_or_topk = evaluate_fn(model, valid_data, loss_fn, output_size, test_historical)
        torch.save(model.state_dict(), os.path.join(param_path, '%d.pt' % epoch))
        epoch_lrs.append(current_lr)
        valid_losses.append(valid_loss)
        mean_losses.append(mean_loss)
        time_costs.append(time_cost)
        f1_scores.append(f1_score)
        auc_or_topks.append(auc_or_topk)
        
    return {
        'model': model,
        'epoch_lrs': epoch_lrs,
        'valid_losses': valid_losses,
        'mean_losses': mean_losses,
        'time_costs': time_costs,
        'f1_scores': f1_scores,
        'auc_or_topks': auc_or_topks,
    }
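
The running loss printed during training is accumulated as `loss.item() * output_size * len(code_x)`. Since `BCELoss` averages over all elements by default, this product recovers the summed loss of the batch, and dividing by `total_num` gives the mean per-patient loss summed over output codes. The NumPy sketch below checks this on toy batches; the helper `bce_mean` and the random data are our own illustration.

```python
import numpy as np

def bce_mean(pred, target):
    """Mean binary cross-entropy over all elements (BCELoss default reduction)."""
    return float(np.mean(-(target * np.log(pred) + (1 - target) * np.log(1 - pred))))

rng = np.random.default_rng(0)
output_size = 4
batches = []
for batch_len in (3, 2):  # two batches of different sizes
    pred = rng.uniform(0.05, 0.95, size=(batch_len, output_size))
    target = rng.integers(0, 2, size=(batch_len, output_size)).astype(float)
    batches.append((pred, target))

# Accumulation as in the training loop
total_loss, total_num = 0.0, 0
for pred, target in batches:
    total_loss += bce_mean(pred, target) * output_size * len(pred)
    total_num += len(pred)
running = total_loss / total_num

# Direct computation: sum over codes per patient, then average over patients
all_pred = np.concatenate([p for p, _ in batches])
all_target = np.concatenate([t for _, t in batches])
elementwise = -(all_target * np.log(all_pred) + (1 - all_target) * np.log(1 - all_pred))
per_patient = elementwise.sum(axis=1).mean()

print(np.isclose(running, per_patient))
```

For the heart failure task `output_size` is 1, so the reported value reduces to the ordinary mean loss per patient.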

### 4.4 Model evaluation function

We create a `test()` function to evaluate the model on the `test_data`. We re-use the `evaluate_codes()` and `evaluate_hf()` functions directly from the original [metrics.py](https://github.com/LuChang-CS/Chet/blob/master/metrics.py) module.

In [None]:
def test(evaluate_fn, model, test_data, loss_fn, output_size, test_historical):
    print("Evaluating model on test data...")
    model.eval()
    test_loss, f1_score, auc_or_topk = evaluate_fn(model, test_data, loss_fn, output_size, historical=test_historical)
    print("Test loss: %s, F1 score: %s, AUC or TopK: %s" % (test_loss, f1_score, auc_or_topk))
    return {
        'test_loss': test_loss,
        'f1_score': f1_score,
        'auc_or_topk': auc_or_topk,
    }

## 5. Load the preprocessed data

Here we load the preprocessed data using the `data_loader()` function defined above, for both the MIMIC-III and MIMIC-IV datasets.

In [None]:
# Todo: Add data analysis here to explain each sub datasets eg: code_x, visit_lens, divided, y, neighbors, code_adj. Also can give an example by printing some data
# Todo: Add some initial data analysis. 1) number/ratio of heart failure patients. 2) some statistic for neighbors
# Possible Todo: build our own data builder using torch.utils.data.DataLoader as HWs

tasks = ['h', 'm']
mimic4_standard_path = "data/mimic4/standard/"
mimic3_standard_path = "data/mimic3/standard/"
mimic3_datasets, mimic4_datasets = {}, {}
for task in tasks:
    mimic3_datasets[task] = data_loader(task, mimic3_standard_path)
    mimic4_datasets[task] = data_loader(task, mimic4_standard_path)
print("data loaded")

## 6. Train and test the model

In [None]:
from metrics import evaluate_codes, evaluate_hf

epochs = 20 # epochs = 200 in original paper

train_results = {}
test_results = {}

for task in tasks:
    dropout_rate_ = 0.45 if task == 'm' else 0.0
    lrs_ = [1e-3, 1e-5] if task == 'm' else [1e-3, 1e-4, 1e-5]
    milestones_ = [15, 17] if task == 'm' else [2,3,4] # [20, 30] [2, 3, 20]
    evaluate_fn_ = evaluate_codes if task == 'm' else evaluate_hf
    for dataset in [mimic3_datasets, mimic4_datasets]:
        output_size_ = dataset[task]['code_num'] if task == 'm' else 1
        print('training for %s task on %s dataset:' % (task, dataset[task]['dataset_name']))
        train_data_ = dataset[task]['train_data']
        valid_data_ = dataset[task]['valid_data']
        test_data_ = dataset[task]['test_data']
        train_results[dataset[task]['dataset_name'] + "-" + task] = train_chet(
            path=dataset[task]['dataset_name'],task=task, output_size=output_size_, 
            evaluate_fn=evaluate_fn_, code_adj=dataset[task]['code_adj'], code_num=dataset[task]['code_num'], 
            dropout_rate=dropout_rate_, train_data=train_data_, valid_data=valid_data_, 
            init_lr=0.01, lrs=lrs_, milestones=milestones_, epochs=epochs, 
            test_historical=historical_hot(valid_data_.code_x, dataset[task]['code_num'], valid_data_.visit_lens)
            )
        test_results[dataset[task]['dataset_name'] + "-" + task] = test(
            evaluate_fn=evaluate_fn_, model=train_results[dataset[task]['dataset_name'] + "-" + task]['model'], 
            test_data=test_data_, loss_fn=torch.nn.BCELoss(), output_size=output_size_,
            test_historical=historical_hot(test_data_.code_x, dataset[task]['code_num'], test_data_.visit_lens)
            )
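
To reason about the `milestones_` and `lrs_` settings above, it can help to expand them into a per-epoch learning rate list. The sketch below is not the `MultiStepLRScheduler` from the original `utils.py`. It is a hypothetical helper that assumes the schedule holds `init_lr` until the first milestone and switches to the matching entry of `lrs` at each milestone, with epochs counted from 1.

```python
def piecewise_lrs(epochs, init_lr, milestones, lrs):
    """Hypothetical helper: learning rate per epoch, assuming 1-indexed milestones."""
    assert len(milestones) == len(lrs)
    schedule = []
    for epoch in range(1, epochs + 1):
        lr = init_lr
        for milestone, value in zip(milestones, lrs):
            if epoch >= milestone:
                lr = value  # the latest milestone reached wins
        schedule.append(lr)
    return schedule

m_task = piecewise_lrs(20, 0.01, [15, 17], [1e-3, 1e-5])
h_task = piecewise_lrs(20, 0.01, [2, 3, 4], [1e-3, 1e-4, 1e-5])

print(m_task)
print(h_task)
```

Under this assumption, the heart failure task would spend all epochs from the fourth onward at the smallest learning rate, which is worth keeping in mind when comparing the shortened 20-epoch run with the original 200-epoch setting.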

In [None]:
# analyze the results here

print("test_results = ", test_results)
print("train_results = ", train_results)

# import matplotlib.pylab as plt
# plt.plot(x, y)
# plt.show()