# SmoothQuant on Llama 2 7B

In this notebook, we use the Llama-2-7B model to demonstrate that SmoothQuant can quantize both weights and activations to 8 bits (W8A8) while keeping perplexity close to that of the FP16 model.

In order to run this notebook, you need to install the following packages:

- smoothquant
- PyTorch
- Transformers
- Accelerate
- Datasets

In [1]:
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
)
from transformers import LlamaTokenizer
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import quantize_llama_like
import tqdm

The following evaluator measures the model's perplexity. As a toy benchmark, it concatenates the Wikitext-2 test split, tokenizes it, and scores the first 40 non-overlapping windows of 2,048 tokens (81,920 tokens in total). You can replace it with your own dataset; the conclusion should be the same.
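
The perplexity arithmetic in `Evaluator.evaluate` can be checked without a GPU. The plain-Python sketch below uses hypothetical helper names (`window_bounds`, `perplexity_from_window_losses`) and made-up loss values; it mirrors the notebook's choice of scaling each window's mean loss by the window length before exponentiating the average.

```python
import math


def window_bounds(n_tokens, seqlen=2048, n_samples=40):
    """Return (start, end) token indices of each non-overlapping evaluation window."""
    needed = seqlen * n_samples
    if n_tokens < needed:
        raise ValueError(f"need at least {needed} tokens, got {n_tokens}")
    return [(i * seqlen, (i + 1) * seqlen) for i in range(n_samples)]


def perplexity_from_window_losses(mean_losses, seqlen=2048):
    """Perplexity from per-window mean cross-entropy losses (natural log)."""
    # Scale each mean loss back to an (approximate) summed NLL for its window...
    total_nll = sum(loss * seqlen for loss in mean_losses)
    # ...then average over all scored tokens and exponentiate.
    return math.exp(total_nll / (len(mean_losses) * seqlen))


bounds = window_bounds(n_tokens=300_000)  # hypothetical token count
print(bounds[0], bounds[-1])  # (0, 2048) (79872, 81920)
# If every window has mean loss ln(5.8), the perplexity is 5.8.
print(round(perplexity_from_window_losses([math.log(5.8)] * 40), 6))  # 5.8
```

Because every window has the same length, this reduces to `exp` of the average window loss.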

In [2]:
class Evaluator:
    def __init__(self, dataset, tokenizer, device, n_samples=40, seqlen=2048):
        self.tokenizer = tokenizer
        self.device = device
        self.n_samples = n_samples
        self.seqlen = seqlen
        # Concatenate the whole split into one long stream of token IDs.
        self.dataset = tokenizer(
            "\n\n".join(dataset["text"]), return_tensors="pt"
        ).input_ids.to(device)

    @torch.no_grad()
    def evaluate(self, model):
        model.eval()
        loss_fct = nn.CrossEntropyLoss()
        nlls = []
        for i in tqdm.tqdm(range(self.n_samples), desc="Evaluating..."):
            # i-th non-overlapping window of `seqlen` tokens.
            batch = self.dataset[:, i * self.seqlen : (i + 1) * self.seqlen].to(model.device)
            lm_logits = model(batch).logits
            # Logits at position t predict token t + 1.
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            shift_labels = batch[:, 1:].reshape(-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels)
            # Mean loss times window length approximates the window's summed NLL.
            nlls.append(loss.float() * self.seqlen)

        return torch.exp(torch.stack(nlls).sum() / (self.n_samples * self.seqlen))

In [3]:
from datasets import load_dataset

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
evaluator = Evaluator(dataset, tokenizer, "cuda")

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


## FP16 Model Perplexity

Let's first check the performance of the original FP16 model.

In [9]:
model_fp16 = LlamaForCausalLM.from_pretrained(
    "../../llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)


In [10]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model_fp16.eval()
model_fp16 = model_fp16.to(device)

input_text = "explain what is AI"
inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=True).to(device)

# Sampling is enabled, so the generated text differs between runs.
output = model_fp16.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=50,
    eos_token_id=tokenizer.eos_token_id,
    do_sample=True,
    use_cache=False,
)

generated_text = tokenizer.batch_decode(
    output, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
print("Generated text:", generated_text)

Using device: cuda
Generated text: ['explain what is AI\nAi is the ability of a computer to mimic human intelligence.\nIt is a subfield of computer science that studies the development of intelligent machines that can think and act like humans.\nIt is']


In [5]:
ppl_fp16 = evaluator.evaluate(model_fp16)
print(f"Original model (fp16) perplexity: {ppl_fp16}")

Evaluating...: 100%|██████████| 40/40 [00:09<00:00,  4.04it/s]


Original model (fp16) perplexity: 5.822948932647705


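We then quantize the model to W8A8 and check its performance. Note that `quantize_llama_like` replaces the linear layers of the model it is given in place, so after this cell `model_fp16` is no longer an FP16 model (the `print(model_fp16)` cell further down shows the quantized layers). Reload the checkpoint before each experiment to start from FP16 weights.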
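
The printed modules report `weight_quant=per_channel` and `act_quant=per_token`. The sketch below illustrates those granularities with symmetric absmax fake quantization (snap to an integer grid, then dequantize back to floats). It assumes absmax scaling, which is our understanding of how smoothquant's fake-quant layers work, but `fake_quantize_rows` is a hypothetical plain-Python helper, not the library's code.

```python
def fake_quantize_rows(matrix, n_bits=8):
    """Symmetric absmax fake quantization with one scale per row.

    Weights stored as [out_features, in_features]: one scale per output channel.
    Activations stored as [tokens, hidden]: one scale per token.
    """
    q_max = 2 ** (n_bits - 1) - 1  # 127 for 8 bits
    result = []
    for row in matrix:
        scale = max(max(abs(v) for v in row), 1e-5) / q_max
        # Snap to the integer grid [-q_max, q_max], then map back to floats.
        result.append([max(-q_max, min(q_max, round(v / scale))) * scale for v in row])
    return result


# Made-up weights (2 output channels) and activations (2 tokens).
weights = [[0.02, -0.5, 0.31], [1.2, 0.003, -0.9]]
acts = [[0.1, 3.0, -0.2], [0.05, -0.04, 0.01]]

w_q = fake_quantize_rows(weights)  # per-channel weight quantization
a_q = fake_quantize_rows(acts)     # per-token activation quantization
for tok, tok_q in zip(acts, a_q):
    err = max(abs(a - b) for a, b in zip(tok, tok_q))
    print(f"token max |error|: {err:.5f}")
```

The token containing the value 3.0 gets a step of 3.0/127, so its small entries lose noticeably more precision than those of the second token, whose grid is much finer.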

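## Naive W8A8 Quantized Model Perplexity

The module printout below (like the generation samples in this notebook) lists `W4A4Linear` for the attention projections, so it appears to come from a later run with a variant of `quantize_llama_like` that quantizes attention to W4A4; the W8A8 perplexity in this section was recorded in a separate run (note the execution counts).
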
In [11]:
model_w8a8 = quantize_llama_like(model_fp16)
print(model_w8a8)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (k_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (v_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (o_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (up_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
         

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model_w8a8.eval()
model_w8a8 = model_w8a8.to(device)

input_text = "explain what is AI"
inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=True).to(device)

# Sampling is enabled, so the generated text differs between runs.
output = model_w8a8.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=100,
    eos_token_id=tokenizer.eos_token_id,
    do_sample=True,
    use_cache=False,
)

generated_text = tokenizer.batch_decode(
    output, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
print("Generated text:", generated_text)

Using device: cuda
Generated text: ['explain what is AI I don’t know who the hell. Unterscheidung 4519465.doc 15.jpg Mitarbeit Dus2022018802010010202001010101010000010001010110101010101010101010']


In [7]:
ppl_w8a8 = evaluator.evaluate(model_w8a8)
print(f"Naive W8A8 quantized model perplexity: {ppl_w8a8}")

Evaluating...: 100%|██████████| 40/40 [00:07<00:00,  5.02it/s]

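Naive W8A8 quantized model perplexity: 5.931240558624268

Naive W8A8 quantization raises perplexity only slightly, from 5.82 (FP16) to 5.93. Next, we quantize the attention projections (`q_proj`, `k_proj`, `v_proj`, `o_proj`) more aggressively, to W4A4, while keeping the MLP at W8A8.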
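
Dropping from 8 to 4 bits costs far more than half the precision: a symmetric 4-bit grid has only 7 positive levels instead of 127, so when one outlier channel sets the scale, every value smaller than half a step rounds to zero. The plain-Python sketch below compares the two bit widths on one token; the `fake_quantize` and `mean_abs_error` helpers and the activation values are hypothetical, and a single shared scale is assumed for the whole token.

```python
def fake_quantize(values, n_bits):
    """Symmetric absmax fake quantization of one vector with a single scale."""
    q_max = 2 ** (n_bits - 1) - 1  # 127 for INT8, 7 for INT4
    scale = max(max(abs(v) for v in values), 1e-5) / q_max
    return [max(-q_max, min(q_max, round(v / scale))) * scale for v in values]


def mean_abs_error(values, n_bits):
    """Average absolute difference between values and their fake-quantized copy."""
    deq = fake_quantize(values, n_bits)
    return sum(abs(a - b) for a, b in zip(values, deq)) / len(values)


# Made-up token activations: small values plus one outlier channel (index 6).
token = [0.12, -0.08, 0.05, 0.31, -0.22, 0.02, 6.0, -0.15]
for bits in (8, 4):
    print(bits, "bits -> mean abs error", round(mean_abs_error(token, bits), 4))
```

At 4 bits every entry except the outlier is quantized to zero, which is one plausible reason low-bit activation quantization without outlier handling degrades a model so sharply.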

## Naive W4A4 Attention + W8A8 MLP Quantized Model Perplexity

In [5]:
model_w4a4 = quantize_llama_like(model_fp16)
print(model_w4a4)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (k_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (v_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (o_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (up_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
         

In [8]:
ppl_w4a4 = evaluator.evaluate(model_w4a4)
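print(f"Naive W4A4-attention/W8A8-MLP quantized model perplexity: {ppl_w4a4}")

Evaluating...: 100%|██████████| 40/40 [00:21<00:00,  1.82it/s]

Naive W4A4-attention/W8A8-MLP quantized model perplexity: 811.0993041992188

Without smoothing, quantizing the attention projections to 4 bits breaks the model: perplexity jumps from 5.82 (FP16) to 811.
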
## SmoothQuant W8A8 Quantized Model Perplexity

In [4]:
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)


In [5]:

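act_scales = torch.load("../act_scales/llama-2-7b.pt")
# alpha=0.85 is the migration strength: higher values move more of the
# quantization difficulty from activations into weights.
smooth_lm(model, act_scales, 0.85)
model_smoothquant_w8a8 = quantize_llama_like(model)
print(model_smoothquant_w8a8)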

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (k_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (v_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (o_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (up_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
         

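SmoothQuant divides each activation channel by a per-channel smoothing factor and multiplies the corresponding weight channel by the same factor, so each linear layer computes the same output while the activation outliers shrink. This balances the quantization difficulty between activations and weights, so the smoothed W8A8 model should stay closer to FP16 perplexity than the naive W8A8 model.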
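
For each input channel j, SmoothQuant computes s_j = max|X_j|^alpha / max|W_j|^(1 - alpha), divides the activations by s_j, and multiplies the matching weight column by s_j, so X W^T is unchanged. The sketch below checks that identity in plain Python on made-up 2x3 matrices; `smoothing_scales` and `matmul_xwT` are hypothetical helpers. As we understand `smooth_lm`, the library folds the 1/s_j factor into the preceding RMSNorm weights rather than dividing activations at run time.

```python
def smoothing_scales(act_absmax, weight_absmax, alpha=0.85):
    """Per-input-channel factors s_j = max|X_j|**alpha / max|W_j|**(1 - alpha)."""
    return [
        max(a ** alpha / max(w, 1e-5) ** (1 - alpha), 1e-5)
        for a, w in zip(act_absmax, weight_absmax)
    ]


def matmul_xwT(x, w):
    """Return x @ w.T for x of shape [tokens, in] and w of shape [out, in]."""
    return [[sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in w] for row in x]


# Made-up activations: 2 tokens x 3 channels, with an outlier in channel 1.
x = [[0.2, 8.0, -0.1], [-0.3, 6.5, 0.4]]
# Made-up weights in nn.Linear layout: 2 output features x 3 input channels.
w = [[0.5, 0.01, -0.2], [0.1, -0.02, 0.3]]

n_in = len(x[0])
act_absmax = [max(abs(row[j]) for row in x) for j in range(n_in)]
w_absmax = [max(abs(row[j]) for row in w) for j in range(n_in)]
s = smoothing_scales(act_absmax, w_absmax, alpha=0.85)

x_smooth = [[v / sj for v, sj in zip(row, s)] for row in x]  # X' = X diag(s)^-1
w_smooth = [[v * sj for v, sj in zip(row, s)] for row in w]  # W' = W diag(s)

print("activation channel max before:", act_absmax)
print("activation channel max after: ", [max(abs(row[j]) for row in x_smooth) for j in range(n_in)])
print("outputs match:", all(
    abs(a - b) < 1e-9
    for r, rs in zip(matmul_xwT(x, w), matmul_xwT(x_smooth, w_smooth))
    for a, b in zip(r, rs)
))
```

With these values the raw per-channel maxima span more than a factor of 20, while after smoothing they are roughly equal, which is what makes per-token INT8 activation quantization easier.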

In [12]:
from transformers import AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

tokenizer = AutoTokenizer.from_pretrained("../../llama-2-7b-hf")

model_smoothquant_w8a8.eval()
model_smoothquant_w8a8 = model_smoothquant_w8a8.to(device)

input_text = "explain what is AI"
inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=True).to(device)

# Sampling is enabled, so the generated text differs between runs.
output = model_smoothquant_w8a8.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=50,
    eos_token_id=tokenizer.eos_token_id,
    do_sample=True,
)

generated_text = tokenizer.batch_decode(
    output, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
print("Generated text:", generated_text)

Using device: cuda
Generated text: ['explain what is AI\nЉ everybodyК everybodyу��.\nIn the first place.\nThe only.\nA.\nI.\nR nobody’s.\nIn the name of the first.\nF.\nH']


In [9]:
ppl_smoothquant_w8a8 = evaluator.evaluate(model_smoothquant_w8a8)
print(f"SmoothQuant W8A8 quantized model perplexity: {ppl_smoothquant_w8a8}")

Evaluating...: 100%|██████████| 40/40 [00:08<00:00,  4.49it/s]

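SmoothQuant W8A8 quantized model perplexity: 5.85634183883667

With SmoothQuant, the W8A8 model reaches a perplexity of 5.86, compared with 5.93 for naive W8A8 and 5.82 for FP16.

Printing `model_fp16` shows that it now contains quantized layers:
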
In [14]:
print(model_fp16)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (k_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (v_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (o_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (up_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
         

## SmoothQuant W4A4 Attention + W8A8 MLP Quantized Model Perplexity

In [4]:
model_fp16 = LlamaForCausalLM.from_pretrained(
    "../../llama2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)
act_scales = torch.load("../act_scales/llama-2-7b.pt")
smooth_lm(model_fp16, act_scales, 0.85)
model_smoothquant_w4a4 = quantize_llama_like(model_fp16)
print(model_smoothquant_w4a4)


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (k_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (v_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (o_proj): W4A4Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (up_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
         

In [None]:
ppl_smoothquant_w4a4 = evaluator.evaluate(model_smoothquant_w4a4)
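print(f"SmoothQuant W4A4-attention/W8A8-MLP quantized model perplexity: {ppl_smoothquant_w4a4}")

Evaluating...: 100%|██████████| 40/40 [00:22<00:00,  1.75it/s]

SmoothQuant W4A4-attention/W8A8-MLP quantized model perplexity: 499.2110290527344

Smoothing lowers the W4A4-attention perplexity from 811 to 499, but the model remains far from the FP16 baseline of 5.82.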

In [12]:
print(act_scales)

{'model.layers.0.self_attn.q_proj': tensor([0.1150, 0.0492, 0.0074,  ..., 0.0454, 0.0423, 0.0233]), 'model.layers.0.self_attn.k_proj': tensor([0.1150, 0.0492, 0.0074,  ..., 0.0454, 0.0423, 0.0233]), 'model.layers.0.self_attn.v_proj': tensor([0.1150, 0.0492, 0.0074,  ..., 0.0454, 0.0423, 0.0233]), 'model.layers.0.self_attn.o_proj': tensor([0.0191, 0.0346, 0.0181,  ..., 0.0137, 0.0127, 0.0158]), 'model.layers.0.mlp.gate_proj': tensor([0.1866, 0.1777, 0.1694,  ..., 0.1953, 0.1864, 0.1858]), 'model.layers.0.mlp.up_proj': tensor([0.1866, 0.1777, 0.1694,  ..., 0.1953, 0.1864, 0.1858]), 'model.layers.0.mlp.down_proj': tensor([0.1345, 0.0787, 0.2124,  ..., 0.2271, 0.0812, 0.4741]), 'model.layers.1.self_attn.q_proj': tensor([0.3853, 0.3708, 0.3867,  ..., 0.2445, 0.3130, 0.2749]), 'model.layers.1.self_attn.k_proj': tensor([0.3853, 0.3708, 0.3867,  ..., 0.2445, 0.3130, 0.2749]), 'model.layers.1.self_attn.v_proj': tensor([0.3853, 0.3708, 0.3867,  ..., 0.2445, 0.3130, 0.2749]), 'model.layers.1.self
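
The `act_scales` dictionary holds one tensor per linear layer containing the maximum absolute input value of each channel, collected on calibration data. Layers that read the same input (`q_proj`, `k_proj`, `v_proj`; `gate_proj`, `up_proj`) therefore share identical scales, as the output above shows. The sketch below illustrates the running-max update in plain Python; the helper `update_act_scales`, the `toy_act_scales` dict, and the batch values are hypothetical, and this is not the repository's calibration code.

```python
def update_act_scales(scales, layer_name, activations):
    """Fold one batch of inputs to a linear layer into its running per-channel max |x|.

    activations: list of token vectors, each of length hidden_size.
    """
    hidden = len(activations[0])
    batch_max = [max(abs(tok[j]) for tok in activations) for j in range(hidden)]
    prev = scales.get(layer_name)
    scales[layer_name] = batch_max if prev is None else [max(a, b) for a, b in zip(prev, batch_max)]


toy_act_scales = {}
# Made-up calibration batches of inputs to layer 0's attention projections.
calib_batches = [
    [[0.1, -2.0, 0.3], [0.05, 1.5, -0.4]],
    [[-0.2, 0.7, 0.1]],
]
for batch in calib_batches:
    # q_proj, k_proj and v_proj all read the same normalized hidden states,
    # so they accumulate identical scales.
    for proj in ("q_proj", "k_proj", "v_proj"):
        update_act_scales(toy_act_scales, f"model.layers.0.self_attn.{proj}", batch)

print(toy_act_scales["model.layers.0.self_attn.q_proj"])  # [0.2, 2.0, 0.4]
```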

In [2]:
%pip show transformers


Name: transformers
Version: 4.36.0
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /home/master_112/m56121041/miniconda3/envs/smoothquant/lib/python3.8/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: 
Note: you may need to restart the kernel to use updated packages.
