In [None]:
import os
import nbimporter

# Re-anchor the working directory at the "survival_analysis" project folder:
# keep the part of the current path that precedes the first occurrence of that
# name, then append the name again so relative paths resolve from there.
root = os.getcwd().partition("survival_analysis")[0]
os.chdir(root + "survival_analysis")

In [None]:
import glob
import torch
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
from lifelines.fitters import BaseFitter
from lifelines.utils import concordance_index

In [None]:
from utils.metrics import AdvancedMetrics
from utils.get_tensors_of_df import get_tensors_of_df

# StatsClassifiers

In [None]:
class StatsClassifier:
    """Binary view of a survival model at one fixed time horizon.

    Subclasses implement `_get_S`; this base class turns its result into
    ``F = 1 - S`` and delegates metric computation to `AdvancedMetrics`.

    Attributes:
        name: Label of the wrapped model (used by subclasses in messages).
        model: The wrapped model object; its type is decided by the subclass.
        horizon: The time point at which the model is evaluated.
    """

    def __init__(self, name, model, horizon):
        self.name, self.model, self.horizon = name, model, horizon


    def _get_S(self, df):
        """Return S for every row of `df`; must be overridden by subclasses."""
        raise NotImplementedError


    @torch.no_grad()
    def get_prediction(self, df):
        """Return ``1 - S`` for `df`, evaluated with autograd disabled."""
        return 1 - self._get_S(df)


    def __call__(self, df_generator, part):
        """Compute the metrics of this classifier on one data split.

        Args:
            df_generator: Forwarded unchanged to `AdvancedMetrics`.
            part: One of "train", "valid" or "test".

        Returns:
            Whatever calling the constructed `AdvancedMetrics` object returns.
        """
        assert part in {"train", "valid", "test"}, f"{part=}"

        metrics_computer = AdvancedMetrics(
            stats_classifier=self,
            df_generator=df_generator,
            part=part,
        )
        return metrics_computer()

In [None]:
class StatsClassifierLifelines(StatsClassifier):
    """StatsClassifier whose `model` is a lifelines fitter.

    A model whose exact type is `KaplanMeierFitter` ignores the covariates and
    yields the same value for every row; any other model is queried through
    `predict_survival_function`.
    """

    def _get_S_normal(self, df):
        """Query the model at `self.horizon` for every row of `df`.

        The "duration" and "event_observed" columns are removed before the
        remaining columns are handed to the model.
        """
        feature_columns = list(df.columns)
        for label_column in ("duration", "event_observed"):
            feature_columns.remove(label_column)
        covariates = df[feature_columns].copy()

        S = self.model.predict_survival_function(covariates, times=self.horizon).values

        # A single horizon must give exactly one row of values.
        assert S.shape == (1, S.shape[1]), f"{self.name}-{self.horizon}: {S.shape=}"
        return S[0]


    def _get_S_km(self, df):
        """Repeat the model's single value at `self.horizon` once per row of `df`."""
        value_at_horizon = self.model.survival_function_at_times(self.horizon).item()
        return np.full(len(df), value_at_horizon)


    def _get_S(self, df):
        """Dispatch to the Kaplan-Meier path or the covariate-based path."""
        is_kaplan_meier = type(self.model) is KaplanMeierFitter
        compute = self._get_S_km if is_kaplan_meier else self._get_S_normal
        return compute(df)

In [None]:
class StatsClassifierPyTorch(StatsClassifier):
    """StatsClassifier whose `model` is a torch module called as ``model(xs=..., ts=...)``."""

    def _get_S(self, df):
        """Evaluate the model at `self.horizon` for every row of `df`.

        The model is switched to eval mode and the inputs are moved to the
        device of the model's first parameter. The result is returned as a
        1-D numpy array.
        """
        # Only the feature tensor is needed here; the other two outputs are unused.
        features, _, _ = get_tensors_of_df(df)
        self.model.eval()

        device = next(self.model.parameters()).device
        features = features.to(device)

        # One column holding the same horizon for every sample.
        ts = torch.ones(len(features), 1, device=device) * self.horizon
        output = self.model(xs=features, ts=ts)
        S = output.detach().cpu().numpy()

        assert S.shape == (S.shape[0], 1), f"{self.name}-{self.horizon}: {S.shape=}"
        return S[:, 0]

# StatsSurvivalModels and how to load them

