<a href="https://colab.research.google.com/github/srimoyee1212/BERT-Quantization/blob/main/LLM_Quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer and encoder are both loaded from the same pre-trained checkpoint name.
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
model = model.to(device)  # move the loaded model onto the selected device

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [None]:
class QuantizedLinearLayer(nn.Module):
    """Linear layer that keeps its weight as packed signed 4-bit integers.

    The weight is held in an int8 buffer (``weight``) together with one
    floating-point scale per output row (``scale``). ``quantize`` fills both
    from a floating-point weight matrix; ``forward`` unpacks the weight,
    applies the linear map and multiplies the result by the per-row scale.

    Args:
        in_features: Number of columns of the placeholder weight buffer.
        out_features: Number of rows of the weight / length of ``scale``.
        bias: When true, a placeholder bias buffer of shape
            ``(1, out_features)`` is registered; otherwise ``bias`` is None.
        dtype: Floating dtype used for the ``scale`` and ``bias`` buffers.
    """

    def __init__(self, in_features, out_features, bias = True, dtype=torch.float32):
        super().__init__()

        # Random placeholders only: ``quantize`` replaces ``weight`` and
        # ``scale`` with real values once a source weight is available.
        self.register_buffer("weight", torch.randint(-8, 7, (out_features, in_features)).to(torch.int8).to(device))
        self.register_buffer("scale", torch.randn((out_features), dtype=dtype).to(device))

        if bias:
            self.register_buffer("bias", torch.randn((1, out_features), dtype=dtype).to(device))
        else:
            self.bias = None

    # For 4-bit quantization
    @torch.no_grad()
    def quantize(self, weight):
        """Quantize ``weight`` row-wise to signed 4-bit and store it packed.

        Each row is scaled so that its largest absolute value maps to 7, then
        rounded and clamped to ``[-8, 7]`` before being handed to
        ``pack_weights``.

        Runs under ``torch.no_grad()``: ``weight`` is typically a trainable
        parameter, and without it the stored ``scale`` buffer would carry an
        autograd graph that keeps a full float32 copy of the weight alive.

        Args:
            weight: 2-D floating tensor, one row per output feature.
        """
        weight_f32 = weight.clone().to(torch.float32).to(device)
        scale = weight_f32.abs().max(dim=-1).values/7
        scale = scale.to(weight.dtype)
        # An all-zero row yields scale == 0 and a 0/0 division below; any
        # non-zero scale is correct there because every quantized value is 0.
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)

        quantized_weight = torch.clamp(torch.round(weight/scale.unsqueeze(1)), -8, 7).to(torch.int8).to(device)
        # Further quantize the weight to 4-bit by packing it into int8 storage
        quantized_weight4bit = pack_weights(quantized_weight)

        self.weight = quantized_weight4bit
        self.scale = scale

    def forward(self, input):
        """Apply the quantized linear map to ``input``.

        The packed weight is unpacked on every call, cast to ``input``'s
        dtype for the matmul, and the result is rescaled per output feature.
        The bias is added when one is present.
        """
        # unpack the self.weight first
        unpacked_weight = unpack_weights(self.weight)
        output = F.linear(input, unpacked_weight.to(input.dtype)) * self.scale
        if self.bias is not None:
            output = output + self.bias
        return output

def replace_linearlayer(base_model, quantizer_class, exception_list, quantized=True):
    """Recursively swap ``nn.Linear`` children of ``base_model`` for ``quantizer_class`` layers.

    The module tree is modified in place and nothing is returned.

    Args:
        base_model: Module whose direct and nested children are visited.
        quantizer_class: Callable invoked as
            ``quantizer_class(in_features, out_features, has_bias, dtype)``.
        exception_list: Child names that must not be replaced.
        quantized: When true, the new layer is built with the original
            ``in_features`` and its ``quantize`` method is called on the old
            weight. When false, the layer is built with ``in_features // 2``
            and ``quantize`` is not called.
    """
    for name, child in base_model.named_children():
        should_replace = isinstance(child, nn.Linear) and name not in exception_list

        if not should_replace:
            # Not a replaceable linear layer: descend into it and look for
            # linear layers further down the tree.
            replace_linearlayer(child, quantizer_class, exception_list, quantized=quantized)
            continue

        old_weight, old_bias = child.weight, child.bias
        has_bias = old_bias is not None

        # Without on-the-fly quantization the layer is created with half the
        # input width.
        layer_in_features = child.in_features if quantized else child.in_features // 2
        new_layer = quantizer_class(layer_in_features, child.out_features, has_bias, old_weight.dtype).to(device)

        # Put the new layer where the linear layer used to be.
        setattr(base_model, name, new_layer)
        installed_layer = getattr(base_model, name)

        # Let the installed layer derive its quantized weight from the old one.
        if quantized:
            installed_layer.quantize(old_weight)

        # Carry the original bias over unchanged.
        if has_bias:
            installed_layer.bias = old_bias


