In [46]:
import pandas as pd
import numpy as np
import config as cfg
import torch
import torch.optim as optim
import torch.nn as nn
import random
import warnings
import copy
import tqdm
import math
import argparse

from utility.survival import (make_time_bins, preprocess_data, convert_to_structured,
                              risk_fn, compute_l1_difference, predict_survival_function,
                              make_times_hierarchical)
from utility.data import (dotdict, format_data, format_data_as_dict_single)
from utility.config import load_config
from utility.evaluation import global_C_index, local_C_index
from data_loader import SeerDataLoader

from survtrace.dataset import load_data
from survtrace.evaluate_utils import Evaluator
from survtrace.utils import set_random_seed
from survtrace.model import SurvTraceMulti
from survtrace.train_utils import Trainer
from survtrace.config import STConfig
from utility.data import calculate_vocab_size

from data_loader import SeerDataLoader

warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")

np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

# Set precision
dtype = torch.float64
torch.set_default_dtype(dtype)

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load and split data
dl = SeerDataLoader().load_data(n_samples=10000, device=device, dtype=dtype)
df_train, df_valid, df_test = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                            random_state=0)
n_events = dl.n_events

# Preprocess data
cat_features = dl.cat_features
num_features = dl.num_features
X_train = df_train.drop(['event', 'time'], axis=1)
X_valid = df_valid.drop(['event', 'time'], axis=1)
X_test = df_test.drop(['event', 'time'], axis=1)
X_train, X_valid, X_test= preprocess_data(X_train, X_valid, X_test, cat_features,
                                            num_features, as_array=True)
train_dict = format_data_as_dict_single(X_train, df_train['event'], df_train['time'], dtype)
valid_dict = format_data_as_dict_single(X_valid, df_valid['event'], df_valid['time'], dtype)
test_dict = format_data_as_dict_single(X_test, df_test['event'], df_test['time'], dtype)
n_samples = train_dict['X'].shape[0]
n_features = train_dict['X'].shape[1]

# Make time bins
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype)

In [47]:
from pycox.preprocessing.label_transforms import LabTransDiscreteTime

def format_survtrace_data(train_dict, valid_dict, time_bins, n_events):
    class LabTransform(LabTransDiscreteTime):
        def transform(self, durations, events):
            durations, is_event = super().transform(durations, events > 0)
            events[is_event == 0] = 0
            return durations, events.astype('int64')
    train_dict_dh = dict()
    train_dict_dh['X'] = train_dict['X'].numpy()
    train_dict_dh['E'] = train_dict['E'].numpy()
    train_dict_dh['T'] = train_dict['T'].numpy()
    valid_dict_dh = dict()
    valid_dict_dh['X'] = valid_dict['X'].numpy()
    valid_dict_dh['E'] = valid_dict['E'].numpy()
    valid_dict_dh['T'] = valid_dict['T'].numpy()
    labtrans = LabTransform(time_bins.numpy())
    get_target = lambda data: (data['T'], data['E'])
    y_train = labtrans.transform(*get_target(train_dict_dh))
    y_valid = labtrans.transform(*get_target(valid_dict_dh))
    out_features = int(labtrans.out_features)
    duration_index = labtrans.cuts
    y_train_df, y_valid_df = pd.DataFrame(), pd.DataFrame()
    y_train_df['duration'] = y_train[0]
    y_train_df['proportion'] = y_train[1]
    y_valid_df['duration'] = y_valid[0]
    y_valid_df['proportion'] = y_valid[1]
    for i in range(n_events):
        event_name = "event_{}".format(i)
        y_train_df[event_name] = (y_train[1] == i+1)*1.0
        y_valid_df[event_name] = (y_valid[1] == i+1)*1.0
    return y_train_df, y_valid_df, duration_index, out_features

