In [None]:
import os
import nbimporter

# Move to the project root so the relative paths used below
# (`data_and_preprocessing/...`, `trained_models/...`) resolve.
_PROJECT_DIR = "survival_analysis"
_cwd = os.getcwd()
if _PROJECT_DIR not in _cwd:
    # Without this guard, split(...)[0] returns the whole cwd and the
    # concatenation yields a bogus path such as "/home/user/notebookssurvival_analysis".
    raise FileNotFoundError(
        f"Current directory {_cwd!r} is not inside a '{_PROJECT_DIR}' checkout; "
        "start the notebook from within the project."
    )
root = _cwd.split(_PROJECT_DIR)[0]
os.chdir(root + _PROJECT_DIR)

In [None]:
import pickle
import pathlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import CoxPHFitter, WeibullAFTFitter, LogLogisticAFTFitter, LogNormalAFTFitter, KaplanMeierFitter

In [None]:
from data_and_preprocessing.dfs_generator import Gbsg2Generator, RecurGenerator, LymphGenerator, CaliforniaHousingGenerator

# Helper functions

In [None]:
def save(name, data, generator=None):
    """Pickle ``data`` to ``trained_models/<generator.name>/<name>.pickle``.

    Args:
        name: file stem of the pickle.
        data: any picklable object.
        generator: object exposing a ``name`` attribute used as the output
            sub-directory. Defaults to the notebook global ``df_generator``.
    """
    gen = df_generator if generator is None else generator
    path = pathlib.Path("trained_models") / gen.name / f"{name}.pickle"
    # Create the directory here so saving does not depend on the training loop
    # having created it (the loop uses the dataset name, this uses gen.name).
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as fh:
        pickle.dump(data, fh)


def save_model(name, model, generator=None):
    """Pickle a fitted model together with its name and the generator's horizon.

    The stored object is ``{"name": name, "max_horizon": ..., "model": model}``.

    Args:
        name: model identifier; also used as the pickle file stem.
        model: the fitted model object to store.
        generator: object exposing ``name`` and ``max_horizon``. Defaults to
            the notebook global ``df_generator``.
    """
    gen = df_generator if generator is None else generator
    data = {
        "name": name,
        "max_horizon": gen.max_horizon,
        "model": model,
    }
    save(name, data, generator=gen)

## KaplanMeierFitter

In [None]:
def get_KaplanMeierFitter(df=None, **kwargs):
    """Fit a lifelines ``KaplanMeierFitter`` on the training split.

    Args:
        df: frame with ``duration`` and ``event_observed`` columns. Defaults to
            the notebook global ``dfs["train"]``.
        **kwargs: forwarded to the ``KaplanMeierFitter`` constructor.

    Returns:
        The fitted ``KaplanMeierFitter``.
    """
    train = dfs["train"] if df is None else df
    model = KaplanMeierFitter(**kwargs)
    # Pass copies, presumably so the fitter cannot alter the shared training frame.
    model.fit(
        durations=train["duration"].copy(),
        event_observed=train["event_observed"].copy(),
    )
    return model

## CoxPHFitter

In [None]:
def get_CoxPHFitter(df=None, **kwargs):
    """Fit a lifelines ``CoxPHFitter`` on the training split.

    Args:
        df: frame with ``duration`` and ``event_observed`` columns plus
            covariates. Defaults to the notebook global ``dfs["train"]``.
        **kwargs: forwarded to the ``CoxPHFitter`` constructor (e.g.
            ``baseline_estimation_method``, ``breakpoints``, ``penalizer``).

    Returns:
        The fitted ``CoxPHFitter``.
    """
    train = dfs["train"] if df is None else df
    model = CoxPHFitter(**kwargs)
    # Fit on a copy, presumably so the fitter cannot alter the shared training frame.
    model.fit(
        df=train.copy(),
        duration_col="duration",
        event_col="event_observed",
    )
    return model

## WeibullAFTFitter

In [None]:
def get_WeibullAFTFitter(df=None, **kwargs):
    """Fit a lifelines ``WeibullAFTFitter`` on the training split.

    Args:
        df: frame with ``duration`` and ``event_observed`` columns plus
            covariates. Defaults to the notebook global ``dfs["train"]``.
        **kwargs: forwarded to the ``WeibullAFTFitter`` constructor (e.g.
            ``penalizer``), mirroring ``get_CoxPHFitter``.

    Returns:
        The fitted ``WeibullAFTFitter``.
    """
    train = dfs["train"] if df is None else df
    model = WeibullAFTFitter(**kwargs)
    # Fit on a copy, presumably so the fitter cannot alter the shared training frame.
    model.fit(
        df=train.copy(),
        duration_col="duration",
        event_col="event_observed",
    )
    return model

## LogLogisticAFTFitter

