
Commit

Merge branch 'master' of github.com:uber/causalml into tmp
rolandrmgservices committed Aug 22, 2023
2 parents dc0ac40 + c0e3ec5 commit 3406182
Showing 5 changed files with 132 additions and 42 deletions.
2 changes: 1 addition & 1 deletion causalml/dataset/regression.py
@@ -102,7 +102,7 @@ def simulate_randomized_trial(n=1000, p=5, sigma=1.0, adj=0.0):
"""

X = np.random.normal(size=n * p).reshape((n, -1))
b = np.maximum(np.repeat(0.0, n), X[:, 0] + X[:, 1], X[:, 2]) + np.maximum(
b = np.maximum.reduce([np.repeat(0.0, n), X[:, 0] + X[:, 1], X[:, 2]]) + np.maximum(
np.repeat(0.0, n), X[:, 3] + X[:, 4]
)
e = np.repeat(0.5, n)
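In the hunk above, the first `b = ...` line is the removed version and the second is its replacement. The change fixes a subtle NumPy bug: `np.maximum` is a binary ufunc, so a third positional argument is interpreted as the `out` buffer rather than another operand, while `np.maximum.reduce` takes the element-wise maximum across the whole list. A minimal sketch of the difference, using illustrative arrays:

```python
import numpy as np

a = np.array([0.0, 0.0, 0.0])
b = np.array([-1.0, 2.0, 0.5])
c = np.array([3.0, -4.0, 1.0])

# Buggy pattern: the third positional argument is treated as `out`,
# so `c` is overwritten with max(a, b) and never enters the comparison.
buggy = np.maximum(a, b, c)   # array([0. , 2. , 0.5]); c is clobbered

# Fixed pattern: reduce element-wise over all three operands.
fixed = np.maximum.reduce([a, b, np.array([3.0, -4.0, 1.0])])  # array([3., 2., 1.])
```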
8 changes: 6 additions & 2 deletions causalml/inference/meta/explainer.py
@@ -200,16 +200,20 @@ def get_shap_values(self):

return shap_dict

def plot_importance(self, importance_dict=None, title_prefix=""):
def plot_importance(self, importance_dict=None, title_prefix="", figsize=(12, 8)):
"""
Calculates and plots feature importances for each treatment group, based on the method specified in __init__.
Skips the calculation part if importance_dict is given.
Args:
importance_dict (optional, dict): a dict of feature importance matrices. If None, importance_dict will be computed.
title_prefix (optional, str): a prefix to the title of the plot.
figsize (optional, tuple): the size of the figure.
"""
if importance_dict is None:
importance_dict = self.get_importance()
for group, series in importance_dict.items():
plt.figure()
series.sort_values().plot(kind="barh", figsize=(12, 8))
series.sort_values().plot(kind="barh", figsize=figsize)
title = group
if title_prefix != "":
title = "{} - {}".format(title_prefix, title)
103 changes: 83 additions & 20 deletions causalml/inference/tree/uplift.pyx
@@ -220,6 +220,10 @@ class UpliftTreeClassifier:
n_reg: int, optional (default=100)
The regularization parameter defined in Rzepakowski et al. 2012, the weight (in terms of sample size) of the
parent node influence on the child node, only effective for 'KL', 'ED', 'Chi', 'CTS' methods.
early_stopping_eval_diff_scale: float, optional (default=1)
If the difference between the training and validation uplift scores is larger than
min(train_uplift_score, valid_uplift_score) / early_stopping_eval_diff_scale, the candidate split is rejected (early stopping; see the numeric sketch after this hunk).
control_name: string
The name of the control group (other experiment groups will be regarded as treatment groups).
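To make the stopping rule concrete, a small illustrative calculation (the scores are made up):

```python
train_score, valid_score = 0.10, 0.08   # illustrative uplift scores for one node
scale = 10                              # early_stopping_eval_diff_scale

gap = abs(train_score - valid_score)              # 0.02
allowed = min(train_score, valid_score) / scale   # 0.008
reject_split = gap > allowed                      # True -> stop splitting here
```

With the default `scale=1` the allowed gap would be 0.08 and the same split would be kept.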
@@ -240,12 +244,13 @@
"""
def __init__(self, control_name, max_features=None, max_depth=3, min_samples_leaf=100,
min_samples_treatment=10, n_reg=100, evaluationFunction='KL',
min_samples_treatment=10, n_reg=100, early_stopping_eval_diff_scale=1, evaluationFunction='KL',
normalization=True, honesty=False, estimation_sample_size=0.5, random_state=None):
self.max_depth = max_depth
self.min_samples_leaf = min_samples_leaf
self.min_samples_treatment = min_samples_treatment
self.n_reg = n_reg
self.early_stopping_eval_diff_scale = early_stopping_eval_diff_scale
self.max_features = max_features

assert evaluationFunction in ['KL', 'ED', 'Chi', 'CTS', 'DDP', 'IT', 'CIT', 'IDDP'], \
@@ -282,7 +287,7 @@ class UpliftTreeClassifier:
self.honesty = True


def fit(self, X, treatment, y):
def fit(self, X, treatment, y, X_val=None, treatment_val=None, y_val=None):
""" Fit the uplift model.
Args
@@ -306,14 +311,23 @@
X, y = check_X_y(X, y)
treatment = np.asarray(treatment)
assert len(y) == len(treatment), 'Data length must be equal for X, treatment, and y.'

if X_val is not None:
X_val, y_val = check_X_y(X_val, y_val)
treatment_val = np.asarray(treatment_val)
assert len(y_val) == len(treatment_val), 'Data length must be equal for X_val, treatment_val, and y_val.'

# Get treatment group keys. self.classes_[0] is reserved for the control group.
treatment_groups = sorted([x for x in list(set(treatment)) if x != self.control_name])
self.classes_ = [self.control_name]
treatment_idx = np.zeros_like(treatment, dtype=int)
treatment_val_idx = None
if treatment_val is not None:
treatment_val_idx = np.zeros_like(treatment_val, dtype=int)
for i, tr in enumerate(treatment_groups, 1):
self.classes_.append(tr)
treatment_idx[treatment == tr] = i
if treatment_val_idx is not None:
treatment_val_idx[treatment_val == tr] = i
self.n_class = len(self.classes_)

self.feature_imp_dict = defaultdict(float)
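For reference, a standalone sketch of the label-encoding step above (group names and arrays are illustrative): the control group maps to index 0, each treatment group to 1..k, and the same mapping is now also applied to the optional validation treatments.

```python
import numpy as np

control_name = "control"
treatment = np.array(["control", "t1", "t2", "t1"])
treatment_val = np.array(["t2", "control", "t1"])

treatment_groups = sorted({t for t in treatment if t != control_name})  # ['t1', 't2']
classes_ = [control_name] + treatment_groups

treatment_idx = np.zeros_like(treatment, dtype=int)
treatment_val_idx = np.zeros_like(treatment_val, dtype=int)
for i, tr in enumerate(treatment_groups, 1):
    treatment_idx[treatment == tr] = i
    treatment_val_idx[treatment_val == tr] = i

# treatment_idx     -> [0, 1, 2, 1]
# treatment_val_idx -> [2, 0, 1]
```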
@@ -333,8 +347,9 @@
random_state=self.random_state)

self.fitted_uplift_tree = self.growDecisionTreeFrom(
X, treatment_idx, y,
max_depth=self.max_depth, min_samples_leaf=self.min_samples_leaf,
X, treatment_idx, y, X_val, treatment_val_idx, y_val,
max_depth=self.max_depth, early_stopping_eval_diff_scale=self.early_stopping_eval_diff_scale,
min_samples_leaf=self.min_samples_leaf,
depth=1, min_samples_treatment=self.min_samples_treatment,
n_reg=self.n_reg, parentNodeSummary=None
)
@@ -1118,7 +1133,8 @@ class UpliftTreeClassifier:
res.append(p)
return res

def growDecisionTreeFrom(self, X, treatment_idx, y, max_depth=10,
def growDecisionTreeFrom(self, X, treatment_idx, y, X_val, treatment_val_idx, y_val,
early_stopping_eval_diff_scale=1, max_depth=10,
min_samples_leaf=100, depth=1,
min_samples_treatment=10, n_reg=100,
parentNodeSummary=None):
@@ -1133,6 +1149,12 @@
An array containing the treatment group idx for each unit.
y : array-like, shape = [num_samples]
An array containing the outcome of interest for each unit.
X_val : ndarray, shape = [num_samples, num_features]
An ndarray of the covariates used to validate the uplift model.
treatment_val_idx : array-like, shape = [num_samples]
An array containing the validation treatment group idx for each unit.
y_val : array-like, shape = [num_samples]
An array containing the validation outcome of interest for each unit.
max_depth: int, optional (default=10)
The maximum depth of the tree.
min_samples_leaf: int, optional (default=100)
@@ -1194,7 +1216,6 @@
else:
p_t = currentNodeSummary[suboptTreatment][0]
n_t = currentNodeSummary[suboptTreatment][1]

p_value = (1. - stats.norm.cdf(abs(p_c - p_t) / np.sqrt(p_t * (1 - p_t) / n_t + p_c * (1 - p_c) / n_c))) * 2
upliftScore = [maxDiff, p_value]

@@ -1223,6 +1244,7 @@

for value in lsUnique:
X_l, X_r, w_l, w_r, y_l, y_r = self.divideSet(X, treatment_idx, y, col, value)

# check the split validity on min_samples_leaf 372
if (len(X_l) < min_samples_leaf or len(X_r) < min_samples_leaf):
continue
@@ -1233,15 +1255,28 @@
min_samples_treatment=min_samples_treatment,
n_reg=n_reg,
parentNodeSummary=currentNodeSummary)

rightNodeSummary = self.tree_node_summary(w_r, y_r,
min_samples_treatment=min_samples_treatment,
min_samples_treatment=min_samples_treatment,
n_reg=n_reg,
parentNodeSummary=currentNodeSummary)

# check the split validity on min_samples_treatment
assert len(leftNodeSummary) == len(rightNodeSummary)

if X_val is not None:
X_val_l, X_val_r, w_val_l, w_val_r, y_val_l, y_val_r = self.divideSet(X_val, treatment_val_idx, y_val, col, value)
leftNodeSummary_val = self.tree_node_summary(w_val_l, y_val_l,
parentNodeSummary=currentNodeSummary)
rightNodeSummary_val = self.tree_node_summary(w_val_r, y_val_r,
parentNodeSummary=currentNodeSummary)
early_stopping_flag = False
for k in range(len(leftNodeSummary_val)):
if (abs(leftNodeSummary_val[k][0]-leftNodeSummary[k][0]) > min(leftNodeSummary_val[k][0],leftNodeSummary[k][0])/early_stopping_eval_diff_scale or
abs(rightNodeSummary_val[k][0]-rightNodeSummary[k][0]) > min(rightNodeSummary_val[k][0],rightNodeSummary[k][0])/early_stopping_eval_diff_scale):
early_stopping_flag = True
break
if early_stopping_flag:
continue

# check the split validity on min_samples_treatment
node_mst = min([stat[1] for stat in leftNodeSummary + rightNodeSummary])
if node_mst < min_samples_treatment:
continue
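A simplified, standalone sketch of the validation-based early-stopping check added above, assuming each node summary is a list of (probability, sample_count) pairs indexed by treatment group, as produced by tree_node_summary: a candidate split is skipped when, for any group, the train/validation probability gap in either child exceeds the scaled minimum of the two probabilities.

```python
def split_rejected_by_early_stopping(left_train, left_val, right_train, right_val,
                                     scale=1.0):
    """Each *_train / *_val argument is a list of (p, n) pairs per treatment group."""
    for k in range(len(left_train)):
        p_lt, p_lv = left_train[k][0], left_val[k][0]
        p_rt, p_rv = right_train[k][0], right_val[k][0]
        if (abs(p_lv - p_lt) > min(p_lv, p_lt) / scale or
                abs(p_rv - p_rt) > min(p_rv, p_rt) / scale):
            return True   # train and validation disagree too much -> skip this split
    return False

# Illustrative node summaries: (probability, sample_count) for [control, treatment].
left_train  = [(0.10, 400), (0.14, 380)]
left_val    = [(0.09, 100), (0.13, 95)]
right_train = [(0.11, 350), (0.12, 360)]
right_val   = [(0.10, 90),  (0.05, 85)]   # treatment arm drifts on the validation split

print(split_rejected_by_early_stopping(left_train, left_val, right_train, right_val))  # True
```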
@@ -1293,13 +1328,16 @@
norm_factor = self.normI(n_c, n_c_left, n_t, n_t_left, alpha=0.9)
else:
norm_factor = 1
gain = gain / norm_factor
gain = gain / norm_factor
if (gain > bestGain and len(X_l) > min_samples_leaf and len(X_r) > min_samples_leaf):
bestGain = gain
bestGainImp = gain_for_imp
bestAttribute = (col, value)
best_set_left = [X_l, w_l, y_l]
best_set_right = [X_r, w_r, y_r]
best_set_left = [X_l, w_l, y_l, None, None, None]
best_set_right = [X_r, w_r, y_r, None, None, None]
if X_val is not None:
best_set_left = [X_l, w_l, y_l, X_val_l, w_val_l, y_val_l]
best_set_right = [X_r, w_r, y_r, X_val_r, w_val_r, y_val_r]

dcY = {'impurity': '%.3f' % currentScore, 'samples': '%d' % len(X)}
# Add treatment size
@@ -1312,12 +1350,12 @@
if bestGain > 0 and depth < max_depth:
self.feature_imp_dict[bestAttribute[0]] += bestGainImp
trueBranch = self.growDecisionTreeFrom(
*best_set_left, max_depth, min_samples_leaf,
*best_set_left, self.early_stopping_eval_diff_scale, max_depth, min_samples_leaf,
depth + 1, min_samples_treatment=min_samples_treatment,
n_reg=n_reg, parentNodeSummary=currentNodeSummary
)
falseBranch = self.growDecisionTreeFrom(
*best_set_right, max_depth, min_samples_leaf,
*best_set_right, self.early_stopping_eval_diff_scale, max_depth, min_samples_leaf,
depth + 1, min_samples_treatment=min_samples_treatment,
n_reg=n_reg, parentNodeSummary=currentNodeSummary
)
@@ -1484,6 +1522,10 @@ class UpliftRandomForestClassifier:
weight (in terms of sample size) of the parent node influence on the
child node, only effective for 'KL', 'ED', 'Chi', 'CTS' methods.
early_stopping_eval_diff_scale: float, optional (default=1)
If the difference between the training and validation uplift scores is larger than
min(train_uplift_score, valid_uplift_score) / early_stopping_eval_diff_scale, the candidate split is rejected (early stopping).
control_name: string
The name of the control group (other experiment groups will be regarded as treatment groups)
@@ -1521,6 +1563,7 @@
min_samples_leaf=100,
min_samples_treatment=10,
n_reg=10,
early_stopping_eval_diff_scale=1,
evaluationFunction='KL',
normalization=True,
honesty=False,
@@ -1538,6 +1581,7 @@
self.min_samples_leaf = min_samples_leaf
self.min_samples_treatment = min_samples_treatment
self.n_reg = n_reg
self.early_stopping_eval_diff_scale = early_stopping_eval_diff_scale
self.evaluationFunction = evaluationFunction
self.control_name = control_name
self.normalization = normalization
@@ -1554,7 +1598,7 @@
if self.n_jobs == -1:
self.n_jobs = mp.cpu_count()

def fit(self, X, treatment, y):
def fit(self, X, treatment, y, X_val=None, treatment_val=None, y_val=None):
"""
Fit the UpliftRandomForestClassifier.
@@ -1568,6 +1612,15 @@
y : array-like, shape = [num_samples]
An array containing the outcome of interest for each unit.
X_val : ndarray, shape = [num_samples, num_features]
An ndarray of the covariates used to validate the uplift model.
treatment_val : array-like, shape = [num_samples]
An array containing the validation treatment group for each unit.
y_val : array-like, shape = [num_samples]
An array containing the validation outcome of interest for each unit.
"""
random_state = check_random_state(self.random_state)

@@ -1578,6 +1631,7 @@
min_samples_leaf=self.min_samples_leaf,
min_samples_treatment=self.min_samples_treatment,
n_reg=self.n_reg,
early_stopping_eval_diff_scale=self.early_stopping_eval_diff_scale,
evaluationFunction=self.evaluationFunction,
control_name=self.control_name,
normalization=self.normalization,
@@ -1595,21 +1649,30 @@

self.uplift_forest = (
Parallel(n_jobs=self.n_jobs, prefer=self.joblib_prefer)
(delayed(self.bootstrap)(X, treatment, y, tree) for tree in self.uplift_forest)
(delayed(self.bootstrap)(X, treatment, y, X_val, treatment_val, y_val, tree) for tree in self.uplift_forest)
)

all_importances = [tree.feature_importances_ for tree in self.uplift_forest]
self.feature_importances_ = np.mean(all_importances, axis=0)
self.feature_importances_ /= self.feature_importances_.sum() # normalize to add to 1

@staticmethod
def bootstrap(X, treatment, y, tree):
def bootstrap(X, treatment, y, X_val, treatment_val, y_val, tree):
random_state = check_random_state(tree.random_state)
bt_index = random_state.choice(len(X), len(X))
x_train_bt = X[bt_index]
y_train_bt = y[bt_index]
treatment_train_bt = treatment[bt_index]
tree.fit(X=x_train_bt, treatment=treatment_train_bt, y=y_train_bt)

if X_val is None:
tree.fit(X=x_train_bt, treatment=treatment_train_bt, y=y_train_bt)
else:
bt_val_index = random_state.choice(len(X_val), len(X_val))
x_val_bt = X_val[bt_val_index]
y_val_bt = y_val[bt_val_index]
treatment_val_bt = treatment_val[bt_val_index]

tree.fit(X=x_train_bt, treatment=treatment_train_bt, y=y_train_bt, X_val=x_val_bt, treatment_val=treatment_val_bt, y_val=y_val_bt)
return tree

@ignore_warnings(category=FutureWarning)
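To show how the new validation arguments flow end to end, a usage sketch of the extended `fit` API on synthetic data; everything besides `control_name` and the new validation keywords is illustrative:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from causalml.inference.tree import UpliftRandomForestClassifier

rng = np.random.default_rng(42)
n, p = 4000, 5
X = rng.normal(size=(n, p))
treatment = rng.choice(["control", "treatment1"], size=n)
y = rng.binomial(1, 0.3 + 0.1 * (treatment == "treatment1") * (X[:, 0] > 0))

X_tr, X_val, w_tr, w_val, y_tr, y_val = train_test_split(
    X, treatment, y, test_size=0.2, random_state=42)

forest = UpliftRandomForestClassifier(
    n_estimators=20,
    control_name="control",
    evaluationFunction="KL",
    early_stopping_eval_diff_scale=10,   # require tighter train/validation agreement
)

# Each bootstrapped tree now also receives a bootstrap sample of the validation set
# and uses it to reject candidate splits whose node summaries do not generalize.
forest.fit(X_tr, treatment=w_tr, y=y_tr, X_val=X_val, treatment_val=w_val, y_val=y_val)
uplift = forest.predict(X_val)
```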
21 changes: 11 additions & 10 deletions envs/environment-py39.yml
@@ -9,7 +9,7 @@ dependencies:
- pip:
- absl-py==0.12.0
- appnope==0.1.2
- argon2-cffi==20.1.0
- argon2-cffi==21.2.0
- astunparse==1.6.3
- async-generator==1.10
- attrs==20.3.0
@@ -20,18 +20,18 @@
- chardet==4.0.0
- coverage==5.5
- cycler==0.10.0
- cython==0.29.23
- Cython==0.29.34
- decorator==5.0.7
- defusedxml==0.7.1
- dill==0.3.3
- dill==0.3.5.1
- entrypoints==0.3
- future==0.18.2
- gast==0.3.3
- google-auth==1.29.0
- google-auth-oauthlib==0.4.4
- google-pasta==0.2.0
- grpcio==1.37.0
- h5py==2.10.0
- grpcio==1.51.3
- h5py==3.9.0
- idna==2.10
- iniconfig==1.1.1
- ipykernel==5.5.3
@@ -54,8 +54,8 @@
- jupyterlab-pygments==0.1.2
- jupyterlab-widgets==1.0.0
- kiwisolver==1.3.1
- lightgbm==3.2.1
- llvmlite==0.36.0
- lightgbm==3.3.4
- llvmlite==0.39.0
- lxml==4.6.3
- markdown==3.3.4
- markupsafe==1.1.1
@@ -66,14 +66,14 @@
- nbformat==5.1.3
- nest-asyncio==1.5.1
- notebook==6.3.0
- numba==0.53.1
- numba==0.56.0
- oauthlib==3.1.0
- opt-einsum==3.3.0
- packaging==20.9
- pandas==1.2.4
- pandocfilters==1.4.3
- parso==0.8.2
- patsy==0.5.1
- patsy==0.5.2
- pexpect==4.8.0
- pickleshare==0.7.5
- pillow==8.2.0
@@ -133,4 +133,5 @@
- widgetsnbextension==3.5.1
- wrapt==1.12.1
- xgboost==1.4.1
- causalml>=0.10.0
## Relative path to setup.py directory for full local installation
- ../.
