In [None]:
#default_exp model

# Base Model
> High-level wrapper classes that build, train, and predict with the underlying hazard and AFT models.

In [None]:
# hide
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# export
import torch
import torch.nn as nn
from torchlife.models.ph import PieceWiseHazard
from torchlife.models.cox import ProportionalHazard
from torchlife.models.aft import AFTModel

from torchlife.data import create_db, create_test_dl, get_breakpoints

from torchlife.losses import *

from fastai.basics import Learner

## Modelling Hazard

In [None]:
# export
# Registry of the short model keys accepted by ModelHazard and the model class each one selects.
_text2model_ = dict([
    ('ph', PieceWiseHazard),
    ('cox', ProportionalHazard),
])

In [None]:
# NOTE(review): scratch assignment. ModelHazard.create_learner computes its own local `dim`
# and no other cell visible in this notebook reads this global — candidate for removal.
dim = 0


In [None]:
# export
class ModelHazard:
    """
    Modelling instantaneous hazard (λ).

    Wraps one of the models registered in `_text2model_` together with a fastai
    `Learner`, exposing `lr_find`, `fit`, `predict` and `plot_survival_function`.

    parameters:
    - model(str): ['ph'|'cox'] which maps to Piecewise Hazard, Cox Proportional Hazard.
    - percentiles: list of time percentiles at which time should be broken
    - h: list of hidden units (disregarding input units)
    - bs: batch size
    - epochs: epochs
    - lr: learning rate
    - beta: l2 penalty on weights

    attributes set by `create_learner`:
    - model: the instantiated network (until then it holds the model class)
    - learner: the fastai Learner
    - breakpoints, t_scaler, x_scaler: values needed again by `predict`
    """
    def __init__(self, model:str, percentiles=[20, 40, 60, 80], h:tuple=(),
                 bs:int=128, epochs:int=20, lr:float=1.0, beta:float=0):
        if model not in _text2model_:
            # Same exception type as a plain dict lookup, but with an actionable message.
            raise KeyError(f"unknown model {model!r}; expected one of {sorted(_text2model_)}")
        self.model_name = model
        # Keep the class separately so create_learner can be called more than once;
        # `self.model` is replaced by an instance once the learner is built.
        self._model_cls = _text2model_[model]
        self.model = self._model_cls
        # Copy so instances never share (or mutate) the default list object.
        self.percentiles = list(percentiles)
        self.loss = hazard_loss
        self.h = h
        # NOTE(review): bs is stored but never forwarded to create_db / create_test_dl
        # in this class — confirm whether those helpers should receive it.
        self.bs, self.epochs, self.lr, self.beta = bs, epochs, lr, beta
        self.learner = None

    def create_learner(self, df):
        """
        Build the data bunch, the model and the fastai Learner from `df`.

        Raises AssertionError when a model other than 'ph' is requested but `df`
        has no columns left for x after removing two.
        """
        # Every column except two counts as an input feature
        # (presumably the two are time and event — confirm against create_db).
        dim = df.shape[1] - 2
        if self.model_name != 'ph':
            # Only covariate-based models need x; the message itself points users to 'ph'.
            assert dim > 0, "dimensions of x input needs to be >0. Choose ph instead"

        breakpoints = get_breakpoints(df, self.percentiles)
        db, t_scaler, x_scaler = create_db(df, breakpoints)

        model_args = {
            'breakpoints': breakpoints,
            't_scaler': t_scaler,
            'x_scaler': x_scaler,
            'h': self.h,
            'dim': dim
        }
        # Instantiate from the stored class, never from `self.model`, which may already
        # be an instance left by an earlier call.
        self.model = self._model_cls(**model_args)
        self.learner = Learner(db, self.model, loss_func=self.loss, wd=self.beta)

        # Kept for predict(), which must prepare test data the same way.
        self.breakpoints = breakpoints
        self.t_scaler = t_scaler
        self.x_scaler = x_scaler

    def _ensure_learner(self, df):
        """Create the learner from `df` on first use; later calls reuse it."""
        if self.learner is None:
            self.create_learner(df)

    def lr_find(self, df):
        """Run the learner's learning-rate finder and plot the recorder's result."""
        self._ensure_learner(df)
        self.learner.lr_find(wd=self.beta)
        self.learner.recorder.plot()

    def fit(self, df):
        """Train for `self.epochs` epochs with learning rate `self.lr` and weight decay `self.beta`."""
        self._ensure_learner(df)
        self.learner.fit(self.epochs, lr=self.lr, wd=self.beta)

    def predict(self, df):
        """
        Run the trained model over `df` and return `(λ, S)`.

        For each batch the model returns a pair `preds`; λ collects `exp(preds[0])`
        and S collects `exp(-preds[1])`, each concatenated over all batches.
        Reads `breakpoints`, `t_scaler` and `x_scaler`, so the learner must have
        been created (via `fit`, `lr_find` or `create_learner`) first.
        """
        test_dl = create_test_dl(df, self.breakpoints, self.t_scaler, self.x_scaler)
        self.model.eval()
        λ, S = [], []
        with torch.no_grad():
            for batch in test_dl:
                preds = self.model(*batch)
                λ.append(torch.exp(preds[0]))
                S.append(torch.exp(-preds[1]))
        return torch.cat(λ), torch.cat(S)

    def plot_survival_function(self, *args):
        """Delegate to the underlying model's `plot_survival_function`."""
        self.model.plot_survival_function(*args)

In [None]:
# hide
import pandas as pd
import numpy as np
# rossi.csv from the lifelines GitHub repository, read over the network.
url = "https://raw.githubusercontent.com/CamDavidsonPilon/lifelines/master/lifelines/datasets/rossi.csv"
df = pd.read_csv(url)

