In [None]:
import os
import nbimporter

# Switch the working directory to the project folder: keep everything in the
# current path that precedes the first occurrence of "survival_analysis" and
# append that folder name again. The following relative imports and file paths
# (e.g. `utils...`, `trained_models/...`) are resolved from there.
root = os.getcwd().split("survival_analysis")[0]
os.chdir(root + "survival_analysis")

In [None]:
import copy
import torch
import pickle
import numpy as np
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt

In [None]:
from utils.get_tensors_of_df import get_tensors_of_df

# SampleTimesAndLabels

In [None]:
class SampleTimesAndLabels:
    """Callable base class for strategies that produce (times, labels) pairs.

    Subclasses override ``_sample_times_and_labels``; calling an instance
    forwards its three arguments to that hook and returns the hook's result.
    """

    def __init__(self):
        # The base strategy has nothing to configure.
        pass


    def __call__(self, durations, event_observeds, horizon):
        """Forward the arguments to the subclass hook and return its result."""
        sampled = self._sample_times_and_labels(durations, event_observeds, horizon)
        return sampled


    def _sample_times_and_labels(self, durations, event_observeds, horizon):
        """Hook for subclasses; the base implementation always raises ``NotImplementedError``."""
        raise NotImplementedError()

In [None]:
class SampleTimesAndLabelsDelta(SampleTimesAndLabels):
    """Deterministic sampler: the evaluation time is the recorded duration itself.

    The label ``alive`` is ``1 - event_observed`` (cast to the dtype of the
    times). ``horizon`` is accepted for interface compatibility but not used.
    """

    def _sample_times_and_labels(self, durations, event_observeds, horizon):
        n_rows = len(durations)

        # Work on copies so the caller's tensors are never modified.
        ts = durations.clone().view(n_rows, 1)
        events_as_float = event_observeds.clone().type(ts.dtype)
        alives = 1 - events_as_float

        # Everything is expected as a column vector of shape (n, 1).
        expected_shape = (n_rows, 1)
        assert ts.shape == expected_shape, f"{ts.shape=} has a bad shape."
        assert durations.shape == ts.shape, f"Different shapes. {durations.shape=} != {ts.shape=}"
        assert alives.shape == ts.shape, f"Different shapes. {alives.shape=} != {ts.shape=}"
        return ts, alives

In [None]:
class SampleTimesAndLabelsGaussianDelta(SampleTimesAndLabels):
    """Sampler that jitters the times of observed events with Gaussian noise.

    For rows with an observed event the evaluation time is
    ``duration + N(0, σ²)`` (clamped at 0); rows without an observed event keep
    their recorded duration. A row is labelled alive (1) when the sampled time
    lies strictly before its duration or when no event was observed.

    Args:
        σ: standard deviation of the Gaussian noise added to the durations.
    """

    def __init__(self, σ):
        super().__init__()
        self._σ = σ


    def _sample_times_and_labels(self, durations, event_observeds, horizon):
        """Return ``(ts, alives)``; ``horizon`` is accepted but not used.

        ``event_observeds`` must be a boolean tensor because it is negated with
        ``~`` and used as a mask.
        """
        # Draw the noise on the device of `durations`; a plain `torch.randn`
        # would always allocate on the default device and break the addition
        # below for tensors living elsewhere (e.g. on a GPU).
        η = torch.randn(len(durations), 1, device=durations.device) * self._σ
        ts = durations + η

        # Rows without an observed event are evaluated exactly at their duration.
        not_observed = ~event_observeds
        ts[not_observed] = durations[not_observed]
        ts = ts.clamp(0)  # times cannot be negative

        alives = (ts < durations) | not_observed

        assert durations.shape == ts.shape, f"Different shapes. {durations.shape=} != {ts.shape=}"
        return ts, alives.type(ts.dtype)

# Criterions

