## Reference:
- https://learn.deeplearning.ai/courses/quantization-in-depth

# Overview: Custom Build an 8-Bit Quantizer

This notebook shows how to:
- Build a custom 8-bit quantizer `W8A16LinearLayer()`
- Compress any model in 8-bit precision using this.
- Replace Pytorch layers with quantized layers
- Quantize any Open source Pytorch model
- Load quantized weights from Huggingface hub in memory efficient way

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Check quantize error

In [2]:
def get_q_scale_symmetric(tensor, dtype=torch.int8):
    r_max = tensor.abs().max().item()
    q_max = torch.iinfo(dtype).max

    # return the scale
    return r_max/q_max


def linear_q_symmetric(tensor, dtype=torch.int8):
    scale = get_q_scale_symmetric(tensor)
    
    quantized_tensor = linear_q_with_scale_and_zero_point(tensor,
                                                     scale=scale,
                   # in symmetric quantization zero point is = 0    
                                                    zero_point=0,
                                                      dtype=dtype)
    
    return quantized_tensor, scale

def linear_q_with_scale_and_zero_point(
    r_tensor, scale, zero_point, dtype=torch.int8):
    """
    Performs simple linear quantization given
    the scale and zero-point.
    """

    # scale tensor and add the zero point
    scaled_and_shifted_tensor = r_tensor / scale + zero_point

    # round the tensor 
    rounded_tensor = torch.round(scaled_and_shifted_tensor)

    # we need to clamp to the min/max value of the specified dtype
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    q_tensor = rounded_tensor.clamp(q_min, q_max).to(dtype)
    return q_tensor


In [3]:
hs = torch.tensor([1, 2, 3], dtype=torch.float32)
w = torch.randn((3,3))
q_w, s_w  = linear_q_symmetric(w)

# Using Quantized weight
dequantized_weight = q_w.to(torch.float32) * s_w + 0 # z_w
output = torch.nn.functional.linear(hs, dequantized_weight)
print(output)

# Using Original non-quuanized weight
output = torch.nn.functional.linear(hs, w)
print(output)

tensor([-1.1628, -0.5933, -2.2189])
tensor([-1.1603, -0.5911, -2.2290])


### `w8_a16_forward` Function

-
```Python
W8A16LinearLayer
                    # 8-bit  # 16-bit         # optional
* w8_a16_forward -> weights, input,   scales, bias=None
                    
```
- Cast the 8-bit `weights` to the same data type as the `input`, "casted weights",
- keeping the "casted weights" in the same range as before, [-128, 127]
- Output: $$(({inputs} \cdot \text{"casted weights"}) * {scale}) + {bias}$$ 

In [4]:
w = torch.randint(-128, 127, (32, 16))
w_int8 = w.to(torch.int8)
hs = torch.randn((1, 16), dtype=torch.bfloat16)

scales = torch.randn((1, 32), dtype=torch.bfloat16)
bias = torch.randn((1, 32), dtype=torch.bfloat16)

In [5]:
(F.linear(hs, w_int8.to(hs.dtype)) * scales) + bias

tensor([[ 110.0000,  -53.5000, -300.0000,  812.0000,   40.0000,   67.0000,
           76.5000,   -7.9375, -164.0000, -122.0000,    7.8125,  -62.2500,
          183.0000,  154.0000,    6.0000, -344.0000,  -60.0000,  -13.0000,
         -128.0000,   11.1250,   58.7500,  468.0000,  346.0000,   36.0000,
            1.3672, -212.0000, -129.0000,  -20.2500, -256.0000,  548.0000,
           36.0000,  159.0000]], dtype=torch.bfloat16)

In [6]:
def w8_a16_forward(weight, input, scales, bias=None):    
    casted_weights = weight.to(input.dtype)
    output = F.linear(input, casted_weights) * scales
    if bias is not None:
        output = output + bias      
    return output

In [7]:
print("With bias:\n\n", 
      w8_a16_forward(w_int8, hs, scales, bias))

print("\nWithout bias:\n\n", 
      w8_a16_forward(w_int8, hs, scales))

With bias:

 tensor([[ 110.0000,  -53.5000, -300.0000,  812.0000,   40.0000,   67.0000,
           76.5000,   -7.9375, -164.0000, -122.0000,    7.8125,  -62.2500,
          183.0000,  154.0000,    6.0000, -344.0000,  -60.0000,  -13.0000,
         -128.0000,   11.1250,   58.7500,  468.0000,  346.0000,   36.0000,
            1.3672, -212.0000, -129.0000,  -20.2500, -256.0000,  548.0000,
           36.0000,  159.0000]], dtype=torch.bfloat16)

