From 3d9d49bd0521c3b2fc7b8e092a422297109c1c9c Mon Sep 17 00:00:00 2001 From: Roland Stevenson Date: Wed, 23 Aug 2023 10:33:07 +0000 Subject: [PATCH] linted with black --- tests/test_uplift_trees.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/tests/test_uplift_trees.py b/tests/test_uplift_trees.py index c5372903..8da4eda4 100644 --- a/tests/test_uplift_trees.py +++ b/tests/test_uplift_trees.py @@ -250,32 +250,40 @@ def getNonleafCount(node): # shouldn't be larger than the number of non-leaf node assert num_non_zero_imp_features <= num_non_leaf_nodes -def test_uplift_tree_visualization(): +def test_uplift_tree_visualization(): # Data generation df, x_names = make_uplift_classification() # Rename features for easy interpretation of visualization - x_names_new = ['feature_%s'%(i) for i in range(len(x_names))] - rename_dict = {x_names[i]:x_names_new[i] for i in range(len(x_names))} + x_names_new = ["feature_%s" % (i) for i in range(len(x_names))] + rename_dict = {x_names[i]: x_names_new[i] for i in range(len(x_names))} df = df.rename(columns=rename_dict) x_names = x_names_new df.head() - df = df[df['treatment_group_key'].isin(['control','treatment1'])] + df = df[df["treatment_group_key"].isin(["control", "treatment1"])] # Split data to training and testing samples for model validation (next section) df_train, df_test = train_test_split(df, test_size=0.2, random_state=111) # Train uplift tree - uplift_model = UpliftTreeClassifier(max_depth = 4, min_samples_leaf = 200, min_samples_treatment = 50, n_reg = 100, evaluationFunction='KL', control_name='control') + uplift_model = UpliftTreeClassifier( + max_depth=4, + min_samples_leaf=200, + min_samples_treatment=50, + n_reg=100, + evaluationFunction="KL", + control_name="control", + ) - uplift_model.fit(df_train[x_names].values, - treatment=df_train['treatment_group_key'].values, - y=df_train['conversion'].values) + uplift_model.fit( + df_train[x_names].values, + treatment=df_train["treatment_group_key"].values, + y=df_train["conversion"].values, + ) # Plot uplift tree - graph = uplift_tree_plot(uplift_model.fitted_uplift_tree,x_names) + graph = uplift_tree_plot(uplift_model.fitted_uplift_tree, x_names) graph.create_png() -