In [None]:
def get_LogLogisticAFTFitter(df=None, **kwargs):
    """Fit a lifelines ``LogLogisticAFTFitter`` on the training split.

    Args:
        df: frame with ``duration`` and ``event_observed`` columns plus
            covariates. Defaults to the notebook global ``dfs["train"]``.
        **kwargs: forwarded to the ``LogLogisticAFTFitter`` constructor (e.g.
            ``penalizer``), mirroring ``get_CoxPHFitter``.

    Returns:
        The fitted ``LogLogisticAFTFitter``.
    """
    train = dfs["train"] if df is None else df
    model = LogLogisticAFTFitter(**kwargs)
    # Fit on a copy, presumably so the fitter cannot alter the shared training frame.
    model.fit(
        df=train.copy(),
        duration_col="duration",
        event_col="event_observed",
    )
    return model

## LogNormalAFTFitter

In [None]:
def get_LogNormalAFTFitter(df=None, **kwargs):
    """Fit a lifelines ``LogNormalAFTFitter`` on the training split.

    Args:
        df: frame with ``duration`` and ``event_observed`` columns plus
            covariates. Defaults to the notebook global ``dfs["train"]``.
        **kwargs: forwarded to the ``LogNormalAFTFitter`` constructor (e.g.
            ``penalizer``), mirroring ``get_CoxPHFitter``.

    Returns:
        The fitted ``LogNormalAFTFitter``.
    """
    train = dfs["train"] if df is None else df
    model = LogNormalAFTFitter(**kwargs)
    # Fit on a copy, presumably so the fitter cannot alter the shared training frame.
    model.fit(
        df=train.copy(),
        duration_col="duration",
        event_col="event_observed",
    )
    return model

## Train and save all applicable models

In [None]:
def _fit_save_and_show_cox(name, **cox_kwargs):
    """Fit a CoxPHFitter with ``cox_kwargs``, pickle it as ``name``, then plot and print it."""
    cox = get_CoxPHFitter(**cox_kwargs)
    save_model(name, cox)

    fig, ax = plt.subplots()
    cox.baseline_survival_.plot(ax=ax, title=f"{name}: baseline survival")
    ax.set_xlabel("timeline")
    plt.show()
    plt.close(fig)  # avoid accumulating figures across the per-dataset loop

    # A bare expression inside a function is never displayed, so print the
    # coefficient table explicitly.
    print(pd.DataFrame(cox.params_).T)


def train_all(resolution=10):
    """Fit every lifelines model on the current training split and pickle each one.

    Relies on the notebook globals used by the helpers: ``dfs`` (read by the
    ``get_*`` functions) and ``df_generator`` (read by ``save_model`` and for
    the piecewise breakpoints below).

    Args:
        resolution: number of equal-width intervals on
            ``[0, df_generator.max_horizon]`` whose right edges are used as
            breakpoints for the piecewise Cox baseline. Must be >= 1.

    Raises:
        ValueError: if ``resolution`` < 1 (that would yield no breakpoints).
    """
    if resolution < 1:
        raise ValueError(f"resolution must be >= 1, got {resolution}")

    print("\t KaplanMeier")
    save_model("KaplanMeier", get_KaplanMeierFitter())

    for label, fit_model in (
        ("Weibull", get_WeibullAFTFitter),
        ("LogLogistic", get_LogLogisticAFTFitter),
        ("LogNormal", get_LogNormalAFTFitter),
    ):
        print(f"\t {label}")
        save_model(label, fit_model())

    print("\t Cox_spline")
    _fit_save_and_show_cox(
        "Cox_spline",
        baseline_estimation_method="spline",
        n_baseline_knots=2,
    )

    print("\t Cox_piecewise")
    _fit_save_and_show_cox(
        "Cox_piecewise",
        baseline_estimation_method="piecewise",
        # Drop the leading 0 so only interior/right-edge breakpoints remain.
        breakpoints=np.linspace(0, df_generator.max_horizon, resolution + 1)[1:],
        penalizer=0.00,
    )

# Load data & train

In [None]:
# Dataset identifiers; each one selects
# data_and_preprocessing/df_generator_<name>.pickle in the training loop.
dataset_names = [
    "gbsg2",
    "recur",
    "lymph",
    "california",
]

In [None]:
for dataset_name in dataset_names:
    print(f"Starting lifelines trainings for {dataset_name}.")
    # NOTE(review): save() writes under df_generator.name; this assumes it
    # equals dataset_name — TODO confirm.
    pathlib.Path(f'trained_models/{dataset_name}').mkdir(parents=True, exist_ok=True)

    # pickle.load can execute arbitrary code: only load generator files that
    # this project produced. The `with` block closes the handle (the previous
    # bare open() leaked it).
    with open(f"data_and_preprocessing/df_generator_{dataset_name}.pickle", "rb") as fh:
        df_generator = pickle.load(fh)
    # `df_generator` and `dfs` are deliberately module-level: save/save_model
    # and the get_* helpers read them as globals.
    dfs = df_generator(horizon=None)

    for part, df in dfs.items():
        print(part, df.shape)

    train_all()