In [48]:
SURVTRACE_PARAMS = {
    "num_hidden_layers": 1,
    "hidden_size": 32,
    "intermediate_size": 32,
    "num_attention_heads": 2,
    "initializer_range": .02,
    "batch_size": 128,
    "weight_decay": 0,
    "learning_rate": 1e-4,
    "epochs": 100,
    "early_stop_patience": 10,
    "hidden_dropout_prob": 0.25,
    "seed": 0,
    "hidden_act": "gelu",
    "attention_probs_dropout_prob": 0.25,
    "layer_norm_eps": 1000000000000,
    "checkpoint": "./checkpoints/survtrace.pt",
    "max_position_embeddings": 512,
    "chunk_size_feed_forward": 0,
    "output_attentions": False,
    "output_hidden_states": False,
    "tie_word_embeddings": True,
    "pruned_heads": {}
}

In [49]:
# Train survtrace
config = dotdict(SURVTRACE_PARAMS)
X_train = pd.DataFrame(train_dict['X'], columns=[f'X{i}' for i in range(n_features)])
X_valid = pd.DataFrame(valid_dict['X'], columns=[f'X{i}' for i in range(n_features)])
cat_features = []
num_features = [f'X{i}' for i in range(n_features)]
y_train, y_valid, duration_index, out_features = format_survtrace_data(train_dict, valid_dict,
                                                                        time_bins, n_events)
config['vocab_size'] = calculate_vocab_size(X_train, cat_features)
config['duration_index'] = duration_index
config['out_feature'] = out_features
config['num_numerical_feature'] = int(len(num_features))
config['num_categorical_feature'] = int(len(cat_features))
config['num_feature'] = n_features
config['num_event'] = n_events
config['in_features'] = n_features
model = SurvTraceMulti(dotdict(config))
trainer = Trainer(model)
trainer.fit((X_train, y_train), (X_valid, y_valid),
            batch_size=config['batch_size'],
            epochs=config['epochs'],
            learning_rate=config['learning_rate'],
            weight_decay=config['weight_decay'],
            val_batch_size=config['batch_size'])

EarlyStopping counter: 1 out of 10
EarlyStopping counter: 2 out of 10
EarlyStopping counter: 3 out of 10
EarlyStopping counter: 4 out of 10
EarlyStopping counter: 5 out of 10
EarlyStopping counter: 6 out of 10
EarlyStopping counter: 7 out of 10
EarlyStopping counter: 8 out of 10
EarlyStopping counter: 9 out of 10
EarlyStopping counter: 10 out of 10
early stops at epoch 13


([55.97630148985181,
  6.966451897985583,
  1.5373863616232106,
  1.5204206347874076,
  1.5199523139916784,
  1.5161057911014815,
  1.5179528190546678,
  1.5169070183157325,
  1.5181505190388975,
  1.514900581382812,
  1.5142344073398128,
  1.5158583275085333,
  1.5134622444611585],
 [54.10038864391313,
  1.5550540929984686,
  1.5339866592113212,
  1.5353198226450724,
  1.5367427431346887,
  1.5369599987068332,
  1.5415617924826341,
  1.5379400739079472,
  1.5405216051582071,
  1.538932658170876,
  1.5394357859127086,
  1.5396921084820625,
  1.5399577247838976])

In [50]:
# Make predictions
preds_e1 = model.predict_surv(test_dict['X'], batch_size=config['batch_size'], event=0)
preds_e2 = model.predict_surv(test_dict['X'], batch_size=config['batch_size'], event=1)

In [51]:
preds_e1

#TODO: Predictions are the same for all instances

tensor([[1.0000, 0.9994, 0.9986,  ..., 0.8225, 0.7343, 0.6028],
        [1.0000, 0.9994, 0.9986,  ..., 0.8225, 0.7343, 0.6028],
        [1.0000, 0.9994, 0.9986,  ..., 0.8225, 0.7343, 0.6028],
        ...,
        [1.0000, 0.9994, 0.9986,  ..., 0.8225, 0.7343, 0.6028],
        [1.0000, 0.9994, 0.9986,  ..., 0.8225, 0.7343, 0.6028],
        [1.0000, 0.9994, 0.9986,  ..., 0.8225, 0.7343, 0.6028]])