Skip to content

Commit

Permalink
Add with_std, with_counts to create_table_one (#748)
Browse files (browse the repository at this point in the history)
* Issue 747: Add with_std, with_counts to create_table_one
  • Loading branch information
lee-junseok committed Apr 12, 2024
1 parent d41c40c commit 18c323d
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions causalml/match.py
Expand Up @@ -31,7 +31,7 @@ def smd(feature, treatment):
return (t.mean() - c.mean()) / np.sqrt(0.5 * (t.var() + c.var()))


def create_table_one(data, treatment_col, features):
def create_table_one(data, treatment_col, features, with_std=True, with_counts=True):
"""Report balance in input features between the treatment and control groups.
References:
Expand All @@ -42,6 +42,8 @@ def create_table_one(data, treatment_col, features):
data (pandas.DataFrame): total or matched sample data
treatment_col (str): the column name for the treatment
features (list of str): the column names of features
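with_std (bool): whether to report the standard deviation alongside each mean, formatted as "<mean> (<std>)"; if False, only the mean is shown
with_counts (bool): whether to prepend an "n" row giving the sample count of each treatment group (counted on the first feature column)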
Returns:
(pandas.DataFrame): A table with the means and standard deviations in
Expand All @@ -51,19 +53,27 @@ def create_table_one(data, treatment_col, features):
t1 = pd.pivot_table(
data[features + [treatment_col]],
columns=treatment_col,
aggfunc=[lambda x: "{:.2f} ({:.2f})".format(x.mean(), x.std())],
aggfunc=[
lambda x: (
"{:.2f} ({:.2f})".format(x.mean(), x.std())
if with_std
else "{:.2f}".format(x.mean())
)
],
)
t1.columns = t1.columns.droplevel(level=0)
t1["SMD"] = data[features].apply(lambda x: smd(x, data[treatment_col])).round(4)

n_row = pd.pivot_table(
data[[features[0], treatment_col]], columns=treatment_col, aggfunc=["count"]
)
n_row.columns = n_row.columns.droplevel(level=0)
n_row["SMD"] = ""
n_row.index = ["n"]
if with_counts:
n_row = pd.pivot_table(
data[[features[0], treatment_col]], columns=treatment_col, aggfunc=["count"]
)
n_row.columns = n_row.columns.droplevel(level=0)
n_row["SMD"] = ""
n_row.index = ["n"]

t1 = pd.concat([n_row, t1], axis=0)

t1 = pd.concat([n_row, t1], axis=0)
t1.columns.name = ""
t1.columns = ["Control", "Treatment", "SMD"]
t1.index.name = "Variable"
Expand Down

0 comments on commit 18c323d

Please sign in to comment.