### This notebook reproduces:

[Variationally Regularized Graph-based Representation Learning for Electronic Health Records](https://arxiv.org/abs/1912.03761).

The code in this notebook is adapted from the original code included with the paper.

[Paper Code](https://github.com/NYUMedML/GNN_for_EHR)


In [5]:
import argparse
import torch
import numpy as np
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from collections import Counter
import pickle
from tqdm import tqdm
from datetime import datetime
import os
import logging
import torch.nn.functional as F
from sklearn.metrics import precision_recall_curve, auc
from torch.utils.data import Dataset
import copy

In [6]:
# Verify the NVIDIA driver setup (optional; comment out the line below on machines without nvidia-smi).
!nvidia-smi

/bin/bash: /home/rtikes/anaconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)
Tue Apr 25 07:53:09 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0  On |                  N/A |
|  0%   41C    P8    23W / 170W |    167MiB / 12288MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                              

In [7]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

Using device: cuda


In [8]:
# Report the GPU model; there is nothing to report when running on the CPU.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(device=0))

NVIDIA GeForce RTX 3060


In [11]:
# Verify available memory on the GPU. The CUDA allocator statistics only exist when a
# CUDA device is present, so skip the summary on CPU-only machines instead of failing.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=0, abbreviated=False))
else:
    print('CUDA is not available; skipping GPU memory summary.')

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   5882 KiB |   5882 KiB |   5882 KiB |      0 B   |
|       from large pool |   5296 KiB |   5296 KiB |   5296 KiB |      0 B   |
|       from small pool |    585 KiB |    585 KiB |    585 KiB |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |   5882 KiB |   5882 KiB |   5882 KiB |      0 B   |
|       from large pool |   5296 KiB |   5296 KiB |   5296 KiB |      0 B   |
|       from small pool |    585 KiB |    585 KiB |    585 KiB |      0 B   |
|---------------------------------------------------------------

In [9]:
# model

def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)


def clone_params(param, N):
    """Return an ``nn.ParameterList`` holding ``N`` independent deep copies of ``param``."""
    duplicated = list(map(copy.deepcopy, [param] * N))
    return nn.ParameterList(duplicated)


