Skip to content

Commit

Permalink
Merge pull request #416 from turboderp/dev
Browse files Browse the repository at this point in the history
Merge dev branch
  • Loading branch information
turboderp committed Apr 19, 2024
2 parents dafb508 + b68c0bd commit ed118b4
Show file tree
Hide file tree
Showing 33 changed files with 12,395 additions and 189 deletions.
6 changes: 5 additions & 1 deletion conversion/adaptivegptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,14 @@ def find_params(self, x):
self.scale = qscale_tw * best_p
self.qscale_max = qscale_max_t * best_p

# Make sure scales are rounded correctly for sanity test
prescale = torch.tensor([1 / 256], dtype = torch.half, device = self.scale.device)
self.scale = ((self.qscale * self.qscale).to(torch.half) * (self.qscale_max.half() * prescale)).float()


class AdaptiveGPTQ:

percdamp: float = 0.07
percdamp: float = 0.12

layer: nn.Linear
device: torch.device
Expand Down
2 changes: 1 addition & 1 deletion conversion/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def measure_quant(job, save_fn, model):
overall_rolling_accuracy = 0

last_snapshot_time = time.time()
snapshot_interval_s = 90
snapshot_interval_s = 180

temp_filename = os.path.join(job["out_dir"], "hidden_states_temp.safetensors")
states_filename = os.path.join(job["out_dir"], "hidden_states.safetensors")
Expand Down
234 changes: 79 additions & 155 deletions conversion/optimize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from conversion.qparams import QParams
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
import math
import itertools
import time

def optimize(job, save_fn, model):

Expand All @@ -9,11 +11,19 @@ def optimize(job, save_fn, model):
mlp_key_up = model.config.arch.mlp_key_up
mlp_key_down = model.config.arch.mlp_key_down

error_norm = 2.4
max_step_size = 2
first_layer_bias = 10
bias_layers = 2
bias_iter = 10
norm_interval = (1.5, 3.5)
norm_2ndstage = 0.15
anneal_temp_max = 2
anneal_temp_min = 0.0001
anneal_cooling_factor = 0.995
anneal_iter = 1000
anneal_samples = 80
anneal_stages = 3

# max_step_size = 2
# first_layer_bias = 4
# bias_layers = 2
# bias_iter = 0

key = "model.layers.0"
key_q = key + ".self_attn.q_proj"
Expand Down Expand Up @@ -57,21 +67,14 @@ def optimize(job, save_fn, model):
numel = sum(m.numel() for m in model.modules[1 : num_modules + 1])

target_bpw = job["bits"]
weight_budget = numel * target_bpw
weight_budget = int(numel * target_bpw)

# Compile options

measurement = job["measurement"]

def fn(x, idx):
if idx < bias_layers:
return 1 - ((1 - x) ** error_norm) * first_layer_bias
else:
return 1 - ((1 - x) ** error_norm)

weights = []
values = []
slots = []
params = []

for i in range(num_layers):
if model.config.arch.parallel_decoder_blocks:
m1 = measurement["model.layers." + str(i) + ".parallel_decoder"]["attn"]
Expand All @@ -80,162 +83,83 @@ def fn(x, idx):
m1 = measurement["model.layers." + str(i) + ".self_attn"]
m2 = measurement["model.layers." + str(i) + "." + mlp_mode]
for m in [m1, m2]:
v = [fn(e["accuracy"], i) for e in m]
w = [e["total_bits"] for e in m]
weights.append(w)
values.append(v)
params.append(m)

print(" -- Pruning...")

# Sort options by weight, eliminate strictly worse options

for i in range(num_layers * 2):
combined = sorted(zip(weights[i], values[i], params[i]))
w_, v_, p_ = zip(*combined)
w_ = list(w_)
v_ = list(v_)
p_ = list(p_)
j = 1
while j < len(v_):
if v_[j] <= v_[j - 1]:
w_.pop(j)
v_.pop(j)
p_.pop(j)
else:
j += 1
weights[i] = w_
values[i] = v_
params[i] = p_

# Quick and dirty iterative solver

print(" -- Solving...")

f_solution = [0] * num_layers * 2
weight = sum(weights[i][0] for i in range(num_layers * 2))
value = 1
for i in range(num_layers * 2): value *= values[i][0]

iteration = 0

while True:
min_idx = -1
min_value = float("inf")
iteration += 1
for i in range(bias_layers if iteration < bias_iter else num_layers * 2):
s = f_solution[i]
if values[i][s] < min_value:
if s < len(weights[i]) - 1:
added_w = weights[i][s + 1] - weights[i][s]
if added_w + weight <= weight_budget:
min_idx = i
min_value = values[i][s]
if min_idx == -1: break
s = f_solution[min_idx]
weight += weights[min_idx][s + 1] - weights[min_idx][s]
value *= values[min_idx][s + 1] / values[min_idx][s]
f_solution[min_idx] += 1

