In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from time import time
import matplotlib.pyplot as plt
import pickle as pkl
import pandas as pd
from collections import Counter
from sklearn.metrics import r2_score
from tensorboardX import SummaryWriter
% matplotlib inline

In [4]:
# Placeholder locations -- point these at the real train/test CSV exports.
# (Both are relative; the original test path had a stray leading '/'.)
TRAIN_CSV = 'path/to/trainset'
TEST_CSV = 'path/to/testset'

train_set = pd.read_csv(TRAIN_CSV)
test_set = pd.read_csv(TEST_CSV)

# Every later cell selects and groups rows by patient stay via `icustayid`.
# NOTE(review): the model consumes each stay's rows in file order, so they are
# presumably already sorted chronologically -- verify against the export.
assert 'icustayid' in train_set.columns, 'train_set is missing the icustayid column'
assert 'icustayid' in test_set.columns, 'test_set is missing the icustayid column'

In [51]:
# Patient observation columns fed to the autoencoder, one feature per column.
# Order matters: it fixes the feature axis of every per-patient matrix built
# with `patient[observ_cols].values`, and its length is the model's input size.
observ_cols = [
    'gender', 'age', 'elixhauser', 're_admission',
    'SOFA', 'SIRS', 'Weight_kg', 'GCS',
    'HR', 'SysBP', 'MeanBP', 'DiaBP', 'RR', 'SpO2', 'Temp_C', 'FiO2_1',
    'Potassium', 'Sodium', 'Chloride', 'Glucose', 'BUN', 'Creatinine',
    'Magnesium', 'Calcium', 'Ionised_Ca', 'CO2_mEqL',
    'SGOT', 'SGPT', 'Total_bili', 'Albumin',
    'Hb', 'WBC_count', 'Platelets_count', 'PTT', 'PT', 'INR',
    'Arterial_pH', 'paO2', 'paCO2', 'Arterial_BE', 'Arterial_lactate',
    'HCO3', 'PaO2_FiO2',
    'output_total', 'output_4hourly',
]

In [52]:
class autoEncoder(nn.Module):
    """Sequence-to-sequence LSTM autoencoder over per-timestep feature vectors.

    The encoder LSTM maps a ``(seq_len, batch, input_size)`` sequence to
    ``hidden_size`` features per step. Its last output is repeated ``seq_len`` times
    and fed to a decoder LSTM whose hidden size is ``input_size``, so the decoder
    output has the same width as the input and is trained as its reconstruction.

    Recurrent state lives on the module (``encode_hidden`` / ``decode_hidden``) and
    is carried from one ``forward`` call to the next. Call ``init_hidden()`` before
    each independent sequence (and before each new backward pass).

    Args:
        input_size: number of features per timestep.
        hidden_size: width of the encoded representation.
        num_layers: stacked LSTM layers in both encoder and decoder (default 1).
        dropout: dropout between stacked LSTM layers. ``nn.LSTM`` applies it only
            between layers, so it has no effect when ``num_layers == 1``. In that case
            it is not passed on, which avoids PyTorch's warning.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.2):
        super(autoEncoder, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # With a single layer, nn.LSTM's dropout is a silent no-op plus a UserWarning.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.encoder = nn.LSTM(input_size, hidden_size, num_layers=num_layers, dropout=lstm_dropout)
        self.decoder = nn.LSTM(hidden_size, input_size, num_layers=num_layers, dropout=lstm_dropout)

        self.init_hidden()

    def init_hidden(self, batch_size=1):
        """Reset encoder and decoder (h, c) states to zeros for ``batch_size`` sequences.

        The zeros take the dtype and device of the model's parameters. The module
        therefore keeps working after ``.to(device)`` or ``.double()``.
        """
        weight = next(self.parameters())
        self.encode_hidden = (weight.new_zeros(self.num_layers, batch_size, self.hidden_size),
                              weight.new_zeros(self.num_layers, batch_size, self.hidden_size))
        self.decode_hidden = (weight.new_zeros(self.num_layers, batch_size, self.input_size),
                              weight.new_zeros(self.num_layers, batch_size, self.input_size))

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: tensor of shape ``(seq_len, batch, input_size)``. ``batch`` must match
                the ``batch_size`` given to the most recent ``init_hidden`` call.

        Returns:
            ``(encoded, decoded)`` of shapes ``(seq_len, batch, hidden_size)`` and
            ``(seq_len, batch, input_size)``. Dimension 1 is squeezed away when
            ``batch == 1``.
        """
        encoded, self.encode_hidden = self.encoder(x, self.encode_hidden)
        # Feed the final encoder output to the decoder at every timestep.
        repeat_last_encoded = encoded[-1].unsqueeze(0).expand(x.size(0), -1, -1)
        decoded, self.decode_hidden = self.decoder(repeat_last_encoded, self.decode_hidden)
        return encoded.squeeze(1), decoded.squeeze(1)