In [None]:
class WeightedBCELoss:
    """Binary cross-entropy with one weight per class and label smoothing.

    Elements whose target is above 0.5 are weighted with ``1 - weight``, all
    other elements with ``weight``. Targets and outputs are clamped to
    ``[label_smoothing, 1 - label_smoothing]`` before the loss is evaluated.
    The matching time/label sampler is exposed as ``sample_times_and_labels``.
    """

    def __init__(self, σ_gaussian_delta, weight=1, label_smoothing=1e-3):
        super().__init__()
        self.my_weight = weight
        assert 0 <= self.my_weight <= 1, f"{self.my_weight=} is not in the interval [0, 1]."

        self.label_smoothing = label_smoothing
        self.sample_times_and_labels = SampleTimesAndLabelsGaussianDelta(σ=σ_gaussian_delta)


    def _get_weights(self, targets):
        """Build the element-wise weight tensor for ``targets``."""
        above_half = targets > .5

        # Start with the weight of the "> 0.5" class everywhere, then overwrite the rest.
        per_element = torch.ones_like(targets) * (1 - self.my_weight)
        per_element = per_element.masked_fill(~above_half, self.my_weight)

        assert torch.all((0 <= per_element) & (per_element <= 1)), "BCE weights outside [0, 1]."
        return per_element


    def __call__(self, outputs, targets, ts):
        """Return the mean weighted BCE of ``outputs`` against ``targets``.

        ``ts`` is accepted for interface compatibility but not used.
        Note that ``outputs`` is clamped in place through ``.data``, i.e. the
        caller's tensor changes and the clamp is not recorded by autograd.
        """
        lower, upper = self.label_smoothing, 1 - self.label_smoothing

        targets = torch.clamp(targets, lower, upper)
        outputs.data.clamp_(lower, upper)

        loss = torch.nn.functional.binary_cross_entropy(
            outputs,
            targets,
            weight=self._get_weights(targets),
            reduction='mean',
        )

        assert torch.all(-1e-2 <= outputs), f"Too small values in outputs. {outputs[0 > outputs]=}"
        assert torch.all(outputs <= 1), "Too large values in outputs."
        return loss

In [None]:
class SuMoLoss:
    """Negative log-likelihood of a survival function S(t) evaluated by a model.

    Rows labelled not alive (label <= 0.5) contribute ``log f(t)`` with
    ``f = -dS/dt`` obtained through autograd; rows labelled alive
    (label > 0.5) contribute ``log S(t)``. The summed log-likelihood is
    negated and divided by the number of rows.

    Args:
        weight: factor applied to the ``log f`` term.
        label_smoothing: stored on the instance; not used by this loss.
    """

    def __init__(self, weight=1, label_smoothing=1e-3):
        super().__init__()
        self.my_weight = weight
        self.label_smoothing = label_smoothing

        self.sample_times_and_labels = SampleTimesAndLabelsDelta()


    def _get_δS(self, S_t, ts):
        """Return dS/dt via autograd; ``ts`` must be part of the graph of ``S_t``."""
        grads = torch.autograd.grad(
            outputs=S_t,
            inputs=ts,
            grad_outputs=torch.ones_like(S_t),
            create_graph=True,  # the loss is differentiated again w.r.t. the weights
            retain_graph=True
        )[0]
        return grads


    def _get_f(self, S_t, ts):
        """Return the density ``f = -dS/dt`` with sanity checks."""
        grads = self._get_δS(S_t, ts)

        f = -grads

        # NaN gradients are tolerated (set to 0) only where S is practically zero.
        f[torch.isnan(f) & (S_t < 1e-3)] = 0.

        assert grads.shape == S_t.shape, f"Shapes don't match: {grads.shape=} {S_t.shape=}"
        assert not torch.any(torch.isnan(f)), f"f has NaNs, {f[torch.isnan(f)]=}, {S_t[torch.isnan(f)]=}, {ts[torch.isnan(f)]=}"
        assert torch.all(f >= -0.1), f"f is negative. {f[f < 0]=}, {S_t[f < 0]=}"
        return f


    def _get_f_ll(self, f, alives, ε=1e-16):
        """Sum of ``log f`` over rows labelled not alive, scaled by ``self.my_weight``.

        ``f`` is clamped from below at ``ε`` before the log and every
        log-term is clamped from below at -10.
        """
        alives_binary = alives > 0.5
        f_ll = torch.log(f[~alives_binary].clamp(ε)).clamp(-10).sum() * self.my_weight
        return f_ll


    def _get_S_ll(self, S, alives, ε=1e-16):
        """Sum of ``log S`` over rows labelled alive (same clamping as ``_get_f_ll``)."""
        alives_binary = alives > 0.5
        S_ll = torch.log(S[alives_binary].clamp(ε)).clamp(-10).sum()
        return S_ll


    def __call__(self, outputs, alives, ts):
        """Return the mean negative log-likelihood.

        Args:
            outputs: model predictions S(t), computed from ``ts`` so that
                autograd can differentiate them with respect to ``ts``.
            alives: labels with the same shape as ``outputs``; > 0.5 means alive.
            ts: evaluation times used to produce ``outputs``.
        """
        S = outputs
        f = self._get_f(S, ts)

        f_ll = self._get_f_ll(f, alives)
        # The survival term must be evaluated on S itself (not on the density f).
        S_ll = self._get_S_ll(S, alives)

        loss = -(f_ll + S_ll) / len(outputs)

        assert not torch.any(torch.isnan(S)), f"Found NaN in outputs. {f[torch.isnan(S)]=}, {S[torch.isnan(S)]=}, {ts[torch.isnan(S)]=}"
        assert torch.all(-1e-2 <= S), f"Too small values in outputs. {f[S < 0]=}, {S[S < 0]=}"
        assert torch.all(S <= 1), "Too large values in outputs."
        return loss

