In [1]:
from typing import Union

import numpy as np
from bokeh.plotting import show
from sklearn.utils.validation import check_consistent_length, _check_sample_weight

from wildwood._utils import criteria_mapping
from wildwood.forest import _generate_train_valid_samples
from wildwood.plot import plot_tree
from wildwood.preprocessing import Encoder
from wildwood.preprocessing._checks import get_is_categorical
from wildwood.score import _estimate_concordance_index
from wildwood.tree import TreeSurvival

In [2]:
# Simulate a toy survival dataset made of three groups of samples whose
# features and survival times are drawn from group-specific Gaussians.
# TODO: wrap this cell into a function taking the settings below as arguments.

# Fixed seed and local generator so that "Restart & Run All" reproduces the
# same data (and therefore the same fitted tree) every time.
data_seed = 42
rng = np.random.default_rng(data_seed)

n_features = 2
n_groups = 3
n_samples = 1000
p = np.array([.3, .5, .2])  # proportion of the samples falling in each group

# Round instead of truncating: float error (e.g. 0.57 * 100 == 56.99...) would
# otherwise silently drop a sample from a group.
n_samples_group = np.rint(p * n_samples).astype(int)
assert n_samples_group.sum() == n_samples, "group sizes must add up to n_samples"

# Group label of every sample, laid out as 0, ..., 0, 1, ..., 1, 2, ..., 2.
G = np.repeat(np.arange(n_groups), n_samples_group)

# create features X (index rows and column in one step rather than through a
# chained view, so the assignment never depends on view semantics)
X = np.zeros((n_samples, n_features))
X[G == 0, 0] = rng.normal(6, 1, n_samples_group[0])
X[G == 0, 1] = rng.normal(-2, 1, n_samples_group[0])
X[(G == 1) | (G == 2), 0] = rng.normal(-1, 1, n_samples_group[1] + n_samples_group[2])
X[G == 2, 1] = rng.normal(4, 1, n_samples_group[2])
X[G == 1, 1] = rng.normal(-2, 1, n_samples_group[1])

# create survival-time and censoring indicator
# (delta is all ones -- presumably meaning no observation is censored; confirm
# against the convention expected by TreeSurvival)
Y = np.zeros(n_samples)
delta = np.ones(n_samples, int)
Y[G == 0] = rng.normal(30, 2, n_samples_group[0])
Y[G == 1] = rng.normal(20, 2, n_samples_group[1])
Y[G == 2] = rng.normal(10, 2, n_samples_group[2])

In [3]:
# ---- hyper-params -----------------------------------------------------------
# Values handed to TreeSurvival in the next cell.
criterion: str = "logrank"
loss: str = "brier"
aggregation: bool = False
step: float = 1.0
max_depth: Union[None, int] = 2
min_samples_leaf: int = 5
min_samples_split: int = 15
random_state = 42
verbose: bool = False

# NOTE(review): `max_features` is not read anywhere in this cell;
# `max_features_` below is set to the number of features instead.
max_features: Union[str, int] = 2

# Values handed to the Encoder / categorical detection below.
max_bins = 10
subsample = 200_000
categorical_features = None
cat_min_categories = "log"
handle_unknown = "error"

# No user-supplied sample weights.
sample_weight = None

# ---- setting ----------------------------------------------------------------
check_consistent_length(X, Y)
n_samples, n_features = X.shape

# Targets as contiguous arrays with fixed dtypes.
Y = np.ascontiguousarray(Y, dtype=np.float32)
delta = np.ascontiguousarray(delta, dtype=np.int32)
sample_weight_ = _check_sample_weight(sample_weight, X, dtype=np.float32)

# No depth limit is encoded as the largest value an np.uintp can hold.
if max_depth is None:
    max_depth_ = np.iinfo(np.uintp).max
else:
    max_depth_ = max_depth

max_features_ = n_features
is_categorical = get_is_categorical(categorical_features, n_features)

# ---- encoder ----------------------------------------------------------------
# Fit the encoder on X, then transform the very same X.
encoder = Encoder(
    is_categorical=is_categorical,
    max_bins=max_bins,
    subsample=subsample,
    cat_min_categories=cat_min_categories,
    handle_unknown=handle_unknown,
)
encoder.fit(X)
is_categorical_ = encoder.is_categorical_
features_bitarray = encoder.transform(X)