In [73]:
def train_autoencoder(train_set, test_set, autoencoder, lr=0.001, batch_size=128, num_epoch=50, print_every=10,
                      val=False, test_every=3):
    """Train ``autoencoder`` to reconstruct each patient's observation sequence.

    The unique ``icustayid`` values of ``train_set`` are split 90/10
    (``random_state=42``) into training and validation patients. Validation
    patients are never trained on, even when ``val`` is False. Each patient is fed
    as its own ``(T, 1, len(observ_cols))`` sequence. One optimizer step uses the
    reconstruction MSE averaged over a mini-batch of patients.

    Args:
        train_set: DataFrame with an ``icustayid`` column and every column in ``observ_cols``.
        test_set: DataFrame with the same columns, evaluated every ``test_every`` epochs.
        autoencoder: model exposing ``init_hidden()`` and ``forward(x) -> (encoded, decoded)``.
        lr: Adam learning rate.
        batch_size: number of patients per optimizer step.
        num_epoch: number of passes over the training patients.
        print_every: print batch metrics every this many batches (0/None disables).
        val: if True, evaluate on the held-out validation patients after every epoch.
        test_every: evaluate on ``test_set`` every this many epochs (0/None disables).
    """
    uids = np.unique(train_set['icustayid'])
    train_uids, val_uids = train_test_split(uids, test_size=0.1, random_state=42)

    # Build each patient's feature tensor once. Previously the whole frame was
    # scanned per patient per epoch. groupby keeps each patient's row order.
    patient_tensors = {
        uid: torch.as_tensor(rows[observ_cols].values, dtype=torch.float32)
        for uid, rows in train_set.groupby('icustayid')
    }

    enc_criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=lr)

    # Ceil division so the final partial batch is trained on as well.
    num_batches = (len(train_uids) + batch_size - 1) // batch_size

    autoencoder.train()
    for epoch in range(1, num_epoch + 1):

        for batch in range(num_batches):
            batch_uids = train_uids[batch * batch_size:(batch + 1) * batch_size]

            enc_loss = 0
            batch_r2 = []
            for uid in batch_uids:
                autoencoder.init_hidden()
                enc_X = patient_tensors[uid]
                encoded, decoded = autoencoder(enc_X.unsqueeze(1))

                enc_loss += enc_criterion(decoded, enc_X)
                batch_r2.append(r2_score(enc_X.numpy(), decoded.detach().cpu().numpy(),
                                         multioutput='variance_weighted'))

            # Average over the patients actually in this batch (the last one may be short).
            enc_loss = enc_loss / len(batch_uids)
            # R^2 is NaN for single-row patients; don't let them poison the mean.
            enc_acc = np.nanmean(batch_r2)

            if print_every and batch != 0 and batch % print_every == 0:
                print('epoch:{}/{}, batch:{}/{}, loss:{}, enc_acc:{}'.format(epoch, num_epoch, batch, num_batches,
                                                                             enc_loss.item(), enc_acc))
            optimizer.zero_grad()
            enc_loss.backward()
            optimizer.step()

        if val:
            print('-----------------------')
            print('evaluating ...')
            val_total_loss, val_enc_acc = do_eval(train_set, val_uids, autoencoder)
            print('Validating: loss:{}, enc_acc:{}'.format(float(val_total_loss), val_enc_acc))
            print('-----------------------')

        if test_every and epoch % test_every == 0:
            print('Testing ...')
            test_loss, test_enc_acc = do_eval(test_set, None, autoencoder)
            print('Testing: loss:{}, enc_acc:{}'.format(float(test_loss), test_enc_acc))
            print('-----------------------')

        # do_eval may leave the model in eval mode; switch back before the next epoch.
        autoencoder.train()

