Skip to content

Commit

Permalink
Merge pull request uber#663 from ras44/ras44/651_GraphViz
Browse files Browse the repository at this point in the history
linted with black
  • Loading branch information
vincewu51 committed Aug 24, 2023
2 parents 41aa6bd + 3d9d49b commit b9eb51e
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions tests/test_uplift_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,32 +250,40 @@ def getNonleafCount(node):
# shouldn't be larger than the number of non-leaf node
assert num_non_zero_imp_features <= num_non_leaf_nodes

def test_uplift_tree_visualization():

def test_uplift_tree_visualization():
# Data generation
df, x_names = make_uplift_classification()

# Rename features for easy interpretation of visualization
x_names_new = ['feature_%s'%(i) for i in range(len(x_names))]
rename_dict = {x_names[i]:x_names_new[i] for i in range(len(x_names))}
x_names_new = ["feature_%s" % (i) for i in range(len(x_names))]
rename_dict = {x_names[i]: x_names_new[i] for i in range(len(x_names))}
df = df.rename(columns=rename_dict)
x_names = x_names_new

df.head()

df = df[df['treatment_group_key'].isin(['control','treatment1'])]
df = df[df["treatment_group_key"].isin(["control", "treatment1"])]

# Split data to training and testing samples for model validation (next section)
df_train, df_test = train_test_split(df, test_size=0.2, random_state=111)

# Train uplift tree
uplift_model = UpliftTreeClassifier(max_depth = 4, min_samples_leaf = 200, min_samples_treatment = 50, n_reg = 100, evaluationFunction='KL', control_name='control')
uplift_model = UpliftTreeClassifier(
max_depth=4,
min_samples_leaf=200,
min_samples_treatment=50,
n_reg=100,
evaluationFunction="KL",
control_name="control",
)

uplift_model.fit(df_train[x_names].values,
treatment=df_train['treatment_group_key'].values,
y=df_train['conversion'].values)
uplift_model.fit(
df_train[x_names].values,
treatment=df_train["treatment_group_key"].values,
y=df_train["conversion"].values,
)

# Plot uplift tree
graph = uplift_tree_plot(uplift_model.fitted_uplift_tree,x_names)
graph = uplift_tree_plot(uplift_model.fitted_uplift_tree, x_names)
graph.create_png()

0 comments on commit b9eb51e

Please sign in to comment.