test_classes_incremental.py
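
"""Tests for the ClassIncremental scenario of continuum.

Covers class increments on real datasets (MNIST, KMNIST, FashionMNIST,
CIFAR10, CIFAR100), task ids and class counts on an in-memory dataset,
and per-task transformation lists.
"""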
import numpy as np
import pytest
from torch.utils.data import DataLoader
from torchvision import transforms

from continuum.datasets import MNIST, CIFAR10, CIFAR100, KMNIST, FashionMNIST, InMemoryDataset
from continuum.scenarios import ClassIncremental

# yapf: disable


@pytest.mark.slow
@pytest.mark.parametrize("dataset, increment", [
    (MNIST, 5),
    (KMNIST, 2),
    (FashionMNIST, 1),
    (CIFAR10, 2),
    (CIFAR100, 10),
])
def test_with_dataset_simple_increment(tmpdir, dataset, increment):
    dataset = dataset(data_path=tmpdir, download=True, train=True)
    scenario = ClassIncremental(
        cl_dataset=dataset,
        increment=increment,
        transformations=[transforms.ToTensor()]
    )

    for task_id, taskset in enumerate(scenario):
        classes = taskset.get_classes()

        # Each task contains exactly `increment` classes.
        assert len(classes) == increment
        # Classes within a task are contiguous by default.
        assert len(classes) == (classes.max() - classes.min() + 1)


@pytest.mark.slow
@pytest.mark.parametrize("dataset, increment", [
    (MNIST, [5, 1, 1, 3]),
    (KMNIST, [2, 2, 4, 2]),
    (FashionMNIST, [1, 2, 1, 2, 1, 2, 1]),
    (CIFAR10, [2, 2, 2, 2, 2]),
    (CIFAR100, [50, 10, 20, 20]),
])
def test_with_dataset_composed_increment(tmpdir, dataset, increment):
    dataset = dataset(data_path=tmpdir, download=True, train=True)
    scenario = ClassIncremental(
        cl_dataset=dataset,
        increment=increment,
        transformations=[transforms.ToTensor()]
    )

    for task_id, taskset in enumerate(scenario):
        classes = taskset.get_classes()

        # Each task contains the number of classes given for its slot.
        assert len(classes) == increment[task_id]
        # Classes within a task are contiguous by default.
        assert len(classes) == (classes.max() - classes.min() + 1)


NB_CLASSES = 10


@pytest.fixture
def fake_data():
    x_train = np.random.randint(0, 255, size=(20, 32, 32, 3))
    y_train = []
    for i in range(NB_CLASSES):
        y_train.append(np.ones(2) * i)  # two samples per class
    y_train = np.concatenate(y_train)
    return InMemoryDataset(x_train, y_train)


@pytest.mark.parametrize("class_order", [
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    [3, 9, 2, 0, 4, 5, 9, 7, 6, 1],
])
def test_taskid(fake_data, class_order):
    # Note: class_order is parametrized but not passed to the scenario here;
    # the task ids are checked under the default class ordering.
    scenario = ClassIncremental(
        cl_dataset=fake_data,
        increment=2
    )
    assert scenario.nb_samples == 20

    for task_id, taskset in enumerate(scenario):
        loader = DataLoader(taskset, batch_size=32)
        for x, y, t in loader:
            # Every sample of a taskset carries the id of the task it belongs to.
            assert t[0].item() == task_id
            assert (t == task_id).all()
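

# A minimal extra sketch, written under the assumption that ClassIncremental
# accepts a `class_order` argument (a permutation of the class ids) and that a
# scenario's length equals its number of tasks; the name and ordering below are
# hypothetical and not taken from the original tests.
def test_class_order_sketch(fake_data):
    custom_order = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]  # hypothetical ordering
    scenario = ClassIncremental(
        cl_dataset=fake_data,
        increment=2,
        class_order=custom_order
    )
    # Reordering classes should not change the class count or the task count.
    assert scenario.nb_classes == NB_CLASSES
    assert len(scenario) == NB_CLASSES // 2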


def test_nb_classes(fake_data):
    scenario = ClassIncremental(
        cl_dataset=fake_data,
        increment=2
    )
    assert scenario.nb_samples == 20
    assert scenario.nb_classes == NB_CLASSES
    assert (scenario.classes == np.arange(NB_CLASSES)).all()


def test_list_transforms(fake_data):
    nb_tasks = 5
    list_trsfs = []
    for _ in range(nb_tasks - 1):
        list_trsfs.append([transforms.RandomAffine(degrees=[0, 90])])

    # Should fail since nb_tasks != len(list_trsfs).
    with pytest.raises(ValueError):
        ClassIncremental(
            cl_dataset=fake_data,
            increment=2,
            transformations=list_trsfs
        )
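

# A complementary sketch of the positive case, assuming that a per-task
# transformations list whose length matches the number of tasks is accepted
# without error; the test name is hypothetical and not from the original file.
def test_list_transforms_matching_length(fake_data):
    nb_tasks = 5  # 10 classes with increment=2
    list_trsfs = [[transforms.RandomAffine(degrees=[0, 90])] for _ in range(nb_tasks)]

    scenario = ClassIncremental(
        cl_dataset=fake_data,
        increment=2,
        transformations=list_trsfs
    )
    assert len(scenario) == nb_tasks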