-
Notifications
You must be signed in to change notification settings - Fork 20
/
exemplar.py
63 lines (57 loc) · 2.34 KB
/
exemplar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class Exemplar:
    """Bounded per-class store of training and validation samples.

    The total budget ``max_size`` is divided evenly over every class
    registered so far. Each class's share is split into a training quota of
    ``int(share * 0.9)`` and a validation quota of ``int(share * 0.1)``.
    Whenever new classes are registered, the lists of already-stored classes
    are truncated to the new (smaller) quotas before new samples are added.
    """

    def __init__(self, max_size, total_cls):
        """Create an empty store.

        Args:
            max_size: total sample budget shared by all classes.
            total_cls: total number of classes; kept as ``total_classes``
                (not read by any method of this class).
        """
        self.val = {}  # label -> list of stored validation samples
        self.train = {}  # label -> list of stored training samples
        self.cur_cls = 0  # number of classes registered so far
        self.max_size = max_size
        self.total_classes = total_cls

    def update(self, cls_num, train, val):
        """Register ``cls_num`` new classes and store samples for them.

        For every label, only the first ``quota`` samples (in input order)
        are kept.

        Args:
            cls_num: number of classes introduced by this call.
            train: ``(train_x, train_y)`` pair of parallel iterables holding
                samples and their labels.
            val: ``(val_x, val_y)`` pair with the same layout as ``train``.

        Raises:
            AssertionError: if the number of distinct stored labels differs
                from the running class count (before or after the update), or
                if any class does not end up with exactly its quota of
                samples (e.g. it was given fewer samples than the quota).
        """
        train_x, train_y = train
        val_x, val_y = val
        assert self.cur_cls == len(self.val)
        assert self.cur_cls == len(self.train)

        self.cur_cls += cls_num
        # Use the whole budget when no class is registered, avoiding a
        # division by zero.
        if self.cur_cls != 0:
            total_store_num = self.max_size / self.cur_cls
        else:
            total_store_num = self.max_size
        train_store_num = int(total_store_num * 0.9)
        val_store_num = int(total_store_num * 0.1)

        # Shrink previously stored classes to the new per-class quotas.
        self._truncate(self.val, val_store_num)
        self._truncate(self.train, train_store_num)

        self._fill(self.val, val_x, val_y, val_store_num)
        self._check(self.val, val_store_num)
        self._fill(self.train, train_x, train_y, train_store_num)
        self._check(self.train, train_store_num)

    @staticmethod
    def _truncate(store, quota):
        """Cut every per-label list in ``store`` down to at most ``quota`` items."""
        for label, samples in store.items():
            # Reassigning existing keys does not resize the dict, so this is
            # safe while iterating.
            store[label] = samples[:quota]

    @staticmethod
    def _fill(store, xs, ys, quota):
        """Append samples from ``zip(xs, ys)`` until each label holds ``quota``.

        A label seen for the first time is always registered (possibly with an
        empty list when ``quota`` is 0), so the label count stays in sync with
        the class count.
        """
        for x, y in zip(xs, ys):
            bucket = store.setdefault(y, [])
            if len(bucket) < quota:
                bucket.append(x)

    def _check(self, store, quota):
        """Assert ``store`` has one entry per class, each with exactly ``quota`` items."""
        assert self.cur_cls == len(store)
        for samples in store.values():
            assert len(samples) == quota

    @staticmethod
    def _flatten(store):
        """Return ``(samples, labels)`` lists covering every entry of ``store``."""
        xs = []
        ys = []
        for label, samples in store.items():
            xs.extend(samples)
            ys.extend([label] * len(samples))
        return xs, ys

    def get_exemplar_train(self):
        """Return ``(train_x, train_y)`` lists of all stored training samples."""
        return self._flatten(self.train)

    def get_exemplar_val(self):
        """Return ``(val_x, val_y)`` lists of all stored validation samples."""
        return self._flatten(self.val)

    def get_cur_cls(self):
        """Return the number of classes registered so far."""
        return self.cur_cls