# TorchDfHandler

In [None]:
class TorchDfHandler:
    """Holds one part of a generated data set as a DataFrame plus tensors.

    On construction the generator is called with ``horizon`` and ``adjust``,
    the frame stored under ``part`` is copied and validated, and
    ``get_tensors_of_df`` converts it into ``(features, durations,
    event_observeds)``. Calling the handler returns those three tensors.
    """

    def __init__(self, df_generator, part, sample_times_and_labels, adjust, horizon=None):
        horizon_missing = horizon is None
        if horizon_missing:
            # Adjusting is only allowed together with an explicit horizon.
            assert not adjust, f"{horizon=}, but {adjust=}."

        self._df_generator = df_generator
        self._part = part
        self._horizon = horizon if not horizon_missing else df_generator.max_horizon

        frames_by_part = df_generator(horizon=horizon, adjust=adjust)
        self._df = self._check_and_clean_df(frames_by_part[part].copy())
        self._df_generator.check_that_column_names_are_valid(self._df)

        tensors = self._get_tensors_of_df()
        self._features, self._durations, self._event_observeds = tensors

        self.sample_times_and_labels = sample_times_and_labels


    def __call__(self):
        """Return the full ``(features, durations, event_observeds)`` tensors."""
        return self._features, self._durations, self._event_observeds


    @property
    def df_generator(self):
        return self._df_generator


    @property
    def horizon(self):
        return self._horizon


    @property
    def df(self):
        return self._df


    @property
    def n_input_features(self):
        """Number of feature columns (second dimension of the feature tensor)."""
        shape = self._features.shape
        assert len(shape) == 2
        return shape[1]


    def _check_and_clean_df(self, df):
        """Re-select all columns of ``df`` and require the two survival columns."""
        columns = list(df.columns)
        selected = df[columns]

        for required in ("duration", "event_observed"):
            assert required in columns, f"No '{required}' in {columns=}"
        return selected


    def _get_tensors_of_df(self):
        return get_tensors_of_df(df=self._df)


    def sample_batch(self, n_samples):
        """Draw ``n_samples`` rows uniformly at random, with replacement."""
        all_tensors = self()
        idxs = np.random.randint(0, len(all_tensors[0]), size=n_samples)

        features, durations, event_observeds = (tensor[idxs] for tensor in all_tensors)
        return features, durations, event_observeds

