-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
test_time_series_mlflow.py
76 lines (62 loc) · 2.08 KB
/
test_time_series_mlflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Module to test time_series "MLflow" functionality
"""
from pycaret.time_series import TSForecastingExperiment
##########################
# Tests Start Here ####
##########################
def test_mlflow_logging(load_pos_and_neg_data):
    """Tests the logging of MLFlow experiment.

    Runs ``create_model``, ``tune_model`` and ``compare_models`` in a single
    session with ``log_experiment=True`` and checks how many MLflow runs were
    logged for each ``tags.Source`` value within that session.
    """
    data = load_pos_and_neg_data
    exp = TSForecastingExperiment()
    exp.setup(
        data=data,
        fh=12,
        session_id=42,
        log_experiment=True,
        experiment_name="ts_unit_test",
        log_plots=True,
    )
    model = exp.create_model("naive")
    _ = exp.tune_model(model)
    _ = exp.compare_models(include=["naive", "ets"])

    mlflow_logs = exp.get_logs()

    # When running locally, there can be multiple experiments with the same
    # name. Keep only the runs of the most recent session (identified by the
    # USI tag of the run with the latest start time); otherwise the counts of
    # the various function calls would include earlier sessions and not match.
    latest_start = mlflow_logs["start_time"].max()
    last_experiment_usi = mlflow_logs.loc[
        mlflow_logs["start_time"] == latest_start, "tags.USI"
    ].unique()[0]
    session_logs = mlflow_logs[mlflow_logs["tags.USI"] == last_experiment_usi]

    # Count runs per source in one pass. Comparing a whole dict reports all
    # three counts at once if the assertion fails.
    sources = ["create_model", "tune_model", "compare_models"]
    counts = {
        source: int((session_logs["tags.Source"] == source).sum())
        for source in sources
    }
    assert counts == {"create_model": 1, "tune_model": 1, "compare_models": 2}
def test_mlflow_log_setup(load_pos_and_neg_data):
    """Tests that ``setup`` (with ``log_plots=True``) logs one MLflow run.

    Only ``setup`` is called, so exactly one run tagged with
    ``Source == 'setup'`` is expected for this session.
    """
    data = load_pos_and_neg_data
    exp = TSForecastingExperiment()
    exp.setup(
        data=data,
        fh=12,
        session_id=42,
        log_experiment=True,
        experiment_name="ts_unit_test",
        log_plots=True,
    )
    mlflow_logs = exp.get_logs()

    # The "ts_unit_test" experiment name is also used by the other test in
    # this module, and local re-runs add to it as well, so the experiment can
    # contain setup runs from other sessions. Restrict the count to the most
    # recent session, identified by the USI tag of the latest run.
    latest_start = mlflow_logs["start_time"].max()
    last_experiment_usi = mlflow_logs.loc[
        mlflow_logs["start_time"] == latest_start, "tags.USI"
    ].unique()[0]
    session_logs = mlflow_logs[mlflow_logs["tags.USI"] == last_experiment_usi]

    num_setup = int((session_logs["tags.Source"] == "setup").sum())
    assert num_setup == 1