In [None]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt

import os, sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from imv_lstm import IMVFullLSTM
from descriptors import descriptors

In [None]:
# Load and process data
def _parse_hour(timestamp, n_hours):
    """Map an "HH:MM" timestamp to an hourly row index in [0, n_hours).

    Only the part before ':' is used, so multi-digit hours such as "13:05"
    map to row 13. Readings at or beyond the last hour (e.g. "48:00") are
    folded into the final row instead of indexing past the tensor.

    Raises:
        ValueError: if the hour part is not an integer.
    """
    hour = int(timestamp.split(':', 1)[0])
    return min(hour, n_hours - 1)


def _load_record(path, n_hours, n_features, column_of):
    """Read one per-record measurement file into an (n_hours, n_features) tensor.

    Each measurement line is ``time,descriptor,value``. Lines with an unknown
    descriptor or an unparseable time/value are skipped; blank lines are
    ignored. If a descriptor appears several times in the same hour, the last
    reading in the file wins. Cells without a reading stay 0.
    """
    example = torch.zeros((n_hours, n_features))
    with open(path, "r") as x_file:
        # The first two lines are not treated as measurements.
        next(x_file)
        next(x_file)
        for line in x_file:
            if not line.strip():
                continue
            fields = line.split(',')
            try:
                hour = _parse_hour(fields[0], n_hours)
                column = column_of[fields[1]]
                value = float(fields[2])  # float() tolerates the trailing newline
            except (ValueError, KeyError):
                # Invalid descriptor or unparseable time/value.
                continue
            example[hour, column] = value
    return example


def load_dataset(data_dir="icu-mortality-data", n_hours=48, descriptor_names=None):
    """Load the ICU mortality data as dense hourly feature tensors.

    Reads ``<data_dir>/outcomes.csv`` (one header line, then one row per
    record). Column 0 is the record name and column 5 is the label. For each
    row, it loads the matching ``<data_dir>/<name>.txt`` measurement file.

    Args:
        data_dir: Directory holding ``outcomes.csv`` and the per-record files.
        n_hours: Number of hourly rows per example.
        descriptor_names: Ordered feature names defining the column layout;
            defaults to the module-level ``descriptors`` list.

    Returns:
        (X, Y): X is a float tensor of shape (n_records, n_hours, n_features);
        Y is a float tensor of shape (n_records,) with the labels.
    """
    if descriptor_names is None:
        descriptor_names = descriptors
    n_features = len(descriptor_names)
    # Build the name -> column lookup once instead of calling list.index per line.
    # setdefault keeps the first occurrence, matching list.index on duplicates.
    column_of = {}
    for column, name in enumerate(descriptor_names):
        column_of.setdefault(name, column)

    X, Y = [], []
    with open(os.path.join(data_dir, "outcomes.csv"), "r") as y_file:
        next(y_file)  # header row
        for line in y_file:
            if not line.strip():
                continue
            y_values = line.split(',')
            x_name, fatality = y_values[0], int(y_values[5])
            Y.append(fatality)
            X.append(_load_record(os.path.join(data_dir, x_name + ".txt"),
                                  n_hours, n_features, column_of))
    if not X:
        # torch.stack rejects an empty list; return correctly shaped empties instead.
        return torch.zeros((0, n_hours, n_features)), torch.FloatTensor(Y)
    return torch.stack(X), torch.FloatTensor(Y)

# X: float tensor stacked from per-record (48, len(descriptors)) examples; Y: float labels parsed from column 5 of outcomes.csv.
X, Y = load_dataset()

In [None]:
# Split data into respective groups
# The first N_TRAIN records are used for training and the next N_TEST for testing.
SEED = 0
N_TRAIN, N_TEST = 3900, 100
# Seed the global generator so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

X_train, Y_train = X[:N_TRAIN], Y[:N_TRAIN]
X_test, Y_test = X[N_TRAIN:N_TRAIN + N_TEST], Y[N_TRAIN:N_TRAIN + N_TEST]
train_loader = DataLoader(TensorDataset(X_train, Y_train), batch_size=64, shuffle=True)