In [74]:
def do_eval(eval_set, eval_uids=None, autoencoder=None, output_embeddings=False):
    """Evaluate ``autoencoder``'s reconstruction of the patients in ``eval_set``.

    Args:
        eval_set: DataFrame with an ``icustayid`` column and every column in ``observ_cols``.
        eval_uids: ``icustayid`` values to evaluate. Defaults to every stay in ``eval_set``.
        autoencoder: model exposing ``init_hidden()`` and ``forward(x) -> (encoded, decoded)``.
            It is required. It defaults to None only so that it may follow ``eval_uids``,
            which was a SyntaxError before.
        output_embeddings: if True, return encoder outputs instead of metrics.

    Returns:
        If ``output_embeddings`` is False, returns ``(mean_loss, mean_r2)``:
        ``mean_loss`` is the per-patient MSE averaged over patients (a tensor).
        ``mean_r2`` is the mean per-patient variance-weighted R^2, ignoring patients
        for which R^2 is undefined (NaN).
        If ``output_embeddings`` is True, returns a list with one encoder-output row
        (a list of floats) per observation row, in ``eval_uids`` order.

    Raises:
        TypeError: if ``autoencoder`` is not given.
        ValueError: if there are no patients to evaluate.
    """
    if autoencoder is None:
        raise TypeError("do_eval() missing required argument: 'autoencoder'")

    if eval_uids is None:
        eval_uids = np.unique(eval_set['icustayid'])
    num_patients = len(eval_uids)
    if num_patients == 0:
        raise ValueError('do_eval() got no patients to evaluate')

    enc_criterion = torch.nn.MSELoss()
    # Group once; get_group returns each patient's rows in their original order.
    patients = eval_set.groupby('icustayid')

    eval_enc_loss = 0
    r2_scores = []
    embeddings = []

    was_training = autoencoder.training
    autoencoder.eval()
    try:
        # No gradients are needed. Without no_grad the running loss would keep every
        # patient's autograd graph alive until the function returns.
        with torch.no_grad():
            for uid in eval_uids:
                autoencoder.init_hidden()
                rows = patients.get_group(uid)
                enc_X = torch.as_tensor(rows[observ_cols].values, dtype=torch.float32)
                encoded, decoded = autoencoder(enc_X.unsqueeze(1))

                if output_embeddings:
                    embeddings.extend(encoded.cpu().numpy().tolist())
                    continue

                eval_enc_loss += enc_criterion(decoded, enc_X)
                r2_scores.append(r2_score(enc_X.numpy(), decoded.cpu().numpy(),
                                          multioutput='variance_weighted'))
    finally:
        # Hand the model back in whatever mode the caller had it in.
        autoencoder.train(was_training)

    if output_embeddings:
        return embeddings

    # Average the loss like the accuracy (the original summed it but averaged R^2).
    return eval_enc_loss / num_patients, np.nanmean(r2_scores)

In [10]:
def save_model(model, path):
    """Write ``model``'s ``state_dict`` (parameters and buffers) to ``path``.

    Only the state dict is saved, not the module object. Reloading requires
    building the same architecture and calling ``load_state_dict`` on it.
    """
    state = model.state_dict()
    torch.save(state, path)

In [69]:
# Seed torch so the LSTM weight initialisation (and hence the training run) is reproducible.
torch.manual_seed(42)

HIDDEN_SIZE = 128  # width of the learned per-timestep embedding
# Input width must equal the number of observation columns the train/eval code feeds in
# (previously a hard-coded 45).
autoencoder = autoEncoder(len(observ_cols), HIDDEN_SIZE)

In [77]:
# Train for 50 epochs, validating after each epoch on the held-out train patients.
# NOTE(review): an earlier run was annotated "50 epoch, 0.94, 0.89" -- presumably
# validation / test R^2 after 50 epochs; confirm before quoting these numbers.
train_autoencoder(
    train_set=train_set,
    test_set=test_set,
    autoencoder=autoencoder,
    num_epoch=50,
    val=True,
)

In [None]:
# Persist the trained weights (placeholder path -- set to a real location).
save_model(model=autoencoder, path='path/to/model')

In [None]:
# Encoder output for every observation row of every test patient.
# The model is passed by keyword. Positionally it landed in `eval_uids`, the second parameter.
test_embeddings = do_eval(test_set, autoencoder=autoencoder, output_embeddings=True)