bpw = weight / numel
print(f" -- Score: {value:.8f} bpw: {bpw:.4f}")

def improve(solution, s_weight, hold = None):

if hold is None: hold = []
best_idx = -1
best_ratio = 0
best_add_w = 0
best_add_v = 0
for idx in range(num_layers * 2):
if idx in hold: continue

si = solution[idx]
if si == len(weights[idx]) - 1: continue

add_w = weights[idx][si + 1] - weights[idx][si]
if s_weight + add_w > weight_budget: continue

add_v = values[idx][si + 1] / values[idx][si]
ratio = add_v / add_w
if ratio > best_ratio:
best_ratio = ratio
best_idx = idx
best_add_w = add_w
best_add_v = add_v

return best_idx, best_add_w, best_add_v

# while True:
# b_idx, b_add_w, b_add_v = improve(f_solution, weight)
# if b_idx == -1:
# break
#
# f_solution[b_idx] += 1
# weight += b_add_w
# value += b_add_v
#
# bpw = weight / numel
# print(f" -- Score: {math.exp(value):.8f} bpw: {bpw:.4f}")

best_value = value
prev_best_value = value
step_size = 1

while True:

for i, j in itertools.permutations(range(num_layers * 2), 2):

t_solution = f_solution.copy()
t_solution[i] = max(t_solution[i] - step_size, 0)
t_solution[j] = max(t_solution[j] - step_size, 0)

t_weight = sum(weights[k][t_solution[k]] for k in range(num_layers * 2))
t_value = 1
for k in range(num_layers * 2): t_value *= values[k][t_solution[k]]

while True:
b_idx, b_add_w, b_add_v = improve(t_solution, t_weight, [i, j])
if b_idx == -1:
break
t_solution[b_idx] += 1
t_weight += b_add_w
t_value *= b_add_v

if t_value > best_value:
f_solution = t_solution
best_value = t_value
break

if best_value == prev_best_value:
step_size += 1
if step_size > max_step_size: break
continue

bpw = t_weight / numel
print(f" -- Score: {best_value:.8f} bpw: {bpw:.4f}")
prev_best_value = best_value
slot = []
param = []
for opt in m:
o = (int(opt["total_bits"]), 1 - opt["accuracy"])
slot.append(o)
param.append(opt)
slots.append(slot)
params.append(param)

# Find some solutions

last_update = 0
m = float("inf")
p = float("inf")
for i in range(anneal_stages * anneal_samples):
if time.time() - last_update > 1 or i == anneal_samples - 1:
print(f" -- Optimizing: {i + 1:4}/{anneal_stages * anneal_samples:4}")
last_update = time.time()

if i < anneal_samples:
t = i / (anneal_samples - 1)
norm = (1 - t) * norm_interval[0] + t * norm_interval[1]

elif i < anneal_samples * 2:
if i == anneal_samples:
norm_a = bestnorm - norm_2ndstage / 2
norm_b = bestnorm + norm_2ndstage / 2
t = i / (anneal_samples - 1) - 1
norm = (1 - t) * norm_a + t * norm_b

else:
norm = bestnorm

s_, si_, p_, c_, m_ = ext_c.sim_anneal(slots,
weight_budget,
anneal_temp_max,
anneal_cooling_factor,
anneal_temp_min,
anneal_iter,
norm)

if i < anneal_samples * 2:
if m_ < m:
m = m_
bestnorm = norm
else:
if p_ < p:
s, si, p, m = s_, si_, p_, m_

solution_idx = si
print(f" -- max(err): {m:.6f}")
print(f" -- error_norm: {bestnorm:.6f}")


# Save strategy

print(" -- Quantization strategy:")

errp = 1
logerr = 0
maxerr = 0
job["strategy"] = {}
for layer_ in range(num_layers):

k1 = "model.layers." + str(layer_) + ".self_attn"
k2 = "model.layers." + str(layer_) + "." + mlp_mode
p1 = params[layer_ * 2][f_solution[layer_ * 2]]
p2 = params[layer_ * 2 + 1][f_solution[layer_ * 2 + 1]]
p1 = params[layer_ * 2][solution_idx[layer_ * 2]]
p2 = params[layer_ * 2 + 1][solution_idx[layer_ * 2 + 1]]

for (k, p, n) in zip((k1, k2), (p1, p2), (numel_attn, numel_mlp)):
job["strategy"][k] = p
bpw = p["total_bits"] / n
err = 1 - p["accuracy"]
print(f" -- {k:50} {bpw:1.4f} bpw - exp. error: {err:1.8f}")
errp *= (1 - err)
logerr += math.log(err)
maxerr = max(err, maxerr)