In [None]:
# Show the frame's dimensions, then preview its first rows.
print(df.shape)
df.head()

(432, 9)


Unnamed: 0,week,arrest,fin,age,race,wexp,mar,paro,prio
0,20,1,0,27,1,0,0,1,3
1,17,1,0,18,1,0,0,1,8
2,25,1,0,19,0,1,0,1,13
3,52,0,1,23,1,1,1,1,1
4,52,0,0,19,0,1,0,1,3


In [None]:
# Rename 'week' -> 't' and 'arrest' -> 'e' by rebinding `df` to the renamed frame.
df = df.rename(columns={'week':'t', 'arrest':'e'})
# Cox proportional hazard wrapper with all other constructor defaults.
model = ModelHazard('cox')

In [None]:
# First call builds the learner from `df`, then trains with the constructor defaults.
# NOTE(review): the recorded output below lists 10 epochs while ModelHazard defaults to
# epochs=20 — the saved output looks stale; re-run to refresh it.
model.fit(df)

epoch,train_loss,valid_loss,time
0,7.833423,0.967097,00:00
1,5.683983,5.419213,00:00
2,4.220398,2.326273,00:00
3,3.532574,2.975143,00:00
4,3.170237,2.695517,00:00
5,2.902388,2.413046,00:00
6,2.703852,2.36825,00:00
7,2.566355,2.401003,00:00
8,2.46817,2.42138,00:00
9,2.380688,2.414237,00:00


In [None]:
# ModelHazard.predict returns (λ, S) where S is exp(-preds[1]), so the second value
# is bound to S rather than Λ.
λ, S = model.predict(df)
df.shape, λ.shape, S.shape

((432, 9), torch.Size([432, 1]), torch.Size([432, 1]))

## Modelling Distribution with [AFT](./AFT_models) models

In [None]:
# export
from torchlife.models.error_dist import *

class ModelAFT:
    """
    Modelling error distribution given inputs x.

    Wraps an `AFTModel` together with a fastai `Learner`, exposing `lr_find`,
    `fit`, `predict` and `plot_survival` / `plot_survival_function`.

    parameters:
    - dist(str): Univariate distribution of error
    - h: list of hidden units (disregarding input units)
    - bs: batch size
    - epochs: epochs
    - lr: learning rate
    - beta: l2 penalty on weights

    attributes set by `create_learner`: `model`, `learner`, `t_scaler`, `x_scaler`.
    """
    def __init__(self, dist:str, h:tuple=(),
                 bs:int=128, epochs:int=20, lr:float=1, beta:float=0):
        self.dist = dist
        self.loss = aft_loss
        self.h = h
        # NOTE(review): bs is stored but never forwarded to create_db / create_test_dl
        # in this class — confirm whether those helpers should receive it.
        self.bs, self.epochs, self.lr, self.beta = bs, epochs, lr, beta
        self.learner = None

    def create_learner(self, df):
        """Build the data bunch, the AFT model and the fastai Learner from `df`."""
        # Every column except two counts as an input feature
        # (presumably the two are time and event — confirm against create_db).
        dim = df.shape[1] - 2
        # create_db returns (db, t_scaler, x_scaler), as unpacked in ModelHazard.create_learner;
        # only the first element is the data bunch the Learner needs.
        db, self.t_scaler, self.x_scaler = create_db(df)
        self.model = AFTModel(self.dist, dim, self.h)
        self.learner = Learner(db, self.model, loss_func=self.loss, wd=self.beta)

    def _ensure_learner(self, df):
        """Create the learner from `df` on first use; later calls reuse it."""
        if self.learner is None:
            self.create_learner(df)

    def lr_find(self, df):
        """Run the learner's learning-rate finder and plot the recorder's result."""
        self._ensure_learner(df)
        self.learner.lr_find(wd=self.beta)
        self.learner.recorder.plot()

    def fit(self, df):
        """Train for `self.epochs` epochs with learning rate `self.lr` and weight decay `self.beta`."""
        self._ensure_learner(df)
        self.learner.fit(self.epochs, lr=self.lr, wd=self.beta)

    def predict(self, df):
        """
        Run the trained model over `df` and return Λ concatenated over all batches.

        The model's second output is treated as log Λ and exponentiated.
        """
        # NOTE(review): unlike ModelHazard.predict, the scalers obtained in create_learner
        # are not forwarded to create_test_dl here — confirm whether the test data should
        # be transformed with them.
        test_dl = create_test_dl(df)
        self.model.eval()
        Λ = []
        with torch.no_grad():
            for batch in test_dl:
                _, logΛ = self.model(*batch)
                # Undo the log: exp(log Λ) = Λ (the previous torch.log took a log of a log).
                Λ.append(torch.exp(logΛ))
        return torch.cat(Λ)

    def plot_survival(self, *args):
        """Delegate to the underlying model's `plot_survival_function`."""
        self.model.plot_survival_function(*args)

    # Same method name as ModelHazard offers, so both wrappers can be used interchangeably.
    plot_survival_function = plot_survival

In [None]:
# hide
# The star import brings notebook2script into scope.
from nbdev.export import *
# Runs nbdev's export step; the recorded output lists each converted notebook.
notebook2script()

Converted 00_index.ipynb.
Converted 10_SAT.ipynb.
Converted 20_KaplanMeier.ipynb.
Converted 50_hazard.ipynb.
Converted 55_hazard.PiecewiseHazard.ipynb.
Converted 59_hazard.Cox.ipynb.
Converted 60_AFT_models.ipynb.
Converted 65_AFT_error_distributions.ipynb.
Converted 80_data.ipynb.
Converted 90_model.ipynb.
Converted 95_Losses.ipynb.
