In [None]:
import os
import nbimporter

PROJECT_DIR_NAME = "survival_analysis"

# Move to the project root so the relative paths used throughout this notebook
# (e.g. "data_and_preprocessing/...", "plotting/...") resolve no matter which
# sub-folder the notebook was started from.
root, found, _ = os.getcwd().partition(PROJECT_DIR_NAME)
if not found:
    # Without this check the cwd was used unchanged as `root`, and chdir was
    # attempted on "<cwd>survival_analysis", failing with a confusing path.
    raise RuntimeError(
        f"Current directory {os.getcwd()!r} is not inside a "
        f"{PROJECT_DIR_NAME!r} project folder; cannot locate the project root."
    )
os.chdir(root + PROJECT_DIR_NAME)

In [None]:
import pickle
import matplotlib.pyplot as plt
from collections import defaultdict

In [None]:
# matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and 3.8
# removed the old names; pick whichever is available so the notebook runs on
# both old and new matplotlib versions.
_DARKGRID_STYLE = next(
    (style for style in ("seaborn-v0_8-darkgrid", "seaborn-darkgrid") if style in plt.style.available),
    "seaborn-darkgrid",
)
plt.style.use(_DARKGRID_STYLE)

In [None]:
from utils.stats_model import LoadStatsSurvivalModels, StatsSurvivalModelsProcessor
from data_and_preprocessing.dfs_generator import Gbsg2Generator, RecurGenerator, LymphGenerator, CaliforniaHousingGenerator

In [None]:
# Datasets this notebook can evaluate; the chosen name is used to build the
# generator / output file paths further down.
dataset_names = [
    "gbsg2",
    "recur",
    "lymph",
    "california",
]

# Select the dataset on which the models will be evaluated.

In [None]:
# Position in `dataset_names`: 0 = gbsg2, 1 = recur, 2 = lymph, 3 = california.
DATASET_INDEX = 0
dataset_name = dataset_names[DATASET_INDEX]
print(dataset_name)

# Load dataset & models

In [None]:
# Load the pickled data-frame generator for the selected dataset. The `with`
# block closes the file handle (the original left it open).
# NOTE: pickle.load can execute arbitrary code - only load generator files
# produced by this project, never ones from an untrusted source.
with open(f"data_and_preprocessing/df_generator_{dataset_name}.pickle", "rb") as generator_file:
    df_generator = pickle.load(generator_file)

In [None]:
loaded_models = LoadStatsSurvivalModels()(df_generator.name, resolution=10)
# Re-key in alphabetical order of model name so downstream tables/plots use a stable order.
stats_survival_models = {model_name: loaded_models[model_name] for model_name in sorted(loaded_models.keys())}

# Compute Metrics

In [None]:
%%time

processor_valid = StatsSurvivalModelsProcessor(
    stats_survival_models=stats_survival_models,
    df_generator=df_generator,
    part="valid",
)
df_fcm_valid = processor_valid.get_full_curve_metrics()

In [None]:
%%time

processor_test = StatsSurvivalModelsProcessor(
    stats_survival_models=stats_survival_models,
    df_generator=df_generator,
    part="test",
)
df_fcm_test = processor_test.get_full_curve_metrics()

# Find the best model of each model group (highest validation `Mean`)

In [None]:
CLASSICAL_MODEL_NAMES = {"Cox_piecewise", "Cox_spline", "LogLogistic", "LogNormal", "Weibull"}


