Increase speed of optimal split search.
This commit increases the speed of the optimal split search along a
given feature axis by decreasing the number of sum and mean calls.
timmens committed Feb 10, 2020
1 parent 74f4c10 commit aca31b1
Showing 1 changed file with 147 additions and 56 deletions.
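The gist of the change: rather than recomputing slice sums and means from
scratch at every candidate split, the search keeps running sums that are
updated once per observation. A minimal sketch of that idea on plain sums
(illustrative only; names and data are made up, not the committed code):

import numpy as np


def split_sums_naive(y):
    # one O(n) pass over each slice for every candidate split
    return [(y[: i + 1].sum(), y[i + 1 :].sum()) for i in range(len(y) - 1)]


def split_sums_running(y):
    # a single O(1) update per candidate split
    left, right = 0.0, float(y.sum())
    sums = []
    for i in range(len(y) - 1):
        left += y[i]
        right -= y[i]
        sums.append((left, right))
    return sums


y = np.array([1.0, 2.0, 3.0, 4.0])
assert split_sums_naive(y) == split_sums_running(y)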
causaltree.py: 203 changes (147 additions & 56 deletions)
@@ -4,22 +4,10 @@

import numpy as np
import pandas as pd
from numba import njit


def _compute_child_node_ids(parent_id):
    """
    Computes the ids of the left and right child nodes given the id of
    the parent node, using the implicit binary-heap numbering of nodes.

    :param parent_id: Id of the parent node.
    :return: Ids of the left and right child nodes.
    """
    left = 2 * parent_id + 1
    right = left + 1

    return left, right


def _compute_potential_splitting_points(t, min_leaf):
def _compute_valid_splitting_indices(t, min_leaf):
"""
Computes potential splitting point indices.
@@ -29,38 +17,39 @@ def _compute_potential_splitting_points(t, min_leaf):
"""
nn = len(t)
if nn <= min_leaf:
return range(0)
return np.arange(0)

tmp = np.where(np.cumsum(t) == min_leaf)[0]
if tmp.size == 0:
return range(0)
return np.arange(0)
else:
left_treated = tmp[0]
tmp = np.where(np.cumsum(~t) == min_leaf)[0]
if tmp.size == 0:
return range(0)
return np.arange(0)
else:
left_untreated = tmp[0]
left = max(left_treated, left_untreated)

tmp = np.where(np.cumsum(np.flip(t)) == min_leaf)[0]
if tmp.size == 0:
return range(0)
return np.arange(0)
else:
right_treated = tmp[0]
tmp = np.where(np.cumsum(np.flip(~t)) == min_leaf)[0]
if tmp.size == 0:
return range(0)
return np.arange(0)
else:
right_untreated = tmp[0]
right = nn - 1 - max(right_treated, right_untreated)

if left > right - 1:
return range(0)
return np.arange(0)
else:
return range(left, right - 1)
return np.arange(left, right - 1)
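# A worked example (illustration, not from the commit): for
# t = np.array([True, False, True, False, True, False]) and min_leaf = 1,
# np.cumsum(t) first equals 1 at index 0 and np.cumsum(~t) at index 1, so
# left = 1; the same logic on the flipped arrays gives right = 4, and the
# function returns np.arange(1, 3), i.e. candidate splits after observation
# 1 or 2, which leave at least min_leaf treated and untreated observations
# in each leaf.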


@njit
def _transform_outcome(y, t):
"""
Transforms outcome using naive propensity scores (Prob[`t`=1] = 1/2).
@@ -75,6 +64,7 @@ def _transform_outcome(y, t):
return y_transformed
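# Background (illustration, not from the commit): with the naive propensity
# Prob[t = 1] = 1/2, the standard transformed outcome is
# y* = 2 * y * t - 2 * y * (1 - t), i.e. 2 * y for treated and -2 * y for
# untreated observations, so that E[y* | x] equals the treatment effect at x.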


@njit
def _estimate_treatment_effect(y, t):
"""
Estimates average treatment effect (ATE) using outcomes `y` and treatment
@@ -87,6 +77,7 @@ def _estimate_treatment_effect(y, t):
return y[t].mean() - y[~t].mean()


@njit
def _weight_loss(left_loss, right_loss, n_left, n_right):
"""
Given loss in a left leaf (`left_loss`) and right leaf (`right_loss`) and
@@ -106,6 +97,7 @@ def _weight_loss(left_loss, right_loss, n_left, n_right):
return left + right


@njit
def _retrieve_index(index, index_sorted, split_index):
"""
Given `index` (bool index of length of the original training data (n)),
@@ -128,12 +120,124 @@ def _retrieve_index(index, index_sorted, split_index):
right_index = np.full((n,), False)
left_index[nonzero_index[left]] = True
right_index[nonzero_index[right]] = True
global_split_index = nonzero_index[index_sorted[split_index]]
# global_split_index = nonzero_index[index_sorted[split_index]]

return left_index, right_index, global_split_index
return left_index, right_index


@njit
def _compute_treatment_effect_raw(sum_1, n_1, sum_0, n_0):
    """
    Computes the average treatment effect from running sums and counts.

    Args:
        sum_1: Sum of outcomes of treated observations.
        n_1: Number of treated observations.
        sum_0: Sum of outcomes of untreated observations.
        n_0: Number of untreated observations.

    Returns:
        Difference in mean outcomes of treated and untreated observations.
    """
    return sum_1 / n_1 - sum_0 / n_0


@njit
def _compute_loss_raw_left(yy_transformed, i, te):
    """
    Computes the loss contribution of the left leaf, i.e. of all
    observations up to and including index `i`, dropping the constant
    squared-sum term that is added back in
    `_find_optimal_split_observation_loop`.

    Args:
        yy_transformed: Transformed outcomes, sorted along the feature axis.
        i: Index of the last observation in the left leaf.
        te: Treatment effect estimate in the left leaf.

    Returns:
        Raw loss contribution of the left leaf.
    """
    return te ** 2 - 2 * te * yy_transformed[: (i + 1)].sum()


@njit
def _compute_loss_raw_right(yy_transformed, i, te):
    """
    Computes the loss contribution of the right leaf, i.e. of all
    observations after index `i`, dropping the constant squared-sum term
    that is added back in `_find_optimal_split_observation_loop`.

    Args:
        yy_transformed: Transformed outcomes, sorted along the feature axis.
        i: Index of the last observation in the left leaf.
        te: Treatment effect estimate in the right leaf.

    Returns:
        Raw loss contribution of the right leaf.
    """
    return te ** 2 - 2 * te * yy_transformed[(i + 1) :].sum()


@njit
def _find_optimal_split_observation_loop(
    splitting_indices, yy, yy_transformed, xx, tt, loss
):
    """
    Given outcomes, transformed outcomes, feature values and treatment
    statuses (`yy`, `yy_transformed`, `xx`, `tt`) sorted along a feature
    axis, scans the valid `splitting_indices` for the split that minimizes
    the loss, starting from the current best `loss`. Sums and counts of
    outcomes by treatment status are carried as running statistics and
    updated in O(1) per candidate split.

    Returns:
        Minimal loss, split value and split index, where value and index
        are None if no improvement over `loss` was found.
    """
    if len(splitting_indices) == 0:
        return loss, None, None

    split_value = None
    split_index = None
    squared_sum_transformed = (yy_transformed ** 2).sum()
    minimal_loss = loss - squared_sum_transformed

    # initialize running sums and counts at the first valid splitting index
    i0 = splitting_indices[0]
    n_1l = np.sum(tt[: (i0 + 1)])
    n_0l = np.sum(~tt[: (i0 + 1)])
    n_1r = np.sum(tt[(i0 + 1) :])
    n_0r = len(tt) - n_1l - n_0l - n_1r

    sum_1l = yy[: (i0 + 1)][tt[: (i0 + 1)]].sum()
    sum_0l = yy[: (i0 + 1)][~tt[: (i0 + 1)]].sum()
    sum_1r = yy[(i0 + 1) :][tt[(i0 + 1) :]].sum()
    sum_0r = yy[(i0 + 1) :][~tt[(i0 + 1) :]].sum()

    left_te = _compute_treatment_effect_raw(sum_1l, n_1l, sum_0l, n_0l)
    right_te = _compute_treatment_effect_raw(sum_1r, n_1r, sum_0r, n_0r)

    left_loss = _compute_loss_raw_left(yy_transformed, i0, left_te)
    right_loss = _compute_loss_raw_right(yy_transformed, i0, right_te)

    global_loss = left_loss + right_loss
    if global_loss < minimal_loss:
        split_value = xx[i0]
        split_index = i0
        minimal_loss = global_loss

    for i in splitting_indices[1:]:

        # observation i moves from the right into the left leaf; update
        # the running sums and counts of its treatment group in O(1)
        if tt[i]:
            sum_1l += yy[i]
            sum_1r -= yy[i]
            n_1l += 1
            n_1r -= 1
        else:
            sum_0l += yy[i]
            sum_0r -= yy[i]
            n_0l += 1
            n_0r -= 1

        left_te = _compute_treatment_effect_raw(sum_1l, n_1l, sum_0l, n_0l)
        right_te = _compute_treatment_effect_raw(sum_1r, n_1r, sum_0r, n_0r)

        left_loss = _compute_loss_raw_left(yy_transformed, i, left_te)
        right_loss = _compute_loss_raw_right(yy_transformed, i, right_te)

        global_loss = left_loss + right_loss
        # global_loss = loss_weighting(
        #     left_loss, right_loss, i + 1, len(yy) - i - 1
        # )
        if global_loss < minimal_loss:
            split_value = xx[i]
            split_index = i
            minimal_loss = global_loss

    return minimal_loss + squared_sum_transformed, split_value, split_index
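# Note (illustration, not from the commit): moving the split from index
# i - 1 to index i shifts exactly one observation into the left leaf, so
# the four running sums and four counts are updated in O(1) per candidate
# split. Of the old implementation's per-split sum and mean calls, only
# the two sums over yy_transformed inside the raw loss functions remain.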


def _find_optimal_split(y, t, x, index, metric, loss_weighting, min_leaf):
def _find_optimal_split(y, t, x, index, min_leaf):
"""
Args:
@@ -149,7 +253,7 @@ def _find_optimal_split(y, t, x, index, metric, loss_weighting, min_leaf):
"""
_, p = x.shape
split_var = None
split_feat = None
split_value = None
split_index = None
loss = np.inf
@@ -158,42 +262,37 @@ def _find_optimal_split(y, t, x, index, metric, loss_weighting, min_leaf):
# loop through features

index_sorted = np.argsort(x[index, j])
xx = x[index, j][index_sorted]
yy = y[index][index_sorted]
xx = x[index, j][index_sorted]
tt = t[index][index_sorted]

yy_transformed = _transform_outcome(yy, tt)

splitting_points = _compute_potential_splitting_points(tt, min_leaf)
for i in splitting_points:
# loop through observations

left_te = _estimate_treatment_effect(yy[: (i + 1)], tt[: (i + 1)])
right_te = _estimate_treatment_effect(yy[(i + 1) :], tt[(i + 1) :])
splitting_indices = _compute_valid_splitting_indices(tt, min_leaf)

left_loss = metric(yy_transformed[: (i + 1)], left_te)
right_loss = metric(yy_transformed[(i + 1) :], right_te)
(
jloss,
jsplit_value,
jsplit_index,
) = _find_optimal_split_observation_loop(
splitting_indices, yy, yy_transformed, xx, tt, loss
)

global_loss = loss_weighting(
left_loss, right_loss, i + 1, len(yy) - i - 1
)
if global_loss < loss:
split_var = j
split_value = xx[i]
split_index = i
loss = global_loss
if jloss < loss:
split_feat = j
split_value = jsplit_value
split_index = jsplit_index
loss = jloss

# check if any split has occurred.
if loss == np.inf:
return None

# create index of observations falling in left and right leaf, respectively
index_sorted = np.argsort(x[index, split_var])
left, right, split_index = _retrieve_index(
index, index_sorted, split_index
)
index_sorted = np.argsort(x[index, split_feat])
left, right = _retrieve_index(index, index_sorted, split_index)

return left, right, split_var, split_value
return left, right, split_feat, split_value


def _fit_node(y, t, x, index, crit_params, func_params, id_params):
@@ -228,15 +327,7 @@ def _fit_node(y, t, x, index, crit_params, func_params, id_params):

df_out = pd.DataFrame(columns=column_names)

tmp = _find_optimal_split(
y,
t,
x,
index,
func_params["metric"],
func_params["weight_loss"],
crit_params["min_leaf"],
)
tmp = _find_optimal_split(y, t, x, index, crit_params["min_leaf"])

if tmp is None or level == crit_params["max_depth"]:
# if we do not split the node must be a leaf, hence we add the
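For orientation, a hypothetical usage sketch of the reworked search
(assuming causaltree.py is importable and numba is installed; the toy data
and parameters below are made up, not taken from the repository):

import numpy as np

from causaltree import _find_optimal_split

# toy data: two features, a binary treatment and a heterogeneous effect
rng = np.random.default_rng(0)
n = 200
x = rng.normal(size=(n, 2))
t = rng.integers(0, 2, size=n).astype(np.bool_)
y = x[:, 0] + 2.0 * t * (x[:, 0] > 0) + rng.normal(size=n)
index = np.full(n, True)  # fit on the full sample

result = _find_optimal_split(y, t, x, index, min_leaf=10)
if result is not None:
    left, right, split_feat, split_value = result  # boolean leaf indices
    print(split_feat, split_value, left.sum(), right.sum())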
