# block_recon.py

import torch
from quant.quant_layer import QuantModule, StraightThrough, lp_loss
from quant.quant_model import QuantModel
from quant.quant_block import BaseQuantBlock
from quant.adaptive_rounding import AdaRoundQuantizer
from quant.data_utils import save_grad_data, save_inp_oup_data


def block_reconstruction(model: QuantModel, block: BaseQuantBlock, cali_data: torch.Tensor,
                         batch_size: int = 32, iters: int = 20000, weight: float = 0.01, opt_mode: str = 'mse',
                         asym: bool = False, include_act_func: bool = True, b_range: tuple = (20, 2),
                         warmup: float = 0.0, act_quant: bool = False, lr: float = 4e-5, p: float = 2.0):
"""
Block reconstruction to optimize the output from each block.
:param model: QuantModel
:param block: BaseQuantBlock that needs to be optimized
:param cali_data: data for calibration, typically 1024 training images, as described in AdaRound
:param batch_size: mini-batch size for reconstruction
:param iters: optimization iterations for reconstruction,
:param weight: the weight of rounding regularization term
:param opt_mode: optimization mode
:param asym: asymmetric optimization designed in AdaRound, use quant input to reconstruct fp output
:param include_act_func: optimize the output after activation function
:param b_range: temperature range
:param warmup: proportion of iterations that no scheduling for temperature
:param act_quant: use activation quantization or not.
:param lr: learning rate for act delta learning
:param p: L_p norm minimization
"""
    model.set_quant_state(False, False)
    block.set_quant_state(True, act_quant)
    round_mode = 'learned_hard_sigmoid'

    if not include_act_func:
        org_act_func = block.activation_function
        block.activation_function = StraightThrough()

    if not act_quant:
        # Replace each weight quantizer with an AdaRoundQuantizer
        for name, module in block.named_modules():
            if isinstance(module, QuantModule):
                module.weight_quantizer = AdaRoundQuantizer(uaq=module.weight_quantizer, round_mode=round_mode,
                                                            weight_tensor=module.org_weight.data)
                module.weight_quantizer.soft_targets = True
        # Set up optimizer
        opt_params = []
        for name, module in block.named_modules():
            if isinstance(module, QuantModule):
                opt_params += [module.weight_quantizer.alpha]
        optimizer = torch.optim.Adam(opt_params)
        scheduler = None
    else:
        # Use UniformAffineQuantizer to learn delta
        if hasattr(block.act_quantizer, 'delta'):
            opt_params = [block.act_quantizer.delta]
        else:
            opt_params = []
        for name, module in block.named_modules():
            if isinstance(module, QuantModule):
                if module.act_quantizer.delta is not None:
                    opt_params += [module.act_quantizer.delta]
        optimizer = torch.optim.Adam(opt_params, lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iters, eta_min=0.)

    loss_mode = 'none' if act_quant else 'relaxation'
    rec_loss = opt_mode

    loss_func = LossFunction(block, round_loss=loss_mode, weight=weight, max_count=iters, rec_loss=rec_loss,
                             b_range=b_range, decay_start=0, warmup=warmup, p=p)

    # Save data before optimizing the rounding
    cached_inps, cached_outs = save_inp_oup_data(model, block, cali_data, asym, act_quant, batch_size)
    if opt_mode != 'mse':
        cached_grads = save_grad_data(model, block, cali_data, act_quant, batch_size=batch_size)
    else:
        cached_grads = None
    device = 'cuda'
    for i in range(iters):
        idx = torch.randperm(cached_inps.size(0))[:batch_size]
        cur_inp = cached_inps[idx].to(device)
        cur_out = cached_outs[idx].to(device)
        cur_grad = cached_grads[idx].to(device) if opt_mode != 'mse' else None

        optimizer.zero_grad()
        out_quant = block(cur_inp)

        err = loss_func(out_quant, cur_out, cur_grad)
        err.backward(retain_graph=True)

        optimizer.step()
        if scheduler:
            scheduler.step()
    torch.cuda.empty_cache()

    # Finish optimization, use hard rounding.
    for name, module in block.named_modules():
        if isinstance(module, QuantModule):
            module.weight_quantizer.soft_targets = False

    # Reset original activation function
    if not include_act_func:
        block.activation_function = org_act_func
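

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original file). A minimal
# example of how block_reconstruction might be driven over a whole model,
# assuming `model` is an already-built QuantModel containing BaseQuantBlock
# children and `cali_data` is a calibration tensor (e.g. 1024 training
# images); both stand in for the project's own setup code.
# ---------------------------------------------------------------------------
def reconstruct_all_blocks(model: QuantModel, cali_data: torch.Tensor):
    # Walk the model and reconstruct each quantized block in turn.
    for module in model.modules():
        if isinstance(module, BaseQuantBlock):
            block_reconstruction(model, module, cali_data,
                                 batch_size=32, iters=20000, weight=0.01,
                                 asym=True, act_quant=False)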


class LossFunction:
    def __init__(self,
                 block: BaseQuantBlock,
                 round_loss: str = 'relaxation',
                 weight: float = 1.,
                 rec_loss: str = 'mse',
                 max_count: int = 2000,
                 b_range: tuple = (10, 2),
                 decay_start: float = 0.0,
                 warmup: float = 0.0,
                 p: float = 2.):

        self.block = block
        self.round_loss = round_loss
        self.weight = weight
        self.rec_loss = rec_loss
        self.loss_start = max_count * warmup
        self.p = p

        self.temp_decay = LinearTempDecay(max_count, rel_start_decay=warmup + (1 - warmup) * decay_start,
                                          start_b=b_range[0], end_b=b_range[1])
        self.count = 0

    def __call__(self, pred, tgt, grad=None):
        """
        Compute the total loss for adaptive rounding:
        rec_loss is the quadratic output reconstruction loss, round_loss is
        a regularization term that optimizes the rounding policy.

        :param pred: output from the quantized model
        :param tgt: output from the FP model
        :param grad: gradients used to compute the Fisher information
        :return: total loss
        """
        self.count += 1
        if self.rec_loss == 'mse':
            rec_loss = lp_loss(pred, tgt, p=self.p)
        elif self.rec_loss == 'fisher_diag':
            rec_loss = ((pred - tgt).pow(2) * grad.pow(2)).sum(1).mean()
        elif self.rec_loss == 'fisher_full':
            a = (pred - tgt).abs()
            grad = grad.abs()
            batch_dotprod = torch.sum(a * grad, (1, 2, 3)).view(-1, 1, 1, 1)
            rec_loss = (batch_dotprod * a * grad).mean() / 100
        else:
            raise ValueError('Not supported reconstruction loss function: {}'.format(self.rec_loss))

        b = self.temp_decay(self.count)
        if self.count < self.loss_start or self.round_loss == 'none':
            b = round_loss = 0
        elif self.round_loss == 'relaxation':
            round_loss = 0
            for name, module in self.block.named_modules():
                if isinstance(module, QuantModule):
                    round_vals = module.weight_quantizer.get_soft_targets()
                    round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum()
        else:
            raise NotImplementedError

        total_loss = rec_loss + round_loss
        if self.count % 500 == 0:
            print('Total loss:\t{:.3f} (rec:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format(
                  float(total_loss), float(rec_loss), float(round_loss), b, self.count))
        return total_loss
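

# ---------------------------------------------------------------------------
# Illustration (not part of the original file). A self-contained sketch of
# the relaxation term computed above: for soft rounding values v in [0, 1],
# 1 - (2 * |v - 0.5|) ** b is near 0 when v sits close to 0 or 1 and near 1
# when v sits close to 0.5, so minimizing it drives every value toward a
# hard 0/1 rounding decision. As the temperature b is annealed downward,
# only values very close to 0 or 1 escape the penalty.
# ---------------------------------------------------------------------------
def _round_reg_demo():
    round_vals = torch.tensor([0.05, 0.30, 0.50, 0.70, 0.95])
    for b in (20.0, 10.0, 2.0):
        round_loss = (1 - ((round_vals - .5).abs() * 2).pow(b)).sum()
        print('b={:>5.1f}  round_loss={:.4f}'.format(b, float(round_loss)))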


class LinearTempDecay:
    def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2):
        self.t_max = t_max
        self.start_decay = rel_start_decay * t_max
        self.start_b = start_b
        self.end_b = end_b

    def __call__(self, t):
        """
        Linear decay schedule for the temperature b.

        :param t: the current time step
        :return: scheduled temperature
        """
        if t < self.start_decay:
            return self.start_b
        else:
            rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
            return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t))
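

# ---------------------------------------------------------------------------
# Illustration (not part of the original file). A quick trace of the
# schedule above: b holds at start_b until rel_start_decay * t_max steps
# have elapsed, then falls linearly to end_b at t_max. The parameter values
# below are arbitrary examples, not defaults from this project.
# ---------------------------------------------------------------------------
def _temp_decay_demo():
    decay = LinearTempDecay(t_max=20000, rel_start_decay=0.2, start_b=20, end_b=2)
    # b stays at 20 through step 4000, then decays linearly to 2 at step 20000.
    for t in (0, 4000, 8000, 12000, 16000, 20000):
        print('t={:>6d}  b={:.2f}'.format(t, decay(t)))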