-
Notifications
You must be signed in to change notification settings - Fork 3
/
engine_finetune.py
147 lines (116 loc) · 4.92 KB
/
engine_finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import math
import sys
from typing import Iterable, Optional
import paddle
import util.misc as misc
import util.lr_sched as lr_sched
def train_one_epoch(model,
                    criterion,
                    data_loader,
                    optimizer,
                    device,
                    epoch,
                    loss_scaler,
                    max_norm=0,
                    log_writer=None,
                    model_ema=None,
                    args=None):
    """Train ``model`` for one epoch using AMP autocast and gradient accumulation.

    Args:
        model: network to train; switched to train mode on entry.
        criterion: callable ``criterion(outputs, targets)`` returning the loss.
        data_loader: iterable of ``(samples, targets)`` batches; ``len()`` is
            used to compute the fractional epoch position.
        optimizer: optimizer whose gradients are cleared at the start and
            after every accumulation window.
        device: unused in this function; kept for interface compatibility.
        epoch: current epoch index (feeds the LR schedule and the log x-axis).
        loss_scaler: called once per micro-batch with the (divided) loss;
            ``update_grad`` is true only on the last micro-batch of each
            accumulation window. Its return value is logged as
            ``grads/total_grad_norm``.
        max_norm: forwarded to ``loss_scaler`` as ``clip_grad``.
        log_writer: optional writer exposing ``add_scalar(tag, value, step)``.
        model_ema: optional object whose ``update(model)`` is called after
            every accumulation window.
        args: namespace; ``args.accum_iter`` is read here and ``args`` is
            passed through to ``lr_sched.adjust_learning_rate``. Required
            despite the ``None`` default.

    Returns:
        dict mapping each metric-logger meter name to its global average.

    Exits the process (``sys.exit(1)``) if the loss becomes non-finite.
    """
    model.train()
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter(
        'lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'loss', misc.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter
    num_steps = len(data_loader)

    optimizer.clear_grad()

    for data_iter_step, (samples, targets) in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):
        # Fractional epoch position of this step. It drives the
        # per-iteration LR schedule and the logging x-axis, so it is computed
        # up front: every logging call in this iteration must use the
        # current step's value.
        epoch_progress = data_iter_step / num_steps + epoch
        # We use epoch_1000x as the x-axis in tensorboard; this calibrates
        # different curves when batch size changes.
        epoch_1000x = int(epoch_progress * 1000)
        # True on the last micro-batch of an accumulation window.
        is_update_step = (data_iter_step + 1) % accum_iter == 0

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, epoch_progress, args)

        with paddle.amp.auto_cast():
            outputs = model(samples)
            loss = criterion(outputs, targets)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Scale so that gradients summed over the window average the loss.
        loss /= accum_iter
        total_grad_norm = loss_scaler(
            loss,
            optimizer,
            clip_grad=max_norm,
            parameters=model.parameters(),
            update_grad=is_update_step)

        # Gradients must be inspected before clear_grad() below wipes them.
        if (log_writer is not None and is_update_step
                and (data_iter_step + 1) % print_freq == 0):
            _log_grad_stats(log_writer, model, total_grad_norm, epoch_1000x)

        if is_update_step:
            if model_ema is not None:
                model_ema.update(model)
            optimizer.clear_grad()

        # NOTE(review): presumably here so per-iteration timing in
        # log_every reflects finished GPU work; requires a CUDA build.
        paddle.device.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        max_lr = 0.
        for group in optimizer.param_groups:
            max_lr = max(max_lr, group["lr"])
        metric_logger.update(lr=max_lr)

        # Called on every step so the cross-process reduction stays in lockstep.
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and is_update_step:
            log_writer.add_scalar('step/loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('step/lr', max_lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def _log_grad_stats(log_writer, model, total_grad_norm, step):
    """Write each parameter's mean |grad| and the total grad norm at ``step``.

    Best effort: if writing a scalar fails, the error is printed and logging
    continues, so a bad value never aborts training. Only ``Exception`` is
    caught, so ``KeyboardInterrupt``/``SystemExit`` still propagate.
    """
    for tag, value in model.named_parameters():
        if value is None or value.grad is None:
            continue
        tag = tag.replace('.', '/')
        try:
            log_writer.add_scalar(
                'grads/' + tag, value.grad.detach().abs().mean().numpy(), step)
        except Exception as err:
            print(f'{tag} could not be logged (possibly NaN or Inf): {err!r}')
    try:
        log_writer.add_scalar(
            'grads/total_grad_norm', total_grad_norm.detach().numpy(), step)
    except Exception as err:
        print(f'total_grad_norm could not be logged (possibly NaN or Inf): {err!r}')
@paddle.no_grad()
def evaluate(data_loader, model, device):
    """Evaluate ``model`` on ``data_loader``; report loss and top-1/top-5 accuracy.

    Each batch is indexed positionally: ``batch[0]`` holds the inputs and
    ``batch[-1]`` the labels. Accuracy meters are weighted by the batch's
    leading dimension; the loss meter receives one unweighted value per batch.
    ``device`` is not used. Gradients are disabled by the decorator.

    Returns:
        dict mapping each meter name (``loss``, ``acc1``, ``acc5``) to its
        global average after cross-process synchronization.
    """
    loss_fn = paddle.nn.CrossEntropyLoss()
    logger = misc.MetricLogger(delimiter=" ")

    # Put the network into inference mode before iterating.
    model.eval()

    for batch in logger.log_every(data_loader, 10, 'Test:'):
        images, target = batch[0], batch[-1]

        with paddle.amp.auto_cast():
            output = model(images)
            loss = loss_fn(output, target)

        acc1, acc5 = misc.accuracy(output, target, topk=(1, 5))
        n_samples = images.shape[0]

        # 'loss' is registered first so meter (and return-dict) order is
        # loss, acc1, acc5.
        logger.update(loss=loss.item())
        logger.meters['acc1'].update(acc1.item(), n=n_samples)
        logger.meters['acc5'].update(acc5.item(), n=n_samples)

    # Combine the meter statistics across all processes.
    logger.synchronize_between_processes()
    summary = ('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} '
               'loss {losses.global_avg:.3f}')
    print(summary.format(top1=logger.acc1, top5=logger.acc5, losses=logger.loss))
    return {name: meter.global_avg for name, meter in logger.meters.items()}