Without bias:

 tensor([[ 109.5000,  -52.5000, -300.0000,  812.0000,   39.7500,   66.5000,
           77.0000,   -6.8125, -164.0000, -123.0000,    8.3125,  -61.7500,
          183.0000,  155.0000,    5.2500, -346.0000,  -60.7500,  -13.0625,
         -126.0000,    9.9375,   59.2500,  468.0000,  346.0000,   35.7500,
            1.6094, -213.0000, -129.0000,  -19.3750, -254.0000,  548.0000,
           35.0000,  159.0000]], dtype=torch.bfloat16)


### `W8A16LinearLayer`

- This is `init` signature of [PyTorch Linear layer](https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear):
```Python
def __init__(self, in_features, out_features, bias=True,
             device=None, dtype=None)
```

In [8]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        # This will error on layer initialization
        # RuntimeError :  Only Tensors of floating point and complex dtype can require gradients 
        #self.int8_weights = nn.Parameter(torch.Tensor([0, 1]
        #                             ).to(dtype=torch.int8))

        # This says dont compute gradients on them.       
        self.register_buffer("int8_weights",
                            torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8))        
        self.register_buffer("scales", 
                             torch.randn((out_features), dtype=dtype))        
        if bias:
            self.register_buffer("bias", 
                                 torch.randn((1, out_features), 
                                             dtype=dtype))        
        else:
            self.bias = None

    def quantize(self, weights):
        w_fp32 = weights.clone().to(torch.float32) # for stability
        
        # Per channel since max from last dim 
        scales = w_fp32.abs().max(dim=-1).values / 127 
        scales = scales.to(weights.dtype) # same type as weights
        
        # linear quantization
        int8_weights = torch.round(weights/scales.unsqueeze(1)).to(torch.int8) 
        self.int8_weights = int8_weights
        self.scales = scales
    
    def forward(self, input):
        return w8_a16_forward(self.int8_weights, 
                              input, self.scales, self.bias)      

In [9]:
# Test dtypes
module = W8A16LinearLayer(16, 32) # input_hidden_size, output_hidden_size
print(module.int8_weights.shape)
print(module.scales.shape)

# Check dtypes
dummy_hidden_states = torch.randn(1, 6, 16) # bs, seq_len, input_hidden_size
dummy_output_states = module(dummy_hidden_states)
dummy_output_states.shape, dummy_output_states.dtype

torch.Size([32, 16])
torch.Size([32])


(torch.Size([1, 6, 32]), torch.float32)

In [10]:
module = W8A16LinearLayer(4, 8)
print(module.scales, module.scales.shape)
print("Weights at init:\n" , module.int8_weights, module.int8_weights.shape)

tensor([-0.3539, -0.5164,  0.7794,  1.2950, -1.1936,  1.1703, -1.0569,  1.5416]) torch.Size([8])
Weights at init:
 tensor([[ -98,   59,   44,   21],
        [  42,   58,   54, -118],
        [  26,   37, -117,  -53],
        [ -92, -127,  -28,  111],
        [ -89,   92,  -51,  -11],
        [  32,   99,  101,   69],
        [  32, -127,   24,   42],
        [ -44, -107,  -62,  108]], dtype=torch.int8) torch.Size([8, 4])


In [11]:
w = torch.randn((4, 8), dtype=torch.bfloat16)
print("Orig Weights:\n" , w, w.shape)

Orig Weights:
 tensor([[-0.0693,  0.1245, -1.1562,  0.8398,  1.3828, -0.1406,  0.1406, -0.2793],
        [-0.7891,  0.0123, -1.2266, -0.6914, -0.4023,  0.2441, -0.4766,  0.1001],
        [ 0.7812, -0.6797, -0.2695, -1.0156, -2.5781, -0.5312,  0.1348, -2.2344],
        [-0.6875, -2.4062, -0.8516, -0.9727, -1.1484,  0.5781, -0.8242, -0.1738]],
       dtype=torch.bfloat16) torch.Size([4, 8])


In [12]:
module.quantize(w)
print("Quant Weights:\n" , module.int8_weights,  module.int8_weights.shape)