def get_groups(model_names=None):
    """Partition model names into families so one representative per family can be picked.

    Groups produced (in this insertion order):
      * ``"KaplanMeier"`` -> ``{"KaplanMeier"}``, only if that model is present;
      * ``"classical_models"`` -> the classical models (see ``CLASSICAL_MODEL_NAMES``)
        that are present, only if there is at least one;
      * ``"<prefix>"`` -> every remaining model named ``"<prefix>_<suffix>"``.

    Args:
        model_names: Iterable of model names to group. Defaults to the index of the
            global ``df_fcm_test`` (the original behaviour).

    Returns:
        ``defaultdict(set)`` mapping group name -> set of model names.

    Raises:
        AssertionError: if a non-classical, non-KaplanMeier name does not contain
            exactly one ``_``.
    """
    if model_names is None:
        model_names = df_fcm_test.index
    all_names = set(model_names)

    classical_models = CLASSICAL_MODEL_NAMES & all_names
    remaining = all_names - classical_models - {"KaplanMeier"}

    model_groups = defaultdict(set)
    # Only register the baseline if it was actually fitted; an unconditional entry
    # makes get_best_of_group look up a missing row and raise KeyError.
    if "KaplanMeier" in all_names:
        model_groups["KaplanMeier"] = {"KaplanMeier"}
    if classical_models:
        model_groups["classical_models"] = classical_models

    for name in remaining:
        spl = name.split("_")
        assert len(spl) == 2, "Not exactly one _ in name."
        model_groups[spl[0]].add(name)

    return model_groups


def get_best_of_group(model_group, df_metrics=None, metric="Mean"):
    """Return the name of the model in ``model_group`` with the largest ``metric`` value.

    Assumes a larger ``metric`` is better - TODO confirm for the metric in use.
    Candidates are looked up in sorted order so that ties are broken by the
    alphabetically first name; iterating the set directly made the winner depend
    on string-hash randomisation and thus vary between kernel sessions.

    Args:
        model_group: Iterable of model names (rows of ``df_metrics``).
        df_metrics: Metrics table indexed by model name. Defaults to the global
            ``df_fcm_valid`` (selection on the validation split).
        metric: Column of ``df_metrics`` to maximise.

    Returns:
        The winning model name.
    """
    if df_metrics is None:
        df_metrics = df_fcm_valid
    group_scores = df_metrics.loc[sorted(model_group), metric]
    return group_scores.idxmax()


def get_names_of_relevant_models(model_names=None, df_metrics=None, metric="Mean"):
    """Return the sorted list of the best model of every group from ``get_groups``.

    Args:
        model_names: Forwarded to ``get_groups`` (default: ``df_fcm_test.index``).
        df_metrics: Forwarded to ``get_best_of_group`` (default: ``df_fcm_valid``).
        metric: Forwarded to ``get_best_of_group``.
    """
    model_groups = get_groups(model_names)
    return sorted(get_best_of_group(group, df_metrics, metric) for group in model_groups.values())


# Best (validation-selected) model of each group, alphabetically ordered.
model_names = list(get_names_of_relevant_models())

# Tables & plots for the best models

In [None]:
# Validation metrics restricted to the selected best model of each group.
df_fcm_valid_short = df_fcm_valid.loc[model_names]
processor_valid.print_full_curve_metrics(df=df_fcm_valid_short)

In [None]:
# Test metrics for the same selected models (selection was done on validation).
df_fcm_test_short = df_fcm_test.loc[model_names]
processor_test.print_full_curve_metrics(df=df_fcm_test_short)

In [None]:
# Validation-split metric plots, limited to the selected best model of each group.
processor_valid.plot_metrics(model_names=model_names)

In [None]:
# Test-split metric plots for the same models selected on the validation split.
processor_test.plot_metrics(model_names=model_names)

# Print and save tables for all models

In [None]:
# Full metric tables for every model: validation split first, then test.
for processor in (processor_valid, processor_test):
    processor.print_full_curve_metrics()

In [None]:
# Reuse the metric tables already computed in the "Compute Metrics" cells rather
# than recomputing them (those cells are wrapped in %%time because they are slow).
df_all_metrics_valid = df_fcm_valid
df_all_metrics_test = df_fcm_test

# DataFrame.to_pickle does not create missing parent directories.
os.makedirs("plotting", exist_ok=True)
df_all_metrics_valid.to_pickle(f"plotting/{dataset_name}_valid.pickle")
df_all_metrics_test.to_pickle(f"plotting/{dataset_name}_test.pickle")