In [1]:
import lightgbm as lgb
import numpy as np
from sksurv.metrics import concordance_index_censored
import pandas as pd
from scipy.stats import norm

Define the loss function of AFT model with log normal distribution assumption

In [16]:
def aft_ln_loss(preds, train_data):
    y_true = train_data.get_label()
    labels = np.abs(y_true)
    nsamp = len(y_true)
    yy = np.log(np.abs(y_true)) - preds
    indicator_event = np.array(y_true > 0)
    indicator_censor = np.array(y_true < 0)
    censor = yy[indicator_censor]
    dcensor = norm.pdf(censor)
    pcensor = norm.cdf(-censor)
    grad = np.ones(nsamp)
    hess = np.ones(nsamp)
    grad[indicator_event] = -yy[indicator_event]
    grad[indicator_censor] = -dcensor / pcensor
    hess[indicator_censor] = dcensor * (dcensor - censor * pcensor) / (pcensor ** 2)
    return grad, hess

Define the loss function of AFT model with exponential distribution assumption

In [17]:
def aft_exp_loss(preds, train_data):
    y_true = train_data.get_label()
    ey = np.exp(np.log(np.abs(y_true)) - preds)
    nsamp = len(y_true)
    indicator_event = np.array(y_true > 0)
    indicator_censor = np.array(y_true < 0)
    event = ey[indicator_event]
    censor = ey[indicator_censor]
    grad = np.ones(nsamp)
    hess = np.ones(nsamp)
    grad[indicator_event] = (1 - event) / (1 + event)
    grad[indicator_censor] = -1 / (1 + 1 / censor)
    hess[indicator_event] = 2 * event / ((1 + event) ** 2)
    hess[indicator_censor] = (1 + censor) ** (-2)
    return grad, hess

Define the loss function of AFT model with Weibull distribution assumption

In [19]:
def aft_weibull_loss(preds, train_data):
    y_true = train_data.get_label()
    ey = np.exp(np.log(np.abs(y_true)) - preds)
    indicator_event = np.array(y_true > 0)
    grad = -ey
    grad[indicator_event] = 1 - ey[indicator_event]
    return grad, ey

Define the evaluation metric with Concordance index

In [21]:
def evalc_aft(preds, train_data):
    y_true = train_data.get_label()
    return 'concordance_index', concordance_index_censored(np.array(y_true > 0), abs(y_true), -preds)[0], True

load sample training, validation and test data

In [22]:
xtr = pd.read_csv('xtr.csv')
xva = pd.read_csv('xva.csv')
xte = pd.read_csv('xte.csv')
ytr = pd.read_csv('ytr.csv')
yva = pd.read_csv('yva.csv')
yte = pd.read_csv('yte.csv')

Parameters can be customized

In [25]:
params = {
    'num_leaves': 31,
    'feature_fraction': 0.5,
    'bagging_fraction': 0.5,
    'bagging_freq': 20,
    'learning_rate': 0.05,
    'verbose': -1
}

Transfer training and validation data from data frame to LGB dataset

In [27]:
lgb_train = lgb.Dataset(xtr, ytr)
lgb_eval = lgb.Dataset(xva, yva, reference=lgb_train)

Train LGB AFT model with log normal, exponential and Weibull distributions. Using validation data for early stop. the number of interation and early stop steps can be customized

In [28]:
model_aftln = lgb.train(params, train_set = lgb_train, valid_sets = lgb_eval, num_boost_round = 1000, callbacks=[lgb.early_stopping(20)] ,fobj = aft_ln_loss, feval = evalc_aft)

Training until validation scores don't improve for 20 rounds
Early stopping, best iteration is:
[179]	valid_0's concordance_index: 0.829046


In [29]:
model_aftexp = lgb.train(params, train_set = lgb_train, valid_sets = lgb_eval, num_boost_round = 1000, callbacks=[lgb.early_stopping(20)] ,fobj = aft_exp_loss, feval = evalc_aft)

Training until validation scores don't improve for 20 rounds
Early stopping, best iteration is:
[40]	valid_0's concordance_index: 0.77013


In [30]:
model_aftwei = lgb.train(params, train_set = lgb_train, valid_sets = lgb_eval, num_boost_round = 1000, callbacks=[lgb.early_stopping(20)] ,fobj = aft_weibull_loss, feval = evalc_aft)

Training until validation scores don't improve for 20 rounds
Early stopping, best iteration is:
[152]	valid_0's concordance_index: 0.746645


Make prediction on test data

In [31]:
yyaftln = -model_aftln.predict(xte)
yyaftexp = -model_aftexp.predict(xte)
yyaftwei = -model_aftwei.predict(xte)

Compute concordance index and select the best model

In [32]:
cidxlist = {'log_normal': concordance_index_censored(np.array(yte > 0).reshape(200,), np.array(abs(yte)).reshape(200,), yyaftln)[0],
           'exponential': concordance_index_censored(np.array(yte > 0).reshape(200,), np.array(abs(yte)).reshape(200,), yyaftexp)[0],
           'weibull': concordance_index_censored(np.array(yte > 0).reshape(200,), np.array(abs(yte)).reshape(200,), yyaftwei)[0]}

In [33]:
cidxlist

{'log_normal': 0.8160516129032258,
 'exponential': 0.7418838709677419,
 'weibull': 0.7301161290322581}