In [None]:
import os
import nbimporter

# Move the working directory to the project root (the "survival_analysis"
# folder), so the project-relative paths used below (data_and_preprocessing/...,
# trained_models/...) resolve regardless of which sub-folder the notebook runs in.
# NOTE(review): if the cwd does not contain "survival_analysis", this tries to
# enter "<cwd>survival_analysis" (no separator) and will most likely fail.
root = os.getcwd().partition("survival_analysis")[0]
os.chdir(root + "survival_analysis")

In [None]:
import glob
import pickle
import random
import numpy as np

In [None]:
from nets.survival_net import SurvivalNN
from nets.cox_nn import CoxNN, CoxTimeDependentNN
from utils.nn_utils import NetTrainer, WeightedBCELoss, SuMoLoss
from nets.monotone_module import MonotonicIncreasingVectorNet, MonotonicIncreasingNet
from data_and_preprocessing.dfs_generator import Gbsg2Generator, RecurGenerator, LymphGenerator, CaliforniaHousingGenerator

In [None]:
def get_loss_postfix(sumo):
    """Return the suffix identifying the training loss in model names.

    ``"SUMO"`` for a truthy ``sumo`` flag, ``"BCE"`` otherwise.
    """
    return "SUMO" if sumo else "BCE"

# Training parameters

In [None]:
# Target number of trained runs per (dataset, model, loss) combination.
glob_total_todo = 5
# Forwarded to NetTrainer as `n_training_steps`.
n_training_steps = 200_000
# Forwarded to NetTrainer as `model_via_moving_average_on_validation`
# (NOTE(review): presumably a moving-average window size — confirm in NetTrainer).
model_via_moving_average_on_validation = 512

In [None]:
def get_bce_weight(dataset_name):
    """Return the ``weight`` that ``get_criterion`` passes to ``WeightedBCELoss``.

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset.
    """
    known_weights = (
        ("gbsg2", 0.71),
        ("recur", 0.85),
        ("lymph", 0.86),
        ("california", 0.53),
    )
    for known_name, bce_weight in known_weights:
        if dataset_name == known_name:
            return bce_weight
    raise ValueError("Bad dataset name.")

In [None]:
def get_bce_std_factor(dataset_name):
    """Return the per-dataset factor for the BCE loss width.

    ``get_criterion`` multiplies this factor by ``df_generator.max_horizon`` to
    obtain the ``σ_gaussian_delta`` argument of ``WeightedBCELoss``.

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset.
    """
    std_factors = (
        ("gbsg2", 0.82),
        ("recur", 0.96),
        ("lymph", 0.79),
        ("california", 0.5),
    )
    for known_name, factor in std_factors:
        if dataset_name == known_name:
            return factor
    raise ValueError("Bad dataset name.")

In [None]:
def get_sumo_weight(dataset_name):
    """Return the ``weight`` that ``get_criterion`` passes to ``SuMoLoss``.

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset.
    """
    sumo_weights = (
        ("gbsg2", 2.7),
        ("recur", 0.87),
        ("lymph", 3.44),
        ("california", 0.89),
    )
    for known_name, sumo_weight in sumo_weights:
        if dataset_name == known_name:
            return sumo_weight
    raise ValueError("Bad dataset name.")

In [None]:
def get_sumo_weight_decay(dataset_name, model_name):
    """Return the weight decay used when training with the SuMo loss.

    The value ends up as ``NetTrainer(weight_decay=...)`` via ``run_training``.
    Only the Cox-type models are regularised; every other model gets ``0``,
    whatever the dataset.

    Raises:
        ValueError: for a Cox-type model on an unknown dataset. A real exception
            is used rather than ``assert False``, because asserts are stripped
            under ``python -O`` and the function would silently return ``None``.
    """
    if model_name not in ("CoxNN", "CoxTimeDependentNN"):
        return 0

    weight_decays = {
        "gbsg2": 0.005,
        "recur": 0.001,
        "lymph": 0.004,
        "california": 0.009,
    }
    if dataset_name not in weight_decays:
        raise ValueError(f"Bad inputs in get_sumo_weight_decay {dataset_name=}, {model_name=}")
    return weight_decays[dataset_name]

