Skip to content

Commit

Permalink
New quant optimization procedure
Browse files Browse the repository at this point in the history
  • Loading branch information
turboderp committed Dec 8, 2023
1 parent 5c97425 commit 2e91239
Show file tree
Hide file tree
Showing 11 changed files with 1,291 additions and 95 deletions.
249 changes: 249 additions & 0 deletions conversion/optimize.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from conversion.qparams import QParams, qparams_options
from conversion.qparams_stats import qparams_stats
import math
import itertools

def optimize(job, save_fn):

Expand Down Expand Up @@ -97,3 +101,248 @@ def optimize(job, save_fn):
max_rfn = target_rfn
else:
min_rfn = target_rfn


def optimize_new(job, save_fn, model):

key = "model.layers.0"
key_q = key + ".self_attn.q_proj"
key_k = key + ".self_attn.k_proj"
key_v = key + ".self_attn.v_proj"
key_o = key + ".self_attn.o_proj"
key_g = key + ".mlp.gate_proj"
key_u = key + ".mlp.up_proj"
key_d = key + ".mlp.down_proj"
shape_q = model.modules_dict[key_q].matrix_shape()
shape_k = model.modules_dict[key_k].matrix_shape()
shape_v = model.modules_dict[key_v].matrix_shape()
shape_o = model.modules_dict[key_o].matrix_shape()
shape_g = model.modules_dict[key_g].matrix_shape()
shape_u = model.modules_dict[key_u].matrix_shape()
shape_d = model.modules_dict[key_d].matrix_shape()
numel_q = shape_q[0] * shape_q[1]
numel_k = shape_k[0] * shape_k[1]
numel_v = shape_v[0] * shape_v[1]
numel_o = shape_o[0] * shape_o[1]
numel_g = shape_g[0] * shape_g[1]
numel_u = shape_u[0] * shape_u[1]
numel_d = shape_d[0] * shape_d[1]

num_layers = model.config.num_hidden_layers
numel = num_layers * (numel_q + numel_k + numel_v + numel_o + numel_g + numel_u + numel_d)
target_bpw = job["bits"]
weight_budget = numel * target_bpw

layer_p1 = num_layers // 2
layer_p2 = num_layers * 3 // 4
layer_p3 = num_layers - 1
assert 2 < layer_p1 < layer_p2 < layer_p3

# Now it's a knapsack problem all of a sudden

weights = []
values = []
params = []
for i in range(num_layers * 2):
weights.append([])
values.append([])
params.append([])

for qcosts in qparams_stats:
mode_q, mode_k, mode_v, mode_o, mode_g, mode_u, mode_d = qcosts[:7]

if mode_q:
bits = 0
bits += mode_q.total_bits(shape_q)
bits += mode_k.total_bits(shape_k)
bits += mode_v.total_bits(shape_v)
bits += mode_o.total_bits(shape_o)
index = 0

else:
bits = 0
bits += mode_g.total_bits(shape_g)
bits += mode_u.total_bits(shape_u)
bits += mode_d.total_bits(shape_d)
index = 1

layer_kldiv = qcosts[7:]
for layer in range(num_layers):
if layer == 0:
kldiv = layer_kldiv[0]
elif layer == 1:
kldiv = layer_kldiv[1]
elif layer == 2:
kldiv = layer_kldiv[2]
elif layer < layer_p1:
a = (layer_p1 - layer) / (layer_p1 - 2)
b = (layer - 2) / (layer_p1 - 2)
kldiv = a * layer_kldiv[2] + b * layer_kldiv[3]
elif layer == layer_p1:
kldiv = layer_kldiv[3]
elif layer < layer_p2:
a = (layer_p2 - layer) / (layer_p2 - layer_p1)
b = (layer - layer_p1) / (layer_p2 - layer_p1)
kldiv = a * layer_kldiv[3] + b * layer_kldiv[4]
elif layer == layer_p2:
kldiv = layer_kldiv[4]
elif layer < layer_p3:
a = (layer_p3 - layer) / (layer_p3 - layer_p2)
b = (layer - layer_p2) / (layer_p3 - layer_p2)
kldiv = a * layer_kldiv[4] + b * layer_kldiv[5]
else:
kldiv = layer_kldiv[5]

weights[2 * layer + index].append(bits)
values[2 * layer + index].append(kldiv)
params[2 * layer + index].append(qcosts)

print(" -- Pruning...")

# Sort options by weight, eliminate strictly worse options

