# Optimizing LLMs through Quantization

Author - Sri Raghu Malireddi

Date - October 3, 2024

# 1. Introduction

Quantization is a process that allows models to be deployed efficiently by reducing the precision of the model's parameters and activations. This is especially important in **Large Language Models (LLMs)**, which contain millions or billions of parameters.

LLMs such as the GPT family, Llama, Mixtral, and Phi require large amounts of memory and computational power. Quantization reduces the size of the model and speeds up computation, while attempting to maintain the accuracy of the original model.

In this notebook, we'll explore the following topics:
- Memory Efficiency
- Inference Speed-Up
- Power Efficiency
- Deploying LLMs on Edge Devices
- Maintaining Acceptable Accuracy
- Scalability and Cost

# 2. Memory Efficiency

LLMs contain billions of parameters, which can make them difficult to store and run on standard hardware. Quantization can dramatically reduce the memory footprint of these models by reducing the precision of the model parameters.

## Example: Memory Savings with Quantization

Let’s take an example of how much memory can be saved by converting from 32-bit floating-point precision (FP32) to 8-bit integer precision (INT8).

### Formula for Memory Savings:

`memory = number of parameters × bytes per parameter`, where

- FP32: 4 bytes per parameter
- INT8: 1 byte per parameter

For a model with 1 billion parameters:

```python
# Example of memory usage comparison for a 1 billion parameter model
parameters = 1e9  # 1 billion parameters

# Memory in FP32 (4 bytes per parameter)
fp32_memory = parameters * 4 / 1e9  # in GB (1 GB = 1e9 bytes)
print(f"Memory usage in FP32: {fp32_memory} GB")

# Memory in INT8 (1 byte per parameter)
int8_memory = parameters * 1 / 1e9  # in GB
print(f"Memory usage in INT8: {int8_memory} GB")
```

```
Memory usage in FP32: 4.0 GB
Memory usage in INT8: 1.0 GB
```


# 3. Inference Speed-Up

Quantization not only saves memory but can also speed up inference, which is crucial for real-time applications of LLMs. On hardware with efficient INT8 support, integer operations are cheaper than FP32 floating-point operations, and smaller weights also mean less data has to be moved from memory.

## PyTorch Example: INT8 vs FP32

We can use a small PyTorch model to illustrate how the linear layers of an LLM can be converted from FP32 to INT8.

### Convert a Floating-Point Model to INT8 Using PyTorch:

```python
import torch
import torch.quantization

# A small FP32 model with two linear layers (a stand-in for an LLM's linear layers)
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = torch.nn.Linear(128, 64)
        self.fc2 = torch.nn.Linear(64, 32)
    
    def forward(self, x):
        x = self.fc1(x)
        return self.fc2(x)

# Instantiate the model
model_fp32 = SimpleModel()

# Dynamically quantize all Linear layers to INT8 (weights are stored as INT8)
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32, {torch.nn.Linear}, dtype=torch.qint8
)

# Compare the storage size of the fc1 weight matrix (64 x 128 values); speed is not measured here
fp32_bytes = model_fp32.fc1.weight.element_size() * model_fp32.fc1.weight.nelement()
int8_bytes = model_int8.fc1.weight().element_size() * model_int8.fc1.weight().nelement()
print(f"FP32 fc1 weight size: {fp32_bytes} bytes")
print(f"INT8 fc1 weight size: {int8_bytes} bytes")
```

```
FP32 fc1 weight size: 32768 bytes
INT8 fc1 weight size: 8192 bytes
```
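The cell above only compares storage size. A rough way to check the speed claim on your own CPU is to time both versions with `torch.utils.benchmark`. This is a sketch under assumptions: the larger two-layer MLP, batch size and run time are arbitrary choices not taken from the original, and whether dynamic INT8 is actually faster depends on the CPU and the quantized backend, so treat the numbers as indicative only.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic
from torch.utils import benchmark

torch.manual_seed(0)

# Hypothetical MLP, larger than SimpleModel so that matrix multiplies dominate runtime
model_fp32 = nn.Sequential(
    nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)
).eval()

# Dynamic quantization: Linear weights stored as INT8, activations quantized on the fly
model_int8 = quantize_dynamic(model_fp32, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(32, 1024)

def median_latency(model, inputs):
    """Median wall-clock time (seconds) of one forward pass without autograd."""
    timer = benchmark.Timer(
        stmt="with torch.no_grad(): model(inputs)",
        globals={"torch": torch, "model": model, "inputs": inputs},
    )
    return timer.blocked_autorange(min_run_time=0.3).median

fp32_latency = median_latency(model_fp32, x)
int8_latency = median_latency(model_int8, x)
print(f"FP32: {fp32_latency * 1e3:.2f} ms, INT8 (dynamic): {int8_latency * 1e3:.2f} ms")
print(f"Speed-up: {fp32_latency / int8_latency:.2f}x")
```

Using the median of many blocked runs reduces the influence of warm-up and scheduling noise compared with timing a single call.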


# 4. Power Efficiency

Quantization can also reduce power consumption. This lesser-known benefit is particularly important when deploying LLMs on devices with limited power, like mobile phones or IoT devices.

Quantized models, particularly INT8 ones, typically need less energy per inference because integer arithmetic is cheaper than floating-point arithmetic and fewer bytes have to be read from memory. This makes it more feasible to run LLMs on such power-constrained devices.

# 5. Deploying LLMs on Edge Devices

LLMs are typically deployed in cloud environments due to the immense resources required to run them. However, quantization makes it possible to deploy them on **edge devices**, like smartphones, where memory and computational resources are limited.

Quantizing large models can reduce their footprint to the point where they can run on such devices with acceptable latency and memory use.

# 6. Maintaining Acceptable Accuracy

Quantization has traditionally been viewed as a trade-off between performance and accuracy. However, techniques like **Quantization-Aware Training (QAT)** allow models to retain most of their accuracy while benefiting from the efficiency of quantization.

## Example: Quantization-Aware Training (QAT)

In QAT, the model is trained with quantization effects simulated during training, allowing it to learn to compensate for the errors introduced by quantization.
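A minimal eager-mode QAT sketch with `torch.ao.quantization` follows. The toy model, random data and 20 training steps are placeholders, and the QAT config is taken from whichever quantized backend the local PyTorch build selected (`torch.backends.quantized.engine`, for example `x86`, `fbgemm` or `qnnpack`); it assumes the build has such a backend. `prepare_qat` inserts fake-quantize modules so the forward pass sees INT8 rounding while gradients still update the FP32 weights; `convert` then swaps in real INT8 modules.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    DeQuantStub, QuantStub, convert, get_default_qat_qconfig, prepare_qat,
)

torch.manual_seed(0)

class QATModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # FP32 -> quantized boundary
        self.fc1 = nn.Linear(128, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)
        self.dequant = DeQuantStub()  # quantized -> FP32 boundary

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return self.dequant(x)

model = QATModel().train()  # prepare_qat requires training mode
# Default QAT config for the quantized backend this PyTorch build selected
model.qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)
model_qat = prepare_qat(model)  # copy of the model with fake-quantize modules inserted

# Toy regression data (hypothetical); gradients flow through the fake-quant ops
x, y = torch.randn(256, 128), torch.randn(256, 32)
optimizer = torch.optim.SGD(model_qat.parameters(), lr=0.01)
for _ in range(20):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model_qat(x), y)
    loss.backward()
    optimizer.step()

# Replace fake-quant modules with real INT8 modules for inference
model_int8 = convert(model_qat.eval())
print(model_int8)
print(model_int8(x[:4]).shape)
```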

# 7. Scalability and Cost

Quantization allows LLMs to be scaled to larger sizes without a proportional increase in hardware cost and infrastructure requirements. With quantization, fewer GPUs are needed to hold the weights, inference can run faster, and memory consumption is much more manageable.

In large-scale deployments, these savings translate into significant cost reductions.
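To make the "fewer GPUs" point concrete, here is a back-of-the-envelope sketch. It assumes a hypothetical 70B-parameter model and 80 GB GPUs (both illustrative choices, not from the original), and it counts the weights only, ignoring the KV cache, activations and framework overhead, so real deployments need more headroom.

```python
import math

def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Memory for the weights alone, in decimal GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

def min_gpus_for_weights(num_params: float, bits_per_param: int,
                         gpu_memory_gb: float = 80.0) -> int:
    """Lower bound on GPU count: ignores KV cache, activations and overhead."""
    return math.ceil(weight_memory_gb(num_params, bits_per_param) / gpu_memory_gb)

num_params = 70e9  # hypothetical 70B-parameter model
for name, bits in [("FP32", 32), ("FP16/BF16", 16), ("INT8", 8), ("INT4", 4)]:
    print(f"{name:>9}: {weight_memory_gb(num_params, bits):6.1f} GB of weights "
          f"-> at least {min_gpus_for_weights(num_params, bits)} x 80 GB GPU(s)")
```

Under these assumptions the FP32 weights alone need at least 4 GPUs, FP16/BF16 at least 2, and INT8 or INT4 weights fit on a single 80 GB GPU.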

# 8. Types of Quantization - Downcasting

## Understanding Floating Point Types: FP32, FP16, and BF16

When performing computations in neural networks and other machine learning tasks, the precision of floating-point numbers plays an important role. Different formats like FP32, FP16, and BF16 represent floating-point numbers with varying levels of precision, which impacts performance, memory usage, and the accuracy of computations.

### 1. **FP32 (Single-Precision Floating-Point)**

- **Bits**: 32 bits
- **Structure**: 
  - 1 bit for sign
  - 8 bits for exponent
  - 23 bits for the fraction (mantissa)
- **Range**: Supports a wide dynamic range (up to about ±3.4e38) with roughly 7 significant decimal digits.
- **Usage**: FP32 is commonly used in most machine learning models as the default precision for training and inference. It provides high accuracy at the cost of memory and computational speed.

**Advantages**:
  - High precision for both very small and very large numbers.
  - Suitable for tasks requiring higher numerical stability.

### 2. **FP16 (Half-Precision Floating-Point)**

- **Bits**: 16 bits
- **Structure**:
  - 1 bit for sign
  - 5 bits for exponent
  - 10 bits for the fraction
- **Range**: Much narrower dynamic range than FP32 (maximum 65504) because of the 5-bit exponent; the 10-bit fraction also gives lower precision (about 3 significant decimal digits).
- **Usage**: FP16 is often used in scenarios where memory and speed are crucial, such as on-device machine learning or inference tasks, but with potential trade-offs in accuracy.

**Advantages**:
  - Requires half the memory of FP32.
  - Faster computations on hardware with native FP16 support (e.g., GPU Tensor Cores), which also suits low-power devices.

**Limitations**:
  - Limited precision, which can lead to accuracy loss in complex tasks.
  
### 3. **BF16 (Bfloat16 Floating-Point)**

- **Bits**: 16 bits
- **Structure**:
  - 1 bit for sign
  - 8 bits for exponent (same as FP32)
  - 7 bits for the fraction
- **Range**: Similar dynamic range to FP32 due to the same 8-bit exponent. However, fewer bits in the fraction reduce the precision of the mantissa.
- **Usage**: BF16 is commonly used in large-scale training, particularly on TPUs and GPUs. It offers a good balance between the range of FP32 and the computational efficiency of FP16.

**Advantages**:
  - Same dynamic range as FP32 (much wider than FP16), but with less precision.
  - More efficient in terms of memory and computation than FP32, without sacrificing as much accuracy as FP16.

### Visual Representation

Here’s a comparison of how these formats allocate bits:

![Floating Point Formats](assets/float_types.png)
<small><i>Source: https://cerebras.ai/machine-learning/to-bfloat-or-not-to-bfloat-that-is-the-question/</i></small>

- FP16 and BF16 use 16 bits but allocate them differently.
- FP32 uses more bits for the fraction, allowing for higher precision.
- BF16 and FP16 save memory and computation at the cost of precision, but BF16 maintains a broader range due to its 8-bit exponent.
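The range difference can be checked directly. This small sketch (the two test values are arbitrary choices) casts a value above FP16's maximum of 65504 and a value below FP16's smallest subnormal (about 6e-8) to both 16-bit formats.

```python
import torch

# One value above FP16's maximum and one below its smallest subnormal
big = torch.tensor(70000.0, dtype=torch.float32)
tiny = torch.tensor(1e-8, dtype=torch.float32)

print("70000 ->", big.to(torch.float16).item(), "(FP16) |", big.to(torch.bfloat16).item(), "(BF16)")
print("1e-8  ->", tiny.to(torch.float16).item(), "(FP16) |", tiny.to(torch.bfloat16).item(), "(BF16)")
```

FP16 turns 70000 into `inf` and 1e-8 into `0.0`, while BF16 keeps both values, at reduced precision (70000 is stored as 70144).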

### Summary of Key Differences:

| Format  | Total Bits | Exponent Bits | Fraction Bits | Use Case                                 |
|---------|------------|---------------|---------------|------------------------------------------|
| FP32    | 32         | 8             | 23            | High precision, training large models    |
| FP16    | 16         | 5             | 10            | Memory-efficient, on-device inference    |
| BF16    | 16         | 8             | 7             | Efficient training on TPUs/GPUs          |

When choosing between these formats, it's essential to consider the trade-offs between precision and computational efficiency, as some tasks can tolerate less precision, while others require the stability offered by FP32.


## 8.1. Lab - Quantization Datatypes

In this lab, let's look at the basic data types and their memory footprints. PyTorch supports various data types; here we focus on 32-bit and 16-bit floats (FP32, FP16), Brain Float 16 (BF16), and 8-bit signed and unsigned integers (INT8, UINT8).

We will also try the simplest form of quantization: downcasting from FP32 to FP16 and BF16.

```python
# Import necessary libraries
import torch
```

```python
# Let's print some information about these data types
print(f"Float 32 - {torch.finfo(torch.float32)}")
print(f"Float 16 - {torch.finfo(torch.float16)}")
print(f"Brain Float 16 - {torch.finfo(torch.bfloat16)}")
print(f"Int 8 - {torch.iinfo(torch.int8)}")
print(f"Unsigned Int 8 - {torch.iinfo(torch.uint8)}")
```

```
Float 32 - finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=float32)
Float 16 - finfo(resolution=0.001, min=-65504, max=65504, eps=0.000976562, smallest_normal=6.10352e-05, tiny=6.10352e-05, dtype=float16)
Brain Float 16 - finfo(resolution=0.01, min=-3.38953e+38, max=3.38953e+38, eps=0.0078125, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=bfloat16)
Int 8 - iinfo(min=-128, max=127, dtype=int8)
Unsigned Int 8 - iinfo(min=0, max=255, dtype=uint8)
```
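The `finfo`/`iinfo` output describes ranges but not storage cost. A quick sketch of the memory side, using `Tensor.element_size()` on a hypothetical one-million-element tensor, shows the per-element footprint of each datatype.

```python
import torch

num_elements = 1_000_000  # hypothetical tensor size
bytes_per_element = {}
for dtype in (torch.float32, torch.float16, torch.bfloat16, torch.int8, torch.uint8):
    t = torch.zeros(num_elements, dtype=dtype)
    bytes_per_element[dtype] = t.element_size()
    total_mb = t.element_size() * t.nelement() / 1e6  # decimal MB
    print(f"{str(dtype):15} {t.element_size()} byte(s) per element -> {total_mb:.1f} MB")
```

FP32 uses 4 bytes per element, FP16 and BF16 use 2, and INT8/UINT8 use 1, which is where the 2x and 4x savings of downcasting and 8-bit quantization come from.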


```python
value = 1/3  # Python floats are stored in 64-bit (double) precision

# Create tensors in various datatypes
tensor_fp32 = torch.tensor(value, dtype=torch.float32)
tensor_fp16 = torch.tensor(value, dtype=torch.float16)
tensor_bf16 = torch.tensor(value, dtype=torch.bfloat16)
```

```python
def highlight_trailing_zeros(value_str):
    # Color the significant digits green and the trailing zeros yellow (ANSI escape codes)
    split_index = len(value_str.rstrip('0'))
    return f"\033[32m{value_str[:split_index]}\033[33m{value_str[split_index:]}\033[0m"


print(f"64-bit base: {highlight_trailing_zeros(format(value, '.60f'))}")
print(f"tensor_fp32: {highlight_trailing_zeros(format(tensor_fp32.item(), '.60f'))}")
print(f"tensor_fp16: {highlight_trailing_zeros(format(tensor_fp16.item(), '.60f'))}")
print(f"tensor_bf16: {highlight_trailing_zeros(format(tensor_bf16.item(), '.60f'))}")
```

```
64-bit base: 0.333333333333333314829616256247390992939472198486328125000000
tensor_fp32: 0.333333343267440795898437500000000000000000000000000000000000
tensor_fp16: 0.333251953125000000000000000000000000000000000000000000000000
tensor_bf16: 0.333984375000000000000000000000000000000000000000000000000000
```


```python
# Downcasting example
tensor_fp32 = torch.rand(128, dtype=torch.float32)

tensor_fp16 = tensor_fp32.to(dtype=torch.float16)
tensor_bf16 = tensor_fp32.to(dtype=torch.bfloat16)
```

```python
# Let's perform a simple dot product and see the difference in values

m_float32 = torch.dot(tensor_fp32, tensor_fp32)
m_float16 = torch.dot(tensor_fp16, tensor_fp16)
m_bfloat16 = torch.dot(tensor_bf16, tensor_bf16)

print(f"fp32: {highlight_trailing_zeros(format(m_float32, '.20f'))}")
print(f"fp16: {highlight_trailing_zeros(format(m_float16, '.20f'))}")
print(f"bf16: {highlight_trailing_zeros(format(m_bfloat16, '.20f'))}")
```

```
fp32: 39.33859252929687500000
fp16: 39.34375000000000000000
bf16: 39.25000000000000000000
```


## 8.2. Lab - Quantization - Downcasting

In this lab, we will load a model and downcast it into different datatypes. For each datatype, we will evaluate its accuracy on the test dataset. 

<small><i>Model credits: https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion</i></small>

<small><i>Data credits: https://huggingface.co/datasets/dair-ai/emotion</i></small>

```python
# Import libraries
import torch
from torch.utils.data import DataLoader, Dataset

from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
from datasets import load_dataset

from tqdm import tqdm
```



```python
# Check the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
```

```
Using device: cuda
```


```python
# Load pre-trained model and tokenizer
model_name = 'bhadresh-savani/distilbert-base-uncased-emotion'
model = DistilBertForSequenceClassification.from_pretrained(model_name)
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

model.to(device)
model.eval()
```

```
DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
```

```python
# Load the dataset
ds = load_dataset("dair-ai/emotion", "split")
```

```python
# Evaluate a model's accuracy on the test split
def evaluate(model, ds, batch_size=64, device=device):
    model.eval()
    total, correct = 0, 0
    dataloader = DataLoader(ds['test'], batch_size=batch_size)
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            tokenized = tokenizer(batch['text'], truncation=True, padding=True, return_tensors="pt")
            input_ids = tokenized.input_ids.to(device)
            attention_mask = tokenized.attention_mask.to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs.logits, dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total
    return accuracy

original_accuracy = evaluate(model, ds)
print(f'\nOriginal model accuracy: {original_accuracy * 100:.2f}%')
```

```
Evaluating: 100%|███████████████████████████████████████████████████████████| 32/32 [00:04<00:00,  7.73it/s]


Original model accuracy: 92.70%
```
```python
# Now let's create downcast variants of the model in different datatypes
def create_model(datatype=torch.float16):
    model = DistilBertForSequenceClassification.from_pretrained(model_name)
    model.to(datatype)  # cast floating-point parameters and buffers to the target dtype
    model.to(device)
    model.eval()
    return model

model_fp32 = create_model(torch.float32)
model_fp16 = create_model(torch.float16)
model_bf16 = create_model(torch.bfloat16)

accuracy_fp32 = evaluate(model_fp32, ds)
accuracy_fp16 = evaluate(model_fp16, ds)
accuracy_bf16 = evaluate(model_bf16, ds)

print(f'\nFP32 model accuracy: {accuracy_fp32 * 100:.2f}%')
print(f'FP16 model accuracy: {accuracy_fp16 * 100:.2f}%')
print(f'BF16 model accuracy: {accuracy_bf16 * 100:.2f}%')
```

```
Evaluating: 100%|███████████████████████████████████████████████████████████| 32/32 [00:03<00:00, 10.22it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████| 32/32 [00:01<00:00, 25.15it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████| 32/32 [00:04<00:00,  7.61it/s]


FP32 model accuracy: 92.70%
FP16 model accuracy: 92.70%
BF16 model accuracy: 92.70%
```
```python
# Let's check these models' runtime memory footprint (1 MB here = 1024 * 1024 bytes)

print(f"FP32: {model_fp32.get_memory_footprint() / (1024 * 1024)} MB")
print(f"FP16: {model_fp16.get_memory_footprint() / (1024 * 1024)} MB")
print(f"BF16: {model_bf16.get_memory_footprint() / (1024 * 1024)} MB")
```

```
FP32: 255.4287338256836 MB
FP16: 127.7163200378418 MB
BF16: 127.7163200378418 MB
```

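As the results above show, there is no accuracy degradation: FP32, FP16 and BF16 all reach 92.70% on the test set.

NOTE - Observe the evaluation throughput reported by the progress bars: 10.22 it/s for FP32, 25.15 it/s for FP16 and 7.61 it/s for BF16 in this run.

From this run we can conclude that deploying the model in FP16 roughly halves its memory footprint (255.4 MB to 127.7 MB) and speeds up evaluation by about 2.5x on this GPU. These timings come from a single run and include tokenization, so treat them as indicative; BF16 saves the same memory but did not speed up evaluation here.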

Now let's go a step further and perform integer quantization.

# 9. Quantization - Integer weight-only

In this section, we explore **Integer Quantization**, a method where floating-point values are mapped to integers, reducing memory footprint and improving inference speed. Integer quantization is especially popular for deployment on edge devices and on hardware accelerators such as TPUs and GPUs.

## Post-Training Integer Quantization

**Post-Training Quantization** is applied after the model has been trained, without requiring model retraining. Weights and activations are mapped from floating-point (usually FP32) to integers (e.g., INT8) using a **scale** and a **zero-point**.

### Scale and Zero-Point

- **Scale**: A multiplicative factor that maps the floating-point range onto the integer range.
- **Zero-point**: The integer that represents floating-point zero exactly, compensating for any offset in the original range.

A value `x` is quantized as `q = round(x / scale) + zero_point` (clamped to the integer range) and approximately recovered as `x ≈ (q - zero_point) * scale`. For example:
- **Scale**: `0.1`
- **Zero-point**: `0`
- A floating-point value of `0.5` is quantized as `round(0.5 / 0.1) + 0 = 5`.
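The mapping can be written out directly. This is a minimal sketch of affine quantization and dequantization with the example's scale and zero-point; the helper names `quantize` and `dequantize` are hypothetical, and the clamp to [-128, 127] assumes a signed INT8 target. It also shows the two error sources: rounding, and clipping of values outside the representable range.

```python
import torch

def quantize(x: torch.Tensor, scale: float, zero_point: int,
             qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    """Affine quantization: q = clamp(round(x / scale) + zero_point, qmin, qmax)."""
    q = torch.round(x / scale) + zero_point
    return torch.clamp(q, qmin, qmax).to(torch.int8)

def dequantize(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    """Approximate inverse: x_hat = (q - zero_point) * scale."""
    return (q.to(torch.float32) - zero_point) * scale

x = torch.tensor([0.5, -0.23, 1.0, 20.0])
q = quantize(x, scale=0.1, zero_point=0)
x_hat = dequantize(q, scale=0.1, zero_point=0)
print(q)      # 0.5 -> 5 as in the example; 20.0 saturates at 127
print(x_hat)  # -0.23 comes back as about -0.2 (rounding); 20.0 as about 12.7 (clipping)
```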

Post-training integer quantization does not require retraining, making it easy to implement. It is particularly useful when speed and memory savings are more important than exact precision.

## Per-Tensor Quantization

In **Per-Tensor Quantization**, a **single scale and zero-point** are used for the entire tensor (e.g., a weight matrix or activation map). 

- It is simple and cheap, because only one scale and zero-point need to be stored and applied for the whole tensor.
- It works well when values across the whole tensor fall within a similar range.
- **Drawback**: Precision suffers when the tensor's values vary significantly, because a few large values force a coarse scale onto everything else.

## Per-Channel Quantization

**Per-Channel Quantization** applies different quantization parameters (scale and zero-point) to **individual channels** of the tensor, which is especially useful for the weights of convolutional and linear layers, where each output channel (a filter or a row) can have its own range.

- Each channel (e.g., weight filter in a CNN) has its own scale and zero-point.
- This leads to better accuracy compared to per-tensor quantization, especially when different channels have varying data distributions.
- **Drawback**: Slightly more computational overhead because each channel needs its own scale and zero-point.
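A small experiment makes the trade-off visible. The sketch below builds a hypothetical weight matrix whose rows differ in magnitude by 1000x and compares symmetric per-tensor and per-row (per-channel) INT8 quantization; the helper names and the restricted [-127, 127] range are illustrative choices.

```python
import torch

torch.manual_seed(0)
# Hypothetical 4 x 256 weight matrix; rows (output channels) differ in scale by 1000x
row_scales = torch.tensor([[0.01], [0.1], [1.0], [10.0]])
w = torch.randn(4, 256) * row_scales

def fake_quant_symmetric(x, scale):
    """Quantize to signed 8-bit levels in [-127, 127], then dequantize."""
    return torch.clamp(torch.round(x / scale), -127, 127) * scale

# Per-tensor: a single scale derived from the largest magnitude in the whole matrix
w_per_tensor = fake_quant_symmetric(w, w.abs().max() / 127)
# Per-channel: one scale per row, shape (4, 1), broadcast across the columns
w_per_channel = fake_quant_symmetric(w, w.abs().amax(dim=1, keepdim=True) / 127)

def relative_error_per_row(w_hat):
    return ((w - w_hat).norm(dim=1) / w.norm(dim=1)).tolist()

err_tensor = relative_error_per_row(w_per_tensor)
err_channel = relative_error_per_row(w_per_channel)
for row, (e_t, e_c) in enumerate(zip(err_tensor, err_channel)):
    print(f"row {row}: per-tensor rel. error {e_t:.4f} | per-channel rel. error {e_c:.4f}")
```

With a single scale, the smallest row is rounded to all zeros (relative error close to 1.0), while per-channel scales keep every row's error small.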

## Symmetric vs. Asymmetric Quantization

- **Symmetric Quantization**: The zero-point is fixed at `0`, meaning the range is symmetric around zero. This simplifies calculations and is often used in models that are well-centered around zero.
- **Asymmetric Quantization**: Allows for a non-zero zero-point, which provides more flexibility in cases where the data distribution isn’t centered around zero.
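The sketch below compares both schemes on data that is not centered on zero. It assumes simple min/max calibration (one common choice; real toolkits offer others), signed levels [-128, 127] for the symmetric case and unsigned levels [0, 255] for the asymmetric case. The helper names are hypothetical.

```python
import torch

def symmetric_params(x, num_bits=8):
    """Zero-point fixed at 0; scale covers [-max|x|, max|x|] with signed levels."""
    qmax = 2 ** (num_bits - 1) - 1                  # 127 for 8 bits
    return x.abs().max().item() / qmax, 0

def asymmetric_params(x, num_bits=8):
    """Scale covers [min(x), max(x)] with unsigned levels; zero-point shifts the range."""
    qmin, qmax = 0, 2 ** num_bits - 1               # 0..255 for 8 bits
    x_min = min(x.min().item(), 0.0)                # the range must include 0
    x_max = max(x.max().item(), 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    return scale, zero_point

def fake_quant(x, scale, zero_point, qmin, qmax):
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

torch.manual_seed(0)
x = torch.rand(10_000) * 6.0 - 1.0  # hypothetical skewed data in [-1, 5), not centered on 0

s_sym, z_sym = symmetric_params(x)
s_asym, z_asym = asymmetric_params(x)
err_sym = (x - fake_quant(x, s_sym, z_sym, -128, 127)).abs().mean().item()
err_asym = (x - fake_quant(x, s_asym, z_asym, 0, 255)).abs().mean().item()

print(f"symmetric : scale={s_sym:.4f}, zero_point={z_sym}, mean abs error={err_sym:.5f}")
print(f"asymmetric: scale={s_asym:.4f}, zero_point={z_asym}, mean abs error={err_asym:.5f}")
```

Because the symmetric scheme must cover [-5, 5] to reach the largest value, most of its negative levels go unused and its scale is coarser; the asymmetric scheme spends all 256 levels on [-1, 5] and therefore has a lower error.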

---

## Example of Integer Quantization in PyTorch

Let's see how we can apply **integer quantization** to a PyTorch model.

```python
import torch
import torch.quantization

# Example: post-training dynamic quantization of an FP32 model to INT8
# (weights stored as INT8, activations quantized on the fly)
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = torch.nn.Linear(128, 64)
        self.fc2 = torch.nn.Linear(64, 32)
    
    def forward(self, x):
        x = self.fc1(x)
        return self.fc2(x)

# Create an FP32 model
model_fp32 = SimpleModel()

# Apply dynamic quantization to convert it to INT8
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32, {torch.nn.Linear}, dtype=torch.qint8
)

# Test with some random input
input_fp32 = torch.randn(1, 128)
output_fp32 = model_fp32(input_fp32)
output_int8 = model_int8(input_fp32)

print(f"Output from FP32 model: {output_fp32}")
print(f"Output from INT8 quantized model: {output_int8}")
```

```
Output from FP32 model: tensor([[ 0.0033,  0.4057, -0.1154,  0.2642, -0.0642, -0.2054, -0.1421,  0.3769,
         -0.0029,  0.1919,  0.4293,  0.0696,  0.1322,  0.4005,  0.1255, -0.1972,
          0.1423,  0.3478,  0.1480, -0.1819,  0.1038, -0.0499,  0.0911, -0.2550,
          0.2613,  0.2535, -0.0996, -0.2911, -0.8366,  0.0810,  0.5041,  0.2107]],
       grad_fn=<AddmmBackward0>)
Output from INT8 quantized model: tensor([[ 0.0087,  0.4169, -0.1180,  0.2583, -0.0688, -0.2049, -0.1453,  0.3776,
          0.0034,  0.1910,  0.4294,  0.0789,  0.1383,  0.3942,  0.1181, -0.1975,
          0.1439,  0.3552,  0.1535, -0.1834,  0.0965, -0.0472,  0.0986, -0.2555,
          0.2624,  0.2425, -0.1019, -0.2953, -0.8378,  0.0809,  0.5058,  0.2076]])
```


# 10. Lab - Integer Quantization in LLMs

When it comes to LLMs, **BitsAndBytes** is one of the easiest options for quantizing a model to 8-bit and 4-bit precision. Its 8-bit quantization (LLM.int8()) uses a mixed-precision decomposition: the few outlier feature dimensions are multiplied in FP16, the remaining non-outlier values are multiplied in INT8, the INT8 results are converted back to FP16, and the two parts are summed so that the output is returned in FP16. This minimizes the negative impact that outlier values can have on a model's performance. 4-bit quantization compresses the model further and is commonly used alongside **QLoRA** for fine-tuning quantized LLMs.

To learn more about BitsAndBytes, see https://github.com/bitsandbytes-foundation/bitsandbytes
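The outlier decomposition described above can be illustrated in plain PyTorch. This is a simplified sketch of the idea, not the bitsandbytes implementation: it uses per-row/per-column absmax INT8 scales, simulates INT32 accumulation with float64, runs the "FP16" outlier path in FP32 so it works on CPU, and uses a threshold of 6.0 to mirror the default `llm_int8_threshold` in `BitsAndBytesConfig`. The data and helper names are hypothetical.

```python
import torch

def absmax_quantize(t, dim):
    """Symmetric INT8 quantization with one scale per slice along `dim`."""
    scale = t.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8) / 127
    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
    return q, scale

def int8_matmul(X, W):
    """X @ W with row-wise INT8 scales for X and column-wise INT8 scales for W."""
    Xq, sx = absmax_quantize(X, dim=1)
    Wq, sw = absmax_quantize(W, dim=0)
    # float64 holds these integer products and sums exactly (stand-in for INT32 accumulation)
    acc = Xq.double() @ Wq.double()
    return acc.float() * sx * sw  # dequantize back to floating point

def mixed_precision_matmul(X, W, threshold=6.0):
    """Simplified LLM.int8()-style decomposition."""
    outlier = (X.abs() > threshold).any(dim=0)          # input feature dims holding outliers
    high_precision = X[:, outlier] @ W[outlier, :]      # the FP16 path (FP32 here, for CPU)
    low_precision = int8_matmul(X[:, ~outlier], W[~outlier, :])
    return high_precision + low_precision, int(outlier.sum())

torch.manual_seed(0)
X = torch.randn(8, 64)   # hypothetical activations: 8 tokens, 64 features
X[:, 3] *= 20            # inject one outlier feature dimension
W = torch.randn(64, 16) * 0.1

reference = X @ W
mixed, n_outliers = mixed_precision_matmul(X, W)
naive = int8_matmul(X, W)
print(f"outlier dims: {n_outliers}")
print(f"max error, naive INT8: {(reference - naive).abs().max():.4f}")
print(f"max error, mixed     : {(reference - mixed).abs().max():.4f}")
```

In the naive version, the single outlier dimension inflates each row's scale, so every other feature is quantized coarsely; routing that dimension through the high-precision path keeps the INT8 part accurate.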

```python
from transformers import BitsAndBytesConfig

# Load the model with its Linear layers quantized to 8-bit (requires the bitsandbytes package)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = DistilBertForSequenceClassification.from_pretrained(
    model_name,
    quantization_config=quantization_config
)

model_8bit.eval()
```

```
`low_cpu_mem_usage` was None, now set to True since model is quantized.


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear8bitLt(in_features=768, out_features=768, bias=True)
            (k_lin): Linear8bitLt(in_features=768, out_features=768, bias=True)
            (v_lin): Linear8bitLt(in_features=768, out_features=768, bias=True)
            (out_lin): Linear8bitLt(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout
```

```python
# Load the model with its Linear layers quantized to 4-bit, computing in FP16
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

model_4bit = DistilBertForSequenceClassification.from_pretrained(
    model_name, 
    quantization_config=quantization_config
)
model_4bit.eval()
```

```
`low_cpu_mem_usage` was None, now set to True since model is quantized.


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear4bit(in_features=768, out_features=768, bias=True)
            (k_lin): Linear4bit(in_features=768, out_features=768, bias=True)
            (v_lin): Linear4bit(in_features=768, out_features=768, bias=True)
            (out_lin): Linear4bit(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, 
```

```python
# Let's check these models' runtime memory footprint (1 MB here = 1024 * 1024 bytes)

print(f"8 bit: {model_8bit.get_memory_footprint() / (1024 * 1024)} MB")
print(f"4 bit: {model_4bit.get_memory_footprint() / (1024 * 1024)} MB")
```

```
8 bit: 86.6538200378418 MB
4 bit: 66.1225700378418 MB
```


```python
# Evaluate the accuracy of both quantized models
accuracy_8bit = evaluate(model_8bit, ds)
accuracy_4bit = evaluate(model_4bit, ds)

print(f'\n8 bit model accuracy: {accuracy_8bit * 100:.2f}%')
print(f'4 bit model accuracy: {accuracy_4bit * 100:.2f}%')
```

```
Evaluating: 100%|███████████████████████████████████████████████████████████| 32/32 [00:04<00:00,  7.33it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████| 32/32 [00:01<00:00, 21.57it/s]


8 bit model accuracy: 28.40%
4 bit model accuracy: 92.35%
```
In this example, the 4-bit model kept most of the original accuracy (92.35% vs. 92.70%) while using about 66 MB instead of 255 MB. In practice, you would evaluate it more rigorously; if the 4-bit model meets the desired quality bar, it can be productionized without further training.

In many cases, however, quantization-aware training or PEFT (parameter-efficient fine-tuning) is needed to recover acceptable accuracy. Here, the 8-bit model's accuracy dropped sharply (to 28.40%), and PEFT can be used to improve it. For a lab on this, please refer to the notebook [PEFT-DistilBert.ipynb](./PEFT-DistilBert.ipynb).

# Next Steps

To learn more about quantization in LLMs and about recent techniques such as GPTQ and AWQ, please refer to the official HuggingFace documentation - [HuggingFace Quantization](https://huggingface.co/docs/transformers/v4.45.1/quantization/overview). 

HuggingFace's Transformers library supports the majority of recent quantization techniques; the following table shows its support matrix as of v4.45.1 - 

![Support Matrix](assets/hf_quant_support_matrix.png)

<small><i>Source: https://huggingface.co/docs/transformers/v4.45.1/quantization/overview</i></small>