In [None]:
def get_bce_weight_decay(dataset_name, model_name):
    """Return the weight decay used when training with the weighted BCE loss.

    The value ends up as ``NetTrainer(weight_decay=...)`` via ``run_training``.
    Only the Cox-type models are regularised; every other model gets ``0``,
    whatever the dataset.

    Raises:
        ValueError: for a Cox-type model on an unknown dataset. A real exception
            is used rather than ``assert False``, because asserts are stripped
            under ``python -O`` and the function would silently return ``None``.
    """
    if model_name not in ("CoxNN", "CoxTimeDependentNN"):
        return 0

    weight_decays = {
        "gbsg2": 0.02,
        "recur": 0.001,
        "lymph": 0.002,
        "california": 0.005,
    }
    if dataset_name not in weight_decays:
        raise ValueError(f"Bad inputs in get_bce_weight_decay {dataset_name=}, {model_name=}")
    return weight_decays[dataset_name]

# General training code

In [None]:
def get_rand():
    """Return a random six-digit run id, drawn from NumPy's global RNG.

    The id is appended to model and placeholder file names by ``run_training``.
    """
    lowest, upper_exclusive = 100000, 1000000
    return np.random.randint(lowest, upper_exclusive)

In [None]:
def get_criterion(df_generator, sumo):
    """Build the training loss for the dataset that ``df_generator`` describes.

    With ``sumo`` truthy, this is a ``SuMoLoss`` weighted by ``get_sumo_weight``.
    Otherwise it is a ``WeightedBCELoss`` whose weight comes from
    ``get_bce_weight`` and whose ``σ_gaussian_delta`` is
    ``get_bce_std_factor(name) * df_generator.max_horizon``.
    """
    name = df_generator.name
    if not sumo:
        bce_weight = get_bce_weight(name)
        delta_std = get_bce_std_factor(name) * df_generator.max_horizon
        return WeightedBCELoss(σ_gaussian_delta=delta_std, weight=bce_weight)
    return SuMoLoss(weight=get_sumo_weight(name))

In [None]:
def get_net_trainer(dataset_name, weight_decay, sumo):
    """Build a ``NetTrainer`` for ``dataset_name`` using the notebook's training settings.

    Loads the pickled dataframe generator for the dataset, chooses the loss via
    ``get_criterion`` (SuMo if ``sumo`` is truthy, weighted BCE otherwise) and
    passes both to ``NetTrainer``, together with ``weight_decay`` and the
    module-level ``n_training_steps`` / ``model_via_moving_average_on_validation``.
    """
    generator_path = f"data_and_preprocessing/df_generator_{dataset_name}.pickle"
    # `with` closes the file deterministically; the original left the handle open.
    # pickle.load can execute arbitrary code, so only load trusted, project-generated files.
    with open(generator_path, "rb") as generator_file:
        df_generator = pickle.load(generator_file)
    criterion = get_criterion(df_generator, sumo)

    return NetTrainer(
        df_generator=df_generator,
        patience_factor=16,
        batch_size=8,
        model_via_moving_average_on_validation=model_via_moving_average_on_validation,
        criterion=criterion,
        clip=1,
        lr=1e-3,
        weight_decay=weight_decay,
        n_training_steps=n_training_steps,
    )

In [None]:
def train(model, model_name, net_trainer, rand, total_todo=None):
    """Train ``model`` with ``net_trainer`` and save it under ``f"{model_name}_{rand}"``.

    First prints the number of trainable parameters. The run is then skipped if
    ``how_many_left_to_do`` reports a negative remainder for ``model_name``.
    When called from ``run_training``, the file count already includes this
    run's own ``.dummy`` placeholder, which is why only a *negative* remainder
    (not zero) skips the run.

    Args:
        model: the model to train (presumably a torch ``nn.Module``; only
            ``parameters()`` is used here).
        model_name: model name including the loss postfix, e.g. ``"CoxNN-BCE"``.
        net_trainer: the ``NetTrainer`` that trains and saves the model.
        rand: random run id appended to the saved name.
        total_todo: target number of runs per (dataset, model). ``None`` (default)
            uses ``glob_total_todo`` at call time. The original hard-coded ``5``
            could disagree with ``get_todo_model_dataset_combinations`` and make
            ``run_worker`` loop forever.

    Returns:
        The trained model from ``net_trainer.train_and_save``, or ``None`` if
        the run was skipped.
    """
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{model_name}: {n_trainable} trainable parameters")

    if total_todo is None:
        total_todo = glob_total_todo

    name = f"{model_name}_{rand}"

    if how_many_left_to_do(net_trainer.df_generator.name, model_name, total_todo=total_todo) < 0:
        return None

    trained_model, _best_val_score, _model_dict = net_trainer.train_and_save(name=name, model=model)
    return trained_model