# Define model; fall back to the CPU when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = IMVFullLSTM(device, X.shape[2], 1, 128).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001)
# Binary target; the model's raw output is treated as a logit.
loss = nn.BCEWithLogitsLoss()
# Multiply the learning rate by 0.9 every 20 scheduler steps (one step per epoch below).
epoch_scheduler = torch.optim.lr_scheduler.StepLR(opt, 20, gamma=0.9)

In [None]:
# Perform training
epochs = 50
# Make sure layers with train/eval-dependent behaviour (if the model has any) are in training mode.
model.train()
for epoch in range(epochs):
    train_loss_sum = 0.0
    for batch_x, batch_y in train_loader:
        # Keep the batch on the same device the model was moved to.
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        opt.zero_grad()
        # The extra outputs are not needed for training.
        y_pred, _alphas, _betas = model(batch_x)
        y_pred = y_pred.squeeze(1)
        batch_loss = loss(y_pred, batch_y)
        batch_loss.backward()
        opt.step()
        # Weight by batch size so a smaller final batch counts proportionally.
        train_loss_sum += batch_loss.item() * batch_x.shape[0]
    # Report the mean loss over the whole epoch, not just the last batch.
    epoch_loss = train_loss_sum / len(train_loader.dataset)
    print(f"Epoch {epoch} complete, loss: {epoch_loss}")
    epoch_scheduler.step()

print("Saving...")
os.makedirs("out", exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), "out/model_weights.pt")

In [None]:
# Reload model using the CPU
device = torch.device("cpu")
model = IMVFullLSTM(device, X.shape[2], 1, 128).to(device)
# Prefer the weights written by the training cell; fall back to the
# ../model_weights.pt checkpoint when training was skipped in this run.
weights_path = "out/model_weights.pt"
if not os.path.exists(weights_path):
    weights_path = "../model_weights.pt"
load_result = model.load_state_dict(torch.load(weights_path, map_location=device))
# Switch to inference mode before running predictions in the following cells.
model.eval()
load_result

In [None]:
# Perform CPU Inference
# Run the model on one held-out example (as a batch of size 1) and compare
# its prediction with the ground-truth label.
index = 98
test_x = X_test[index].unsqueeze(0)
with torch.no_grad():
    output, a, b = model(test_x)

output = output.detach()
# Raw model output, its sigmoid, then the true label -- one per line.
print(output, torch.sigmoid(output), Y_test[index], sep="\n")

# Drop singleton dimensions (including the batch axis) before plotting.
# `a` is transposed, presumably so its rows follow the descriptor y-axis of the heat map.
a = a.squeeze().cpu().T
b = b.squeeze().cpu()

In [None]:
# Heat map of the per-feature, per-timestep weights `a` returned by the model for the selected example.
n_steps = test_x.shape[1]
fig, ax = plt.subplots(figsize=(20, 20))
im = ax.imshow(a)
# Plain Python ranges for tick positions; labels are 1-based time steps.
ax.set_xticks(range(n_steps))
ax.set_xticklabels(list(range(1, n_steps + 1)))
ax.set_yticks(range(len(descriptors)))
ax.set_yticklabels(descriptors)
ax.set_xlabel("Time step (hour)")
ax.set_ylabel("Feature")
ax.set_title("Importance of features and timesteps")
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Bar chart of the per-feature weights `b` returned by the model for the selected example.
fig, ax = plt.subplots(figsize=(20, 20))
ax.bar(range(len(descriptors)), b)
ax.set_xticks(range(len(descriptors)))
ax.set_xticklabels(descriptors, rotation=90)
ax.set_xlabel("Feature")
ax.set_ylabel("Importance")
ax.set_title("Feature importance")
# Render explicitly so the xticks return value is not echoed as cell output.
plt.show()