In [None]:
import random
import numpy as np
import pandas as pd
import torch
import config as cfg
from data_loader import get_data_loader
from torchmtlr.utils import make_time_bins
from utility.survival import preprocess_data

matplotlib_style = 'default'
import matplotlib.pyplot as plt; plt.style.use(matplotlib_style)
plt.rcParams.update({'axes.labelsize': 'medium',
                     'axes.titlesize': 'medium',
                     'font.size': 14.0,
                     'text.usetex': True,
                     'text.latex.preamble': r'\usepackage{amsfonts} \usepackage{bm}'})

np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

# Setup precision
dtype = torch.float64
torch.set_default_dtype(dtype)

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load data
dl = get_data_loader("proact_me").load_data()
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1,
                                                  test_size=0.2, random_state=0)
n_events = dl.n_events

# Preprocess data
cat_features = dl.cat_features
num_features = dl.num_features
event_cols = [f'e{i+1}' for i in range(n_events)]
time_cols = [f't{i+1}' for i in range(n_events)]
X_train = pd.DataFrame(train_dict['X'], columns=dl.columns)
X_valid = pd.DataFrame(valid_dict['X'], columns=dl.columns)
X_test = pd.DataFrame(test_dict['X'], columns=dl.columns)
X_train, X_valid, X_test = preprocess_data(X_train, X_valid, X_test, cat_features,
                                           num_features, as_array=False)
feature_names = X_train.columns
n_features = train_dict['X'].shape[1]
train_dict['X'] = torch.tensor(X_train.to_numpy(), device=device, dtype=dtype)
train_dict['E'] = torch.tensor(train_dict['E'], device=device, dtype=torch.int64)
train_dict['T'] = torch.tensor(train_dict['T'], device=device, dtype=torch.int64)
valid_dict['X'] = torch.tensor(X_valid.to_numpy(), device=device, dtype=dtype)
valid_dict['E'] = torch.tensor(valid_dict['E'], device=device, dtype=torch.int64)
valid_dict['T'] = torch.tensor(valid_dict['T'], device=device, dtype=torch.int64)
test_dict['X'] = torch.tensor(X_test.to_numpy(), device=device, dtype=dtype)
test_dict['E'] = torch.tensor(test_dict['E'], device=device, dtype=torch.int64)
test_dict['T'] = torch.tensor(test_dict['T'], device=device, dtype=torch.int64)

# Make time bins
time_bins = make_time_bins(train_dict['T'].cpu(), event=None, dtype=dtype).to(device)
time_bins = torch.cat((torch.tensor([0]).to(device), time_bins))

In [6]:
from mensa.model import MENSA
from utility.config import load_config

config = load_config(cfg.MENSA_CONFIGS_DIR, "proact.yaml")
n_epochs = config['n_epochs']
n_dists = config['n_dists']
lr = config['lr']
batch_size = config['batch_size']
layers = config['layers']
trajectories = config['trajectories']
weight_decay = config['weight_decay']

# Train 4 models with n_dists = 1, 3, 5, 10
trained_models = []
for n_dists in [1]: #, 3, 5, 10
    model = MENSA(n_features, layers=layers, n_events=n_events,
                  n_dists=n_dists, trajectories=trajectories, device=device)
    model.fit(train_dict, valid_dict, learning_rate=lr, n_epochs=n_epochs,
              weight_decay=weight_decay, patience=10, batch_size=batch_size,
              verbose=False)
    trained_models.append(model)

KeyError: 'trajectories'