# Training

In [None]:
def moving_average(a, n):
    """Return the trailing moving average of ``a`` over windows of ``n`` entries.

    The result has ``len(a) - n + 1`` entries (none if ``n > len(a)``);
    entry ``i`` is the mean of ``a[i:i + n]``.
    """
    totals = np.cumsum(a, dtype=float)

    # Subtracting the cumulative sum lagged by n turns running totals into window sums.
    lagged = totals[:-n].copy()
    totals[n:] -= lagged

    first_full_window = n - 1
    return totals[first_full_window:] / n

## BestModelMemorizer

In [None]:
class BestModelMemorizer:
    """Keeps a copy of the model with the best smoothed validation loss.

    Args:
        model_via_moving_average_on_validation: window size (in logged steps)
            of the moving average used to score the model; ``None`` disables
            the selection and simply remembers the most recent model.
        patience_factor: after ``window * patience_factor`` logged steps
            without improvement ``likely_best_found`` becomes ``True``.

    Calling the instance returns ``(best_model, selected_epoch, best_val_score)``.
    """

    def __init__(self, model_via_moving_average_on_validation, patience_factor):
        self.model_via_moving_average_on_validation = model_via_moving_average_on_validation

        self._best_model = None
        self._selected_epoch = np.inf
        self._best_val_score = np.inf
        self._likely_best_found = False
        self._patience_factor = patience_factor


    def __call__(self):
        return self._best_model, self._selected_epoch, self._best_val_score


    def log(self, losses_valid, model):
        """Register the current ``model`` given all validation losses so far.

        Args:
            losses_valid: sequence of every validation loss logged so far,
                the newest one last.
            model: the model that produced the newest loss; it is deep-copied
                whenever it becomes the remembered model.
        """
        window = self.model_via_moving_average_on_validation
        n_logged = len(losses_valid)

        if window is None:
            # No selection: always remember the latest model.
            self._best_model = copy.deepcopy(model)
            self._selected_epoch = n_logged
            self._best_val_score = np.mean(losses_valid)
            return

        if window > n_logged:
            # Warm-up: not enough losses for one full window yet.
            self._best_model = copy.deepcopy(model)
            self._selected_epoch = n_logged
            self._best_val_score = np.inf
            return

        # Only the newest window matters here, so average it directly instead
        # of recomputing the moving average over the whole history each step.
        # (May differ from a cumulative-sum based value by float rounding.)
        mova_val_loss = np.mean(losses_valid[-window:])

        if mova_val_loss < self._best_val_score:
            self._best_val_score = mova_val_loss
            self._best_model = copy.deepcopy(model)
            self._selected_epoch = n_logged
            return

        patience = window * self._patience_factor
        if self._selected_epoch < n_logged - patience:
            self._likely_best_found = True


    @property
    def likely_best_found(self):
        """``True`` once no improvement was seen for longer than the patience."""
        return self._likely_best_found

## TrainingPlotting