Quant Weights:
 tensor([[  -6,   11, -106,   78, -128,  -13,   13,  -26],
        [ -82,    1, -127,  -72,  -42,   25,  -50,   10],
        [  38,  -34,  -13,  -50, -127,  -26,    7, -110],
        [ -36, -127,  -45,  -52,  -61,   30,  -44,   -9]], dtype=torch.int8) torch.Size([4, 8])


In [13]:
### dequantized weights
wdeq = module.int8_weights * module.scales.unsqueeze(1)
print("Dequant Weights:\n",wdeq)

Dequant Weights:
 tensor([[-0.0654,  0.1196, -1.1484,  0.8477, -1.3906, -0.1416,  0.1416, -0.2832],
        [-0.7891,  0.0096, -1.2266, -0.6953, -0.4043,  0.2412, -0.4824,  0.0967],
        [ 0.7695, -0.6875, -0.2637, -1.0156, -2.5781, -0.5273,  0.1416, -2.2344],
        [-0.6797, -2.4062, -0.8516, -0.9844, -1.1562,  0.5664, -0.8320, -0.1699]],
       dtype=torch.bfloat16)


In [14]:
# Quantization error
(w - wdeq).abs().mean()

tensor(0.0913, dtype=torch.bfloat16)

### Replace Linear layers in Pytorch module + Quantization of weights

In [15]:
def replace_linear_with_target_and_quantize(module, 
                               target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not \
        any([x == name for x in module_name_to_exclude]):
            old_bias = child.bias
            old_weight = child.weight

            new_module = target_class(child.in_features, 
                                      child.out_features, 
                                      old_bias is not None, 
                                      child.weight.dtype)
            setattr(module, name, new_module)

            getattr(module, name).quantize(old_weight)
            
            if old_bias is not None:
              getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target_and_quantize(child, 
                     target_class, module_name_to_exclude)

In [16]:
class DummyModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.emb = torch.nn.Embedding(1, 1)
    # Try with bias
    self.linear_1 = nn.Linear(1, 1)
    # Try without bias
    self.linear_2 = nn.Linear(1, 1, bias=False)
    # Lm prediction head
    self.lm_head = nn.Linear(1, 1, bias=False)

In [17]:
model_1 = DummyModel()
model_2 = DummyModel()
model_3 = DummyModel()

In [18]:
replace_linear_with_target_and_quantize(model_1, W8A16LinearLayer, ["lm_head"])
print(model_1)

DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): Linear(in_features=1, out_features=1, bias=False)
)


In [19]:
replace_linear_with_target_and_quantize(model_2, W8A16LinearLayer, [])
print(model_2)

DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): W8A16LinearLayer()
)


In [20]:
replace_linear_with_target_and_quantize(model_3, W8A16LinearLayer, ["lm_head"])
print(model_3)

DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): Linear(in_features=1, out_features=1, bias=False)
)


### Test the Implementation on Various LLMs

- Text generation model: [Salesforce/codegen-350M-mono](https://huggingface.co/Salesforce/codegen-350M-mono)

In [21]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "./models/Salesforce/codegen-350M-mono"

model = AutoModelForCausalLM.from_pretrained(model_id, 
                                    torch_dtype=torch.bfloat16, 
                                             low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [22]:
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

In [23]:
print(pipe("def hello_world():", max_new_tokens=20, do_sample=False))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'def hello_world():\n    print("Hello World")\n\nhello_world()\n\n# 파'}]


In [24]:
print("Model before:\n\n", model)

