
<p style="text-align: center; font-weight: bold;">SmoothQuant</p>

![image](../images/smq_intuition.png)


In [72]:
%ls

[0m[01;34mimages[0m/  QAT.ipynb      Quantization_compare_minmax_percentile.ipynb  SMQ.ipynb
[01;34mPTQ[0m/     [01;34mquantization[0m/  Quantization-Implementation.ipynb


In [4]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import functools
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm


In [12]:
# Poor Mans NN
class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(128, 64)
        self.linear3 = nn.Linear(64, 10)
           
    def forward(self, x):
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.relu(x)
        x = self.linear3(x)
        return x

transform = transforms.Compose([transforms.ToTensor()])
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

calibration_data = DataLoader(mnist_dataset, batch_size=64, shuffle=False)

model = NN()
model.load_state_dict(torch.load("./linear_nn.pth", weights_only=True)) 


<All keys matched successfully>

In [13]:
model

NN(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=784, out_features=128, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=128, out_features=64, bias=True)
  (linear3): Linear(in_features=64, out_features=10, bias=True)
)

In [18]:
def get_activations(model, calibration_data):

    model.eval()
    device = next(model.parameters()).device
    activations = {}

    def stat_tensor(name, tensor):
        """Update activation scales for the given tensor."""
        hidden_dim = tensor.shape[-1]
        tensor = tensor.view(-1, hidden_dim).abs().detach()
        comming_max = torch.max(tensor, dim=0)[0].float().cpu()
        if name in activations:
            activations[name] = torch.max(activations[name], comming_max)
        else:
            activations[name] = comming_max

    def stat_input_hook(m, x, y, name):
        """Hook function to capture input activations."""
        if isinstance(x, tuple):
            x = x[0]  # In case x is a tuple, extract the first element.
        stat_tensor(name, x)

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear):
            hooks.append(
                m.register_forward_hook(functools.partial(stat_input_hook, name=name))
            )

    for images, labels in tqdm(calibration_data, desc="Processing batches"):
        inputs = images.to(device) 
        model(inputs)

    for h in hooks:
        h.remove()
    return activations


# Fetch Activations for the Calibration Data
activations = get_activations(model, calibration_data)


Processing batches: 100%|██████████| 938/938 [00:04<00:00, 206.23it/s]


In [19]:
for name, layer in model.named_children():
    if name in activations:
        print(f"Layer Name: {name}, Layer: {layer}, Activation: {activations[name].shape}")

Layer Name: linear1, Layer: Linear(in_features=784, out_features=128, bias=True), Activation: torch.Size([784])
Layer Name: linear2, Layer: Linear(in_features=128, out_features=64, bias=True), Activation: torch.Size([128])
Layer Name: linear3, Layer: Linear(in_features=64, out_features=10, bias=True), Activation: torch.Size([64])


# Understanding Linear Layer Smoothing

Now we see that `linear1` and `linear2` can be smoothed, as the input of `linear2` is the output of `linear1`.

### Model Operations
1. **Input fed into the Model:**
   - `Xin`
2. **Linear Layer 1:**
   - `X1_out = Xin * W1.T + B1`
3. **Linear Layer 2:**
   - `X2_out = X1_out * W2.T + B2`
4. **Linear Layer 3:**
   - `Out = X2_out * W3.T + B3`

### Smoothing Linear2
To smoothen `linear2`, we derive activation scales from `X1_out` using the function `stat_tensor`. The process involves:

1. **Compute Max Weight:**  
   Use the function `smooth_linear_linear` to compute the maximum weight.

2. **Compute Scale:**  
   Use a scale factor `alpha = 0.5`.

3. **Scale Adjustments:**
   - Multiply the scale with `W2`.
   - Divide the same scale with `W1`.

Since the operations are associative, these transformations are mathematically equivalent.

### Key Insight
This process ensures smoother transitions between layers without altering the mathematical equivalence of the operations.


![image](images/smq_hardness_mitigation.png)



![image](images/smq_example.png)



![image](images/smq_formula.png)
![image](images/smq_scale.png)

It is known that activation are difficult to quantize as their range and variance are quite high and are susceptible to the input fed.

Hence, to mitigate this SmoothQuant postulates that the diffculty from the activation could be partially handed over to the weights of the previous layer.
So that the difficutly is balanced betweent he weights and the activations. 

The main idea of SmoothQuant is to have Alpha = 0.5 for the scale such that the distribution of the difficulty is shared between the current layer weights and the activations.

NOTE: The activations here mean that the scale is carried over the previous layer weight, thus having no overhead of the scale.

In [63]:

@torch.no_grad()
def smooth_linear_linear(fc1, fc2, act_scales, alpha=0.5):
    """Apply weight smoothing to a pair of linear layers. """
    
    assert isinstance(fc1, nn.Linear), "fc1 must be an instance of nn.Linear"
    assert isinstance(fc2, nn.Linear), "fc2 must be an instance of nn.Linear"
    assert fc1.out_features == fc2.in_features, "fc1 out_features must match fc2 in_features"
    assert fc1.out_features == act_scales.numel(), "Number of act_scales must match fc1 out_features"

    device, dtype = fc1.weight.device, fc1.weight.dtype
    act_scales = act_scales.to(device=device, dtype=dtype)

    # Compute weight scales per Output Feature Dimension
    weight_scales = fc2.weight.abs().max(dim=0)[0]  # Max along Input Features Dimension
    weight_scales = weight_scales.clamp(min=1e-5)  # Avoid division by zero

    # print("Weight Scales Shape: ", weight_scales.shape)
    # print("Activation Scales Shape: ", act_scales.shape)

    # Calculate scales based on alpha
    scales = (
        (act_scales.pow(alpha) / weight_scales.pow(1 - alpha))
        .clamp(min=1e-5)
        .to(device=device, dtype=dtype)
    )
    # Weights are of shape (out_features, in_features), biases are of shape (out_features)
    # For FC1, we divide the weights by the scales along the out_features, and the biases by the scales
    # For FC2, we multiply the weights by the scales along the in_features

    fc1.weight.div_(scales.view(-1, 1))  
    if fc1.bias is not None:
        fc1.bias.div_(scales)

    fc2.weight.mul_(scales)


In [65]:
previous_layer = None
for name, current_layer in model.named_children():
    if name in activations:   
        if previous_layer is not None:
            print(f"Smoothing between Layers: {previous_layer} and {current_layer}")
            smooth_linear_linear(previous_layer, current_layer, activations[name], alpha=0.5)
        previous_layer = current_layer
        

Smoothing between Layers: Linear(in_features=784, out_features=128, bias=True) and Linear(in_features=128, out_features=64, bias=True)
Smoothing between Layers: Linear(in_features=128, out_features=64, bias=True) and Linear(in_features=64, out_features=10, bias=True)


 Now Lets Quantize

In [71]:
# Code borrowed from: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/fake_quant.py
# If you would like to understand how this can be implemented from scratch, please refer to my implementation here: https://github.com/satabios/quantization/quant/Weight_Only

@torch.no_grad()
def quantize_weight_per_channel_absmax(w, n_bits=8):
    # w: (out_features, in_features)
    scales = w.abs().max(dim=-1, keepdim=True)[0]
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    w.div_(scales).round_().mul_(scales)
    return w


@torch.no_grad()
def quantize_weight_per_tensor_absmax(w, n_bits=8):
    # w: (out_features, in_features)
    scales = w.abs().max()
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    w.div_(scales).round_().mul_(scales)
    return w


@torch.no_grad()
def quantize_activation_per_token_absmax(t, n_bits=8):
    t_shape = t.shape
    t.view(-1, t_shape[-1])
    scales = t.abs().max(dim=-1, keepdim=True)[0]
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    t.div_(scales).round_().mul_(scales)
    return t


@torch.no_grad()
def quantize_activation_per_tensor_absmax(t, n_bits=8):
    t_shape = t.shape
    t.view(-1, t_shape[-1])
    scales = t.abs().max()
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    t.div_(scales).round_().mul_(scales)
    return t


class W8A8Linear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        act_quant="per_token",
        quantize_output=False,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randn(
                self.out_features,
                self.in_features,
                dtype=torch.float16,
                requires_grad=False,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (1, self.out_features), dtype=torch.float16, requires_grad=False
                ),
            )
        else:
            self.register_buffer("bias", None)

        if act_quant == "per_token":
            self.act_quant_name = "per_token"
            self.act_quant = partial(quantize_activation_per_token_absmax, n_bits=8)
        elif act_quant == "per_tensor":
            self.act_quant_name = "per_tensor"
            self.act_quant = partial(quantize_activation_per_tensor_absmax, n_bits=8)
        else:
            raise ValueError(f"Invalid act_quant: {act_quant}")

        if quantize_output:
            self.output_quant_name = self.act_quant_name
            self.output_quant = self.act_quant
        else:
            self.output_quant_name = "None"
            self.output_quant = lambda x: x

    def to(self, *args, **kwargs):
        super(W8A8Linear, self).to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        if self.bias is not None:
            self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        q_x = self.act_quant(x)
        y = torch.functional.F.linear(q_x, self.weight, self.bias)
        q_y = self.output_quant(y)
        return q_y

    @staticmethod
    def from_float(
        module, weight_quant="per_channel", act_quant="per_token", quantize_output=False
    ):
        assert isinstance(module, torch.nn.Linear)
        new_module = W8A8Linear(
            module.in_features,
            module.out_features,
            module.bias is not None,
            act_quant=act_quant,
            quantize_output=quantize_output,
        )
        if weight_quant == "per_channel":
            new_module.weight = quantize_weight_per_channel_absmax(
                module.weight, n_bits=8
            )  # use 8-bit integer for weight
        elif weight_quant == "per_tensor":
            new_module.weight = quantize_weight_per_tensor_absmax(
                module.weight, n_bits=8
            )
        else:
            raise ValueError(f"Invalid weight_quant: {weight_quant}")
        new_module.weight_quant_name = weight_quant
        if module.bias is not None:
            new_module.bias = module.bias
        return new_module

    def __repr__(self):
        return f"W8A8Linear({self.in_features}, {self.out_features}, bias={self.bias is not None}, weight_quant={self.weight_quant_name}, act_quant={self.act_quant_name}, output_quant={self.output_quant_name})"


In [70]:
print(f"Original Model: {model}") 
print(" ---------------- Quantizing Model Post Smoothing ---------------- ")
for name, current_layer in model.named_children():
    if isinstance(current_layer, nn.Linear):
        model._modules[name] = W8A8Linear.from_float(current_layer, weight_quant="per_tensor", act_quant="per_tensor", quantize_output=True)
print(f"Quantized Model: {model}")

Original Model: NN(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): W8A8Linear(784, 128, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (relu): ReLU()
  (linear2): W8A8Linear(128, 64, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (linear3): W8A8Linear(64, 10, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
)
 ---------------- Quantizing Model Post Smoothing ---------------- 
Quantized Model: NN(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): W8A8Linear(784, 128, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (relu): ReLU()
  (linear2): W8A8Linear(128, 64, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (linear3): W8A8Linear(64, 10, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
)



<p style="text-align: center; font-weight: bold;">GPU Rich Bois!!... Peruse Further SmoothQuant on LLM</p>

In [1]:
import torch
import torch.nn as nn
from functools import partial

from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.bloom.modeling_bloom import BloomBlock
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.models.mistral.modeling_mistral import (
    MistralDecoderLayer,
    MistralRMSNorm,
)
from transformers.models.mixtral.modeling_mixtral import (
    MixtralDecoderLayer,
    MixtralRMSNorm,
)
from transformers.models.opt.modeling_opt import (
    OPTAttention,
    OPTDecoderLayer,
    OPTForCausalLM,
)
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer

from transformers.models.opt.modeling_opt import OPTPreTrainedModel
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
from transformers.models.mistral.modeling_mistral import MistralPreTrainedModel
from transformers.models.mixtral.modeling_mixtral import MixtralPreTrainedModel
from transformers.models.falcon.modeling_falcon import FalconPreTrainedModel




<p style="text-align: center; font-weight: bold;">SmoothQuant Helpers</p>

In [2]:

def quantize_opt(
    model, weight_quant="per_tensor", act_quant="per_tensor", quantize_bmm_input=True
):
    from transformers.models.opt.modeling_opt import (
        OPTAttention,
        OPTDecoderLayer,
    )

    for name, m in model.model.named_modules():
        if isinstance(m, OPTDecoderLayer):
            m.fc1 = W8A8Linear.from_float(
                m.fc1, weight_quant=weight_quant, act_quant=act_quant
            )
            m.fc2 = W8A8Linear.from_float(
                m.fc2, weight_quant=weight_quant, act_quant=act_quant
            )
        elif isinstance(m, OPTAttention):
            # Her we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
            m.q_proj = W8A8Linear.from_float(
                m.q_proj,
                weight_quant=weight_quant,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.k_proj = W8A8Linear.from_float(
                m.k_proj,
                weight_quant=weight_quant,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.v_proj = W8A8Linear.from_float(
                m.v_proj,
                weight_quant=weight_quant,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.out_proj = W8A8Linear.from_float(
                m.out_proj, weight_quant=weight_quant, act_quant=act_quant
            )
    return model


def quantize_llama_like(
    model, weight_quant="per_channel", act_quant="per_token", quantize_bmm_input=False
):
    from transformers.models.llama.modeling_llama import (
        LlamaAttention,
        LlamaMLP,
    )

    from transformers.models.mistral.modeling_mistral import (
        MistralAttention,
        MistralMLP,
    )

    for name, m in model.model.named_modules():
        if isinstance(m, (LlamaMLP, MistralMLP)):
            m.gate_proj = W8A8Linear.from_float(
                m.gate_proj, weight_quant=weight_quant, act_quant=act_quant
            )
            m.up_proj = W8A8Linear.from_float(
                m.up_proj, weight_quant=weight_quant, act_quant=act_quant
            )
            m.down_proj = W8A8Linear.from_float(
                m.down_proj, weight_quant=weight_quant, act_quant=act_quant
            )
        elif isinstance(m, (LlamaAttention, MistralAttention)):
            # Her we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
            m.q_proj = W8A8Linear.from_float(
                m.q_proj,
                weight_quant=weight_quant,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.k_proj = W8A8Linear.from_float(
                m.k_proj,
                weight_quant=weight_quant,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.v_proj = W8A8Linear.from_float(
                m.v_proj,
                weight_quant=weight_quant,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.o_proj = W8A8Linear.from_float(
                m.o_proj, weight_quant=weight_quant, act_quant=act_quant
            )
    return model


def quantize_mixtral(
    model, weight_quant="per_channel", act_quant="per_token", quantize_bmm_input=False
):
    from transformers.models.mixtral.modeling_mixtral import (
        MixtralAttention,
        MixtralSparseMoeBlock,
        MixtralBLockSparseTop2MLP,
    )

    for name, m in model.model.named_modules():
        if isinstance(m, MixtralBLockSparseTop2MLP):
            m.w1 = W8A8Linear.from_float(
                m.w1, weight_quant=weight_quant, act_quant=act_quant
            )
            m.w2 = W8A8Linear.from_float(
                m.w2, weight_quant=weight_quant, act_quant=act_quant
            )
            m.w3 = W8A8Linear.from_float(
                m.w3, weight_quant=weight_quant, act_quant=act_quant
            )
        elif isinstance(m, MixtralAttention):
            # Her we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
            m.q_proj = W8A8Linear.from_float(
                m.q_proj,
                weight_quant=weight_quant,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.k_proj = W8A8Linear.from_float(
                m.k_proj,
                weight_quant=weight_quant,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.v_proj = W8A8Linear.from_float(
                m.v_proj,
                weight_quant=weight_quant,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.o_proj = W8A8Linear.from_float(
                m.o_proj, weight_quant=weight_quant, act_quant=act_quant
            )
        elif isinstance(m, MixtralSparseMoeBlock):
            m.gate = W8A8Linear.from_float(
                m.gate, weight_quant=weight_quant, act_quant=act_quant
            )
    return model


def quantize_falcon(
    model, weight_quant="per_channel", act_quant="per_token", quantize_bmm_input=True
):
    from transformers.models.falcon.modeling_falcon import (
        FalconAttention,
        FalconMLP,
    )

    for name, m in model.named_modules():
        if isinstance(m, FalconMLP):
            m.dense_h_to_4h = W8A8Linear.from_float(
                m.dense_h_to_4h, weight_quant=weight_quant, act_quant=act_quant
            )
            m.dense_4h_to_h = W8A8Linear.from_float(
                m.dense_4h_to_h, weight_quant=weight_quant, act_quant=act_quant
            )
        elif isinstance(m, FalconAttention):
            # Her we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
            m.query_key_value = W8A8Linear.from_float(
                m.query_key_value,
                weight_quant=weight_quant,
                act_quant=act_quant,
                quantize_output=quantize_bmm_input,
            )
            m.dense = W8A8Linear.from_float(
                m.dense, weight_quant=weight_quant, act_quant=act_quant
            )
    return model


def quantize_model(
    model, weight_quant="per_channel", act_quant="per_token", quantize_bmm_input=False
):


    if isinstance(model, OPTPreTrainedModel):
        return quantize_opt(
            model,
            weight_quant=weight_quant,
            act_quant=act_quant,
            quantize_bmm_input=quantize_bmm_input,
        )
    elif isinstance(model, (LlamaPreTrainedModel, MistralPreTrainedModel)):
        return quantize_llama_like(
            model,
            weight_quant=weight_quant,
            act_quant=act_quant,
            quantize_bmm_input=quantize_bmm_input,
        )
    elif isinstance(model, MixtralPreTrainedModel):
        return quantize_mixtral(
            model,
            weight_quant=weight_quant,
            act_quant=act_quant,
            quantize_bmm_input=quantize_bmm_input,
        )
    elif isinstance(model, FalconPreTrainedModel):
        return quantize_falcon(
            model,
            weight_quant=weight_quant,
            act_quant=act_quant,
            quantize_bmm_input=quantize_bmm_input,
        )
    else:
        raise ValueError(f"Unsupported model type: {type(model)}")


In [3]:

@torch.no_grad()
def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
    if not isinstance(fcs, list):
        fcs = [fcs]
    assert isinstance(ln, nn.LayerNorm)
    for fc in fcs:
        assert isinstance(fc, nn.Linear)
        assert ln.weight.numel() == fc.in_features == act_scales.numel()

    device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
    act_scales = act_scales.to(device=device, dtype=dtype)
    weight_scales = torch.cat(
        [fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0
    )
    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)

    scales = (
        (act_scales.pow(alpha) / weight_scales.pow(1 - alpha))
        .clamp(min=1e-5)
        .to(device)
        .to(dtype)
    )

    ln.weight.div_(scales)
    ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))


@torch.no_grad()
def smooth_ln_fcs_llama_like(ln, fcs, act_scales, alpha=0.5):
    if not isinstance(fcs, list):
        fcs = [fcs]
    assert isinstance(ln, (LlamaRMSNorm, MistralRMSNorm, MixtralRMSNorm))
    for fc in fcs:
        assert isinstance(fc, nn.Linear)
        assert ln.weight.numel() == fc.in_features == act_scales.numel()
    device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
    act_scales = act_scales.to(device=device, dtype=dtype)
    weight_scales = torch.cat(
        [fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0
    )
    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
    scales = (
        (act_scales.pow(alpha) / weight_scales.pow(1 - alpha))
        .clamp(min=1e-5)
        .to(device)
        .to(dtype)
    )

    ln.weight.div_(scales)
    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))


@torch.no_grad()
def smooth_lm(model, scales, alpha=0.5):
    for name, module in model.named_modules():
        if isinstance(module, OPTDecoderLayer):
            attn_ln = module.self_attn_layer_norm
            qkv = [
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            ]
            qkv_input_scales = scales[name + ".self_attn.q_proj"]
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

            ffn_ln = module.final_layer_norm
            fc1 = module.fc1
            fc1_input_scales = scales[name + ".fc1"]
            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
        elif isinstance(module, BloomBlock):
            attn_ln = module.input_layernorm
            qkv = module.self_attention.query_key_value
            qkv_input_scales = scales[name + ".self_attention.query_key_value"]
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

            ffn_ln = module.post_attention_layernorm
            fc1 = module.mlp.dense_h_to_4h
            fc1_input_scales = scales[name + ".mlp.dense_h_to_4h"]
            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
        elif isinstance(module, FalconDecoderLayer):
            qkv = module.self_attention.query_key_value
            qkv_input_scales = scales[name + ".self_attention.query_key_value"]
            fc1_input_scales = scales[name + ".mlp.dense_h_to_4h"]
            fc1 = module.mlp.dense_h_to_4h

            if (
                not module.config.new_decoder_architecture
                and module.config.parallel_attn
            ):
                attn_ln = module.input_layernorm
                smooth_ln_fcs(attn_ln, [qkv, fc1], qkv_input_scales, alpha)
            else:
                attn_ln = (
                    module.ln_attn
                    if module.config.new_decoder_architecture
                    else module.input_layernorm
                )
                ffn_ln = (
                    module.ln_mlp
                    if module.config.new_decoder_architecture
                    else module.post_attention_layernorm
                )
                smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)
                smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
        elif isinstance(module, (LlamaDecoderLayer, MistralDecoderLayer)):
            attn_ln = module.input_layernorm  # attention forward norm
            qkv = [
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            ]

            qkv_input_scales = scales[name + ".self_attn.q_proj"]
            smooth_ln_fcs_llama_like(attn_ln, qkv, qkv_input_scales, alpha)

            ffn_ln = module.post_attention_layernorm  # feed forward norm
            fcs = [module.mlp.gate_proj, module.mlp.up_proj]
            fcs_input_scales = scales[name + ".mlp.gate_proj"]

            smooth_ln_fcs_llama_like(ffn_ln, fcs, fcs_input_scales, alpha)
        elif isinstance(module, MixtralDecoderLayer):
            attn_ln = module.input_layernorm  # attention forward norm
            qkv = [
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            ]

            qkv_input_scales = scales[name + ".self_attn.q_proj"]
            smooth_ln_fcs_llama_like(attn_ln, qkv, qkv_input_scales, alpha)

            ffn_ln = module.post_attention_layernorm  # feed forward norm
            fcs = [module.block_sparse_moe.gate]
            for expert in module.block_sparse_moe.experts:
                fcs.append(expert.w1)
                fcs.append(expert.w3)
            fcs_input_scales = scales[name + ".block_sparse_moe.gate"]

            smooth_ln_fcs_llama_like(ffn_ln, fcs, fcs_input_scales, alpha)

In this notebook, we simulate the 8-bit dynamic per-tensor weight and activation quantization with FP16, i.e., fake quantization. 
The following is an evaluator to see the performance of the model. We use a toy dataset (the first 1000 examples in the validation set of the Lambada dataset) to evaluate the model. You can replace it with your own dataset. The conclusion should be the same.

In this demo, we have simplified the evaluation by using the first 1,000 samples from the LAMBADA dataset's validation set. We employ the "Last Token Prediction Accuracy" as our evaluation metric. This approximate evaluation is intended for demonstration purposes, providing simple but meaningful comparisons of relative performance between methods. For a more strict assessment, we recommend using the lm-eval-harness to obtain the "Last Word Prediction Accuracy" for the LAMBADA dataset, which is the reported metric in our paper.


<p style="text-align: center; font-weight: bold;">SmoothQuant</p>

In [4]:
class Evaluator:
    def __init__(self, dataset, tokenizer, device):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.device = device

        # tokenize the dataset
        def tokenize_function(examples):
            example = self.tokenizer(examples["text"])
            return example

        self.dataset = self.dataset.map(tokenize_function, batched=True)
        self.dataset.set_format(type="torch", columns=["input_ids"])

    @torch.no_grad()
    def evaluate(self, model):
        model.eval()
        # The task is to predict the last word of the input.
        total, hit = 0, 0
        for batch in self.dataset:
            input_ids = batch["input_ids"].to(self.device).unsqueeze(0)
            label = input_ids[:, -1]
            outputs = model(input_ids)
            last_token_logits = outputs.logits[:, -2, :]
            pred = last_token_logits.argmax(dim=-1)
            total += label.size(0)
            hit += (pred == label).sum().item()
        acc = hit / total
        return acc

In [5]:
from datasets import load_dataset
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-1.3b")
dataset = load_dataset("lambada", split="validation[:1000]")
evaluator = Evaluator(dataset, tokenizer, "cuda")

In [6]:
model_fp16 = OPTForCausalLM.from_pretrained(
    "facebook/opt-1.3b", torch_dtype=torch.float16, device_map="auto"
)
model_fp16

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTSdpaAttention(
            (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=2048, out_features=8192, bias=True)
          (fc2): Linear(in_features=8192, out_features=2048, bias=True)
          (final_layer_norm): La

In [7]:
acc_fp16 = evaluator.evaluate(model_fp16)
print(f"Original model (fp16) accuracy: {acc_fp16}")

Original model (fp16) accuracy: 0.722


We then quantize the model to W8A8 and check the performance.

Naive W8A8 Quantized Model Accuracy

In [8]:
model_w8a8 = quantize_opt(model_fp16)
print(model_w8a8)

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTSdpaAttention(
            (k_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
            (v_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
            (q_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
            (out_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=None)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, ele

In [9]:
model_fp16

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTSdpaAttention(
            (k_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
            (v_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
            (q_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
            (out_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=None)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, ele

In [10]:
model_w8a8

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTSdpaAttention(
            (k_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
            (v_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
            (q_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
            (out_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=None)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, ele

In [11]:
del model_fp16
acc_w8a8 = evaluator.evaluate(model_w8a8)
print(f"Naive W8A8 quantized model accuracy: {acc_w8a8}")

Naive W8A8 quantized model accuracy: 0.692


We can see there is a significant accuracy drop. This is consistent with LLM.int8()'s finding: when the model size increases larger than 6.7B, systematic outliers will emerge in activations, which makes fully INT8 quantization impossible.

SmoothQuant W8A8 Quantized Model Accuracy
Let's smooth the model, quantize it, and check the performance! In ../act_scales, we provide the activation scales for OPT and BLOOM models. You can also use this notebook to test quantizing those models.

In [13]:
model = OPTForCausalLM.from_pretrained(
    "facebook/opt-1.3b", torch_dtype=torch.float16, device_map="auto"
)
act_scales = torch.load("opt-1.3b_smq_scales.pt", weights_only=False)
smooth_lm(model, act_scales, 0.5)
model_smoothquant_w8a8 = quantize_opt(model)
print(model_smoothquant_w8a8)

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTSdpaAttention(
            (k_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
            (v_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
            (q_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
            (out_proj): W8A8Linear(2048, 2048, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=None)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, ele

We can see the smoothed model has the same accuracy as the FP16 model. This is because SmoothQuant smooths the outliers in activations and moves the quantization difficulty from activations to weights.



In [14]:
acc_smoothquant_w8a8 = evaluator.evaluate(model_smoothquant_w8a8)
print(f"SmoothQuant W8A8 quantized model accuracy: {acc_smoothquant_w8a8}")

SmoothQuant W8A8 quantized model accuracy: 0.707
