Skip to content

Commit

Permalink
fix #771 and add a test for get/plot_tmlegain()
Browse files Browse the repository at this point in the history
  • Loading branch information
jeongyoonlee committed May 9, 2024
1 parent cc30c4a commit c891f46
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 11 deletions.
20 changes: 12 additions & 8 deletions causalml/metrics/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,16 @@ def get_tmlegain(
lift_ub = []

for col in model_names:
# Create `n_segment` equal segments from sorted model estimates. Rank is used to break ties.
# ref: https://stackoverflow.com/a/46979206/3216742
segments = pd.qcut(df[col].rank(method="first"), n_segment, labels=False)

ate_model, ate_model_lb, ate_model_ub = tmle.estimate_ate(
X=df[inference_col],
p=df[p_col],
treatment=df[treatment_col],
y=df[outcome_col],
segment=pd.qcut(df[col], n_segment, labels=False),
segment=segments,
)
lift_model = [0.0] * (n_segment + 1)
lift_model[n_segment] = ate_all[0]
Expand Down Expand Up @@ -446,19 +450,21 @@ def get_tmleqini(
qini_ub = []

for col in model_names:
# Create `n_segment` equal segments from sorted model estimates. Rank is used to break ties.
# ref: https://stackoverflow.com/a/46979206/3216742
segments = pd.qcut(df[col].rank(method="first"), n_segment, labels=False)

ate_model, ate_model_lb, ate_model_ub = tmle.estimate_ate(
X=df[inference_col],
p=df[p_col],
treatment=df[treatment_col],
y=df[outcome_col],
segment=pd.qcut(df[col], n_segment, labels=False),
segment=segments,
)

qini_model = [0]
for i in range(1, n_segment):
n_tr = df[pd.qcut(df[col], n_segment, labels=False) == (n_segment - i)][
treatment_col
].sum()
n_tr = df[segments == (n_segment - i)][treatment_col].sum()
qini_model.append(ate_model[0][n_segment - i] * n_tr)

qini.append(qini_model)
Expand All @@ -467,9 +473,7 @@ def get_tmleqini(
qini_lb_model = [0]
qini_ub_model = [0]
for i in range(1, n_segment):
n_tr = df[pd.qcut(df[col], n_segment, labels=False) == (n_segment - i)][
treatment_col
].sum()
n_tr = df[segments == (n_segment - i)][treatment_col].sum()
qini_lb_model.append(ate_model_lb[0][n_segment - i] * n_tr)
qini_ub_model.append(ate_model_ub[0][n_segment - i] * n_tr)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "causalml"
version = "0.15.1"
version = "0.15.2dev"
description = "Python Package for Uplift Modeling and Causal Inference with Machine Learning Algorithms"
readme = { file = "README.md", content-type = "text/markdown" }

Expand Down
64 changes: 62 additions & 2 deletions tests/test_visualize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import pytest
from causalml.metrics.visualize import get_cumlift
from sklearn.model_selection import KFold, train_test_split

from causalml.metrics.visualize import get_cumlift, plot_tmlegain
from causalml.inference.meta import LRSRegressor


def test_visualize_get_cumlift_errors_on_nan():
Expand All @@ -12,3 +16,59 @@ def test_visualize_get_cumlift_errors_on_nan():

with pytest.raises(Exception):
get_cumlift(df)


def test_plot_tmlegain(generate_regression_data, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)

y, X, treatment, tau, b, e = generate_regression_data()

(
X_train,
X_test,
y_train,
y_test,
e_train,
e_test,
treatment_train,
treatment_test,
tau_train,
tau_test,
b_train,
b_test,
) = train_test_split(X, y, e, treatment, tau, b, test_size=0.5, random_state=42)

learner = LRSRegressor()
learner.fit(X_train, treatment_train, y_train)
cate_test = learner.predict(X_test, treatment_test).flatten()

df = pd.DataFrame(
{
"y": y_test,
"w": treatment_test,
"p": e_test,
"S-Learner": cate_test,
"Actual": tau_test,
}
)

inference_cols = []
for i in range(X_test.shape[1]):
col = "col_" + str(i)
df[col] = X_test[:, i]
inference_cols.append(col)

n_fold = 3
kf = KFold(n_splits=n_fold)

plot_tmlegain(
df,
inference_col=inference_cols,
outcome_col="y",
treatment_col="w",
p_col="p",
n_segment=5,
cv=kf,
calibrate_propensity=True,
ci=False,
)

0 comments on commit c891f46

Please sign in to comment.