Model before:

 CodeGenForCausalLM(
  (transformer): CodeGenModel(
    (wte): Embedding(51200, 1024)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-19): 20 x CodeGenBlock(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): CodeGenAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (qkv_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): CodeGenMLP(
          (fc_in): Linear(in_features=1024, out_features=4096, bias=True)
          (fc_out): Linear(in_features=4096, out_features=1024, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=51200, bi

In [25]:
previous_memory_footprint = model.get_memory_footprint()
print("Footprint of the model in MBs: ", 
      previous_memory_footprint/1e+6)

Footprint of the model in MBs:  797.310976


In [26]:
replace_linear_with_target_and_quantize(model, 
                                        W8A16LinearLayer, ["lm_head"])

In [27]:
pipe.model

CodeGenForCausalLM(
  (transformer): CodeGenModel(
    (wte): Embedding(51200, 1024)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-19): 20 x CodeGenBlock(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): CodeGenAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (qkv_proj): W8A16LinearLayer()
          (out_proj): W8A16LinearLayer()
        )
        (mlp): CodeGenMLP(
          (fc_in): W8A16LinearLayer()
          (fc_out): W8A16LinearLayer()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=51200, bias=True)
)

In [28]:
print(pipe("def hello_world():", max_new_tokens=20, 
           do_sample=False)[0]["generated_text"])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


def hello_world():
    print("Hello World")

# hello_world()

# def hello_


In [29]:
new_footprint = model.get_memory_footprint()
print("Footprint of the model in MBs: ", 
      new_footprint/1e+6)

Footprint of the model in MBs:  546.021376


In [30]:
### Memory saved
print("Memory saved in MBs: ", 
      (previous_memory_footprint - new_footprint)/1e+6)

Memory saved in MBs:  251.2896


## Memory Efficient Model Loading using `meta`

- First the weights are quantized & saved

In [31]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "./models/facebook/opt-125m"

model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [32]:
replace_linear_with_target_and_quantize(model, 
                             W8A16LinearLayer, 
                                   ["lm_head"])

In [33]:
model

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-11): 12 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): W8A16LinearLayer()
            (v_proj): W8A16LinearLayer()
            (q_proj): W8A16LinearLayer()
            (out_proj): W8A16LinearLayer()
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): W8A16LinearLayer()
          (fc2): W8A16LinearLayer()
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=768, out_features=50272, bias=False)
)

In [34]:
quantized_state_dict = model.state_dict()
torch.save(quantized_state_dict, "quantized_state_dict.pth")

- Optionally push them to Huggingface hub

```Python
from huggingface_hub import HfApi, create_repo

YOUR_HF_USERNAME = ""
your_repo_id = f"{YOUR_HF_USERNAME}/opt-125m-quantized-dlai"

api = HfApi()

# create_repo(your_repo_id)

api.upload_file(
 path_or_fileobj="quantized_state_dict.pth",
 path_in_repo="quantized_state_dict.pth",
 repo_id=your_repo_id
)
```

- When you load quantized weights next time, use `meta` device

In [35]:
from transformers import OPTForCausalLM, AutoTokenizer, AutoConfig

model_id = "./models/facebook/opt-125m"
config = AutoConfig.from_pretrained(model_id)

# Loads only skeleton of the model. Weights are not loaded
with torch.device("meta"):
  model = OPTForCausalLM(config)

tokenizer = AutoTokenizer.from_pretrained(model_id)

In [36]:
# Weights are not loaded yet
for param in model.parameters():
  print(param)

Parameter containing:
tensor(..., device='meta', size=(50272, 768), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(2050, 768), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768, 768), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768, 768), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768, 768), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768, 768), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768,), requires_

In [37]:
model

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-11): 12 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), ep

In [40]:
def replace_linear_with_target(module, 
                               target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not \
        any([x == name for x in module_name_to_exclude]):
            old_bias = child.bias
            old_weight = child.weight

            new_module = target_class(child.in_features, 
                                      child.out_features, 
                                      old_bias is not None, 
                                      child.weight.dtype)
            setattr(module, name, new_module)
            
            if old_bias is not None:
              getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target(child, 
                     target_class, module_name_to_exclude)

In [41]:
# Update the skeleton
replace_linear_with_target(model, W8A16LinearLayer, ["lm_head"])

In [42]:
model

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-11): 12 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): W8A16LinearLayer()
            (v_proj): W8A16LinearLayer()
            (q_proj): W8A16LinearLayer()
            (out_proj): W8A16LinearLayer()
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): W8A16LinearLayer()
          (fc2): W8A16LinearLayer()
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=768, out_features=50272, bias=False)
)

In [43]:
from huggingface_hub import hf_hub_download

# Quantized weights from huggingface
state_dict_cache_path = hf_hub_download(
    "ybelkada/opt-125m-quantized-dlai",
    "quantized_state_dict.pth"
)

# 125MB parameter model => 166MB size
# 8bit (1byte) => 125MB + 16bit scales etc.

In [44]:
state_dict = torch.load(state_dict_cache_path)

In [45]:
# Loads quantized weights into the model
model.load_state_dict(state_dict, strict=True, assign=True)

<All keys matched successfully>

In [46]:
from transformers import pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe("Hello today I am", max_new_tokens=40)

[{'generated_text': 'Hello today I am a new member of the team.\nI am a new member of the team.\nI am a new member of the team.\nI am a new member of the team.\nI am'}]

In [47]:
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe("Hello today I am giving a course about", max_new_tokens=10)

[{'generated_text': 'Hello today I am giving a course about the new technology of the future.\nI am'}]