In [4]:
# learner: a single survival tree configured from the hyper-parameters and the
# categorical mask reported by the fitted encoder.
learner = TreeSurvival(
    criterion=criteria_mapping[criterion],
    loss=loss,
    step=step,
    aggregation=aggregation,
    max_depth=max_depth_,
    min_samples_split=min_samples_split,
    min_samples_leaf=min_samples_leaf,
    categorical_features=categorical_features,
    is_categorical=is_categorical_,
    max_features=max_features_,
    random_state=random_state,
    verbose=verbose,
)

# Train / validation indices for the tree (the first argument, 0, is presumably
# the random state of the split -- confirm against the helper's signature).
train_indices, valid_indices, train_indices_count = _generate_train_valid_samples(
    0, n_samples
)

# Work on a copy so that `sample_weight_` stays untouched, then scale the
# weight of each training sample by its entry in `train_indices_count`.
sample_weight = sample_weight_.copy()
sample_weight[train_indices] *= train_indices_count

In [5]:
# fit the tree on the binned features, the survival times and the censoring
# indicator, using the train / validation indices and weights built above
learner.fit(
    features_bitarray,
    Y,
    delta,
    train_indices,
    valid_indices,
    sample_weight,
)

TreeSurvival(aggregation=False, criterion=4,
             is_categorical=array([False, False]), loss='brier', max_depth=2,
             max_features=2, min_samples_leaf=5, min_samples_split=15,
             random_state=42, verbose=False)

In [6]:
# predict on the same binned features the tree was fitted on
pred_ = learner.predict(features_bitarray)

# score: concordance index computed on the negated predictions, with a vector
# of ones as last argument
score_ = _estimate_concordance_index(
    delta,
    Y,
    -pred_,
    np.ones(len(delta)),
)
# keep whichever of the index and its complement is larger
score = score_ if score_ >= 1 - score_ else 1 - score_

In [7]:
# Display the nodes of the fitted tree, as returned by `get_nodes()`.
learner.get_nodes()

Unnamed: 0,node_id,parent,left_child,right_child,is_leaf,is_left,depth,feature,threshold,bin_threshold,...,w_samples_valid,start_train,end_train,start_valid,end_valid,is_split_categorical,bin_partition_start,bin_partition_end,bin_partition,y_pred
0,0,18446744073709551614,2,1,False,False,0,1,0.42,7,...,0.0,0,650,0,350,False,0,0,7,7.054957
1,1,0,-1,-1,True,False,1,18446744073709551614,-2.0,18446744073709551614,...,0.0,509,650,291,350,False,0,0,18446744073709551614,5.529518
2,2,0,4,3,False,False,1,0,0.42,6,...,0.0,0,509,0,291,False,0,0,6,6.810646
3,3,2,-1,-1,True,False,2,18446744073709551614,-2.0,18446744073709551614,...,0.0,314,509,186,291,False,0,0,18446744073709551614,5.852777
4,4,2,-1,-1,True,False,2,18446744073709551614,-2.0,18446744073709551614,...,0.0,0,314,0,186,False,0,0,18446744073709551614,6.3282


In [8]:
# Splitting value of node 0: the binning threshold at index 7 (the node's
# `bin_threshold`) of feature 1 (the node's `feature`). Both indices are copied
# by hand from the `get_nodes()` table, so update them if the tree is refit.
encoder.binning_thresholds_[1][7]

1.5574034221217474

In [9]:
# Splitting value of node 2: the binning threshold at index 6 (the node's
# `bin_threshold`) of feature 0 (the node's `feature`). Both indices are copied
# by hand from the `get_nodes()` table, so update them if the tree is refit.
encoder.binning_thresholds_[0][6]

2.7308100293369915

In [11]:
# Plot the fitted tree: `plot_tree` builds the figure and bokeh's `show`
# displays it. Both names are imported in the import cell at the top of the
# notebook rather than here, so every dependency is visible in one place.
fig = plot_tree(learner)
show(fig)