In [2]:
import sys
sys.path.append('../smoothquant')
from smooth import *
from fake_quant import *

In [3]:
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
)
from transformers import LlamaTokenizer
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import quantize_llama_like
import tqdm

In [4]:
class Evaluator:
    def __init__(self, dataset, tokenizer, device, n_samples=40):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.device = device

        self.dataset = tokenizer(
            "\n\n".join(dataset["text"]), return_tensors="pt"
        ).input_ids.to(device)

        self.n_samples = n_samples

    @torch.no_grad()
    def evaluate(self, model):
        model.eval()
        nlls = []
        for i in tqdm.tqdm(range(self.n_samples), desc="Evaluating..."):
            batch = self.dataset[:, (i * 2048) : ((i + 1) * 2048)].to(model.device)
            with torch.no_grad():
                lm_logits = model(batch).logits
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            shift_labels = self.dataset[:, (i * 2048) : ((i + 1) * 2048)][:, 1:]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            neg_log_likelihood = loss.float() * 2048
            nlls.append(neg_log_likelihood)

        return torch.exp(torch.stack(nlls).sum() / (self.n_samples * 2048))

In [5]:
from datasets import load_dataset

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
evaluator = Evaluator(dataset, tokenizer, "cuda")

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [6]:
import functools

def get_act_scales(model, sample_text):
    model.eval()  # Set the model to evaluation mode
    act_scales_all = {}  # Dictionary to store the activations of all layers

    def stat_tensor(name, tensor):
        hidden_dim = tensor.shape[-1]
        # shape: (batch_size, seq_len, hidden_dim) -> (batch_size * seq_len, hidden_dim)
        # clone() is used to prevent modifying the original tensor
        tensor_original = tensor.clone().view(-1, hidden_dim).detach()
        act_scales_all[name] = tensor_original

    # module: is the layer being hooked
    def stat_input_hook(module, input, output, name):
        input_tensor = input[0] if isinstance(input, tuple) else input
        stat_tensor(name, input_tensor)

    hooks = []
    for name, module in model.named_modules():
        # if isinstance(module, (nn.Linear, W8A8Linear)):
        if isinstance(module, (nn.Linear, W8A8Linear, W4A4Linear)):
            print(f"Hooking layer: {name}")
            hook = module.register_forward_hook(
                functools.partial(stat_input_hook, name=name)
            )
            hooks.append(hook)

    # Tokenize the sample
    tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    input_ids = tokenizer(sample_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024).input_ids

    # Move input_ids to the same device as the model
    input_ids = input_ids.to(model.device)

    # Forward pass
    with torch.no_grad():
        model(input_ids)

    for hook in hooks:
        hook.remove()

    return act_scales_all

# 固定樣本文字
def get_sample_text():
    tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    sample = dataset[torch.randint(len(dataset), (1,)).item()]["text"]
    print("固定樣本文字:", sample)  # 查看固定的樣本文字
    return sample

# Example usage
sample_text = get_sample_text()  # 第一次執行時抽取樣本
# act_scales = get_act_scales(model, sample_text)


normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


固定樣本文字: 


## quantizing 0~3layer to 4bit


