-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
test_classification_plots.py
52 lines (39 loc) · 1.37 KB
/
test_classification_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import pandas as pd
import pycaret.classification
import pycaret.datasets
def test_plot():
    """Smoke-test pycaret classification plotting and SHAP interpretation.

    Loads the ``juice`` dataset and runs ``pycaret.classification.setup`` on it.
    It then draws every plot listed in
    ``ClassificationExperiment._available_plots`` for a small random forest,
    on both hold-out data (``use_train_data=False``) and training data
    (``use_train_data=True``).

    Next it runs a fixed set of SHAP ``interpret_model`` plots for an
    extra-trees model and an XGBoost model. Each of those plots is drawn twice:
    once with default data and once with the first 10 feature rows as
    ``X_new_sample``.

    The test passes when none of these calls raises.
    """
    # loading dataset
    data = pycaret.datasets.get_data("juice")
    assert isinstance(data, pd.DataFrame)

    # init setup
    pycaret.classification.setup(
        data,
        target="Purchase",
        log_experiment=True,
        log_plots=True,
        html=False,
        session_id=123,
        fold=2,
        n_jobs=1,
    )

    # Deliberately tiny forest (depth 2, 5 trees) keeps the plot loop cheap.
    rf_model = pycaret.classification.create_model(
        "rf", max_depth=2, n_estimators=5
    )

    # A fresh experiment object is created only to read the names of the
    # plots it supports (a private attribute; iterating it yields the names).
    available_plots = pycaret.classification.ClassificationExperiment()._available_plots
    for plot in available_plots:
        pycaret.classification.plot_model(rf_model, plot=plot, use_train_data=False)
        pycaret.classification.plot_model(rf_model, plot=plot, use_train_data=True)

    # Separate name so the random forest above is not silently rebound.
    shap_models = [
        pycaret.classification.create_model("et"),
        pycaret.classification.create_model("xgboost"),
    ]
    # no pfi due to dependency hell
    available_shap = ["summary", "correlation", "reason", "pdp", "msa"]
    # Loop-invariant: the first 10 feature rows (target column dropped),
    # built once instead of on every inner iteration.
    new_sample = data.drop("Purchase", axis=1).iloc[:10]
    for shap_model in shap_models:
        for plot in available_shap:
            pycaret.classification.interpret_model(shap_model, plot=plot)
            pycaret.classification.interpret_model(
                shap_model, plot=plot, X_new_sample=new_sample
            )
if __name__ == "__main__":
    # Allow running this smoke test directly as a script, without pytest.
    test_plot()