From d4a19a2a9728f13cd463df97a722affbf5734f7e Mon Sep 17 00:00:00 2001 From: Ian Delbridge <67915427+IanDelbridge@users.noreply.github.com> Date: Fri, 12 Apr 2024 14:18:21 -0400 Subject: [PATCH] 20240207 honest leaf size (#753) * refactor to just use fillTree for honestApproach --- causalml/inference/tree/uplift.pyx | 31 +----------------------------- 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/causalml/inference/tree/uplift.pyx b/causalml/inference/tree/uplift.pyx index dfcbd263..1f60b908 100644 --- a/causalml/inference/tree/uplift.pyx +++ b/causalml/inference/tree/uplift.pyx @@ -564,36 +564,7 @@ class UpliftTreeClassifier: An array containing the outcome of interest for each unit. """ - self.modifyEstimation(X_est, T_est, Y_est, self.fitted_uplift_tree) - - def modifyEstimation(self, X_est, t_est, y_est, tree): - """ Modifies the leafs of the current decision tree to only contain unbiased estimates. - Applies the honest approach based on "Athey, S., & Imbens, G. (2016). Recursive partitioning for heterogeneous causal effects." - Args - ---- - X_est : ndarray, shape = [num_samples, num_features] - An ndarray of the covariates used to calculate the unbiased estimates in the leafs of the decision tree. - T_est : array-like, shape = [num_samples] - An array containing the treatment group for each unit. - Y_est : array-like, shape = [num_samples] - An array containing the outcome of interest for each unit. - tree : object - object of DecisionTree class - the current decision tree that shall be modified - """ - - # Divide sets for child nodes - if tree.trueBranch or tree.falseBranch: - X_l, X_r, w_l, w_r, y_l, y_r = self.divideSet(X_est, t_est, y_est, tree.col, tree.value) - - # recursive call for each branch - if tree.trueBranch is not None: - self.modifyEstimation(X_l, w_l, y_l, tree.trueBranch) - if tree.falseBranch is not None: - self.modifyEstimation(X_r, w_r, y_r, tree.falseBranch) - - # classProb - if tree.results is not None: - tree.results = self.uplift_classification_results(t_est, y_est) + self.fillTree(X_est, T_est, Y_est, self.fitted_uplift_tree) def pruneTree(self, X, treatment_idx, y, tree, rule='maxAbsDiff', minGain=0., n_reg=0,