In [7]:
model_fp16 = LlamaForCausalLM.from_pretrained(
    "../../llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
act_scales = torch.load("../act_scales/llama-2-7b.pt")
smooth_lm_layer(model_fp16, act_scales, 0.85,  quant_layers=range(4))

Processing layer 0
Module: Linear(in_features=4096, out_features=4096, bias=False), Type: <class 'torch.nn.modules.linear.Linear'>
Module: Linear(in_features=4096, out_features=4096, bias=False), Type: <class 'torch.nn.modules.linear.Linear'>
Module: Linear(in_features=4096, out_features=4096, bias=False), Type: <class 'torch.nn.modules.linear.Linear'>
Module: Linear(in_features=4096, out_features=11008, bias=False), Type: <class 'torch.nn.modules.linear.Linear'>
Module: Linear(in_features=4096, out_features=11008, bias=False), Type: <class 'torch.nn.modules.linear.Linear'>
Processing layer 0
Processing layer 0
Processing layer 0
Processing layer 0
Processing layer 0
Processing layer 0
Processing layer 0
Processing layer 0
Processing layer 0
Processing layer 0
Processing layer 0
Processing layer 0
Processing layer 0
Processing layer 1
Module: Linear(in_features=4096, out_features=4096, bias=False), Type: <class 'torch.nn.modules.linear.Linear'>
Module: Linear(in_features=4096, out_feat

In [9]:
model_smoothquant_layer0_3 = quantize_llama_like_layer(model_fp16)

Quantizing layer 0
Quantizing layer 0
Quantizing layer 0
Quantizing layer 0
Quantizing layer 0
Quantizing layer 0
Quantizing layer 0
Quantizing layer 0
Quantizing layer 0
Quantizing layer 0
Quantizing layer 0
Quantizing layer 0
Quantizing layer 0
Quantizing layer 0
Quantizing layer 1
Quantizing layer 1
Quantizing layer 1
Quantizing layer 1
Quantizing layer 1
Quantizing layer 1
Quantizing layer 1
Quantizing layer 1
Quantizing layer 1
Quantizing layer 1
Quantizing layer 1
Quantizing layer 1
Quantizing layer 1
Quantizing layer 1
Quantizing layer 2
Quantizing layer 2
Quantizing layer 2
Quantizing layer 2
Quantizing layer 2
Quantizing layer 2
Quantizing layer 2
Quantizing layer 2
Quantizing layer 2
Quantizing layer 2
Quantizing layer 2
Quantizing layer 2
Quantizing layer 2
Quantizing layer 2
Quantizing layer 3
Quantizing layer 3
Quantizing layer 3
Quantizing layer 3
Quantizing layer 3
Quantizing layer 3
Quantizing layer 3
Quantizing layer 3
Quantizing layer 3
Quantizing layer 3
Quantizing l

In [10]:
print(model_smoothquant_layer0_3)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (k_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (v_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (o_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): W4A4Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (up_proj): W4A4Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
         

In [11]:
smoothquant_layer0_3 = evaluator.evaluate(model_smoothquant_layer0_3)
print(f"Original model (fp16) perplexity: {smoothquant_layer0_3}")

Evaluating...: 100%|██████████| 40/40 [00:20<00:00,  1.96it/s]


Original model (fp16) perplexity: 168.82559204101562


## Quantize layer 5~29

In [6]:
model_fp16 = LlamaForCausalLM.from_pretrained(
    "../../llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
act_scales = torch.load("../act_scales/llama-2-7b.pt")
smooth_lm_layer(model_fp16, act_scales, 0.85,  quant_layers=range(5,30))

Skipping layer 0
Skipping layer 0
Skipping layer 0
Skipping layer 0
Skipping layer 0
Skipping layer 0
Skipping layer 0
Skipping layer 0
Skipping layer 0
Skipping layer 0
Skipping layer 0
Skipping layer 0
Skipping layer 0
Skipping layer 0
Skipping layer 1
Skipping layer 1
Skipping layer 1
Skipping layer 1
Skipping layer 1
Skipping layer 1
Skipping layer 1
Skipping layer 1
Skipping layer 1
Skipping layer 1
Skipping layer 1
Skipping layer 1
Skipping layer 1
Skipping layer 1
Skipping layer 2
Skipping layer 2
Skipping layer 2
Skipping layer 2
Skipping layer 2
Skipping layer 2
Skipping layer 2
Skipping layer 2
Skipping layer 2
Skipping layer 2
Skipping layer 2
Skipping layer 2
Skipping layer 2
Skipping layer 2
Skipping layer 3
Skipping layer 3
Skipping layer 3
Skipping layer 3
Skipping layer 3
Skipping layer 3
Skipping layer 3
Skipping layer 3
Skipping layer 3
Skipping layer 3
Skipping layer 3
Skipping layer 3
Skipping layer 3
Skipping layer 3
Skipping layer 4
Skipping layer 4
Skipping layer

In [8]:
model_smoothquant_layer4_29 = quantize_llama_like_layer(model_fp16, quant_layers=range(5,30))

Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 0 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 1 for quantization
Skipping layer 2 for quantization
Skipping layer

In [9]:
print(model_smoothquant_layer4_29)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
      (1): LlamaDecoderLayer(
        (self_attn): LlamaAtten

In [10]:
ppl_smoothquant_layer4_29 = evaluator.evaluate(model_smoothquant_layer4_29)
print(f"Original model (fp16) perplexity: {ppl_smoothquant_layer4_29}")

Evaluating...: 100%|██████████| 40/40 [00:22<00:00,  1.81it/s]


Original model (fp16) perplexity: 1427.8328857421875


: 