forked from GMvandeVen/continual-learning

train.py (349 lines, 300 loc, 18.1 KB)
import copy

import numpy as np
import torch
import tqdm
from torch import optim
from torch.utils.data import ConcatDataset

import utils
import dgr
from continual_learner import ContinualLearner


def train_cl(model, train_datasets, replay_mode="none", scenario="class", classes_per_task=None,
iters=2000, batch_size=32, collate_fn=None, visualize=True,
generator=None, gen_iters=0, gen_loss_cbs=list(), loss_cbs=list(), eval_cbs=list(), sample_cbs=list()):
    '''Train a model (with a "train_a_batch" method) on multiple tasks, with the replay-strategy specified by [replay_mode].

    [model]             <nn.Module> main model to optimize across all tasks
    [train_datasets]    <list> with the training <DataSet> for each task
    [replay_mode]       <str>, choice from "generative", "exact", "current", "offline" and "none"
    [scenario]          <str>, choice from "task", "domain" and "class"
    [classes_per_task]  <int>, number of classes per task
    [iters]             <int>, number of optimization-steps (i.e., number of batches) per task
    [visualize]         <bool>, whether all losses should be calculated for plotting (even if not used)
    [generator]         None or <nn.Module>, a separate generative model to be trained (for [gen_iters] per task)
    [*_cbs]             <list> of call-back functions to evaluate training-progress'''
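    # Example call (illustrative sketch only; how [model], [train_datasets] and [generator] are
    # constructed depends on the rest of this repository, e.g. main.py, and the callback names
    # below are placeholders):
    #   train_cl(model, train_datasets, replay_mode="generative", scenario="class",
    #            classes_per_task=2, iters=2000, batch_size=128, generator=generator,
    #            gen_iters=2000, loss_cbs=[solver_loss_cb], eval_cbs=[eval_cb])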
# Set model in training-mode
model.train()
# Use cuda?
cuda = model._is_on_cuda()
device = model._device()
    # Initialize possible sources for replay (no replay for the 1st task)
previous_model = previous_scholar = previous_datasets = None
exact_replay = generative_replay = current_replay = False
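    # NOTE: except in the "offline" setting, these flags are only switched on at the end of a task
    #       (see the bottom of the task-loop), so replay becomes active from the 2nd task onwards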
# Register starting param-values (needed for "intelligent synapses").
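    # -these "*_SI_prev_task" buffers hold each parameter's value at the start of the task;
    #  [model.update_omega] later normalizes the accumulated path integral [W] by the squared
    #  parameter-change relative to these stored values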
if isinstance(model, ContinualLearner) and (model.si_c>0 or visualize):
for n, p in model.named_parameters():
if p.requires_grad:
n = n.replace('.', '__')
model.register_buffer('{}_SI_prev_task'.format(n), p.data.clone())
# Loop over all tasks.
for task, train_dataset in enumerate(train_datasets, 1):
# Do not train if non-positive iterations
if iters <= 0:
return
# If offline replay-setting, create large database of all tasks so far
if replay_mode=="offline" and (not scenario=="task"):
train_dataset = ConcatDataset(train_datasets[:task])
# -but if "offline"+"task"-scenario: all tasks so far included in 'exact replay' & no current batch
if replay_mode=="offline" and scenario == "task":
exact_replay = True
####################################### MAIN MODEL #######################################
# Prepare <dicts> to store running importance estimates and param-values before update ("Synaptic Intelligence")
if isinstance(model, ContinualLearner) and (model.si_c>0 or visualize):
W = {}
p_old = {}
for n, p in model.named_parameters():
if p.requires_grad:
n = n.replace('.', '__')
W[n] = p.data.clone().zero_()
p_old[n] = p.data.clone()
# Reset state of optimizer for every task (if requested)
if model.optim_type=="adam_reset":
model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))
# Initialize # iters left on current data-loader(s)
iters_left = iters_left_previous = 1
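        # -[iters_left] tracks how many batches the current data-loader(s) can still provide;
        #  whenever a counter reaches 0 in the loop below, a fresh (re-shuffled) data-loader is created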
if scenario=="task":
up_to_task = task if replay_mode=="offline" else task-1
iters_left_previous = [1]*up_to_task
data_loader_previous = [None]*up_to_task
# Define a tqdm progress bar
progress = tqdm.tqdm(range(1, iters+1))
# Loop over all iterations
for batch_index in progress:
# Update # iters left on current data-loader(s) and, if needed, create new one(s)
iters_left -= 1
if iters_left==0:
data_loader = iter(utils.get_data_loader(train_dataset, batch_size, cuda=cuda, collate_fn=collate_fn,
drop_last=True))
iters_left = len(data_loader)
if exact_replay:
if scenario=="task":
up_to_task = task if replay_mode=="offline" else task-1
batch_size_replay = int(np.ceil(batch_size/up_to_task)) if (up_to_task>1) else batch_size
# -in incremental task learning scenario, need separate replay for each task
for task_id in range(up_to_task):
iters_left_previous[task_id] -= 1
if iters_left_previous[task_id]==0:
data_loader_previous[task_id] = iter(utils.get_data_loader(
train_datasets[task_id], batch_size_replay, cuda=cuda,
collate_fn=collate_fn, drop_last=True
))
iters_left_previous[task_id] = len(data_loader_previous[task_id])
else:
iters_left_previous -= 1
if iters_left_previous==0:
data_loader_previous = iter(utils.get_data_loader(ConcatDataset(previous_datasets),
batch_size, cuda=cuda,
collate_fn=collate_fn, drop_last=True))
iters_left_previous = len(data_loader_previous)
#####-----CURRENT BATCH-----#####
if replay_mode=="offline" and scenario=="task":
x = y = None
else:
x, y = next(data_loader) #--> sample training data of current task
y = y-classes_per_task*(task-1) if scenario=="task" else y #--> ITL: adjust y-targets to 'active range'
x, y = x.to(device), y.to(device) #--> transfer them to correct device
#####-----REPLAYED BATCH-----#####
if not exact_replay and not generative_replay and not current_replay:
x_ = y_ = scores_ = None #-> if no replay
##-->> Current Replay <<--##
if current_replay:
scores_ = None
if not scenario=="task":
# Use same as CURRENT BATCH to replay
x_ = x
y_ = y if ((model.replay_targets=="hard") or visualize) else None
# Get predicted "logits"/"scores" on replayed data (from previous model)
if (model.replay_targets=="soft") or visualize:
with torch.no_grad():
scores_ = previous_model(x_)
scores_ = scores_[:, :(classes_per_task * (task - 1))] if scenario=="class" else scores_
# --> ICL: zero probabilities will be added in the [utils.loss_fn_kd]-function
else:
if model.replay_targets=="hard":
raise NotImplementedError(
"'Current' replay with 'hard targets' not implemented for 'incremental task learning'."
)
# For each task to replay, use same [x] as in CURRENT BATCH
x_ = list()
for task_id in range(task-1):
x_.append(x)
# Get predicted "logits" on replayed data (from previous model)
if (model.replay_targets=="soft") or visualize:
scores_ = list()
for task_id in range(task-1):
with torch.no_grad():
scores_temp = previous_model(x_[task_id])
scores_.append(scores_temp[:, (classes_per_task*task_id):(classes_per_task*(task_id+1))])
##-->> Exact Replay <<--##
if exact_replay:
scores_ = None
if not scenario=="task":
                    # Sample replayed training data and move it to the correct device
x_, y_ = next(data_loader_previous)
x_ = x_.to(device)
y_ = y_.to(device) if (model.replay_targets=="hard") or visualize else None
# Get predicted "logits"/"scores" on replayed data (from previous model)
if (model.replay_targets=="soft") or visualize:
with torch.no_grad():
scores_ = previous_model(x_)
scores_ = scores_[:, :(classes_per_task * (task - 1))] if scenario=="class" else scores_
# --> ICL: zero probabilities will be added in the [utils.loss_fn_kd]-function
else:
                    # Sample replayed training data, move it to the correct device and store it in lists
x_ = list()
y_ = list()
up_to_task = task if replay_mode=="offline" else task-1
for task_id in range(up_to_task):
x_temp, y_temp = next(data_loader_previous[task_id])
x_.append(x_temp.to(device))
# -only keep [y_] if required (as otherwise unnecessary computations will be done)
if (model.replay_targets == "hard") or visualize:
y_temp = y_temp - (classes_per_task*task_id) #-> adjust y-targets to 'active range'
y_.append(y_temp.to(device))
else:
y_.append(None)
# Get predicted "logits" on replayed data (from previous model)
if ((model.replay_targets=="soft") or visualize) and (previous_model is not None):
scores_ = list()
for task_id in range(up_to_task):
with torch.no_grad():
scores_temp = previous_model(x_[task_id])
scores_.append(scores_temp[:, (classes_per_task*task_id):(classes_per_task*(task_id+1))])
##-->> Generative Replay <<--##
if generative_replay:
if not scenario=="task":
# Which classes could be predicted (=[allowed_predictions])?
allowed_predictions = None if scenario=="domain" else list(range(classes_per_task*(task-1)))
# Sample replayed data, along with their predicted "logits" (both from previous model / scholar)
sample_model = previous_model if generator is None else previous_scholar
x_, y_, scores_ = sample_model.sample(batch_size, allowed_predictions=allowed_predictions,
return_scores=True)
# -only keep predicted y/scores if required (as otherwise unnecessary computations will be done)
y_ = y_ if ((model.replay_targets=="hard") or visualize) else None
scores_ = scores_ if ((model.replay_targets=="soft") or visualize) else None
else:
x_ = list()
y_ = list()
scores_ = list()
# For each previous task, list which classes could be predicted
allowed_pred_list = [list(range(classes_per_task*i, classes_per_task*(i+1))) for i in range(task)]
for prev_task_id in range(1, task):
# Sample replayed data, along with their predicted "logits" (both from previous model / scholar)
sample_model = previous_model if generator is None else previous_scholar
batch_size_replay = int(np.ceil(batch_size / (task-1))) if (task > 2) else batch_size
x_temp, y_temp, scores_temp = sample_model.sample(
batch_size_replay, allowed_predictions=allowed_pred_list[prev_task_id-1],
return_scores=True,
)
x_.append(x_temp)
# -only keep [y_] / [scores_] if required (as otherwise unnecessary computations will be done)
y_.append(y_temp if (model.replay_targets=="hard" or visualize) else None)
scores_.append(scores_temp if (model.replay_targets=="soft" or visualize) else None)
# Find [active_classes]
active_classes = None #-> for "domain"-sce, always all classes are active
if scenario=="task":
# -for "task"-sce, create <list> with for all tasks so far a <list> with the active classes
active_classes = [list(range(classes_per_task*i, classes_per_task*(i+1))) for i in range(task)]
elif scenario=="class":
# -for "class"-sce, create one <list> with active classes of all tasks so far
active_classes = list(range(classes_per_task*task))
# Train the model with this batch
loss_dict = model.train_a_batch(
                x, y, x_=x_, y_=y_, scores_=scores_, active_classes=active_classes, task=task, rnt=1./task,
)
# Update running parameter importance estimates in W
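            # -per parameter, [W] accumulates minus the gradient times the update applied on this step
            #  (the per-step term of the SI path integral of Zenke et al., 2017)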
if isinstance(model, ContinualLearner) and (model.si_c>0 or visualize):
for n, p in model.named_parameters():
if p.requires_grad:
n = n.replace('.', '__')
if p.grad is not None:
W[n].add_(-p.grad*(p.detach()-p_old[n]))
p_old[n] = p.detach().clone()
# Fire callbacks (for visualization of training-progress / evaluating performance after each task)
for loss_cb in loss_cbs:
if loss_cb is not None:
loss_cb(progress, batch_index, loss_dict, task=task)
for eval_cb in eval_cbs:
if eval_cb is not None:
eval_cb(model, batch_index, task=task)
if model.label=="VAE":
for sample_cb in sample_cbs:
if sample_cb is not None:
sample_cb(model, batch_index, task=task)
####################################### GENERATOR #######################################
if generator is not None:
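            # -the generative model is trained in its own loop, on the same current-task inputs plus
            #  replayed inputs, but without any class labels (only its generative / reconstruction loss)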
# Reset state of optimizer for every task (if requested)
if generator.optim_type=="adam_reset":
                generator.optimizer = optim.Adam(generator.optim_list, betas=(0.9, 0.999))
# Initialize number of iters left on current data-loader(s)
iters_left = iters_left_previous = 1
# Define a tqdm progress bar
progress = tqdm.tqdm(range(1, gen_iters+1))
# Loop over all iterations.
for batch_index in progress:
# Update # iters left on current data-loader(s) and, if needed, create new one(s)
iters_left -= 1
if iters_left == 0:
data_loader = iter(utils.get_data_loader(train_dataset, batch_size, cuda=cuda,
collate_fn=collate_fn, drop_last=True))
iters_left = len(data_loader)
if exact_replay:
iters_left_previous -= 1
if iters_left_previous == 0:
data_loader_previous = iter(utils.get_data_loader(ConcatDataset(previous_datasets),
batch_size, cuda=cuda,
collate_fn=collate_fn, drop_last=True))
iters_left_previous = len(data_loader_previous)
# Sample training data of current task
x, _ = next(data_loader)
x = x.to(device)
# Sample replayed training data
if exact_replay:
x_, _ = next(data_loader_previous)
x_ = x_.to(device)
elif generative_replay:
x_, _ = previous_scholar.sample(batch_size)
elif current_replay:
x_ = x
else:
x_ = None
# Train the generator with this batch
loss_dict = generator.train_a_batch(x, y=None, x_=x_, y_=None, rnt=1./task)
# Fire callbacks on each iteration
for loss_cb in gen_loss_cbs:
if loss_cb is not None:
loss_cb(progress, batch_index, loss_dict, task=task)
for sample_cb in sample_cbs:
if sample_cb is not None:
sample_cb(generator, batch_index, task=task)
##----------> UPON FINISHING EACH TASK...
# EWC: estimate Fisher Information matrix (FIM) and update term for quadratic penalty
if isinstance(model, ContinualLearner) and (model.ewc_lambda>0 or visualize):
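            # -restrict the Fisher-estimation to the classes that can be predicted: only the current
            #  task's classes ("task"), all classes seen so far ("class"), or no restriction ("domain")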
allowed_classes = list(
range(classes_per_task*(task-1), classes_per_task*task)
) if scenario=="task" else (list(range(classes_per_task*task)) if scenario=="class" else None)
model.estimate_fisher(train_dataset, allowed_classes=allowed_classes, collate_fn=collate_fn)
# SI: calculate and update the normalized path integral
if isinstance(model, ContinualLearner) and (model.si_c>0 or visualize):
model.update_omega(W, model.epsilon)
# REPLAY: update source for replay
previous_model = copy.deepcopy(model)
previous_model.eval()
if generator is not None:
scholar = dgr.Scholar(generator=generator, solver=model)
previous_scholar = copy.deepcopy(scholar)
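        # -this frozen copy (and, if a separate generator is used, the [Scholar] combining generator
        #  and solver) provides the replayed inputs and/or soft targets while training on the next task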
if replay_mode=='generative':
generative_replay = True
elif replay_mode=='exact':
previous_datasets = train_datasets[:task]
exact_replay = True
elif replay_mode=='current':
current_replay = True