Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion causalml/inference/meta/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from .slearner import LRSRegressor, BaseSLearner, BaseSRegressor, BaseSClassifier
from .tlearner import (
XGBTRegressor,
XGBTClassifier,
MLPTRegressor,
BaseTLearner,
BaseTRegressor,
BaseTClassifier,
)
from .xlearner import BaseXLearner, BaseXRegressor, BaseXClassifier
from .rlearner import BaseRLearner, BaseRRegressor, BaseRClassifier, XGBRRegressor
from .rlearner import (
BaseRLearner,
BaseRRegressor,
BaseRClassifier,
XGBRRegressor,
XGBRClassifier,
)
from .tmle import TMLELearner
from .drlearner import BaseDRLearner, BaseDRRegressor, BaseDRClassifier, XGBDRRegressor
51 changes: 50 additions & 1 deletion causalml/inference/meta/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tqdm import tqdm
from scipy.stats import norm
from sklearn.model_selection import cross_val_predict, KFold, train_test_split
from xgboost import XGBRegressor
from xgboost import XGBClassifier, XGBRegressor

from causalml.inference.meta.base import BaseLearner
from causalml.inference.meta.utils import (
Expand Down Expand Up @@ -693,3 +693,52 @@ def fit(self, X, treatment, y, p=None, sample_weight=None, verbose=True):
sample_weight_filt_t = sample_weight_filt[w == 1]
self.vars_c[group] = get_weighted_variance(diff_c, sample_weight_filt_c)
self.vars_t[group] = get_weighted_variance(diff_t, sample_weight_filt_t)


class XGBRClassifier(BaseRClassifier):
"""Convenience subclass mirroring :class:`XGBRRegressor` for the
classification case. The outcome learner is an ``XGBClassifier``
(``BaseRClassifier.fit`` calls ``cross_val_predict(method='predict_proba')``)
while the effect learner stays an ``XGBRegressor`` because the
R-loss target is real-valued. See uber/causalml#824.
"""

def __init__(
self,
propensity_learner=ElasticNetPropensityModel(),
ate_alpha=0.05,
control_name=0,
n_fold=5,
random_state=None,
outcome_learner_kwargs=None,
effect_learner_kwargs=None,
):
"""Initialize an R-learner classifier with XGBoost models.

Args:
propensity_learner: see :class:`BaseRClassifier`.
ate_alpha: see :class:`BaseRClassifier`.
control_name: see :class:`BaseRClassifier`.
n_fold: see :class:`BaseRClassifier`.
random_state: forwarded to both XGBoost models.
outcome_learner_kwargs (dict, optional): extra kwargs forwarded
to the underlying ``XGBClassifier`` outcome learner. Use
e.g. ``{"max_depth": 3, "n_estimators": 200}``.
effect_learner_kwargs (dict, optional): extra kwargs forwarded
to the underlying ``XGBRegressor`` effect learner.
"""
# Use explicit kwargs dicts rather than ``*args, **kwargs`` so the two
# models can be tuned independently — they have non-overlapping
# hyperparameter spaces in practice (objective, eval_metric, etc.).
outcome_kwargs = dict(outcome_learner_kwargs or {})
effect_kwargs = dict(effect_learner_kwargs or {})

super().__init__(
outcome_learner=XGBClassifier(random_state=random_state, **outcome_kwargs),
effect_learner=XGBRegressor(random_state=random_state, **effect_kwargs),
propensity_learner=propensity_learner,
ate_alpha=ate_alpha,
control_name=control_name,
n_fold=n_fold,
random_state=random_state,
)
18 changes: 17 additions & 1 deletion causalml/inference/meta/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
else:
from sklearn.utils.testing import ignore_warnings
from tqdm import tqdm
from xgboost import XGBRegressor
from xgboost import XGBClassifier, XGBRegressor

from causalml.inference.meta.base import BaseLearner
from causalml.inference.meta.utils import check_treatment_vector, convert_pd_to_np
Expand Down Expand Up @@ -420,6 +420,22 @@ def __init__(self, ate_alpha=0.05, control_name=0, *args, **kwargs):
)


class XGBTClassifier(BaseTClassifier):
"""Convenience subclass mirroring :class:`XGBTRegressor` for the
classification case. Symmetric with :class:`XGBTRegressor` so users
don't have to wire up ``BaseTClassifier(learner=XGBClassifier(...))``
by hand. See uber/causalml#824.
"""

def __init__(self, ate_alpha=0.05, control_name=0, *args, **kwargs):
"""Initialize a T-learner with two XGBoost classifier models."""
super().__init__(
learner=XGBClassifier(*args, **kwargs),
ate_alpha=ate_alpha,
control_name=control_name,
)