In [None]:
class TrainingPlotting:
    """Namespace for the diagnostic plots shown after training."""

    @staticmethod
    def plot_losses(losses_train, losses_valid, n_mova):
        """Plot the moving averages of the training and validation losses.

        Args:
            losses_train: sequence of training losses, one per step.
            losses_valid: sequence of validation losses, one per step.
            n_mova: window size of the moving average.

        The figure title is the minimum of the smoothed validation loss.
        """
        mova_train = moving_average(losses_train, n=n_mova)
        mova_valid = moving_average(losses_valid, n=n_mova)

        # The first n_mova - 1 steps have no full window, hence the shifted x-range.
        plt.plot(range(n_mova - 1, len(losses_train)), mova_train, label="train mova")
        plt.plot(range(n_mova - 1, len(losses_valid)), mova_valid, label="valid mova")

        plt.title(mova_valid.min())
        plt.xlabel("training step")
        plt.ylabel("loss (moving average)")
        plt.legend()
        plt.show()


    @staticmethod
    def plot_some_survival_curves(model, df_handler_valid):
        """Plot predicted survival curves for 18 randomly drawn validation rows.

        Each subplot shows the model output on 64 evenly spaced times in
        ``[0, horizon]`` and a vertical line at the row's duration: red when
        the event was observed, blue otherwise.

        The model is evaluated in eval mode; afterwards its previous
        train/eval mode is restored (assumes ``model`` is a ``torch.nn.Module``).
        """
        n_rows, n_cols, n_points = 3, 6, 64
        horizon = df_handler_valid.horizon
        device = next(model.parameters()).device

        was_training = model.training
        model.eval()
        try:
            for i in range(n_rows * n_cols):
                plt.subplot(n_rows, n_cols, i + 1)

                features, durations, event_observeds = df_handler_valid.sample_batch(1)
                # Repeat the single row so it can be evaluated at every time point.
                multi_features = torch.cat([features] * n_points, dim=0).to(device)

                ts = torch.linspace(0, horizon, n_points, device=device).view(-1, 1)
                S_ts = model(xs=multi_features, ts=ts)

                plt.plot(ts.flatten().detach().cpu().numpy(), S_ts.flatten().detach().cpu().numpy())

                colour = 'r' if event_observeds.item() else 'b'
                plt.vlines(durations.item(), 0, 1, colour)

                plt.ylim(0, 1)
        finally:
            # Do not silently switch an eval-mode model into train mode.
            model.train(was_training)

        plt.gcf().set_size_inches(35, 10)
        plt.show()

## NetTrainer