class LayerNorm(nn.Module):
    """Normalise over the last dimension, then apply a learnable gain ``a_2`` and bias ``b_2``.

    The input is centred by its mean and divided by ``std + eps``, where ``std`` is
    ``Tensor.std`` (unbiased) over the last dimension.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        scale = x.std(-1, keepdim=True) + self.eps
        return self.b_2 + self.a_2 * centered / scale


class GraphLayer(nn.Module):
    """Sparse multi-head graph-attention layer over a fixed set of ``num_of_nodes`` nodes.

    Each head projects node features with its own linear map ``W[i]``, scores every
    edge with its attention vector ``a[i]``, and aggregates neighbours through a
    row-normalised sparse matrix product.

    Args:
        in_features: size of the incoming node features.
        hidden_features: per-head projection size.
        out_features: output size of ``V`` (applied only when ``concat`` is False).
        num_of_nodes: number of nodes ``N``; the attention matrix is ``N x N``.
        num_of_heads: number of attention heads.
        dropout: dropout probability applied to the aggregated features.
        alpha: negative slope of the LeakyReLU applied to attention scores.
        concat: if True, head outputs are concatenated and returned as
            ``ELU(norm(.))``; otherwise they are averaged and returned as
            ``V(ReLU(norm(.)))``.
    """

    def __init__(self, in_features, hidden_features, out_features, num_of_nodes,
                 num_of_heads, dropout, alpha, concat=True):
        super(GraphLayer, self).__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.num_of_nodes = num_of_nodes
        self.num_of_heads = num_of_heads
        self.W = clones(nn.Linear(in_features, hidden_features), num_of_heads)
        self.a = clone_params(nn.Parameter(torch.rand(size=(1, 2 * hidden_features)), requires_grad=True), num_of_heads)
        # NOTE(review): ``ffn`` is not used in forward; kept so existing checkpoints still load.
        self.ffn = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU()
        )
        if not concat:
            self.V = nn.Linear(hidden_features, out_features)
        else:
            self.V = nn.Linear(num_of_heads * hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        # Concatenated heads yield num_of_heads * hidden_features columns, so the norm must
        # span all of them (sizing it to hidden_features only worked for a single head).
        norm_features = num_of_heads * hidden_features if concat else hidden_features
        self.norm = LayerNorm(norm_features)

    def initialize(self):
        """Xavier-normal initialise every head's ``W`` weight and ``a`` vector, plus ``V`` when not concatenating."""
        for linear in self.W:
            nn.init.xavier_normal_(linear.weight.data)
        for att in self.a:
            nn.init.xavier_normal_(att.data)
        if not self.concat:
            # (This layer has no ``out_layer``; referencing one raised AttributeError.)
            nn.init.xavier_normal_(self.V.weight.data)

    def attention(self, linear, a, N, data, edge):
        """Aggregate neighbour features for a single attention head.

        Args:
            linear: this head's projection (``in_features -> hidden_features``).
            a: this head's attention vector, shape ``(1, 2 * hidden_features)``.
            N: number of nodes; the attention matrix is built as ``N x N``.
            data: node features, shape ``(N, in_features)``.
            edge: ``(2, E)`` integer tensor; ``edge[0]`` are rows and ``edge[1]``
                columns of the sparse attention matrix.

        Returns:
            ``(N, hidden_features)`` tensor whose row ``i`` is the attention-weighted
            mean of the projected features of nodes ``j`` with an edge ``(i, j)``.
            Rows without edges get a unit self-loop, so they return their own
            projected features instead of dividing by zero.
        """
        # Follow the input's device rather than a global, so DataParallel replicas work.
        dev = data.device
        edge = edge.to(dev)
        data = linear(data).unsqueeze(0)
        assert not torch.isnan(data).any()
        # h: 2 x E x hidden (row-node features stacked on column-node features)
        h = torch.cat((data[:, edge[0, :], :], data[:, edge[1, :], :]), dim=0)
        data = data.squeeze(0)
        assert not torch.isnan(h).any()
        # edge_h: (2 * hidden) x E
        edge_h = torch.cat((h[0, :, :], h[1, :, :]), dim=1).transpose(0, 1)
        # edge_e: E unnormalised attention weights. squeeze(0) (not squeeze()) keeps the
        # result 1-D even when E == 1, which sparse_coo_tensor requires.
        edge_e = torch.exp(self.leakyrelu(a.mm(edge_h).squeeze(0)) / np.sqrt(self.hidden_features * self.num_of_heads))
        assert not torch.isnan(edge_e).any()
        edge_e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
        # e_rowsum: N x 1 normaliser for each row
        e_rowsum = torch.sparse.mm(edge_e, torch.ones(size=(N, 1), device=dev, dtype=edge_e.dtype))
        row_check = (e_rowsum == 0)
        e_rowsum[row_check] = 1
        zero_idx = row_check.nonzero()[:, 0]
        # Unit self-loops for rows that have no edges at all.
        self_loops = torch.sparse_coo_tensor(zero_idx.repeat(2, 1),
                                             torch.ones(len(zero_idx), device=dev, dtype=edge_e.dtype),
                                             torch.Size([N, N]))
        edge_e = edge_e.add(self_loops)
        h_prime = torch.sparse.mm(edge_e, data)
        assert not torch.isnan(h_prime).any()
        # h_prime: N x hidden, normalised by the row sums
        h_prime.div_(e_rowsum)
        assert not torch.isnan(h_prime).any()
        return h_prime

    def forward(self, edge, data=None):
        """Run every head over ``data`` with edge list ``edge`` and combine them.

        Returns ``ELU(norm(concat(heads)))`` when ``concat`` is True, otherwise
        ``V(ReLU(norm(mean(heads))))``; dropout is applied before the norm.
        """
        N = self.num_of_nodes
        head_outputs = [self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)]
        if self.concat:
            h_prime = torch.cat(head_outputs, dim=1)
        else:
            h_prime = torch.stack(head_outputs, dim=0).mean(dim=0)
        h_prime = self.dropout(h_prime)
        if self.concat:
            return F.elu(self.norm(h_prime))
        return self.V(F.relu(self.norm(h_prime)))


