-
Notifications
You must be signed in to change notification settings - Fork 14
/
trainer.py
399 lines (321 loc) · 11.5 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
"""
basic trainer
"""
import time
import torch.autograd
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import utils as utils
import numpy as np
import torch
__all__ = ["Trainer"]
class Trainer(object):
    """
    Trainer that alternates generator and student updates against a teacher.

    Every training iteration draws random noise and labels, lets the generator
    synthesise a batch, scores that batch with the teacher (cross-entropy on
    the sampled labels plus a BatchNorm-statistics matching term) to update
    the generator, and then distils the teacher outputs into the student.
    """

    def __init__(self, model, model_teacher, generator, lr_master_S, lr_master_G,
                 train_loader, test_loader, settings, logger, tensorboard_logger=None,
                 opt_type="SGD", optimizer_state=None, run_count=0):
        """
        init trainer

        :param model: student network, optimised by ``optimizer_S``
        :param model_teacher: teacher network, only evaluated by this class
        :param generator: network called as ``generator(z, labels)``
        :param lr_master_S: student schedule; must expose ``lr`` and ``get_lr(epoch)``
        :param lr_master_G: generator schedule; must expose ``get_lr(epoch)``
        :param train_loader: stored on the instance (not read by the methods here)
        :param test_loader: iterable of ``(images, labels)`` used by the test methods
        :param settings: configuration object (nGPU, GPU, momentum, weightDecay,
            lr_G, b1, b2, alpha, temperature, batchSize, latent_dim, nClasses,
            warmup_epochs, nEpochs, tenCrop are read by this class)
        :param logger: stored on the instance
        :param tensorboard_logger: optional object with ``scalar_summary(tag, value, step)``
        :param opt_type: student optimiser, one of "SGD", "RMSProp", "Adam"
        :param optimizer_state: optional state dict loaded into the student optimiser
        :param run_count: initial step index used when writing scalar summaries
        :raises ValueError: if ``opt_type`` is not a supported optimiser name
        """
        self.settings = settings

        self.model = utils.data_parallel(
            model, self.settings.nGPU, self.settings.GPU)
        self.model_teacher = utils.data_parallel(
            model_teacher, self.settings.nGPU, self.settings.GPU)
        self.generator = utils.data_parallel(
            generator, self.settings.nGPU, self.settings.GPU)

        self.train_loader = train_loader
        self.test_loader = test_loader
        self.tensorboard_logger = tensorboard_logger

        self.criterion = nn.CrossEntropyLoss().cuda()
        self.bce_logits = nn.BCEWithLogitsLoss().cuda()
        self.MSE_loss = nn.MSELoss().cuda()

        self.lr_master_S = lr_master_S
        self.lr_master_G = lr_master_G
        self.opt_type = opt_type

        if opt_type == "SGD":
            self.optimizer_S = torch.optim.SGD(
                params=self.model.parameters(),
                lr=self.lr_master_S.lr,
                momentum=self.settings.momentum,
                weight_decay=self.settings.weightDecay,
                nesterov=True,
            )
        elif opt_type == "RMSProp":
            self.optimizer_S = torch.optim.RMSprop(
                params=self.model.parameters(),
                lr=self.lr_master_S.lr,
                eps=1.0,
                weight_decay=self.settings.weightDecay,
                momentum=self.settings.momentum,
                alpha=self.settings.momentum
            )
        elif opt_type == "Adam":
            self.optimizer_S = torch.optim.Adam(
                params=self.model.parameters(),
                lr=self.lr_master_S.lr,
                eps=1e-5,
                weight_decay=self.settings.weightDecay
            )
        else:
            # An explicit exception (rather than ``assert``) survives ``python -O``;
            # ``%s`` is required because opt_type is a string.
            raise ValueError("invalid type: %s" % opt_type)

        if optimizer_state is not None:
            self.optimizer_S.load_state_dict(optimizer_state)

        # The generator optimiser is always created, independent of optimizer_state.
        self.optimizer_G = torch.optim.Adam(
            self.generator.parameters(), lr=self.settings.lr_G,
            betas=(self.settings.b1, self.settings.b2))

        self.logger = logger
        self.run_count = run_count
        self.scalar_info = {}

        # Buffers filled by the forward hooks below.
        self.mean_list = []
        self.var_list = []
        self.teacher_running_mean = []
        self.teacher_running_var = []
        self.save_BN_mean = []
        self.save_BN_var = []

        self.fix_G = False
        # Guards against registering the teacher hooks more than once.
        self._teacher_hooks_registered = False

    def _register_teacher_hooks(self):
        """Attach ``hook_fn_forward`` to every teacher BatchNorm2d layer, once."""
        if getattr(self, "_teacher_hooks_registered", False):
            return
        for m in self.model_teacher.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.register_forward_hook(self.hook_fn_forward)
        self._teacher_hooks_registered = True

    def _clear_bn_statistics(self):
        """Empty all four per-forward BatchNorm statistic buffers in place."""
        for buf in (self.mean_list, self.var_list,
                    self.teacher_running_mean, self.teacher_running_var):
            buf.clear()

    def update_lr(self, epoch):
        """
        update learning rate of optimizers
        :param epoch: current training epoch
        """
        lr_S = self.lr_master_S.get_lr(epoch)
        lr_G = self.lr_master_G.get_lr(epoch)
        # update learning rate of model optimizer
        for param_group in self.optimizer_S.param_groups:
            param_group['lr'] = lr_S
        # update learning rate of generator optimizer
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr_G

    def loss_fn_kd(self, output, labels, teacher_outputs):
        """
        Compute the knowledge-distillation (KD) loss given outputs, labels.

        The result is ``KL(log_softmax(output / T), softmax(teacher / T)) *
        alpha * T * T + cross_entropy(output, labels)`` with ``alpha`` and
        ``T`` read from ``settings.alpha`` and ``settings.temperature``.

        :param output: student logits
        :param labels: class indices used for the cross-entropy term
        :param teacher_outputs: teacher logits for the same batch
        :return: scalar loss tensor
        """
        alpha = self.settings.alpha
        T = self.settings.temperature

        # KL divergence in PyTorch expects log-probabilities as its input.
        student_log_prob = F.log_softmax(output / T, dim=1)
        teacher_prob = F.softmax(teacher_outputs / T, dim=1)

        # NOTE(review): reduction="mean" (the nn.KLDivLoss default used
        # originally) averages over every element; PyTorch recommends
        # "batchmean". Kept as-is so the loss scale does not change.
        soft_term = F.kl_div(student_log_prob, teacher_prob, reduction="mean")
        hard_term = F.cross_entropy(output, labels)
        return soft_term * (alpha * T * T) + hard_term

    def forward(self, images, teacher_outputs, labels=None):
        """
        forward propagation through the student model

        :param images: input batch for the student
        :param teacher_outputs: teacher logits, only used when labels are given
        :param labels: optional targets; when ``None`` no loss is computed
        :return: ``(output, loss)`` where loss is ``None`` without labels
        """
        # The second value returned by the model is not used here.
        output, _ = self.model(images, True)

        if labels is None:
            return output, None
        loss = self.loss_fn_kd(output, labels, teacher_outputs)
        return output, loss

    def backward_G(self, loss_G):
        """
        backward propagation and optimiser step for the generator
        """
        self.optimizer_G.zero_grad()
        loss_G.backward()
        self.optimizer_G.step()

    def backward_S(self, loss_S):
        """
        backward propagation and optimiser step for the student
        """
        self.optimizer_S.zero_grad()
        loss_S.backward()
        self.optimizer_S.step()

    def backward(self, loss):
        """
        backward propagation with a joint step of both optimisers
        """
        self.optimizer_G.zero_grad()
        self.optimizer_S.zero_grad()
        loss.backward()
        self.optimizer_G.step()
        self.optimizer_S.step()

    def hook_fn_forward(self, module, input, output):
        """
        Forward hook recording batch statistics next to the module's running ones.

        Appends the mean and biased variance of the hooked module's input
        (reduced over dims 0, 2, 3) plus the module's ``running_mean`` and
        ``running_var`` to the corresponding instance lists.
        """
        features = input[0]
        mean = features.mean([0, 2, 3])
        # use biased var in train
        var = features.var([0, 2, 3], unbiased=False)

        self.mean_list.append(mean)
        self.var_list.append(var)
        self.teacher_running_mean.append(module.running_mean)
        self.teacher_running_var.append(module.running_var)

    def hook_fn_forward_saveBN(self, module, input, output):
        """
        Forward hook storing CPU copies of the module's running statistics.
        """
        self.save_BN_mean.append(module.running_mean.cpu())
        self.save_BN_var.append(module.running_var.cpu())

    def train(self, epoch):
        """
        training

        Runs a fixed number of iterations (200). Each one updates the
        generator, and updates the student only once ``epoch`` has reached
        ``settings.warmup_epochs``.

        :param epoch: current training epoch (0-based)
        :return: ``(top1_error.avg, top1_loss.avg, top5_error.avg)`` of the student
        """
        top1_error = utils.AverageMeter()
        top1_loss = utils.AverageMeter()
        top5_error = utils.AverageMeter()
        fp_acc = utils.AverageMeter()

        iters = 200
        self.update_lr(epoch)

        self.model.eval()
        self.model_teacher.eval()
        self.generator.train()

        # Registered on the first call whatever the epoch, so a run that
        # starts at epoch > 0 still collects BN statistics.
        self._register_teacher_hooks()

        for i in range(iters):
            z = torch.randn(self.settings.batchSize, self.settings.latent_dim).cuda()

            # Get labels ranging from 0 to n_classes for n rows
            labels = torch.randint(
                0, self.settings.nClasses, (self.settings.batchSize,)).cuda()
            z = z.contiguous()
            labels = labels.contiguous()
            images = self.generator(z, labels)

            # Drop statistics of any previous forward so the four lists only
            # hold entries produced by the teacher pass below.
            self._clear_bn_statistics()

            output_teacher_batch, _ = self.model_teacher(images, out_feature=True)

            # One hot loss
            loss_one_hot = self.criterion(output_teacher_batch, labels)

            # BN statistic loss
            BNS_loss = torch.zeros(1).cuda()
            bn_stats = zip(self.mean_list, self.var_list,
                           self.teacher_running_mean, self.teacher_running_var)
            for batch_mean, batch_var, running_mean, running_var in bn_stats:
                BNS_loss += (self.MSE_loss(batch_mean, running_mean)
                             + self.MSE_loss(batch_var, running_var))

            # Avoid 0/0 = NaN when no BatchNorm statistics were collected.
            if self.mean_list:
                BNS_loss = BNS_loss / len(self.mean_list)

            # loss of Generator
            loss_G = loss_one_hot + 0.1 * BNS_loss

            self.backward_G(loss_G)

            # The student sees detached tensors so its loss cannot reach the generator.
            output, loss_S = self.forward(
                images.detach(), output_teacher_batch.detach(), labels)

            if epoch >= self.settings.warmup_epochs:
                self.backward_S(loss_S)

            single_error, single_loss, single5_error = utils.compute_singlecrop(
                outputs=output, labels=labels,
                loss=loss_S, top5_flag=True, mean_flag=True)

            top1_error.update(single_error, images.size(0))
            top1_loss.update(single_loss, images.size(0))
            top5_error.update(single5_error, images.size(0))

            # Fraction of generated samples the teacher assigns to the sampled label.
            gt = labels.data.cpu().numpy()
            d_acc = np.mean(np.argmax(output_teacher_batch.data.cpu().numpy(), axis=1) == gt)

            fp_acc.update(d_acc)

        print(
            "[Epoch %d/%d] [Batch %d/%d] [acc: %.4f%%] [G loss: %f] [One-hot loss: %f] [BNS_loss:%f] [S loss: %f] "
            % (epoch + 1, self.settings.nEpochs, i+1, iters, 100 * fp_acc.avg, loss_G.item(), loss_one_hot.item(), BNS_loss.item(),
               loss_S.item())
        )

        # Values come from the last iteration; stored as Python floats so the
        # dictionary does not hold on to tensors.
        self.scalar_info['accuracy every epoch'] = 100 * d_acc
        self.scalar_info['G loss every epoch'] = loss_G.item()
        self.scalar_info['One-hot loss every epoch'] = loss_one_hot.item()
        self.scalar_info['S loss every epoch'] = loss_S.item()

        self.scalar_info['training_top1error'] = top1_error.avg
        self.scalar_info['training_top5error'] = top5_error.avg
        self.scalar_info['training_loss'] = top1_loss.avg

        if self.tensorboard_logger is not None:
            for tag, value in self.scalar_info.items():
                self.tensorboard_logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}

        return top1_error.avg, top1_loss.avg, top5_error.avg

    def test(self, epoch):
        """
        testing the student model on ``test_loader``

        :param epoch: current epoch, only used in the printed summary
        :return: ``(top1_error.avg, top1_loss.avg, top5_error.avg)``
        """
        top1_error = utils.AverageMeter()
        top1_loss = utils.AverageMeter()
        top5_error = utils.AverageMeter()

        self.model.eval()
        self.model_teacher.eval()

        iters = len(self.test_loader)

        with torch.no_grad():
            for i, (images, labels) in enumerate(self.test_loader):
                labels = labels.cuda()
                images = images.cuda()
                output = self.model(images)

                # No real loss is computed during testing; a constant stands in.
                loss = torch.ones(1)
                self._clear_bn_statistics()

                single_error, single_loss, single5_error = utils.compute_singlecrop(
                    outputs=output, loss=loss,
                    labels=labels, top5_flag=True, mean_flag=True)

                top1_error.update(single_error, images.size(0))
                top1_loss.update(single_loss, images.size(0))
                top5_error.update(single5_error, images.size(0))

        print(
            "[Epoch %d/%d] [Batch %d/%d] [acc: %.4f%%]"
            % (epoch + 1, self.settings.nEpochs, i + 1, iters, (100.00-top1_error.avg))
        )

        self.scalar_info['testing_top1error'] = top1_error.avg
        self.scalar_info['testing_top5error'] = top5_error.avg
        self.scalar_info['testing_loss'] = top1_loss.avg
        if self.tensorboard_logger is not None:
            for tag, value in self.scalar_info.items():
                self.tensorboard_logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}
        self.run_count += 1

        return top1_error.avg, top1_loss.avg, top5_error.avg

    def test_teacher(self, epoch):
        """
        testing the teacher model on ``test_loader``

        :param epoch: current epoch, only used in the printed summary
        :return: ``(top1_error.avg, top1_loss.avg, top5_error.avg)``
        """
        top1_error = utils.AverageMeter()
        top1_loss = utils.AverageMeter()
        top5_error = utils.AverageMeter()

        self.model_teacher.eval()

        iters = len(self.test_loader)

        with torch.no_grad():
            for i, (images, labels) in enumerate(self.test_loader):
                labels = labels.cuda()

                if self.settings.tenCrop:
                    image_size = images.size()
                    # Integer division: ``view`` rejects float dimensions.
                    images = images.view(
                        image_size[0] * 10, image_size[1] // 10, image_size[2], image_size[3])
                    images_tuple = images.split(image_size[0])

                    crop_outputs = []
                    for img in images_tuple:
                        if self.settings.nGPU == 1:
                            img = img.cuda()
                        # Evaluate the teacher directly; already inside no_grad.
                        crop_outputs.append(self.model_teacher(img))
                    output = torch.cat(crop_outputs)

                    single_error, single_loss, single5_error = utils.compute_tencrop(
                        outputs=output, labels=labels)
                else:
                    if self.settings.nGPU == 1:
                        images = images.cuda()

                    output = self.model_teacher(images)

                    # No real loss is computed during testing; a constant stands in.
                    loss = torch.ones(1)

                    single_error, single_loss, single5_error = utils.compute_singlecrop(
                        outputs=output, loss=loss,
                        labels=labels, top5_flag=True, mean_flag=True)

                # Hooks may have fired during the teacher forward; discard their output.
                self._clear_bn_statistics()

                top1_error.update(single_error, images.size(0))
                top1_loss.update(single_loss, images.size(0))
                top5_error.update(single5_error, images.size(0))

        print(
            "Teacher network: [Epoch %d/%d] [Batch %d/%d] [acc: %.4f%%]"
            % (epoch + 1, self.settings.nEpochs, i + 1, iters, (100.00 - top1_error.avg))
        )

        self.run_count += 1

        return top1_error.avg, top1_loss.avg, top5_error.avg