Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions causalml/inference/meta/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,17 @@ def predict(
X, treatment, y = convert_pd_to_np(X, treatment, y)

te = np.zeros((X.shape[0], self.t_groups.shape[0]))
yhat_cs = {}
yhat_ts = {}

# models_mu_c is fold-specific but not group-specific; predict once and reuse.
yhat_c = np.r_[[model.predict(X) for model in self.models_mu_c]].mean(axis=0)
# Shared-reference dict preserves the public yhat_cs[group] API cheaply.
yhat_cs = {group: yhat_c for group in self.t_groups}

for i, group in enumerate(self.t_groups):
models_tau = self.models_tau[group]
_te = np.r_[[model.predict(X) for model in models_tau]].mean(axis=0)
te[:, i] = np.ravel(_te)
yhat_cs[group] = np.r_[
[model.predict(X) for model in self.models_mu_c]
].mean(axis=0)
yhat_ts[group] = np.r_[
[model.predict(X) for model in self.models_mu_t[group]]
].mean(axis=0)
Expand All @@ -256,7 +257,7 @@ def predict(
w = (treatment_filt == group).astype(int)

yhat = np.zeros_like(y_filt, dtype=float)
yhat[w == 0] = yhat_cs[group][mask][w == 0]
yhat[w == 0] = yhat_c[mask][w == 0]
yhat[w == 1] = yhat_ts[group][mask][w == 1]

logger.info("Error metrics for group {}".format(group))
Expand Down Expand Up @@ -595,16 +596,18 @@ def predict(
X, treatment, y = convert_pd_to_np(X, treatment, y)

te = np.zeros((X.shape[0], self.t_groups.shape[0]))
yhat_cs = {}
yhat_ts = {}

# models_mu_c is fold-specific but not group-specific; predict once and reuse.
yhat_c = np.r_[
[model.predict_proba(X)[:, 1] for model in self.models_mu_c]
].mean(axis=0)
yhat_cs = {group: yhat_c for group in self.t_groups}

for i, group in enumerate(self.t_groups):
models_tau = self.models_tau[group]
_te = np.r_[[model.predict(X) for model in models_tau]].mean(axis=0)
te[:, i] = np.ravel(_te)
yhat_cs[group] = np.r_[
[model.predict_proba(X)[:, 1] for model in self.models_mu_c]
].mean(axis=0)
yhat_ts[group] = np.r_[
[model.predict_proba(X)[:, 1] for model in self.models_mu_t[group]]
].mean(axis=0)
Expand All @@ -616,7 +619,7 @@ def predict(
w = (treatment_filt == group).astype(int)

yhat = np.zeros_like(y_filt, dtype=float)
yhat[w == 0] = yhat_cs[group][mask][w == 0]
yhat[w == 0] = yhat_c[mask][w == 0]
yhat[w == 1] = yhat_ts[group][mask][w == 1]

logger.info("Error metrics for group {}".format(group))
Expand Down
24 changes: 10 additions & 14 deletions causalml/inference/meta/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,15 @@ def predict(
yhat_cs = {}
yhat_ts = {}

# Build the augmented arrays once; they are identical for every group.
# (Separate allocations avoid in-place mutation by learners like CatBoost
# that set the writeable flag to False on arrays passed to predict().)
X_new_c = np.hstack((np.zeros((X.shape[0], 1)), X))
X_new_t = np.hstack((np.ones((X.shape[0], 1)), X))

for group in self.t_groups:
model = self.models[group]

# Build separate arrays for control and treatment to avoid in-place
# mutation, which fails when learners like CatBoost set the
# writeable flag to False on arrays passed to predict().
X_new_c = np.hstack((np.zeros((X.shape[0], 1)), X))
yhat_cs[group] = model.predict(X_new_c)

X_new_t = np.hstack((np.ones((X.shape[0], 1)), X))
yhat_ts[group] = model.predict(X_new_t)

if (y is not None) and (treatment is not None) and verbose:
Expand Down Expand Up @@ -344,16 +343,13 @@ def predict(
yhat_cs = {}
yhat_ts = {}

# Build the augmented arrays once; they are identical for every group.
X_new_c = np.hstack((np.zeros((X.shape[0], 1)), X))
X_new_t = np.hstack((np.ones((X.shape[0], 1)), X))

for group in self.t_groups:
model = self.models[group]

# Build separate arrays for control and treatment to avoid in-place
# mutation, which fails when learners like CatBoost set the
# writeable flag to False on arrays passed to predict().
X_new_c = np.hstack((np.zeros((X.shape[0], 1)), X))
yhat_cs[group] = model.predict_proba(X_new_c)[:, 1]

X_new_t = np.hstack((np.ones((X.shape[0], 1)), X))
yhat_ts[group] = model.predict_proba(X_new_t)[:, 1]

if y is not None and (treatment is not None) and verbose:
Expand Down
61 changes: 34 additions & 27 deletions causalml/inference/meta/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def __init__(
else:
self.model_c = control_learner

# Preserve the unfitted template so repeated fit() calls always start fresh.
self._model_c_template = self.model_c

if treatment_learner is None:
self.model_t = deepcopy(learner)
else:
Expand Down Expand Up @@ -82,18 +85,20 @@ def fit(self, X, treatment, y, p=None):
self.t_groups = np.unique(treatment[treatment != self.control_name])
self.t_groups.sort()
self._classes = {group: i for i, group in enumerate(self.t_groups)}
self.models_c = {group: deepcopy(self.model_c) for group in self.t_groups}
self.models_t = {group: deepcopy(self.model_t) for group in self.t_groups}

for group in self.t_groups:
mask = (treatment == group) | (treatment == self.control_name)
treatment_filt = treatment[mask]
X_filt = X[mask]
y_filt = y[mask]
w = (treatment_filt == group).astype(int)
# model_c is trained on the control group, which is identical for every
# treatment group, so fit it once. Deepcopy from the unfitted template so
# re-calling fit() always starts from a clean state (safe with warm_start).
control_mask = treatment == self.control_name
self.model_c = deepcopy(self._model_c_template)
self.model_c.fit(X[control_mask], y[control_mask])
# Expose as a shared-reference dict to preserve the public models_c API.
self.models_c = {group: self.model_c for group in self.t_groups}

self.models_c[group].fit(X_filt[w == 0], y_filt[w == 0])
self.models_t[group].fit(X_filt[w == 1], y_filt[w == 1])
for group in self.t_groups:
treatment_mask = treatment == group
self.models_t[group].fit(X[treatment_mask], y[treatment_mask])

def predict(
self, X, treatment=None, y=None, p=None, return_components=False, verbose=True
Expand All @@ -110,14 +115,15 @@ def predict(
(numpy.ndarray): Predictions of treatment effects.
"""
X, treatment, y = convert_pd_to_np(X, treatment, y)
yhat_cs = {}
yhat_ts = {}

yhat_c = self.model_c.predict(X)
# Build a shared-reference dict so return_components callers keep the
# yhat_cs[group] indexing API without duplicating the underlying array.
yhat_cs = {group: yhat_c for group in self.t_groups}

for group in self.t_groups:
model_c = self.models_c[group]
model_t = self.models_t[group]
yhat_cs[group] = model_c.predict(X)
yhat_ts[group] = model_t.predict(X)
yhat_ts[group] = self.models_t[group].predict(X)

if (y is not None) and (treatment is not None) and verbose:
mask = (treatment == group) | (treatment == self.control_name)
Expand All @@ -126,15 +132,15 @@ def predict(
w = (treatment_filt == group).astype(int)

yhat = np.zeros_like(y_filt, dtype=float)
yhat[w == 0] = yhat_cs[group][mask][w == 0]
yhat[w == 0] = yhat_c[mask][w == 0]
yhat[w == 1] = yhat_ts[group][mask][w == 1]

logger.info("Error metrics for group {}".format(group))
regression_metrics(y_filt, yhat, w)

te = np.zeros((X.shape[0], self.t_groups.shape[0]))
for i, group in enumerate(self.t_groups):
te[:, i] = yhat_ts[group] - yhat_cs[group]
te[:, i] = yhat_ts[group] - yhat_c

if not return_components:
return te
Expand Down Expand Up @@ -178,7 +184,7 @@ def fit_predict(
else:
t_groups_global = self.t_groups
_classes_global = self._classes
models_c_global = deepcopy(self.models_c)
model_c_global = deepcopy(self.model_c)
models_t_global = deepcopy(self.models_t)
te_bootstraps = np.zeros(
shape=(X.shape[0], self.t_groups.shape[0], n_bootstraps)
Expand All @@ -197,7 +203,8 @@ def fit_predict(
# set member variables back to global (currently last bootstrapped outcome)
self.t_groups = t_groups_global
self._classes = _classes_global
self.models_c = deepcopy(models_c_global)
self.model_c = deepcopy(model_c_global)
self.models_c = {group: self.model_c for group in self.t_groups}
self.models_t = deepcopy(models_t_global)

return (te, te_lower, te_upper)
Expand Down Expand Up @@ -271,7 +278,7 @@ def estimate_ate(
else:
t_groups_global = self.t_groups
_classes_global = self._classes
models_c_global = deepcopy(self.models_c)
model_c_global = deepcopy(self.model_c)
models_t_global = deepcopy(self.models_t)

logger.info("Bootstrap Confidence Intervals for ATE")
Expand All @@ -291,7 +298,8 @@ def estimate_ate(
# set member variables back to global (currently last bootstrapped outcome)
self.t_groups = t_groups_global
self._classes = _classes_global
self.models_c = deepcopy(models_c_global)
self.model_c = deepcopy(model_c_global)
self.models_c = {group: self.model_c for group in self.t_groups}
self.models_t = deepcopy(models_t_global)

return ate, ate_lower, ate_upper
Expand Down Expand Up @@ -371,14 +379,13 @@ def predict(
Returns:
(numpy.ndarray): Predictions of treatment effects.
"""
yhat_cs = {}
yhat_ts = {}

yhat_c = self.model_c.predict_proba(X)[:, 1]
yhat_cs = {group: yhat_c for group in self.t_groups}

for group in self.t_groups:
model_c = self.models_c[group]
model_t = self.models_t[group]
yhat_cs[group] = model_c.predict_proba(X)[:, 1]
yhat_ts[group] = model_t.predict_proba(X)[:, 1]
yhat_ts[group] = self.models_t[group].predict_proba(X)[:, 1]

if (y is not None) and (treatment is not None) and verbose:
mask = (treatment == group) | (treatment == self.control_name)
Expand All @@ -387,15 +394,15 @@ def predict(
w = (treatment_filt == group).astype(int)

yhat = np.zeros_like(y_filt, dtype=float)
yhat[w == 0] = yhat_cs[group][mask][w == 0]
yhat[w == 0] = yhat_c[mask][w == 0]
yhat[w == 1] = yhat_ts[group][mask][w == 1]

logger.info("Error metrics for group {}".format(group))
classification_metrics(y_filt, yhat, w)

te = np.zeros((X.shape[0], self.t_groups.shape[0]))
for i, group in enumerate(self.t_groups):
te[:, i] = yhat_ts[group] - yhat_cs[group]
te[:, i] = yhat_ts[group] - yhat_c

if not return_components:
return te
Expand Down
Loading