class VariationalGNN(nn.Module):
    """Graph-attention encoder/decoder with an optional variational bottleneck.

    Every graph feature is a node with a learned embedding (node 0 is the padding
    node). For each sample, ``data_to_edges`` builds a fully connected graph over the
    sample's nonzero features. The ``in_att`` layers encode it, an optional
    reparameterised Gaussian layer samples node representations, and ``out_att``
    propagates them to an extra node (index ``len(data) + 1``). The last row of the
    output, ``h_prime[-1]``, feeds the classifier ``out_layer``.
    NOTE(review): this assumes ``len(data) + 1 == self.num_of_nodes - 1``, i.e. the
    caller passes ``num_of_nodes`` = number of graph features + 1. TODO confirm.

    Args:
        in_features: node embedding size.
        out_features: size of the representation fed to the classifier.
        num_of_nodes: node count before the padding shift (see note above).
        n_heads: attention heads per graph layer.
        n_layers: number of ``in_att`` encoder layers.
        dropout: dropout probability.
        alpha: LeakyReLU slope used for attention scores.
        variational: if True, use the Gaussian bottleneck and return its KL term;
            otherwise the KL term is 0.
        none_graph_features: number of leading input columns that are not graph
            nodes; they go through ``features_ffn`` and are concatenated to the
            graph representation.
        concat: accepted for interface compatibility; unused here.
    """

    def __init__(self, in_features, out_features, num_of_nodes, n_heads, n_layers,
                 dropout, alpha, variational=True, none_graph_features=0, concat=True):
        super(VariationalGNN, self).__init__()
        self.variational = variational
        self.num_of_nodes = num_of_nodes + 1 - none_graph_features
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.parameterize = nn.Linear(out_features, out_features * 2)
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        if none_graph_features > 0:
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features//2),
                nn.ReLU(),
                nn.Dropout(dropout))
            # Wider classifier input: graph representation + non-graph feature embedding.
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features//2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        for i in range(n_layers):
            self.in_att[i].initialize()

    def data_to_edges(self, data):
        """Build the encoder and decoder edge lists for one sample.

        Args:
            data: 1-D tensor of graph-feature values for one sample.

        Returns:
            ``(input_edges, output_edges)``: ``(2, E)`` LongTensors on ``data.device``.
            ``input_edges`` holds every ordered pair of present features, with each
            feature index shifted by +1. ``output_edges`` does the same after adding
            node ``len(data) + 1``. In training mode each present feature is kept
            with probability 0.95. If no feature is present, the result is a single
            self-loop on node 0 (input) and on node ``len(data) + 1`` (output).
        """
        dev = data.device
        length = data.size()[0]

        def isolated_edges():
            # Single self-loops so the attention layers still receive one edge.
            return (torch.tensor([[0], [0]], dtype=torch.long, device=dev),
                    torch.tensor([[length + 1], [length + 1]], dtype=torch.long, device=dev))

        nonzero = data.nonzero()
        if nonzero.size()[0] == 0:
            return isolated_edges()
        if self.training:
            # The mask is drawn on the CPU (same RNG stream as before) and moved next to nonzero.
            mask = (torch.rand(nonzero.size()[0]) > 0.05).to(dev)
            nonzero = nonzero[mask]
            if nonzero.size()[0] == 0:
                return isolated_edges()
        # Shift by +1 because node 0 is the padding node.
        nonzero = nonzero.transpose(0, 1) + 1
        lengths = nonzero.size()[1]
        input_edges = torch.cat((nonzero.repeat(1, lengths),
                                 nonzero.repeat(lengths, 1).transpose(0, 1)
                                 .contiguous().view((1, lengths ** 2))), dim=0)

        nonzero = torch.cat((nonzero, torch.tensor([[length + 1]], dtype=torch.long, device=dev)), dim=1)
        lengths = nonzero.size()[1]
        output_edges = torch.cat((nonzero.repeat(1, lengths),
                                  nonzero.repeat(lengths, 1).transpose(0, 1)
                                  .contiguous().view((1, lengths ** 2))), dim=0)
        return input_edges, output_edges

    def reparameterise(self, mu, logvar):
        """Return ``mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, 1)`` in training mode, else ``mu``."""
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        return mu

    def encoder_decoder(self, data):
        """Encode one sample; return ``(h_prime[-1], kld)``.

        ``kld`` is the Gaussian KL term averaged over the selected rows when
        ``variational`` is True, otherwise a 0 tensor.
        """
        N = self.num_of_nodes
        input_edges, output_edges = self.data_to_edges(data)
        h_prime = self.embed(torch.arange(N, device=self.embed.weight.device))
        for attn in self.in_att:
            h_prime = attn(input_edges, h_prime)
        if self.variational:
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            h_prime = self.dropout(h_prime)
            mu = h_prime[:, 0, :]
            logvar = h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)
            # NOTE(review): rows are indexed by the feature *values* (as in the reference
            # implementation), not by the indices of the present features; for 0/1 inputs
            # this only selects nodes 0 and 1. Confirm intent before changing -- it alters
            # the training objective.
            mu = mu[data, :]
            logvar = logvar[data, :]
        h_prime = self.out_att(output_edges, h_prime)
        if self.variational:
            kld = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size()[0]
            return h_prime[-1], kld
        return h_prime[-1], torch.tensor(0.0, device=h_prime.device)

    def forward(self, data):
        """Score a batch.

        Args:
            data: ``(batch, features)`` tensor; the first ``none_graph_features``
                columns are non-graph features, the rest are graph-node values.

        Returns:
            ``(logits, kld)``: logits of shape ``(batch, 1)`` and the KL term summed
            over the batch.
        """
        batch_size = data.size()[0]
        if self.none_graph_features == 0:
            outputs = [self.encoder_decoder(data[i, :]) for i in range(batch_size)]
            graph_repr = torch.stack([out[0] for out in outputs])
            kld = torch.sum(torch.stack([out[1] for out in outputs]))
            return self.out_layer(F.relu(graph_repr)), kld
        # Leading non-graph columns (e.g. the eICU previous-admission flag) bypass the graph.
        outputs = [(data[i, :self.none_graph_features],
                    self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
        # extra.float() keeps the 1-D (none_graph_features,) shape for any number of
        # non-graph features, on the same device as the model.
        combined = torch.stack([torch.cat((self.features_ffn(extra.float()), enc[0]))
                                for extra, enc in outputs])
        kld = torch.sum(torch.stack([enc[1] for _, enc in outputs]), dim=-1)
        return self.out_layer(F.relu(combined)), kld

In [10]:
# Configuration
# Paths default to the original locations but can be overridden with environment
# variables so the notebook runs on other machines.
# output path of model checkpoints
result_path = os.environ.get('VGNN_RESULT_PATH', '/ml/ehr/mimic/models/notebook-run')
# input path of processed dataset (must end with a path separator; file names are appended)
data_path = os.environ.get('VGNN_DATA_PATH', '/ml/ehr/mimic/out2/')
embedding_size = 128
# total number of graph layers (encoder layers + the output attention layer)
num_of_layers = 2
# number of attention heads
num_of_heads = 1
# learning rate
lr = 1e-4
batch_size = 10
dropout = 0.4
# enable the variational (KL-regularised) bottleneck
reg = True
# weight of the KL term in the loss (loss = bce + lbd * kld)
lbd = 1.0
in_feature = embedding_size
out_feature = embedding_size
# encoder layers only: the last graph layer is the separate output-attention layer
n_layers = num_of_layers - 1
n_heads = num_of_heads
# negative slope of the LeakyReLU used for attention scores
alpha = 0.1
number_of_epochs = 1
# evaluate on the validation set and checkpoint every `eval_freq` training batches
eval_freq = 1000
# seed torch (CPU + CUDA) and numpy so weight init, dropout and shuffling are reproducible
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [11]:
# Utils
def train(data, model, optim, criterion, lbd, max_clip_norm=5):
    """Run one optimisation step on a batch.

    Args:
        data: batch tensor whose last column is the label and whose remaining
            columns are the model input (the layout produced by ``collate_fn``).
        model: module returning ``(logits, kld)`` for the input columns.
        optim: optimiser over ``model``'s parameters.
        criterion: loss applied to ``(logits, labels)``, e.g. ``BCEWithLogitsLoss``.
        lbd: weight of the KL term: ``loss = bce + lbd * kld``.
        max_clip_norm: maximum total gradient norm.

    Returns:
        ``(loss, kld, bce)`` as Python floats.
    """
    # Send the batch to wherever the model's parameters live instead of a global device.
    target_device = next(model.parameters()).device
    model.train()
    features = data[:, :-1].to(target_device)
    label = data[:, -1].float().to(target_device)
    optim.zero_grad()
    logits, kld = model(features)
    logits = logits.squeeze(-1)
    kld = kld.sum()
    bce = criterion(logits, label)
    loss = bce + lbd * kld
    loss.backward()
    # Clip only after backward(): clipping before it never touched this step's gradients.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_clip_norm)
    optim.step()
    return loss.item(), kld.item(), bce.item()


def evaluate(model, data_iter, length):
    """Compute the area under the precision-recall curve (AUPRC) on a dataset.

    Args:
        model: module returning ``(logits, kld)``; it is switched to eval mode and
            left there.
        data_iter: iterable of batch tensors whose last column is the label.
        length: total number of samples yielded by ``data_iter``.

    Returns:
        ``(auprc, (y_pred, y_prob, y_true))``, where ``y_prob`` is the sigmoid of
        the logits and ``y_pred`` thresholds it at 0.5.
    """
    model.eval()
    target_device = next(model.parameters()).device
    y_pred = np.zeros(length)
    y_true = np.zeros(length)
    y_prob = np.zeros(length)
    pointer = 0
    # Evaluation needs no gradients; skipping autograd avoids building graphs (memory/time).
    with torch.no_grad():
        for data in data_iter:
            features = data[:, :-1].to(target_device)
            label = data[:, -1]
            batch_size = len(label)
            logits, _ = model(features)
            probability = torch.sigmoid(logits.squeeze(-1))
            predicted = probability > 0.5
            end = pointer + batch_size
            y_true[pointer:end] = label.cpu().numpy()
            y_pred[pointer:end] = predicted.cpu().numpy()
            y_prob[pointer:end] = probability.cpu().numpy()
            pointer = end
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    return auc(recall, precision), (y_pred, y_prob, y_true)


class EHRData(Dataset):
    """Map-style dataset pairing ``data[idx]`` (features) with ``cla[idx]`` (label).

    Its length is the number of labels.
    """

    def __init__(self, data, cla):
        self.data, self.cla = data, cla

    def __len__(self):
        return len(self.cla)

    def __getitem__(self, idx):
        features = self.data[idx]
        label = self.cla[idx]
        return features, label


def collate_fn(data):
    """Stack ``(sparse_row, label)`` samples into one dense LongTensor.

    Each sample's first element is densified with ``toarray()`` and flattened; its
    second element (the label) is appended as the last column.
    """
    rows = [np.hstack((sample[0].toarray().ravel(), sample[1])) for sample in data]
    stacked = np.array(rows)
    return torch.from_numpy(stacked).long()

In [None]:
# Load data
def load_split(name):
    """Return the ``(x, y)`` pair pickled in ``<data_path><name>_csr.pkl``, closing the file afterwards.

    NOTE: ``pickle.load`` can execute arbitrary code -- only load trusted, locally produced files.
    """
    with open(data_path + name + '_csr.pkl', 'rb') as f:
        return pickle.load(f)


train_x, train_y = load_split('train')
val_x, val_y = load_split('validation')
test_x, test_y = load_split('test')
# Oversample the positive class: every positive row index is appended one extra time.
train_upsampling = np.concatenate((np.arange(len(train_y)), np.repeat(np.where(train_y == 1)[0], 1)))
train_x = train_x[train_upsampling]
train_y = train_y[train_upsampling]

# Create result root (and any missing parent directories)
s = datetime.now().strftime('%Y%m%d%H%M%S')
result_root = '%s/lr_%s-input_%s-output_%s-dropout_%s' % (result_path, lr, in_feature, out_feature, dropout)
os.makedirs(result_root, exist_ok=True)
# Remove handlers left over from earlier runs of this cell so basicConfig takes effect.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
logging.basicConfig(filename='%s/train.log' % result_root, format='%(asctime)s %(message)s', level=logging.INFO)
logging.info("Time:%s" % (s))

# initialize models
num_of_nodes = train_x.shape[1] + 1
device_ids = range(torch.cuda.device_count())

# Every input column is a graph feature here (none_graph_features=0); the reference eICU
# setup instead keeps 1 previous-readmission feature outside the graph.
model = VariationalGNN(in_feature, out_feature, num_of_nodes, n_heads, n_layers,
                       dropout=dropout, alpha=alpha, variational=reg, none_graph_features=0).to(device)
model = nn.DataParallel(model, device_ids=device_ids)
val_loader = DataLoader(dataset=EHRData(val_x, val_y), batch_size=batch_size,
                        collate_fn=collate_fn, num_workers=torch.cuda.device_count(), shuffle=False)
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=1e-8)
# Halve the learning rate every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# The training set, class ratio and loss are fixed across epochs, so build them once;
# the loader still reshuffles on every pass because shuffle=True.
ratio = Counter(train_y)
# NOTE(review): pos_weight = #positives / #negatives, as in the reference code.
# BCEWithLogitsLoss conventionally uses #negatives / #positives to up-weight a rare
# positive class -- confirm intent.
pos_weight = torch.ones(1).float().to(device) * (ratio[True] / ratio[False])
criterion = nn.BCEWithLogitsLoss(reduction="sum", pos_weight=pos_weight)
train_loader = DataLoader(dataset=EHRData(train_x, train_y), batch_size=batch_size,
                          collate_fn=collate_fn, num_workers=torch.cuda.device_count(), shuffle=True)
max_clip_norm = 5

# Train models
for epoch in range(number_of_epochs):
    print("Learning rate:{}".format(optimizer.param_groups[0]['lr']))
    t = tqdm(iter(train_loader), leave=False, total=len(train_loader))
    model.train()
    # running sums of [loss, bce, kld]
    total_loss = np.zeros(3)
    for idx, batch_data in enumerate(t):
        loss, kld, bce = train(batch_data, model, optimizer, criterion, lbd, max_clip_norm)
        total_loss += np.array([loss, bce, kld])
        # idx is 0-based, so idx + 1 batches have been accumulated so far.
        mean_loss, mean_bce, mean_kld = total_loss / (idx + 1)
        if idx % eval_freq == 0 and idx > 0:
            torch.save(model.state_dict(), "{}/parameter{}_{}".format(result_root, epoch, idx))
            val_auprc, _ = evaluate(model, val_loader, len(val_y))
            message = 'epoch:%d AUPRC:%f; loss: %.4f, bce: %.4f, kld: %.4f' % \
                      (epoch + 1, val_auprc, mean_loss, mean_bce, mean_kld)
            logging.info(message)
            print(message)
        if idx % 50 == 0 and idx > 0:
            t.set_description('[epoch:%d] loss: %.4f, bce: %.4f, kld: %.4f' %
                              (epoch + 1, mean_loss, mean_bce, mean_kld))
            t.refresh()
    scheduler.step()