/
engine.py
185 lines (162 loc) · 6.66 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import torch.nn as nn
from core.engine import get_lr
from core.utils import *
from models import LIQ_wn_qsam
from utils.bypass_bn import disable_running_stats, enable_running_stats
def set_first_forward(model):
    """Put every quantized conv/linear layer of *model* into first-forward mode."""
    quantized_types = (LIQ_wn_qsam.QConv2d, LIQ_wn_qsam.QLinear)
    for _, module in model.named_modules():
        if isinstance(module, quantized_types):
            module.set_first_forward()
def set_layer_first_forward(model, layer_name):
    """Put the quantized layers whose names occur in *layer_name* into first-forward mode."""
    quantized_types = (LIQ_wn_qsam.QConv2d, LIQ_wn_qsam.QLinear)
    for name, module in model.named_modules():
        if not isinstance(module, quantized_types):
            continue
        if name in layer_name:
            module.set_first_forward()
def set_second_forward(model):
    """Put every quantized conv/linear layer of *model* into second-forward mode."""
    quantized_types = (LIQ_wn_qsam.QConv2d, LIQ_wn_qsam.QLinear)
    for _, module in model.named_modules():
        if isinstance(module, quantized_types):
            module.set_second_forward()
def set_layer_second_forward(model, layer_name):
    """Put the quantized layers whose names occur in *layer_name* into second-forward mode."""
    quantized_types = (LIQ_wn_qsam.QConv2d, LIQ_wn_qsam.QLinear)
    for name, module in model.named_modules():
        if not isinstance(module, quantized_types):
            continue
        if name in layer_name:
            module.set_second_forward()
def train(
    model,
    train_loader,
    criterion,
    optimizer,
    minimizer,
    scheduler,
    device,
    logger,
    tensorboard_logger,
    epoch,
    args,
):
    """
    Train one epoch using a SAM-style two-pass (ascent/descent) update.

    :param model: network under training (DDP attributes ``require_*_sync``
        are toggled, so a DDP-wrapped model is presumably expected — confirm)
    :param train_loader: iterable yielding ``(image, target)`` batches
    :param criterion: loss function mapping (output, target) -> scalar loss
    :param optimizer: underlying optimizer; only read here for LR reporting
    :param minimizer: SAM-like wrapper exposing ``ascent_step()`` /
        ``descent_step()``
    :param scheduler: LR scheduler, stepped once at the end of the epoch
    :param device: device the batches are moved to
    :param logger: text logger
    :param tensorboard_logger: ``SummaryWriter``-like object or ``None``
    :param epoch: index of epoch
    :param args: namespace providing ``opt_type``, ``print_frequency``,
        ``world_size``, ``conv_type``, ``fc_type``
    :return: ``(train_error, train_loss, train5_error)`` — top-1 error,
        average loss, and top-5 error over the epoch (percentages)
    """
    metric_logger = MetricLogger(logger=logger, delimiter=" ")
    metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", SmoothedValue(window_size=10, fmt="{value}"))
    model.train()
    header = "Epoch: [{}]".format(epoch)
    for image, target in metric_logger.log_every(
        train_loader, args.print_frequency, header
    ):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        # Ascent step: compute the loss gradient at the current weights and
        # let the minimizer perturb the weights toward higher loss.
        model.require_backward_grad_sync = False
        model.require_forward_param_sync = True
        output = model(image)
        loss = criterion(output, target)
        loss.backward()
        minimizer.ascent_step()
        # Descent step: gradient at the perturbed point; descent_step()
        # restores the weights and applies the actual update.
        model.require_backward_grad_sync = True
        model.require_forward_param_sync = False
        if "QSAM" in args.opt_type or "QASAM" in args.opt_type:
            # Quantized SAM variants must reuse state cached during the
            # first forward on the second pass.
            set_second_forward(model)
        criterion(model(image), target).backward()
        minimizer.descent_step()
        if "QSAM" in args.opt_type or "QASAM" in args.opt_type:
            set_first_forward(model)
        # Accuracy is measured on the first (unperturbed) forward pass.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(
            batch_size * args.world_size / (time.time() - start_time)
        )
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    scheduler.step()
    lr = get_lr(optimizer)
    logger.info("Change Learning rate: {}".format(lr))
    train_error = 100 - metric_logger.acc1.global_avg
    train_loss = metric_logger.loss.global_avg
    train5_error = 100 - metric_logger.acc5.global_avg
    if tensorboard_logger is not None:
        tensorboard_logger.add_scalar("train_top1_error", train_error, epoch)
        tensorboard_logger.add_scalar("train_top5_error", train5_error, epoch)
        tensorboard_logger.add_scalar("train_loss", train_loss, epoch)
        tensorboard_logger.add_scalar("lr", lr, epoch)
        # Per-layer SAM-epsilon attributes worth tracking on quantized layers.
        weight_eps_names = [
            "epsilon",
            "tw_epsilon_norm",
            "normalized_tw_epsilon_norm",
            "weight_clip_value_epsilon",
            "weight_clip_value_tw_epsilon_norm",
            "weight_clip_value_normalized_tw_epsilon_norm",
            "activation_clip_value_epsilon",
            "activation_clip_value_tw_epsilon_norm",
            "activation_clip_value_normalized_tw_epsilon_norm",
            "bias_epsilon",
            "bias_epsilon_norm",
            "bias_normalized_epsilon_norm",
        ]
        # Epsilon attributes tracked on BatchNorm layers.
        bn_eps_names = [
            "weight_epsilon",
            "weight_epsilon_norm",
            "weight_normalized_epsilon_norm",
            "bias_epsilon",
            "bias_epsilon_norm",
            "bias_normalized_epsilon_norm",
        ]
        for name, module in model.named_modules():
            if isinstance(module, (args.conv_type, args.fc_type)):
                if hasattr(module, "weight_clip_value"):
                    tensorboard_logger.add_scalar(
                        "{}_{}".format(name, "weight_clip_value"),
                        module.weight_clip_value,
                        epoch,
                    )
                if hasattr(module, "activation_clip_value"):
                    tensorboard_logger.add_scalar(
                        "{}_{}".format(name, "activation_clip_value"),
                        module.activation_clip_value,
                        epoch,
                    )
                for weight_eps_name in weight_eps_names:
                    if hasattr(module, weight_eps_name):
                        eps = getattr(module, weight_eps_name)
                        # Scalars go to add_scalar, tensors to add_histogram.
                        if eps.numel() == 1:
                            tensorboard_logger.add_scalar(
                                "{}_{}".format(name, weight_eps_name), eps, epoch,
                            )
                        else:
                            tensorboard_logger.add_histogram(
                                "{}_{}".format(name, weight_eps_name), eps, epoch,
                            )
            elif isinstance(module, nn.BatchNorm2d):
                for bn_eps_name in bn_eps_names:
                    if hasattr(module, bn_eps_name):
                        eps = getattr(module, bn_eps_name)
                        # BUG FIX: the tag previously reused the stale
                        # `weight_eps_name` from the quantized-layer branch,
                        # mislabeling BN stats, and an extra unconditional
                        # add_histogram logged each value twice. Each BN
                        # epsilon is now logged once under its own name.
                        tag = "{}_{}".format(name, bn_eps_name)
                        if eps.numel() == 1:
                            tensorboard_logger.add_scalar(tag, eps, epoch)
                        else:
                            tensorboard_logger.add_histogram(tag, eps, epoch)
    logger.info(
        "|===>Training Error: {:.4f} Loss: {:.4f}, Top5 Error: {:.4f}".format(
            train_error, train_loss, train5_error
        )
    )
    return train_error, train_loss, train5_error