for i in range(num_layers * 2):
combined = sorted(zip(weights[i], values[i], params[i]))
w_, v_, p_ = zip(*combined)
w_ = list(w_)
v_ = list(v_)
p_ = list(p_)
j = 1
while j < len(v_):
if v_[j] >= v_[j - 1]:
w_.pop(j)
v_.pop(j)
p_.pop(j)
else:
j += 1
weights[i] = w_
values[i] = v_
params[i] = p_

# Quick and dirty iterative solver

print(" -- Solving...")

f_solution = [0] * num_layers * 2
weight = sum(weights[i][0] for i in range(num_layers * 2))
value = sum(values[i][0] for i in range(num_layers * 2))

def improve(solution, s_weight, hold = None):

if hold is None: hold = []
best_idx = -1
best_ratio = 0
best_add_w = 0
best_add_v = 0
for idx in range(num_layers * 2):
if idx in hold: continue

si = solution[idx]
if si == len(weights[idx]) - 1: continue

add_w = weights[idx][si + 1] - weights[idx][si]
if s_weight + add_w > weight_budget: continue

add_v = values[idx][si + 1] - values[idx][si]
ratio = -add_v / add_w
if ratio > best_ratio:
best_ratio = ratio
best_idx = idx
best_add_w = add_w
best_add_v = add_v

return best_idx, best_add_w, best_add_v

while True:
b_idx, b_add_w, b_add_v = improve(f_solution, weight)
if b_idx == -1:
break

f_solution[b_idx] += 1
weight += b_add_w
value += b_add_v

bpw = weight / numel
print(f" -- Estimated divergence: {value:.8f} bpw: {bpw:.4f}")

best_value = value
prev_best_value = value
step_size = 1

while True:

for i, j in itertools.permutations(range(num_layers * 2), 2):

t_solution = f_solution.copy()
t_solution[i] = max(t_solution[i] - step_size, 0)
t_solution[j] = max(t_solution[j] - step_size, 0)

t_weight = sum(weights[i][t_solution[i]] for i in range(num_layers * 2))
t_value = sum(values[i][t_solution[i]] for i in range(num_layers * 2))

while True:
b_idx, b_add_w, b_add_v = improve(t_solution, t_weight, [i, j])
if b_idx == -1:
break
t_solution[b_idx] += 1
t_weight += b_add_w
t_value += b_add_v

if t_value < best_value:
f_solution = t_solution
best_value = t_value
break

if best_value == prev_best_value:
step_size += 1
if step_size > 2: break
continue

bpw = t_weight / numel
print(f" -- Estimated divergence: {best_value:.8f} bpw: {bpw:.4f}")
prev_best_value = best_value

# Compile as measurement

print(" -- Quantization strategy:")

job["measurement"] = []
for layer_ in range(num_layers):

key = "model.layers." + str(layer_)
key_q = key + ".self_attn.q_proj"
key_k = key + ".self_attn.k_proj"
key_v = key + ".self_attn.v_proj"
key_o = key + ".self_attn.o_proj"
key_g = key + ".mlp.gate_proj"
key_u = key + ".mlp.up_proj"
key_d = key + ".mlp.down_proj"

qp1 = params[layer_ * 2][f_solution[layer_ * 2]]
qp2 = params[layer_ * 2 + 1][f_solution[layer_ * 2 + 1]]
mode_q, mode_k, mode_v, mode_o, _, _, _ = qp1[:7]
_, _, _, _, mode_g, mode_u, mode_d = qp2[:7]

def store_res(key_, numel_, mode_, shape_):
bpw_ = mode_.bpw(shape_)
desc_ = mode_.get_desc()
print(f" -- {key_:40} bpw: {bpw_:.4f} mode: {desc_}")
job["measurement"].append({
"key": key_,
"numel": numel_,
"best_option": {
"desc": desc_,
"bpw": bpw_,
"total_bits": mode_.total_bits(shape_),
"err": 0,
"qparams": mode_.get_dict()
}
})

store_res(key_q, numel_q, mode_q, shape_q)
store_res(key_k, numel_k, mode_k, shape_k)
store_res(key_v, numel_v, mode_v, shape_v)
store_res(key_o, numel_o, mode_o, shape_o)
store_res(key_g, numel_g, mode_g, shape_g)
store_res(key_u, numel_u, mode_u, shape_u)
store_res(key_d, numel_d, mode_d, shape_d)

0 comments on commit 2e91239

Please sign in to comment.