In [None]:
class NetTrainer:
    """Trains a model on random mini-batches and selects the best one on validation.

    Args:
        df_generator: callable data source; it must offer ``max_horizon`` and
            ``name`` and is passed on to ``TorchDfHandler``.
        criterion: loss object called as ``criterion(preds, alives, ts=ts)``
            that also provides ``sample_times_and_labels``.
        n_training_steps: maximum number of optimisation steps.
        batch_size: number of rows drawn per (training or validation) batch.
        lr: learning rate of the Adam optimiser.
        weight_decay: weight decay of the Adam optimiser.
        model_via_moving_average_on_validation: window size of the moving
            average of the validation loss used for model selection; must not
            be ``None``.
        patience_factor: multiple of the window size to wait without
            improvement before training stops early.
        adjust: forwarded to the data handler, but only when a horizon is given.
        clip: maximum gradient norm, or ``None`` to disable gradient clipping.
    """

    def __init__(
        self,
        df_generator,
        criterion,
        n_training_steps=200000,
        batch_size=32,
        lr=1e-3,
        weight_decay=0,
        model_via_moving_average_on_validation=512,
        patience_factor=5,
        adjust=False,
        clip=None,
    ):
        assert model_via_moving_average_on_validation is not None, "model_via_moving_average_on_validation is None"

        self.horizon = df_generator.max_horizon
        self.df_generator = df_generator
        self.adjust = adjust
        self.clip = clip

        self.criterion = criterion
        self.weight_decay = weight_decay

        self.lr = lr
        self.batch_size = batch_size
        self.n_training_steps = n_training_steps
        self._patience_factor = patience_factor

        self.model_via_moving_average_on_validation = model_via_moving_average_on_validation

        # Kept on the instance for `n_input_features`, plotting and saving.
        self.df_handler_valid = self._get_df_handler("valid")


    @property
    def n_input_features(self):
        """Number of input features, taken from the validation data."""
        return self.df_handler_valid.n_input_features


    def _get_df_handler(self, part, horizon=None):
        """Create a ``TorchDfHandler`` for ``part``; ``adjust`` applies only with a horizon."""
        adjust = False if horizon is None else self.adjust

        handler = TorchDfHandler(
            df_generator=self.df_generator,
            part=part,
            sample_times_and_labels=self.criterion.sample_times_and_labels,
            adjust=adjust,
            horizon=horizon,
        )
        return handler


    def get_loss_on_random_batch(self, model, df_handler):
        """Sample a batch from ``df_handler`` and return the criterion's loss on it."""
        features, durations, event_observeds = df_handler.sample_batch(self.batch_size)
        ts, alives = df_handler.sample_times_and_labels(durations, event_observeds, horizon=self.horizon)

        preds = model(xs=features, ts=ts)
        loss = self.criterion(preds, alives, ts=ts)
        return loss


    def _eval_current_model(self, model, losses_valid, best_model_memorizer, df_handler_valid):
        """Append one validation loss and let the memorizer judge the model.

        The model is in eval mode while it is evaluated and logged and is
        switched back to train mode afterwards.
        """
        model.eval()
        loss = self.get_loss_on_random_batch(model=model, df_handler=df_handler_valid)
        losses_valid.append(loss.item())
        best_model_memorizer.log(losses_valid=losses_valid, model=model)
        model.train()


    def _get_best_model_memorizer(self):
        best_model_memorizer = BestModelMemorizer(
            model_via_moving_average_on_validation=self.model_via_moving_average_on_validation,
            patience_factor=self._patience_factor
        )
        return best_model_memorizer


    def _train_step(self, optimizer, loss, model):
        """Run one optimisation step for ``loss``, clipping gradients if configured."""
        optimizer.zero_grad()
        loss.backward()
        if self.clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip)
        optimizer.step()


    def train_model(self, model):
        """Train ``model`` and return the model remembered by the memorizer.

        Returns:
            ``(best_model, losses_train, losses_valid, selected_epoch,
            best_val_score)``; the returned model is the memorizer's copy, not
            necessarily the object passed in.
        """
        best_model_memorizer = self._get_best_model_memorizer()

        df_handler_train = self._get_df_handler(part="train")
        df_handler_valid = self._get_df_handler(part="valid")

        losses_train, losses_valid = [], []
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        model.train()

        for _ in tqdm(range(self.n_training_steps)):
            loss = self.get_loss_on_random_batch(model=model, df_handler=df_handler_train)
            losses_train.append(loss.item())

            self._train_step(optimizer, loss, model)
            self._eval_current_model(model, losses_valid, best_model_memorizer, df_handler_valid)

            # Early stopping once the memorizer ran out of patience.
            if best_model_memorizer.likely_best_found:
                print(f"Stoped training at step {len(losses_valid)}.")
                break

        model.eval()
        model, selected_epoch, best_val_score = best_model_memorizer()
        return model, losses_train, losses_valid, selected_epoch, best_val_score


    def _get_model_dict(self, name, model):
        """Bundle the model with its name and the maximum horizon."""
        model_dict = {
            "model": model,
            "name": name,
            "max_horizon": self.horizon,
        }
        return model_dict


    def _save(self, name, data):
        """Pickle ``data`` to ``trained_models/<generator name>/<name>.pickle``.

        The target directory is created when it does not exist yet.
        """
        path = f'trained_models/{self.df_handler_valid.df_generator.name}/{name}.pickle'
        target_dir = os.path.dirname(path)
        os.makedirs(target_dir, exist_ok=True)

        with open(path, 'wb') as f:
            pickle.dump(data, f)


    def train_and_save(self, name, model, verbose=True, save=True):
        """Train ``model``, optionally plot diagnostics and pickle the result.

        Returns:
            ``(best_model, best_val_score, model_dict)``.
        """
        model, losses_train, losses_valid, selected_epoch, best_val_score = self.train_model(model)

        if verbose:
            print(f"{selected_epoch=}")
            TrainingPlotting.plot_losses(losses_train, losses_valid, n_mova=self.model_via_moving_average_on_validation)
            TrainingPlotting.plot_some_survival_curves(model, self.df_handler_valid)

        model_dict = self._get_model_dict(name, model)
        if save:
            self._save(name, model_dict)
        return model, best_val_score, model_dict