/
test_datasets.py
148 lines (110 loc) · 4.59 KB
/
test_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""This file contains tests for the `shap.datasets` module."""
import pytest
import shap
@pytest.mark.parametrize("n_points", [None, 12])
def test_imagenet50(n_points):
    """Load imagenet50 and verify the shapes of the returned arrays.

    ``n_points=None`` is expected to yield all 50 samples; any other value
    should yield exactly that many samples (i.e. subsampling is honoured).
    """
    rows = 50 if n_points is None else n_points

    images, labels = shap.datasets.imagenet50(n_points=n_points)

    assert images.shape == (rows, 224, 224, 3)
    assert labels.shape == (rows,)
@pytest.mark.parametrize("n_points", [None, 12])
def test_california(n_points):
    """Load the California dataset and verify the shapes of X and y.

    ``n_points=None`` is expected to return all 20,640 rows; otherwise the
    result should contain exactly ``n_points`` rows with 8 feature columns.
    """
    rows = 20_640 if n_points is None else n_points

    features, target = shap.datasets.california(n_points=n_points)

    assert features.shape == (rows, 8)
    assert target.shape == (rows,)
@pytest.mark.parametrize("n_points", [None, 12])
def test_linnerud(n_points):
    """Load the Linnerud dataset and verify the shapes of X and y.

    Both X and y are expected to have 3 columns; the row count is 20 when
    ``n_points`` is None and ``n_points`` otherwise.
    """
    rows = 20 if n_points is None else n_points

    features, targets = shap.datasets.linnerud(n_points=n_points)

    assert features.shape == (rows, 3)
    assert targets.shape == (rows, 3)
@pytest.mark.parametrize("n_points", [None, 12])
def test_imdb(n_points):
    """Load the IMDB dataset and verify how many samples come back.

    Only lengths are checked here: 25,000 samples when ``n_points`` is None,
    otherwise exactly ``n_points`` samples in both X and y.
    """
    rows = 25_000 if n_points is None else n_points

    samples, labels = shap.datasets.imdb(n_points=n_points)

    assert len(samples) == rows
    assert len(labels) == rows
@pytest.mark.parametrize("n_points", [None, 12])
def test_diabetes(n_points):
    """Load the diabetes dataset and verify the shapes of X and y.

    Expects 442 rows by default (``n_points=None``), or ``n_points`` rows
    otherwise, with 10 feature columns.
    """
    rows = 442 if n_points is None else n_points

    features, target = shap.datasets.diabetes(n_points=n_points)

    assert features.shape == (rows, 10)
    assert target.shape == (rows,)
@pytest.mark.parametrize("n_points", [None, 12])
def test_iris(n_points):
    """Load the iris dataset and verify the shapes of X and y.

    Expects 150 rows by default (``n_points=None``), or ``n_points`` rows
    otherwise, with 4 feature columns.
    """
    rows = 150 if n_points is None else n_points

    features, target = shap.datasets.iris(n_points=n_points)

    assert features.shape == (rows, 4)
    assert target.shape == (rows,)
@pytest.mark.parametrize("n_points", [None, 12])
def test_adult(n_points):
    """Load the adult dataset and verify the shapes of X and y.

    Expects 32,561 rows by default (``n_points=None``), or ``n_points`` rows
    otherwise, with 12 feature columns.
    """
    rows = 32_561 if n_points is None else n_points

    features, target = shap.datasets.adult(n_points=n_points)

    assert features.shape == (rows, 12)
    assert target.shape == (rows,)
@pytest.mark.parametrize("n_points", [None, 12])
def test_nhanesi(n_points):
    """Load the NHANES I dataset and verify the shapes of X and y.

    Expects 14,264 rows by default (``n_points=None``), or ``n_points`` rows
    otherwise, with 79 feature columns.
    """
    rows = 14_264 if n_points is None else n_points

    features, target = shap.datasets.nhanesi(n_points=n_points)

    assert features.shape == (rows, 79)
    assert target.shape == (rows,)
@pytest.mark.parametrize("n_points", [100, 2_000])
def test_corrgroups60(n_points):
    """Load corrgroups60 and check that ``n_points`` sets the row count.

    The test is only parametrized with explicit integer sizes (no ``None``);
    X should always have 60 columns.
    """
    features, target = shap.datasets.corrgroups60(n_points=n_points)

    expected_x, expected_y = (n_points, 60), (n_points,)
    assert features.shape == expected_x
    assert target.shape == expected_y
@pytest.mark.parametrize("n_points", [100, 2_000])
def test_independentlinear60(n_points):
    """Load independentlinear60 and check that ``n_points`` sets the row count.

    The test is only parametrized with explicit integer sizes (no ``None``);
    X should always have 60 columns.
    """
    features, target = shap.datasets.independentlinear60(n_points=n_points)

    expected_x, expected_y = (n_points, 60), (n_points,)
    assert features.shape == expected_x
    assert target.shape == expected_y
@pytest.mark.parametrize("n_points", [None, 12])
def test_a1a(n_points):
    """Load the a1a dataset and verify the shapes of X and y.

    Expects 1,605 rows by default (``n_points=None``), or ``n_points`` rows
    otherwise, with 119 feature columns.
    """
    rows = 1_605 if n_points is None else n_points

    features, target = shap.datasets.a1a(n_points=n_points)

    assert features.shape == (rows, 119)
    assert target.shape == (rows,)
def test_rank():
    """Load the ranking dataset and verify the shape of each returned array.

    ``shap.datasets.rank()`` is expected to return exactly six arrays; the
    unpacking below fails if it returns any other number.
    """
    # NOTE(review): the pairs look like two (X, y) splits followed by two
    # 1-D arrays (presumably query/group info) -- confirm against shap docs.
    x_1, y_1, x_2, y_2, q_1, q_2 = shap.datasets.rank()

    # Checked in the same order as the values were unpacked.
    expectations = (
        (x_1, (3_005, 300)),
        (y_1, (3_005,)),
        (x_2, (768, 300)),
        (y_2, (768,)),
        (q_1, (201,)),
        (q_2, (50,)),
    )
    for array, shape in expectations:
        assert array.shape == shape