class MLPTRegressor(BaseTRegressor):
def __init__(self, ate_alpha=0.05, control_name=0, *args, **kwargs):
"""Initialize a T-learner with two MLP models."""
Expand Down
121 changes: 121 additions & 0 deletions tests/test_meta_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
BaseTRegressor,
BaseTClassifier,
XGBTRegressor,
XGBTClassifier,
MLPTRegressor,
)
from causalml.inference.meta import BaseXLearner, BaseXClassifier, BaseXRegressor
Expand All @@ -28,6 +29,7 @@
BaseRClassifier,
BaseRRegressor,
XGBRRegressor,
XGBRClassifier,
)
from causalml.inference.meta import TMLELearner
from causalml.inference.meta import BaseDRLearner
Expand Down Expand Up @@ -773,6 +775,125 @@ def test_BaseSClassifier(generate_classification_data):
assert auuc["tau_pred"] > 0.5


def test_XGBTClassifier(generate_classification_data):
"""Regression test for uber/causalml#824.

Asserts that the new ``XGBTClassifier`` convenience subclass exists,
is importable from ``causalml.inference.meta``, behaves like a
``BaseTClassifier`` wired with two ``XGBClassifier`` instances, and
produces a non-trivial AUUC. On master this test would fail at
import-time (the symbol does not exist).
"""
np.random.seed(RANDOM_SEED)

df, x_names = generate_classification_data()

df["treatment_group_key"] = np.where(
df["treatment_group_key"] == CONTROL_NAME, 0, 1
)

df_train, df_test = train_test_split(df, test_size=0.2, random_state=RANDOM_SEED)

# Forward an XGBoost-specific kwarg to make sure the subclass passes
# them through to the underlying XGBClassifier.
uplift_model = XGBTClassifier(n_estimators=20)

uplift_model.fit(
X=df_train[x_names].values,
treatment=df_train["treatment_group_key"].values,
y=df_train[CONVERSION].values,
)

tau_pred = uplift_model.predict(
X=df_test[x_names].values, treatment=df_test["treatment_group_key"].values
)

# Verify the underlying models are XGBClassifier (the load-bearing
# invariant the convenience subclass exists to enforce).
sample_group = next(iter(uplift_model.models_c))
assert isinstance(uplift_model.models_c[sample_group], XGBClassifier)
assert isinstance(uplift_model.models_t[sample_group], XGBClassifier)

auuc_metrics = pd.DataFrame(
{
"tau_pred": tau_pred.flatten(),
"W": df_test["treatment_group_key"].values,
CONVERSION: df_test[CONVERSION].values,
"treatment_effect_col": df_test["treatment_effect"].values,
}
)

auuc = auuc_score(
auuc_metrics,
outcome_col=CONVERSION,
treatment_col="W",
treatment_effect_col="treatment_effect_col",
normalize=True,
)
assert auuc["tau_pred"] > 0.5


def test_XGBRClassifier(generate_classification_data):
"""Regression test for uber/causalml#824 (R-learner counterpart).

Asserts that ``XGBRClassifier`` exists, wires an ``XGBClassifier``
outcome learner and an ``XGBRegressor`` effect learner (the only
correct combination — R-loss has a real-valued target), and produces
a non-trivial AUUC.
"""
np.random.seed(RANDOM_SEED)

df, x_names = generate_classification_data()

df["treatment_group_key"] = np.where(
df["treatment_group_key"] == CONTROL_NAME, 0, 1
)

propensity_model = LogisticRegression()
propensity_model.fit(X=df[x_names].values, y=df["treatment_group_key"].values)
df["propensity_score"] = propensity_model.predict_proba(df[x_names].values)[:, 1]

df_train, df_test = train_test_split(df, test_size=0.2, random_state=RANDOM_SEED)

uplift_model = XGBRClassifier(
outcome_learner_kwargs={"n_estimators": 20},
effect_learner_kwargs={"n_estimators": 20},
random_state=RANDOM_SEED,
)

uplift_model.fit(
X=df_train[x_names].values,
treatment=df_train["treatment_group_key"].values,
y=df_train[CONVERSION].values,
)

# Verify the underlying model types — outcome must be a classifier
# (BaseRClassifier.fit calls cross_val_predict with predict_proba),
# effect must be a regressor (R-loss target is real-valued).
assert isinstance(uplift_model.model_mu, XGBClassifier)
assert isinstance(uplift_model.model_tau, XGBRegressor)

tau_pred = uplift_model.predict(X=df_test[x_names].values)

auuc_metrics = pd.DataFrame(
{
"tau_pred": tau_pred.flatten(),
"W": df_test["treatment_group_key"].values,
CONVERSION: df_test[CONVERSION].values,
"treatment_effect_col": df_test["treatment_effect"].values,
}
)

auuc = auuc_score(
auuc_metrics,
outcome_col=CONVERSION,
treatment_col="W",
treatment_effect_col="treatment_effect_col",
normalize=True,
)
assert auuc["tau_pred"] > 0.5


def test_BaseTClassifier(generate_classification_data):
np.random.seed(RANDOM_SEED)

Expand Down