In [None]:
class StatsClassifiersGetter:
    """Builds one StatsClassifier per horizon for a StatsSurvivalModels-like object.

    All functions are static; the class only serves as a namespace and is
    called as ``StatsClassifiersGetter.get(...)``.
    """

    @staticmethod
    def _get_horizons(stats_survival_models, resolution, include_zero):
        """Return the horizons at which classifiers should be built.

        Args:
            stats_survival_models: Object exposing `default_horizons` and
                `max_horizon`.
            resolution: Number of evenly spaced interior points between 0 and
                `max_horizon`, or None to reuse `default_horizons`.
            include_zero: Whether horizon 0 is part of the result. Ignored
                when `resolution` is None.

        Returns:
            `default_horizons` unchanged, or a numpy array of horizons that
            never contains `max_horizon` itself.
        """
        if resolution is None:
            return stats_survival_models.default_horizons

        # There is no `self` in this function: the maximum horizon lives on the
        # object that was passed in.
        horizons = np.linspace(0, stats_survival_models.max_horizon, resolution + 2)[:-1]

        if include_zero:
            return horizons

        return horizons[1:]


    @staticmethod
    def get(stats_survival_models, resolution, include_zero):
        """Return ``{horizon: classifier}`` for the wrapped model.

        A `torch.nn.Module` is wrapped in `StatsClassifierPyTorch`; anything
        else must be a lifelines `BaseFitter` and is wrapped in
        `StatsClassifierLifelines`.

        Raises:
            AssertionError: If the model is neither of the two supported types.
        """
        horizons = StatsClassifiersGetter._get_horizons(stats_survival_models, resolution, include_zero)

        model = stats_survival_models.model
        model_name = stats_survival_models.model_name

        if isinstance(model, torch.nn.Module):
            MyStatsClassifier = StatsClassifierPyTorch
        else:
            assert isinstance(model, BaseFitter), "Neither torch.nn.Module nor lifelines.fitters.BaseFitter."
            MyStatsClassifier = StatsClassifierLifelines

        return {horizon: MyStatsClassifier(model_name, model, horizon) for horizon in horizons}

In [None]:
class ComputeConcordance:
    """Namespace of helpers that score a model's per-horizon predictions with a concordance index."""

    def get_survival_curves_and_auc(stats_survival_models, df_generator, part):
        """Return the per-horizon predictions plus their row-wise mean.

        The mean is taken over every column except "duration" and
        "event_observed" and stored in the column "area under survival curve".
        """
        curves = stats_survival_models.get_predctions_for_each_horizon(
            df_generator=df_generator,
            part=part,
        ).copy()
        prediction_columns = curves.drop(columns=["duration", "event_observed"])
        curves["area under survival curve"] = prediction_columns.mean(axis=1)
        return curves


    def get_concordance(stats_survival_models, df_generator, part):
        """Return the concordance index of the negated row-wise mean prediction against the durations."""
        curves = ComputeConcordance.get_survival_curves_and_auc(stats_survival_models, df_generator, part)

        # The score is negated before it is handed to lifelines.
        return concordance_index(
            event_times=curves["duration"],
            predicted_scores=-curves["area under survival curve"],
            event_observed=curves["event_observed"],
        )

In [None]:
class StatsSurvivalModels:
    """A survival model together with the horizons at which it is evaluated.

    Attributes:
        model_name: Label of the model.
        model: The wrapped model object.
        max_horizon: Upper end of the horizon range.
        default_horizons: `resolution` evenly spaced points strictly between
            0 and `max_horizon` (both endpoints excluded).
    """

    def __init__(self, model, model_name, max_horizon, resolution):
        self.model_name = model_name
        self.model = model
        self.max_horizon = max_horizon

        # resolution + 2 grid points including both ends, then drop the ends.
        grid = np.linspace(0, self.max_horizon, resolution+2)
        self.default_horizons = grid[1:-1]


    def get_stats_classifiers(self, resolution, include_zero):
        """Return ``{horizon: classifier}`` as built by `StatsClassifiersGetter.get`."""
        return StatsClassifiersGetter.get(self, resolution, include_zero)


    def get_predctions_for_each_horizon(self, df_generator, part, resolution=None, include_zero=False):
        """Return a frame with one "prediction_<horizon>" column per horizon.

        The frame starts as a copy of the "duration" and "event_observed"
        columns of ``df_generator(horizon=None)[part]``; each classifier's
        `get_prediction` result on that split is appended as a new column.
        """
        split_df = df_generator(horizon=None)[part]
        predictions = split_df[["duration", "event_observed"]].copy()

        classifiers = self.get_stats_classifiers(resolution=resolution, include_zero=include_zero)
        for horizon, classifier in classifiers.items():
            predictions[f"prediction_{horizon}"] = classifier.get_prediction(split_df)

        return predictions


    def get_concordance(self, df_generator, part):
        """Return the concordance computed by `ComputeConcordance.get_concordance`."""
        return ComputeConcordance.get_concordance(self, df_generator, part)


    def __call__(self, df_generator, part, resolution=None):
        """Return a DataFrame with one column of metrics per horizon (zero excluded)."""
        classifiers = self.get_stats_classifiers(resolution=resolution, include_zero=False)

        metrics_by_horizon = {
            horizon: classifier(df_generator=df_generator, part=part)
            for horizon, classifier in classifiers.items()
        }
        return pd.DataFrame(metrics_by_horizon)

