Skip to content

Commit

Permalink
expose n_jobs for rlearner (#714)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiJiaW committed Dec 1, 2023
1 parent 3615bc8 commit bae55f5
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion causalml/inference/meta/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
control_name=0,
n_fold=5,
random_state=None,
cv_n_jobs=-1,
):
"""Initialize an R-learner.
Expand All @@ -52,6 +53,7 @@ def __init__(
control_name (str or int, optional): name of control group
n_fold (int, optional): the number of cross validation folds for outcome_learner
random_state (int or RandomState, optional): a seed (int) or random number generator (RandomState)
cv_n_jobs (int, optional): number of parallel jobs to run for cross_val_predict. -1 means using all processors
"""
assert (learner is not None) or (
(outcome_learner is not None) and (effect_learner is not None)
Expand All @@ -71,6 +73,7 @@ def __init__(

self.random_state = random_state
self.cv = KFold(n_splits=n_fold, shuffle=True, random_state=random_state)
self.cv_n_jobs = cv_n_jobs

self.propensity = None
self.propensity_model = None
Expand Down Expand Up @@ -120,7 +123,7 @@ def fit(self, X, treatment, y, p=None, sample_weight=None, verbose=True):

if verbose:
logger.info("generating out-of-fold CV outcome estimates")
yhat = cross_val_predict(self.model_mu, X, y, cv=self.cv, n_jobs=-1)
yhat = cross_val_predict(self.model_mu, X, y, cv=self.cv, n_jobs=self.cv_n_jobs)

for group in self.t_groups:
mask = (treatment == group) | (treatment == self.control_name)
Expand Down

0 comments on commit bae55f5

Please sign in to comment.