print(f" -- Total exp. error: {1 - errp:1.12f}")
print(f" -- sum(log(err)): {logerr:.6f}")
print(f" -- max(err): {maxerr:.6f}")

xx = 0
4 changes: 2 additions & 2 deletions conversion/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def quant_parallel_decoder(job, module, hidden_states, target_states, quantizers
def quant(job, save_fn, model):

last_snapshot_time = time.time()
snapshot_interval_s = 90
snapshot_interval_s = 180

temp_filename = os.path.join(job["out_dir"], "hidden_states_temp.safetensors")
states_filename = os.path.join(job["out_dir"], "hidden_states.safetensors")
Expand Down Expand Up @@ -526,4 +526,4 @@ def quant(job, save_fn, model):
del job["invalid"]
save_fn()

time_since_snapshot = time.time()
last_snapshot_time = time.time()
3 changes: 3 additions & 0 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from conversion.optimize import optimize
from conversion.compile import compile_model
from conversion.qparams import qparams_headoptions
import torch

parser = argparse.ArgumentParser(description = "Convert model to ExLlamaV2")
parser.add_argument("-i", "--in_dir", type = str, help = "Input directory", default = "")
Expand All @@ -29,6 +30,8 @@

args = parser.parse_args()

torch.set_printoptions(precision = 7, sci_mode = False, linewidth = 200)

# Check some args

if not args.in_dir:
Expand Down
33 changes: 19 additions & 14 deletions doc/qcache_eval.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,23 @@ The tl;dr:
Token-level perplexity tests for various full-precision and quantized models using FP16, FP8 and Q4 cache
modes. Dataset is The Pile, 10 rows of 512 tokens per test.

Model | Precision | FP16 cache | FP8 cache | Q4 cache
--------|-----------|---------------|-----------|---------
Mistral 7B Instruct | 3.0 bpw | 13.33 | 13.43 | 13.41
-- | 3.5 bpw | 13.07 | 13.14 | 13.12
-- | 4.0 bpw | 12.90 | 12.90 | 12.90
-- | 5.0 bpw | 12.73 | 12.73 | 12.75
-- | 6.0 bpw | 12.73 | 12.75 | 12.74
-- | FP16 | 12.69 | 12.71 | 12.72
Mixtral 8x7B | 3.5 bpw | 10.27 | 10.41 | 10.39
-- | 4.0 bpw | 10.09 | 10.26 | 10.23
-- | 5.0 bpw | 10.02 | 10.16 | 10.15
Llama2 7B | 4.0 bpw | 11.43 | 11.92 | 11.74
-- | 5.0 bpw | 11.13 | 11.40 | 11.31
-- | FP16 | 10.91 | 11.24 | 11.16
Results are updated for the new method which uses Hadamard rotations on the keys/values. Old results for version
0.0.18 and prior kept for reference.

Model | Precision | FP16 cache | FP8 cache | Q4 cache (old) | Q4 cache
--------|---------|-------------|-----------|-------|----------
Mistral 7B Instruct | 3.0 bpw | **13.33** | 13.43 | 13.41 | **13.37**
-- | 3.5 bpw | **13.07** | 13.14 | 13.12 | **13.09**
-- | 4.0 bpw | **12.90** | 12.90 | 12.90 | **12.90**
-- | 5.0 bpw | **12.73** | 12.73 | 12.75 | **12.75**
-- | 6.0 bpw | **12.73** | 12.75 | 12.74 | **12.74**
-- | FP16 | **12.69** | 12.71 | 12.72 | **12.69**
Mixtral 8x7B | 3.5 bpw | **10.27** | 10.41 | 10.39 | **10.32**
-- | 4.0 bpw | **10.09** | 10.26 | 10.23 | **10.19**
-- | 5.0 bpw | **10.02** | 10.16 | 10.15 | **10.04**
Llama2 7B | 4.0 bpw | **11.43** | 11.92 | 11.74 | **11.60**
-- | 5.0 bpw | **11.13** | 11.40 | 11.31 | **11.19**
-- | FP16 | **10.91** | 11.24 | 11.16 | **11.05**


### HumanEval
Expand All @@ -37,6 +40,8 @@ The following are HumanEval tests on various full-precision and quantized models
respectively. Number of samples per task is limited to 10 (still giving 39360 completions in total produced
over about 24 hours.)

The following tests were done prior to the improvements in 0.0.18-dev.

#### pass@1

Model | Precision | FP16 cache | Q4 cache | diff
Expand Down

0 comments on commit ed118b4

Please sign in to comment.