In [None]:
# default_exp models.aft

# Accelerated Failure Time Models
> AFT Model theory.

We can model the time to failure as:
$$
\log T_i = \mu + \xi_i
$$
where $\xi_i\sim p(\xi|\theta)$ and $\mu$ is the most likely log time of death (the mode of the distribution of $\log T_i$, provided the density of $\xi_i$ peaks at zero). We model the log of the time of death because that way we do not need to restrict $\mu + \xi_i$ to be positive.

In the censored case, where $t_i$ is the time at which an instance was censored and $T_i$ is its unobserved time of death, we have:
$$
\begin{aligned}
\log T_i &= \mu(x_i) + \xi_i > \log t_i\\
\therefore \xi_i &> \log t_i - \mu(x_i)
\end{aligned}
$$
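Note that $\mu$ is now a function of the features $x_i$. Writing $y_i=1$ when the death of instance $i$ was observed at time $t_i$ and $y_i=0$ when the instance was censored at time $t_i$, the log likelihood of the data ($\mathcal{D}$) can then be shown to be: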
$$
\begin{aligned}
\log p(\mathcal{D}) = \sum_{i=1}^N \left[ \mathbb{1}(y_i=1)\log p(\xi_i = \log t_i - \mu(x_i)) + \mathbb{1}(y_i=0)\log p(\xi_i > \log t_i - \mu(x_i)) \right]
\end{aligned}
$$

In [None]:
#export
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import MaxAbsScaler, StandardScaler

from torchlife.models.error_dist import get_distribution

In [None]:
# hide
# Development-only notebook setup: the autoreload extension (mode 2) re-imports
# edited modules before each cell runs, and matplotlib figures render inline.
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [None]:
# export
class AFTModel(nn.Module):
    """
    Accelerated Failure Time model: log T = μ(x) + ξ, where the error ξ follows
    the chosen distribution with scale σ.

    parameters:
    - distribution: name of the error distribution, handed to `get_distribution`,
      which supplies the two callables `logpdf(ξ, σ)` and `logicdf(ξ, σ)`
    - input_dim: number of input features; when it is not positive no feature
      layers are built and μ is just the learned intercept β
    - h (optional): widths of the hidden layers, one entry per layer. A tuple,
      any other iterable of ints, or a single int (one hidden layer) is accepted
    """
    def __init__(self, distribution:str, input_dim:int, h:tuple=()):
        super().__init__()
        self.logpdf, self.logicdf = get_distribution(distribution)
        # β is the intercept of μ; the scale is stored as log σ so that
        # σ = exp(log σ) is always positive.
        self.β = nn.Parameter(-torch.rand(1))
        self.logσ = nn.Parameter(-torch.rand(1))

        if input_dim > 0:
            # Normalise `h` so that a list or a bare int works as well as a tuple.
            hidden = (h,) if isinstance(h, int) else tuple(h)
            nodes = (input_dim,) + hidden + (1,)
            # The linear layers carry no bias of their own; β is added to the
            # output of the last layer instead.
            self.layers = nn.ModuleList([nn.Linear(a, b, bias=False)
                                         for a, b in zip(nodes[:-1], nodes[1:])])

        # Added to t before taking the log so that t = 0 does not give log(0).
        self.eps = 1e-7

    def get_mode_time(self, x:torch.Tensor=None):
        """
        Return (μ, σ): the location of log T and the scale of the error.

        Without features μ is the intercept β; with features the output of the
        network (ReLU after every layer but the last) is added to β.

        raises:
        - ValueError if features are given but the model was built without
          feature layers (input_dim <= 0)
        """
        μ = self.β
        if x is not None:
            layers = getattr(self, "layers", None)
            if layers is None:
                raise ValueError(
                    "features were passed to an AFTModel built with input_dim <= 0"
                )
            *hidden_layers, output_layer = layers
            for layer in hidden_layers:
                x = F.relu(layer(x))
            μ = self.β + output_layer(x)

        σ = torch.exp(self.logσ)
        return μ, σ

    def forward(self, t:torch.Tensor, x:torch.Tensor=None):
        """
        Evaluate both likelihood terms at the residual ξ = log(t + eps) - μ(x).

        returns: (logpdf, logicdf), the outputs of the two distribution
        callables evaluated at (ξ, σ).
        """
        μ, σ = self.get_mode_time(x)
        ξ = torch.log(t + self.eps) - μ
        logpdf = self.logpdf(ξ, σ)
        logicdf = self.logicdf(ξ, σ)
        return logpdf, logicdf

    def survival_function(self, t:np.array, t_scaler, x:np.array=None, x_scaler=None):
        """
        Survival probability at the (unscaled) times `t`.

        parameters:
        - t: array-like of times; a 1-d input is treated as one time per row
        - t_scaler: object whose `transform` maps raw times to model scale
        - x (optional): array-like of features; a single row (or a 1-d vector)
          is repeated for every time point
        - x_scaler: object whose `transform` scales the features; required
          whenever `x` is given

        returns: CPU tensor holding exp(logicdf) for every time point.
        """
        t = np.asarray(t)
        if t.ndim == 1:
            t = t[:, None]
        # Build the inputs on the same device as the parameters so the method
        # also works after the model has been moved off the CPU.
        device = self.β.device
        t = torch.as_tensor(t_scaler.transform(t), dtype=torch.float32, device=device)
        if x is not None:
            x = np.asarray(x)
            if x.ndim == 1:
                x = x[None, :]
            if len(x) == 1:
                x = np.repeat(x, len(t), axis=0)
            x = torch.as_tensor(x_scaler.transform(x), dtype=torch.float32, device=device)

        with torch.no_grad():
            # The second output of forward is the log of the survival function
            # (log of one minus the CDF), so exponentiating gives S(t).
            _, log_survival = self(t, x)
            return torch.exp(log_survival).cpu()

    def plot_survival_function(self, t:np.array, t_scaler, x:np.array=None, x_scaler=None):
        """
        Plot the survival probability against the (unscaled) times `t`.

        Takes the same arguments as `survival_function`.
        """
        surv_fun = self.survival_function(t, t_scaler, x, x_scaler)

        # plot
        plt.figure(figsize=(12,5))
        # Hand matplotlib a plain numpy array rather than a torch tensor.
        plt.plot(t, surv_fun.numpy())
        plt.xlabel('Time')
        plt.ylabel('Survival Probability')
        plt.show()

Modelling based on **only** time and (death) event variables:

In [None]:
# from torchlife.data import create_dl
# import pandas as pd

# url = "https://raw.githubusercontent.com/vincentarelbundock/Rdatasets/master/csv/survival/flchain.csv"
# df = pd.read_csv(url).iloc[:,1:]
# df.rename(columns={'futime':'t', 'death':'e'}, inplace=True)

# cols = ["age", "sample.yr", "kappa"]
# db, t_scaler, x_scaler = create_dl(df[['t', 'e'] + cols])

# death_rate = 100*df["e"].mean()
# print(f"Death occurs in {death_rate:.2f}% of cases")
# print(df.shape)
# df.head()

In [None]:
# # hide
# from fastai.basics import Learner
# from torchlife.losses import aft_loss

# model = AFTModel("Gumbel", input_dim=len(cols))
# learner = Learner(db, model, loss_func=aft_loss)
# # wd = 1e-4
# learner.lr_find(start_lr=1, end_lr=10)
# learner.recorder.plot()

In [None]:
# learner.fit(epochs=10, lr=2)

In [None]:
# model.plot_survival_function(np.linspace(0, df["t"].max(), 100), t_scaler, df.loc[0, cols].values, x_scaler)

In [None]:
# hide
# Convert the project's notebooks into library modules (see the "Converted ..."
# lines printed below). Import the one function needed instead of everything
# in nbdev.export, so the notebook namespace is not polluted.
from nbdev.export import notebook2script
notebook2script()

Converted 00_index.ipynb.
Converted 10_SAT.ipynb.
Converted 20_KaplanMeier.ipynb.
Converted 30_overall_model.ipynb.
Converted 50_hazard.ipynb.
Converted 55_hazard.PiecewiseHazard.ipynb.
Converted 59_hazard.Cox.ipynb.
Converted 60_AFT_models.ipynb.
Converted 65_AFT_error_distributions.ipynb.
Converted 80_data.ipynb.
Converted 90_model.ipynb.
Converted 95_Losses.ipynb.