In [None]:
class LoadStatsSurvivalModels:
    """Loads pickled model dictionaries from ``trained_models/<data_name>/`` and wraps them.

    Calling an instance returns ``{file_name: StatsSurvivalModels}``.
    """

    # Only files with this suffix are treated as stored models.
    _MODEL_SUFFIX = ".pickle"

    def get_survival_models_file_names(self, data_name):
        """Return the sorted base names (without suffix) of all stored models.

        Files in the directory that do not end in ".pickle" are ignored, and
        the base name is taken with `os.path` so it does not depend on the
        platform's path separator.
        """
        pattern = os.path.join("trained_models", data_name, f"*{self._MODEL_SUFFIX}")
        names = [
            os.path.splitext(os.path.basename(file_path))[0]
            for file_path in glob.glob(pattern)
        ]
        # glob order is filesystem-dependent; sort for reproducible results.
        return sorted(names)


    def load_survival_models(self, data_name):
        """Unpickle every stored model dictionary, keyed by file base name.

        NOTE(security): `pickle.load` can execute arbitrary code; only point
        this at model files from a trusted source.
        """
        model_dicts = {}

        for name in self.get_survival_models_file_names(data_name):
            file_path = os.path.join("trained_models", data_name, name + self._MODEL_SUFFIX)
            with open(file_path, 'rb') as f:
                model_dicts[name] = pickle.load(f)

        return model_dicts


    def turn_stats_classifiers_into_stats_survival_models(self, model_dicts, resolution):
        """Wrap each loaded dictionary in a `StatsSurvivalModels`.

        Each dictionary must provide the keys "model", "name" and
        "max_horizon".
        """
        stats_survival_models = {}

        for model_name, model_dict in model_dicts.items():
            stats_survival_models[model_name] = StatsSurvivalModels(
                model=model_dict["model"],
                model_name=model_dict["name"],
                max_horizon=model_dict["max_horizon"],
                resolution=resolution,
            )

        return stats_survival_models


    def __call__(self, data_name, resolution):
        """Load all models stored for `data_name` and wrap them with the given resolution."""
        model_dicts = self.load_survival_models(data_name)
        stats_survival_models = self.turn_stats_classifiers_into_stats_survival_models(model_dicts, resolution)
        return stats_survival_models

# StatsSurvivalModelsProcessor