In [None]:
def pack_weights(quantized_weight8bit):
    """Pack signed 4-bit values, stored one per int8, two to a byte.

    PyTorch has no 4-bit integer dtype, so two 4-bit values share one int8:
    column ``2*k`` of the input goes into the low nibble of packed column
    ``k`` and column ``2*k + 1`` into the high nibble. Negative values are
    kept in 4-bit two's complement form.

    Args:
        quantized_weight8bit: 2-D integer tensor whose values lie in
            ``[-8, 7]``. Only the low 4 bits of each value are kept.

    Returns:
        int8 tensor of shape ``(rows, cols // 2)`` on the input's device.

    Raises:
        ValueError: If the number of columns does not fill whole bytes.
    """
    bits = 4
    if quantized_weight8bit.shape[-1] * bits % 8 != 0:
        raise ValueError("quantized_weight8bit.shape[-1] * bits should be divisible by 8")

    values = quantized_weight8bit.to(torch.int8)

    # Mask the low-nibble value to 4 bits: a negative int8 has all of its
    # upper bits set, and OR-ing it in unmasked would overwrite the high nibble.
    low_nibbles = values[:, 0::2] & 0x0F
    # Shifting left by 4 pushes the sign-extension bits out of the byte, so
    # the high-nibble value needs no mask.
    high_nibbles = values[:, 1::2] << bits

    return low_nibbles | high_nibbles

def unpack_weights(packed_weights):
    """Expand packed 4-bit values back into one signed int8 per value.

    Every packed byte holds two 4-bit two's complement values: the low nibble
    becomes output column ``2*k`` and the high nibble column ``2*k + 1``.

    Args:
        packed_weights: 2-D tensor of packed bytes (cast to int8 if needed).

    Returns:
        int8 tensor of shape ``(rows, 2 * cols)`` with values in ``[-8, 7]``,
        on the input's device.
    """
    bits = 4
    mask = 2**bits - 1
    packed = packed_weights.to(torch.int8)

    # total number of 4-bit values held by one packed row
    num_values = packed.shape[-1] * 8 // bits

    # Masking after the shift discards whatever the right shift of a negative
    # int8 brought in from the left.
    low_nibbles = packed & mask
    high_nibbles = (packed >> bits) & mask

    # Interleave as (low, high) pairs so the original column order is restored.
    nibbles = torch.stack((low_nibbles, high_nibbles), dim=-1).reshape(packed.shape[0], num_values)

    # Reinterpret each unsigned nibble 0..15 as a signed value: 8..15 -> -8..-1.
    sign_bit = 1 << (bits - 1)
    return torch.where(nibbles >= sign_bit, nibbles - 2 * sign_bit, nibbles)

def optimize_unpacked_weights(unpacked_weights):
    """Reinterpret unsigned 4-bit values as signed 4-bit two's complement.

    Values 0..7 are returned unchanged and values 8..15 map to -8..-1. Only
    the low 4 bits of each element are considered, so values that are already
    sign-extended pass through unchanged.

    Args:
        unpacked_weights: Integer tensor of 4-bit values.

    Returns:
        int8 tensor of the same shape, on the input's device.
    """
    nibbles = unpacked_weights.to(torch.int8) & 0x0F
    # Bit 3 is the sign bit of a 4-bit number and carries a weight of -8,
    # which is the same as subtracting 16 from the unsigned reading.
    return torch.where(nibbles >= 8, nibbles - 16, nibbles)

In [None]:
# Calculate model size function
def calculate_model_size(model):
    """Return the storage used by a module's tensors, in mebibytes.

    Both parameters and buffers are counted. Counting parameters alone would
    leave out every tensor a layer keeps via ``register_buffer``.

    Args:
        model: Module whose ``parameters()`` and ``buffers()`` are summed.

    Returns:
        Total size as a float, in units of 1024**2 bytes.
    """
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())
    return (param_bytes + buffer_bytes) / (1024 ** 2)  # Convert bytes to megabytes

# Measure, quantize, measure again, then smoke-test the quantized model.
# Statement order matters: `model` is modified in place by replace_linearlayer,
# so the "before" figures must be taken first.

# Print original model size
original_model_size = calculate_model_size(model)
print(f"Original model size: {original_model_size:.2f} MB")

# Second measurement, as reported by the model's own get_memory_footprint()
# (presumably a byte count, given the division by 1e9 — TODO confirm).
model_memory_size_before_quantization = model.get_memory_footprint()
print(f"Total memory size before quantization (in GB): {model_memory_size_before_quantization / 1e+9}")

# Replace linear layers in BERT model with quantized linear layers
# Children named "lm_head" are excluded from replacement; quantized=True makes
# each new layer quantize the weight of the layer it replaces.
replace_linearlayer(model, QuantizedLinearLayer, ["lm_head"], quantized=True)

# Calculate the size of the quantized model
quantized_model_size = calculate_model_size(model)
print(f"Quantized model size: {quantized_model_size:.2f} MB")

model_memory_size_after_quantization = model.get_memory_footprint()
print(f"Total memory size after quantization (in GB): {model_memory_size_after_quantization / 1e+9}")

# Verify by running a forward pass with quantized model
# Only the token ids are passed to the model; any other tokenizer outputs are dropped.
input_ids = tokenizer("This is a test sentence.", return_tensors='pt')['input_ids'].to(device)
outputs = model(input_ids)
print(outputs)

Original model size: 417.64 MB
Total memory size before quantization (in GB): 0.437937152
Quantized model size: 91.39 MB
Total memory size after quantization (in GB): 0.13893632