In [None]:
def run_training(dataset_name, get_model, sumo, max_attempts=None):
    """Train one model produced by ``get_model`` on ``dataset_name``, retrying on failure.

    Before training, an empty placeholder ``trained_models/<dataset_name>/<model_name>_<rand>.dummy``
    is created, so ``how_many_left_to_do`` counts this run while it is in
    progress. The placeholder is removed when the run ends: on success, on
    skip, when retries are exhausted, or on KeyboardInterrupt. This keeps an
    aborted run from permanently claiming a slot.

    Args:
        dataset_name: dataset key, e.g. ``"gbsg2"``.
        get_model: a model getter such as ``get_CoxNN``. It is called as
            ``get_model(None, only_name=True)`` for the name, and as
            ``get_model(net_trainer)`` for a ``(model, name)`` tuple.
        sumo: train with the SuMo loss if truthy, else with weighted BCE.
        max_attempts: maximum number of training attempts. ``None`` (default)
            retries until an attempt succeeds, as before.

    Raises:
        Exception: the last training error, once ``max_attempts`` is exhausted.
    """
    import traceback

    base_model_name = get_model(None, only_name=True)

    if sumo:
        weight_decay = get_sumo_weight_decay(dataset_name, base_model_name)
    else:
        weight_decay = get_bce_weight_decay(dataset_name, base_model_name)

    model_name = base_model_name + "-" + get_loss_postfix(sumo)
    rand = get_rand()

    model_dir = f'trained_models/{dataset_name}'
    # Create the folder on a fresh checkout instead of crashing on open().
    os.makedirs(model_dir, exist_ok=True)
    dummy_file_path = f'{model_dir}/{model_name}_{rand}.dummy'
    open(dummy_file_path, 'wb').close()

    attempt = 0
    try:
        while True:
            attempt += 1
            try:
                net_trainer = get_net_trainer(dataset_name=dataset_name, weight_decay=weight_decay, sumo=sumo)
                model, _ = get_model(net_trainer)
                train(model, model_name=model_name, net_trainer=net_trainer, rand=rand)
                return
            except Exception:
                # `Exception`, not a bare `except:`, so KeyboardInterrupt/SystemExit
                # still stop the notebook instead of being retried forever.
                print("Training failed for:", dummy_file_path)
                traceback.print_exc()
                if max_attempts is not None and attempt >= max_attempts:
                    raise
    finally:
        if os.path.exists(dummy_file_path):
            os.remove(dummy_file_path)

# CoxNN

In [None]:
def get_CoxNN(net_trainer, only_name=False):
    """Return ``"CoxNN"`` if ``only_name`` is truthy, else a ``(model, "CoxNN")`` tuple.

    The model is a ``CoxNN`` whose input size and time scaling come from
    ``net_trainer``. Its monotonic net is a ``MonotonicIncreasingNet`` with
    ``latent_sizes=[32] * 5``.
    """
    name = "CoxNN"
    if only_name:
        return name

    baseline_net = MonotonicIncreasingNet(latent_sizes=[32] * 5)
    cox_model = CoxNN(
        n_input_features=net_trainer.n_input_features,
        monotonic_increasing_net=baseline_net,
        t_scaling=net_trainer.horizon,
    )
    return cox_model, name

# CoxTimeDependentNN

In [None]:
def get_CoxTimeDependentNN(net_trainer, only_name=False):
    """Return ``"CoxTimeDependentNN"`` if ``only_name`` is truthy, else ``(model, name)``.

    The model combines two monotonic nets:
    - a ``MonotonicIncreasingNet`` baseline with ``latent_sizes=[32] * 5``;
    - a ``MonotonicIncreasingVectorNet`` for the coefficients, whose last
      latent size equals ``net_trainer.n_input_features``.
    """
    name = "CoxTimeDependentNN"
    if only_name:
        return name

    n_features = net_trainer.n_input_features
    baseline_net = MonotonicIncreasingNet(latent_sizes=[32] * 5)
    coefficient_net = MonotonicIncreasingVectorNet(latent_sizes=[32] * 5 + [n_features])

    cox_model = CoxTimeDependentNN(
        n_input_features=n_features,
        monotonic_increasing_net_baseline=baseline_net,
        monotonic_increasing_net_coefficients=coefficient_net,
        t_scaling=net_trainer.horizon,
    )
    return cox_model, name

# SurvivalNN

