<a href="https://colab.research.google.com/github/yuvrajiro/yuvrajiro/blob/master/Random_Survival_Forest2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
pip install lifelines

Collecting lifelines
[?25l  Downloading https://files.pythonhosted.org/packages/50/ba/d010b22c8bcdfe3bbba753bd976f5deddfa4ec1c842b991579e9c2c3cd61/lifelines-0.26.0-py3-none-any.whl (348kB)
[K     |████████████████████████████████| 358kB 5.1MB/s 
[?25hCollecting autograd-gamma>=0.3
  Downloading https://files.pythonhosted.org/packages/85/ae/7f2031ea76140444b2453fa139041e5afd4a09fc5300cfefeb1103291f80/autograd-gamma-0.5.0.tar.gz
Collecting formulaic<0.3,>=0.2.2
[?25l  Downloading https://files.pythonhosted.org/packages/02/64/6702b5cadc89ece93af2e01996504f3a895196354a35713e2ef22f089d3e/formulaic-0.2.3-py3-none-any.whl (55kB)
[K     |████████████████████████████████| 61kB 7.3MB/s 
Collecting interface-meta>=1.2
  Downloading https://files.pythonhosted.org/packages/71/31/5e474208f5df9012ebecfaa23884b14f93671ea4f4f6d468eb096b73e499/interface_meta-1.2.3-py2.py3-none-any.whl
Building wheels for collected packages: autograd-gamma
  Building wheel for autograd-gamma (setup.py) ... [?25l[?

In [261]:
import pandas as pd
import numpy as np
from lifelines import NelsonAalenFitter
from lifelines.statistics import logrank_test
from itertools import combinations
from joblib import Parallel, delayed
import multiprocessing
from tqdm.notebook import tqdm as tqdm


## Node

In [262]:
class Node:

    def __init__(self, x, y, tree, unique_deaths=3, min_leaf=3, timeline=None):

        self.x = x
        self.y = y
        self.tree = tree
        self.f_idxs = np.random.permutation(np.shape(x)[1])[np.arange(np.int(np.ceil(np.sqrt(np.shape(x)[1]))))]
        self.unique_deaths = unique_deaths
        self.min_leaf = min_leaf
        self.score = 0
        self.split_val = None
        self.split_var = None
        self.lhs = None
        self.rhs = None
        self.chf = None
        self.terminal = False
        self.timeline = timeline
        self.grow_tree()

    def grow_tree(self):

        unique_deaths = np.sum(np.unique(self.y , axis = 0) , axis = 0)[1]

        if unique_deaths <= self.unique_deaths:
            self.compute_terminal_node()
            return self

        self.score, self.split_val, self.split_var, lhs_idxs_opt, rhs_idxs_opt = find_split(self)

        if self.split_var is None:
            self.compute_terminal_node()
            return self

        self.lhs = Node(self.x[lhs_idxs_opt, :], self.y[lhs_idxs_opt, :], self.tree,
                         min_leaf=self.min_leaf, timeline=self.timeline)

        self.rhs = Node(self.x[rhs_idxs_opt, :], self.y[rhs_idxs_opt, :], self.tree,
                         min_leaf=self.min_leaf, timeline=self.timeline)

        return self

    def compute_terminal_node(self):

        self.terminal = True
        self.chf = NelsonAalenFitter()
        t = self.y[:, 0]    # Time of The Event
        e = self.y[:, 1]    # Indicator for occurrence of event
        self.chf.fit(t, event_observed=e, timeline=self.timeline)
        return self

    def predict(self, x):

        if self.terminal:
            self.tree.chf = self.chf.cumulative_hazard_.to_numpy()
            self.tree.chf = self.tree.chf[:, 0]
            return self.tree.chf
        else:
            if x[self.split_var] <= self.split_val:
                self.lhs.predict(x)
            else:
                self.rhs.predict(x)


## Functions for Splitting





In [263]:
def find_split(node):

    score_opt = -5000
    split_val_opt = None
    lhs_idxs_opt = None
    rhs_idxs_opt = None
    split_var_opt = None
    for i in node.f_idxs:
        score, split_val, lhs_idxs, rhs_idxs = logrank_statistics(node.x[:,i], node.y , node.min_leaf)
        if score > score_opt:
            score_opt = score
            split_val_opt = split_val
            lhs_idxs_opt = lhs_idxs
            rhs_idxs_opt = rhs_idxs
            split_var_opt = i

    return score_opt, split_val_opt, split_var_opt, lhs_idxs_opt, rhs_idxs_opt

def logrank_statistics(x_var, y, min_leaf):

    score_opt = -5000
    split_val_opt = None
    lhs_idxs = None
    rhs_idxs = None
    thresholds = np.unique(np.sort(x_var))

    for split_val in thresholds:
        
        feature1 = x_var <= split_val       #Creating an array of True , False which works as index
        feature2 = ~feature1
        if np.sum(feature1) < min_leaf or np.sum(feature2) < min_leaf:
            continue
        score = logrank_test(y[feature1, 0],y[feature2, 0] ,y[feature1, 1] , y[feature2, 1]).test_statistic
        if score > score_opt:
            score_opt = round(score, 3)
            split_val_opt = round(split_val, 3)
            lhs_idxs = feature1
            rhs_idxs = feature2

    return score_opt, split_val_opt, lhs_idxs, rhs_idxs


## Survival Tree

In [264]:
class SurvivalTree:

    def __init__(self, x, y,unique_deaths=3, min_leaf=3, timeline=None):

        self.x = x
        self.y = y
        self.f_idxs = np.random.permutation(np.shape(x)[1])[np.arange(np.int(np.ceil(np.sqrt(np.shape(x)[1]))))]
        self.min_leaf = min_leaf
        self.unique_deaths = unique_deaths
        self.score = 0
        self.index = 0
        self.split_val = None
        self.split_var = None
        self.lhs = None
        self.rhs = None
        self.chf = None
        self.prediction_possible = None
        self.timeline = timeline
        self.grow_tree()

    def grow_tree(self):

        unique_deaths = np.sum(np.unique(self.y , axis = 0) , axis = 0)[1]

        self.score, self.split_val, self.split_var, lhs_idxs_opt, rhs_idxs_opt = find_split(self)

        if self.split_var is not None and unique_deaths > self.unique_deaths:
            self.prediction_possible = True


            self.lhs = Node(x=self.x[lhs_idxs_opt, :], y=self.y[lhs_idxs_opt, :],
                            tree = self ,unique_deaths=self.unique_deaths, 
                            min_leaf=self.min_leaf,timeline=self.timeline)

            self.rhs = Node(x=self.x[rhs_idxs_opt, :], y=self.y[rhs_idxs_opt, :],
                            tree = self , unique_deaths=self.unique_deaths, 
                            min_leaf=self.min_leaf,timeline=self.timeline)

            return self
        else:
            self.prediction_possible = False
            return self

    def predict(self, x):

        if x[self.split_var] <= self.split_val:
            self.lhs.predict(x)
        else:
            self.rhs.predict(x)
        return self.chf


## Random Survival Forest

In [265]:
class RandomSurvivalForest:

    def __init__(self, n_estimators=100, min_leaf=3, unique_deaths=3,
                 n_jobs=None, parallelization_backend="multiprocessing", oob_score=False):

        self.n_estimators = n_estimators
        self.min_leaf = min_leaf
        self.unique_deaths = unique_deaths
        self.n_jobs = n_jobs
        self.parallelization_backend = parallelization_backend
        self.bootstrap_idxs = None
        self.bootstraps = []
        self.oob_idxs = None
        self.oob_score = oob_score
        self.trees = []
        self.timeline = None

    def fit(self, x, y):

        # Check The Timeline Argument Once Again
        self.timeline =  np.sort(y[:,0])  #np.arange(y[:, 0].min(), y[:, 0].max(), 1)
        if self.n_jobs == -1:
            self.n_jobs = multiprocessing.cpu_count()
        elif self.n_jobs is None:
            self.n_jobs = 1
        self.bootstrap_idxs = self.draw_bootstrap_samples(x)

        trees = Parallel(n_jobs=self.n_jobs, backend=self.parallelization_backend)(delayed(SurvivalTree)(x[self.bootstrap_idxs[i], :], 
                                                                                                         y[self.bootstrap_idxs[i], :],
                                                                                                         unique_deaths=self.unique_deaths, 
                                                                                                         min_leaf=self.min_leaf,
                                                                                                         timeline=self.timeline)
                                                                                                        for i in tqdm(range(self.n_estimators)))
        for i in range(len(trees)):
            if trees[i].prediction_possible:
                self.trees.append(trees[i])
                self.bootstraps.append(self.bootstrap_idxs[i])

        if self.oob_score:
            self.oob_score = self.compute_oob_score(x, y)

        return self

    def compute_oob_ensembles(self, xs):

        results = [compute_oob_ensemble_chf(sample_idx=sample_idx, xs=xs, trees=self.trees,
                                            bootstraps=self.bootstraps) for sample_idx in range(xs.shape[0])]
        oob_ensemble_chfs = [i for i in results if not (i.size ==0)]
        return oob_ensemble_chfs

    def compute_oob_score(self, x, y):

        oob_ensembles = self.compute_oob_ensembles(x)
        c = concordance_index(y_time=y[:, 0], y_pred=oob_ensembles, y_event=y[:, 1])
        return c

    def predict(self, xs):

        ensemble_chfs = [compute_ensemble_chf(sample_idx=sample_idx, xs=xs, trees=self.trees)
                         for sample_idx in range(xs.shape[0])]
        return ensemble_chfs

    def draw_bootstrap_samples(self, data):

        bootstrap_idxs = []
        for i in range(self.n_estimators):
            no_samples = len(data)
            data_rows = range(no_samples)
            bootstrap_idx = np.random.choice(data_rows, no_samples)
            bootstrap_idxs.append(bootstrap_idx)

        return bootstrap_idxs


def compute_ensemble_chf(sample_idx, xs, trees):
    denominator = 0
    numerator = 0
    for b in range(len(trees)):
        sample = xs[sample_idx]
        chf = trees[b].predict(sample)
        denominator = denominator + 1
        numerator = numerator + 1 * chf
    ensemble_chf = numerator / denominator
    return ensemble_chf


def compute_oob_ensemble_chf(sample_idx, xs, trees, bootstraps):
    denominator = 0
    numerator = 0
    for b in range(len(trees)):
        if sample_idx not in bootstraps[b]:
            sample = xs[sample_idx]
            chf = trees[b].predict(sample)
            denominator = denominator + 1
            numerator = numerator + 1 * chf
    if denominator != 0:
        oob_ensemble_chf = numerator / denominator
    else:
        oob_ensemble_chf = pd.Series()
    return oob_ensemble_chf


## Scoring

In [266]:
def concordance_index(y_time, y_pred, y_event):

    predicted_outcome = [x.sum() for x in y_pred]
    possible_pairs = (combinations(range(len(y_pred)), 2))
    concordance = 0
    permissible = 0
    for i,j in possible_pairs:
        t1 = y_time[i]
        t2 = y_time[j]
        e1 = y_event[i]
        e2 = y_event[j]
        predicted_outcome_1 = predicted_outcome[i]
        predicted_outcome_2 = predicted_outcome[j]

        shorter_survival_time_censored = (t1 < t2 and e1 == 0) or (t2 < t1 and e2 == 0)
        t1_equals_t2_and_no_death = (t1 == t2 and (e1 == 0 and e2 == 0))

        if shorter_survival_time_censored or t1_equals_t2_and_no_death:
            continue
        else:
            permissible = permissible + 1
            if t1 != t2:
                if t1 < t2:
                    if predicted_outcome_1 > predicted_outcome_2:
                        concordance = concordance + 1
                        continue
                    elif predicted_outcome_1 == predicted_outcome_2:
                        concordance = concordance + 0.5
                        continue
                elif t2 < t1:
                    if predicted_outcome_2 > predicted_outcome_1:
                        concordance = concordance + 1
                        continue
                    elif predicted_outcome_2 == predicted_outcome_1:
                        concordance = concordance + 0.5
                        continue
            elif t1 == t2:
                if e1 == 1 and e2 == 1:
                    if predicted_outcome_1 == predicted_outcome_2:
                        concordance = concordance + 1
                        continue
                    else:
                        concordance = concordance + 0.5
                        continue
                elif not (e1 == 1 and e2 == 1):
                    if e1 == 1 and predicted_outcome_1 > predicted_outcome_2:
                        concordance = concordance + 1
                        continue
                    elif e2 == 1 and predicted_outcome_2 > predicted_outcome_1:
                        concordance = concordance + 1
                        continue
                    else:
                        concordance = concordance + 0.5
                        continue

    c = concordance / permissible

    return c


In [267]:
veteran = pd.read_csv("veteran (1).csv")

In [268]:
y = veteran.loc[:, ["time","status"]].to_numpy()
X = veteran.drop(["time","status"], axis=1).to_numpy()

In [269]:
rsf = RandomSurvivalForest(n_estimators=10,min_leaf = 15,n_jobs = -1,oob_score = True)

In [270]:
rsf.fit(X, y)


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))






<__main__.RandomSurvivalForest at 0x7fad3d907690>

In [271]:
rsf.oob_score

0.6223013321084061

In [181]:
from tqdm.notebook import tqdm as tqdm

In [203]:
len(rsf.lhs.tree.trees)

AttributeError: ignored

In [67]:
a = M()

In [68]:
a.b

6