-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathtest_rotations.py
109 lines (78 loc) · 3.01 KB
/
test_rotations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import copy
import numpy as np
import pytest
from continuum.scenarios import Rotations
from continuum.datasets import MNIST, CIFAR100, InMemoryDataset
from tests.test_classorder import InMemoryDatasetTest
# Root directory handed to the real datasets (MNIST, CIFAR100) in the slow
# tests; None when the CONTINUUM_DATA_PATH environment variable is not set.
DATA_PATH = os.getenv("CONTINUUM_DATA_PATH")
@pytest.fixture
def numpy_data():
    """Build a tiny synthetic image dataset for the rotation tests.

    There are 6 classes with 100 images each. Every image is a 4x4x3
    ``uint8`` array filled entirely with its class index, so the class of a
    sample can be read off any pixel.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``. The test split is an
        independent copy of the train split, and labels are integer arrays.
    """
    n_classes, per_class = 6, 100
    images = np.concatenate(
        [np.full((per_class, 4, 4, 3), c, dtype=np.uint8) for c in range(n_classes)]
    )
    labels = np.repeat(np.arange(n_classes), per_class).astype(int)
    return (images, labels), (images.copy(), labels.copy())
def test_init(numpy_data):
    """Test the initialization with three tasks.

    The degree list mixes three forms:

    - a single angle (0);
    - a 2-tuple, presumably a (min, max) range of angles (to be confirmed
      against ``Rotations``);
    - another single angle (45).

    The scenario must accept all three forms and expose one task per entry.
    """
    train, _ = numpy_data
    dummy = InMemoryDatasetTest(*train)
    list_degrees = [0, (15, 20), 45]
    scenario = Rotations(cl_dataset=dummy, nb_tasks=3, list_degrees=list_degrees)

    assert len(scenario) == len(list_degrees)
    # Iterate every task so that building each taskset is actually exercised,
    # and check that iteration yields exactly one taskset per degree entry.
    nb_tasks_seen = sum(1 for _ in scenario)
    assert nb_tasks_seen == len(list_degrees)
@pytest.mark.parametrize("shared_label_space", [True, False])
def test_shared_labels(numpy_data, shared_label_space):
    """Check the class ids exposed by each rotation task.

    With a shared label space, every task reports classes 0..5. Otherwise,
    task ``t`` reports its own block of ids, 6*t .. 6*t + 5.
    """
    (x, y), _ = numpy_data
    scenario = Rotations(
        cl_dataset=InMemoryDatasetTest(x, y),
        nb_tasks=3,
        list_degrees=[0, 15, 45],
        shared_label_space=shared_label_space,
    )
    n_cls = 6
    base = np.arange(n_cls)
    for t, task_set in enumerate(scenario):
        observed = task_set.get_classes()
        if not shared_label_space:
            # Each task's ids are offset by the number of classes per task.
            assert (observed == base + t * n_cls).all(), t
            continue
        assert (observed == base).all()
def test_fail_init(numpy_data):
    """Rotations must reject a malformed entry in ``list_degrees``.

    The middle entry is a 3-tuple, which is neither a single angle nor the
    2-tuple form accepted elsewhere in these tests. Constructing the scenario
    is therefore expected to raise ``ValueError``.
    """
    (x, y), _ = numpy_data
    base_dataset = InMemoryDatasetTest(x, y)
    bad_degrees = [2, (15, 20, 25), 45]  # the 3-tuple makes no sense
    # The scenario should detect the meaningless transformation in the list.
    with pytest.raises(ValueError):
        Rotations(cl_dataset=base_dataset, nb_tasks=3, list_degrees=bad_degrees)
@pytest.mark.slow
@pytest.mark.parametrize("shared_label_space", [True, False])
@pytest.mark.parametrize("dataset", [MNIST, CIFAR100])
def test_with_dataset(dataset, shared_label_space):
    """Run Rotations over a real dataset and compare class ids across tasks.

    The dataset is downloaded into ``DATA_PATH`` if needed, and the scenario
    must contain three tasks.

    - With a shared label space, consecutive tasks expose identical class ids.
    - Otherwise, each task's ids are the previous task's ids shifted by the
      number of classes in a task.
    """
    cl_dataset = dataset(data_path=DATA_PATH, download=True, train=True)
    scenario = Rotations(
        cl_dataset=cl_dataset,
        nb_tasks=3,
        list_degrees=[0, 45, 90],
        shared_label_space=shared_label_space,
    )
    assert len(scenario) == 3

    prev = None
    for taskset in scenario:
        current = taskset.get_classes()
        if prev is not None:
            if shared_label_space:
                assert (current == prev).all()
            else:
                assert (current == prev + len(current)).all(), current
        prev = current