-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
test_time_series_utils_plots.py
102 lines (83 loc) · 3.45 KB
/
test_time_series_utils_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Module to test time_series plotting functionality
"""
from typing import List
import pandas as pd # type: ignore
import pytest
from time_series_test_utils import _ALL_PLOTS
from pycaret.internal.plots.utils.time_series import (
ALLOWED_PLOT_DATA_TYPES,
MULTIPLE_PLOT_TYPES_ALLOWED_AT_ONCE,
_get_data_types_to_plot,
_reformat_dataframes_for_plots,
)
# Module-wide mark: pytest applies every mark in ``pytestmark`` to each test
# collected from this file, so UserWarning is ignored for all of them.
pytestmark = [pytest.mark.filterwarnings("ignore::UserWarning")]

##########################
# Tests Start Here ####
##########################
@pytest.mark.parametrize("plot", _ALL_PLOTS)
def test_get_data_types_to_plot(plot):
    """Tests for _get_data_types_to_plot.

    For each parametrized ``plot`` this checks that:

    1. With nothing requested explicitly, a single-item list containing the
       first entry of ``ALLOWED_PLOT_DATA_TYPES[plot]`` is returned.
    2. Requesting all allowed data types returns all of them when
       ``MULTIPLE_PLOT_TYPES_ALLOWED_AT_ONCE[plot]`` is truthy, and only the
       first one otherwise.
    3. Requesting a data type that is not allowed raises a ValueError.
    """
    if plot is None:
        # This test only exercises concrete plot names. Report the ``None``
        # parametrization as skipped instead of letting it pass without
        # running any assertion.
        pytest.skip("No data types to check when plot is None")

    allowed_data_types = ALLOWED_PLOT_DATA_TYPES.get(plot)

    ############################################################
    # 1. Nothing requested explicitly - returns defaults ----
    ############################################################
    returned_val = _get_data_types_to_plot(plot=plot)
    # The first allowed data type acts as the default.
    expected = [allowed_data_types[0]]
    assert isinstance(returned_val, List)
    assert returned_val == expected

    #####################################
    # 2. Allowed values requested ----
    #####################################
    # Pass a copy so that (a) an in-place mutation by the helper cannot
    # corrupt the shared module-level constant used by other tests, and
    # (b) the assertions below compare against an independent object rather
    # than possibly the very list that was passed in.
    data_types_requested = list(allowed_data_types)
    returned_val = _get_data_types_to_plot(
        plot=plot, data_types_requested=data_types_requested
    )
    assert isinstance(returned_val, List)
    if MULTIPLE_PLOT_TYPES_ALLOWED_AT_ONCE.get(plot):
        # 2A. Multiple data types can be plotted at once ----
        assert returned_val == list(allowed_data_types)
    else:
        # 2B. Only one data type can be plotted at once ----
        assert returned_val == [allowed_data_types[0]]

    ######################################
    # 3. Incorrect value requested ----
    ######################################
    with pytest.raises(ValueError) as errmsg:
        _get_data_types_to_plot(plot=plot, data_types_requested="wrong")

    # Check that the expected error message was raised.
    exceptionmsg = errmsg.value.args[0]
    assert (
        "No data to plot. Please check to make sure that you have requested "
        "an allowed data type for plot" in exceptionmsg
    )
def test_reformat_dataframes_for_plots():
    """Tests for _reformat_dataframes_for_plots.

    Two dataframes with the same columns ``a``, ``b`` and ``c`` are combined
    using the suffixes "original" and "imputed". The expected result is one
    dataframe per input column. Each result's columns are that column name
    followed by each suffix in parentheses, e.g. ``"a (original)"``.
    Supplying fewer suffixes than input dataframes must raise a ValueError.
    """
    df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
    df2 = pd.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
    input_dfs = [df1, df2]

    # 1. Correct Working ----
    labels_suffix = ["original", "imputed"]
    all_expected_cols = [
        ["a (original)", "a (imputed)"],
        ["b (original)", "b (imputed)"],
        ["c (original)", "c (imputed)"],
    ]
    output_dfs = _reformat_dataframes_for_plots(
        data=input_dfs, labels_suffix=labels_suffix
    )
    assert isinstance(output_dfs, List)
    # ``zip`` stops at the shorter sequence, so check the lengths first.
    # Otherwise missing output dataframes would go unnoticed.
    assert len(output_dfs) == len(all_expected_cols)
    for item, expected_cols in zip(output_dfs, all_expected_cols):
        assert isinstance(item, pd.DataFrame)
        assert item.columns.to_list() == expected_cols

    # 2. Error raised ----
    # One suffix for two input dataframes is a mismatch. The setup stays
    # outside the ``raises`` block so that only the call under test can
    # satisfy it.
    mismatched_suffix = ["original"]
    with pytest.raises(ValueError) as errmsg:
        _reformat_dataframes_for_plots(
            data=input_dfs, labels_suffix=mismatched_suffix
        )

    # Check that the expected error message was raised.
    exceptionmsg = errmsg.value.args[0]
    assert "does not match the number of input dataframes" in exceptionmsg