In [None]:
class StatsSurvivalModelsProcessorPlotter:
    """Plotting helpers for a StatsSurvivalModelsProcessor-like object.

    All functions are static and take the processor as their first argument;
    the processor must expose a `stats` dict (metric name -> DataFrame) and a
    `get_model_auc(metric_name)` method.
    """

    @staticmethod
    def get_metric_names(stats_survival_models_processor):
        """Return the metric names, i.e. the keys of the processor's `stats` dict."""
        return list(stats_survival_models_processor.stats.keys())


    @staticmethod
    def plot_metric(stats_survival_models_processor, metric_name, ax, model_names):
        """Plot one metric for the selected models on `ax`.

        Args:
            stats_survival_models_processor: Processor holding the statistics.
            metric_name: Key into the processor's `stats` dict.
            ax: Matplotlib axes to draw on.
            model_names: Columns to plot, or None for all columns.
        """
        df = stats_survival_models_processor.stats[metric_name].copy()

        if model_names is None:
            model_names = df.columns

        df = df[model_names]

        # Legend entries show each model's value from `get_model_auc` next to its name.
        new_column_names = {
            method_name: f"{method_name}: {mean}"
            for method_name, mean in stats_survival_models_processor.get_model_auc(metric_name).items()
        }
        df = df.rename(columns=new_column_names)
        df.plot(title=metric_name, ax=ax, ylim=(0, 1)).legend(loc='lower right', ncol=2)


    @staticmethod
    def plot_metrics(stats_survival_models_processor, model_names=None):
        """Plot every metric of the processor in a two-column grid of subplots.

        Does nothing when the processor has no metrics.
        """
        metric_names = StatsSurvivalModelsProcessorPlotter.get_metric_names(stats_survival_models_processor)
        if not metric_names:
            # A grid with zero rows cannot be created.
            return

        nrows = int(np.ceil(len(metric_names) / 2))
        # squeeze=False keeps `axes` two-dimensional even for a single row, so
        # one or two metrics are handled like any other count.
        fig, axes = plt.subplots(nrows=nrows, ncols=2, squeeze=False)
        fig.set_size_inches(22, 5*nrows)

        for i, metric_name in enumerate(metric_names):
            ax = axes[i//2, i%2]
            StatsSurvivalModelsProcessorPlotter.plot_metric(stats_survival_models_processor, metric_name, ax, model_names=model_names)

        # With an odd number of metrics the last grid cell stays unused; hide it.
        for unused_ax in axes.flat[len(metric_names):]:
            unused_ax.set_visible(False)

        # Adjust the layout once titles and legends exist.
        fig.tight_layout()

In [None]:
class StatsSurvivalModelsProcessor:
    """Collects per-horizon metrics and concordances for several models on one data split.

    Attributes:
        df_generator: Passed through to every model when it is evaluated.
        part: Name of the data split passed through to every model.
        stats_survival_models: Dict ``{model_name: model wrapper}``.
        stats: Dict ``{metric_name: DataFrame}`` with horizons as index and
            model names as columns.
        concordances: Dict ``{model_name: concordance}``.
    """

    def __init__(self, stats_survival_models, df_generator, part):
        self.df_generator = df_generator
        self.part = part
        self.stats_survival_models = stats_survival_models

        # Both results are computed eagerly, metrics first.
        self.stats = self.get_each_metric_over_all_models()
        self.concordances = self._get_concordances()


    def _get_concordances(self):
        """Return ``{model_name: concordance}`` for all models."""
        return {
            model_name: ssm.get_concordance(self.df_generator, self.part)
            for model_name, ssm in self.stats_survival_models.items()
        }


    def get_full_curve_metrics(self):
        """Return one row per model with every metric averaged over the horizons.

        The columns are the metric names, "Concordance", "Mean" (row-wise mean
        of the preceding columns) and "Min" (row-wise minimum, computed after
        "Mean" has been added).
        """
        summary = {metric: frame.mean() for metric, frame in self.stats.items()}
        summary["Concordance"] = self.concordances
        table = pd.DataFrame(summary)

        table["Mean"] = table.mean(axis=1)
        table["Min"] = table.min(axis=1)
        return table


    def print_full_curve_metrics(self, df=None):
        """Display `df` (default: `get_full_curve_metrics()`) with a red-yellow-green gradient."""
        table = self.get_full_curve_metrics() if df is None else df
        styled = table.style.background_gradient(cmap='RdYlGn')
        styled.format(precision=2)
        display(styled)


    def _get_stats_of_each_model_as_dfs(self):
        """Return ``{model_name: result of calling the model wrapper on the split}``."""
        return {
            model_name: ssm(df_generator=self.df_generator, part=self.part)
            for model_name, ssm in self.stats_survival_models.items()
        }


    def _get_stats_names(self, stats):
        """Return the metric names, read from the index of the first model's frame."""
        first_model = list(stats.keys())[0]
        return list(stats[first_model].index)


    def get_each_metric_over_all_models(self):
        """Regroup the per-model frames into ``{metric_name: DataFrame}``.

        Each resulting frame has one column per model and one row per entry of
        that model's `default_horizons`.
        """
        per_model = self._get_stats_of_each_model_as_dfs()

        per_metric = {}
        for metric_name in self._get_stats_names(per_model):
            per_metric[metric_name] = pd.DataFrame({
                model_name: {
                    horizon: per_model[model_name][horizon][metric_name]
                    for horizon in ssm.default_horizons
                }
                for model_name, ssm in self.stats_survival_models.items()
            })
        return per_metric


    def get_model_auc(self, metric_name):
        """Return ``{model_name: column mean of the metric, rounded to 2 decimals}``."""
        frame = self.stats[metric_name]
        return dict(zip(frame.mean().keys(), frame.mean().round(2).values))


    def plot_metrics(self, model_names=None):
        """Plot all metrics via `StatsSurvivalModelsProcessorPlotter.plot_metrics`."""
        StatsSurvivalModelsProcessorPlotter.plot_metrics(self, model_names=model_names)