In [None]:
def get_SurvivalNN(net_trainer, only_name=False):
    """Return ``"SurvivalNN"`` if ``only_name`` is truthy, else a ``(model, "SurvivalNN")`` tuple.

    The model is a ``SurvivalNN`` with 32 latent features; its input size and
    time scaling come from ``net_trainer``.
    """
    name = "SurvivalNN"
    if only_name:
        return name

    survival_model = SurvivalNN(
        n_input_features=net_trainer.n_input_features,
        n_latent_features=32,
        t_scaling=net_trainer.horizon,
    )
    return survival_model, name

# Train

In [None]:
def get_model_getters():
    """Map each model name to its getter function.

    The order is CoxNN, CoxTimeDependentNN, SurvivalNN. Each key is what the
    getter returns when called with ``only_name=True``.
    """
    getters_by_name = {}
    for getter in (get_CoxNN, get_CoxTimeDependentNN, get_SurvivalNN):
        getters_by_name[getter(None, only_name=True)] = getter
    return getters_by_name

In [None]:
def get_dataset_names():
    """Return a new list with the names of all datasets used in this notebook."""
    dataset_names = ["gbsg2", "recur", "lymph", "california"]
    return dataset_names

In [None]:
def get_all_model_dataset_combinations(dataset_names):
    """Return every ``(dataset_name, model_name)`` pair.

    Datasets form the outer loop and models the inner loop. Model names are the
    keys of ``get_model_getters()``, in that mapping's order.
    """
    # Build the registry once, not once per dataset. The unused getter values
    # are no longer unpacked.
    model_names = list(get_model_getters())
    return [
        (dataset_name, model_name)
        for dataset_name in dataset_names
        for model_name in model_names
    ]

In [None]:
def how_many_left_to_do(dataset_name, model_name, total_todo):
    """Return ``total_todo`` minus the number of files already present for ``model_name``.

    Counts the entries of ``trained_models/<dataset_name>/`` whose path contains
    ``f"{model_name}_"`` (finished models as well as ``.dummy`` placeholders).
    The result can be negative.
    """
    marker = f"{model_name}_"
    n_existing = sum(1 for path in glob.glob(f'trained_models/{dataset_name}/*') if marker in path)
    return total_todo - n_existing

In [None]:
def get_todo_model_dataset_combinations(dataset_names, sumo, total_todo=None):
    """Return the ``(dataset_name, model_name)`` pairs that still need training runs.

    A pair is unfinished while ``how_many_left_to_do`` reports a positive
    remainder for ``f"{model_name}-{get_loss_postfix(sumo)}"``.

    Args:
        dataset_names: datasets to consider.
        sumo: whether the SuMo loss (rather than BCE) is being trained.
        total_todo: target number of runs per combination. ``None`` (default)
            reads ``glob_total_todo`` at call time. A default bound at
            definition time would go stale if the config cell were re-run.
    """
    if total_todo is None:
        total_todo = glob_total_todo

    loss_postfix = get_loss_postfix(sumo)
    return [
        (dataset_name, model_name)
        for dataset_name, model_name in get_all_model_dataset_combinations(dataset_names)
        if how_many_left_to_do(dataset_name, model_name + "-" + loss_postfix, total_todo=total_todo) > 0
    ]

In [None]:
def run_worker(dataset_names, sumo):
    """Repeatedly train a randomly chosen unfinished (dataset, model) pair until none remain.

    The to-do list is recomputed from the files on disk after every run.
    """
    remaining = get_todo_model_dataset_combinations(dataset_names, sumo=sumo)
    while remaining:
        dataset_name, model_name = random.choice(remaining)
        model_getter = get_model_getters()[model_name]
        print(f"{dataset_name=}, {model_name=}")
        run_training(dataset_name, model_getter, sumo=sumo)
        remaining = get_todo_model_dataset_combinations(dataset_names, sumo=sumo)

In [None]:
def run_cpu_worker(sumo, dataset_names=None):
    """Run ``run_worker`` for one loss type over the given datasets.

    Args:
        sumo: train with the SuMo loss if truthy, else with weighted BCE.
        dataset_names: datasets to train on. ``None`` (default) uses
            ``get_dataset_names()``, instead of a duplicated literal list that
            could drift out of sync.
    """
    if dataset_names is None:
        dataset_names = get_dataset_names()
    run_worker(dataset_names, sumo)

## Run all training jobs (BCE loss first, then SuMo loss)

In [None]:
# Train everything with the weighted BCE loss first, then with the SuMo loss.
for train_with_sumo_loss in (False, True):
    run_cpu_worker(sumo=train_with_sumo_loss)