In [None]:
#default_exp model

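# Base Model
> A high-level `Model` wrapper that builds one of torchlife's survival models (Kaplan–Meier, piecewise hazard, or Cox proportional hazards) from a short name, then trains it, predicts with it, and plots its survival function.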

In [None]:
# hide
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# export
import torch
import torch.nn as nn
from torchlife.models.km import KaplanMeier
from torchlife.models.ph import PieceWiseHazard
from torchlife.models.cox import ProportionalHazard

from torchlife.data import create_db, create_test_dl

from torchlife.losses import *

from fastai.basics import Learner

In [None]:
# export
# Short model name (as accepted by `Model`) -> class implementing that model.
_text2model_ = dict(
    km=KaplanMeier,
    ph=PieceWiseHazard,
    cox=ProportionalHazard,
)

# Short model name -> training loss; every model currently uses `hazard_loss`.
_text2loss_ = {name: hazard_loss for name in _text2model_}

In [None]:
# export
class Model:
    """Build a survival model by short name, then train, predict and plot with it.

    Parameters
    ----------
    model : str
        Key into `_text2model_` / `_text2loss_`: 'km' (KaplanMeier),
        'ph' (PieceWiseHazard) or 'cox' (ProportionalHazard).
    model_args : dict, optional
        Keyword arguments forwarded to the model constructor. `None` (the
        default) means "no arguments".
    breakpoints : list, optional
        Forwarded to `create_db` / `create_test_dl` when building data loaders.
    bs : int
        Stored as `self.bs`. NOTE(review): it is not passed to any data loader
        in this class; confirm whether `create_db` should receive it.
    epochs : int
        Number of epochs used by `fit` on the fastai path.
    lr : float
        Learning rate used by `fit` on the fastai path.
    beta : float
        Weight decay (`wd`) given to the fastai `Learner`, `lr_find` and `fit`.

    Raises
    ------
    KeyError
        If `model` is not one of the registered short names.
    """
    def __init__(self, model:str, model_args:dict=None, breakpoints:list=None,
                 bs:int=128, epochs:int=20, lr:float=1, beta:float=0):
        # `None` is the documented default; treat it as "no kwargs" instead of
        # failing with a TypeError on `**None`.
        self.model = _text2model_[model](**(model_args or {}))
        self.loss = _text2loss_[model]
        self.breakpoints = breakpoints
        self.bs, self.epochs, self.lr, self.beta = bs, epochs, lr, beta
        # Created lazily by `lr_find` or `fit`.
        self.learner = None

    def _create_learner(self, df):
        """Build a fastai `Learner` over `df` with this wrapper's model, loss and weight decay."""
        db = create_db(df, self.breakpoints)
        return Learner(db, self.model, loss_func=self.loss, wd=self.beta)

    def lr_find(self, df):
        """Run fastai's learning-rate finder on `df`, then plot the recorder's results.

        Always rebuilds `self.learner` from `df`; a later `fit` call reuses it.
        """
        self.learner = self._create_learner(df)
        self.learner.lr_find(wd=self.beta)
        self.learner.recorder.plot()

    def fit(self, df):
        """Fit the wrapped model to `df`.

        Models that expose their own `fit` method are fit directly. All others
        are trained with a fastai `Learner` for `self.epochs` epochs at
        `self.lr`, with weight decay `self.beta`.

        NOTE(review): an existing `self.learner` (e.g. one created by `lr_find`)
        is reused as-is, so training runs on the frame it was built from, which
        is not necessarily `df`.
        """
        if hasattr(self.model, 'fit'):
            self.model.fit(df)
            return
        if self.learner is None:
            self.learner = self._create_learner(df)
        self.learner.fit(self.epochs, lr=self.lr, wd=self.beta)

    def predict(self, df):
        """Predict for every row of `df` without tracking gradients.

        Returns
        -------
        (λ, Λ) : tuple of tensors
            `λ` is `torch.exp` of the model's first output (presumably a
            log-hazard; confirm against the model classes). `Λ` is the model's
            second output. Both are concatenated over batches along dim 0.

        The model is switched to eval mode and left there.
        """
        test_dl = create_test_dl(df, self.breakpoints)
        hazards, cum_hazards = [], []
        self.model.eval()
        with torch.no_grad():
            for x in test_dl:
                preds = self.model(*x)
                hazards.append(torch.exp(preds[0]))
                cum_hazards.append(preds[1])
        return torch.cat(hazards), torch.cat(cum_hazards)

    def plot_survival(self):
        """Plot the survival function via the wrapped model's `plot_survival_function`.

        NOTE(review): only works for model classes that define that method.
        """
        self.model.plot_survival_function()

In [None]:
# hide
import pandas as pd
import numpy as np
# Rossi dataset shipped with lifelines. It is fetched from the `master` branch,
# so its contents are not pinned to a specific version.
url = (
    "https://raw.githubusercontent.com/CamDavidsonPilon/"
    "lifelines/master/lifelines/datasets/rossi.csv"
)
df = pd.read_csv(url)

In [None]:
# Sanity-check the raw data: (rows, columns), then the first five rows.
n_rows, n_cols = df.shape
print((n_rows, n_cols))
df.head(n=5)

(432, 9)


Unnamed: 0,week,arrest,fin,age,race,wexp,mar,paro,prio
0,20,1,0,27,1,0,0,1,3
1,17,1,0,18,1,0,0,1,8
2,25,1,0,19,0,1,0,1,13
3,52,0,1,23,1,1,1,1,1
4,52,0,0,19,0,1,0,1,3


In [None]:
# Rename duration -> `t` and event indicator -> `e` (presumably the names the
# torchlife data helpers expect; confirm against `create_db`).
df = df.rename(columns={'week': 't', 'arrest': 'e'})

# Place hazard breakpoints at these percentiles of the observed event times.
BREAKPOINT_PERCENTILES = [20, 40, 60, 80]
event_times = df.loc[df['e'] == 1, 't'].values
breakpoints = np.percentile(event_times, BREAKPOINT_PERCENTILES)

max_t = df['t'].max()
# Number of covariates: every column except `t` and `e`.
dim = df.shape[1] - 2

args = dict(breakpoints=breakpoints, max_t=max_t, dim=dim)
model = Model('cox', args, breakpoints)

In [None]:
# Train on the full prepared frame with the hyper-parameters stored on `model`.
model.fit(df=df)

epoch,train_loss,valid_loss,time
0,4137.632812,593.761292,00:00
1,2169.122803,88.810028,00:00
2,1421.311523,24.565239,00:00
3,1037.651733,10.273753,00:00
4,805.983765,5.843089,00:00
5,651.403809,4.139925,00:00
6,541.172974,3.382817,00:00
7,458.76059,3.009356,00:00
8,394.930756,2.810371,00:00
9,344.123444,2.697758,00:00


In [None]:
# Per-row predictions: λ (exp of the model's first output) and Λ (its second output).
λ, Λ = model.predict(df)
(λ.shape, Λ.shape)

(torch.Size([432, 1]), torch.Size([432, 1]))

In [None]:
# hide
from nbdev.export import *
# nbdev: export this notebook's `# export` cells to the library module
# (presumably the one named by `#default_exp`; confirm against nbdev docs).
notebook2script("model.ipynb")

Converted model.ipynb.
