## Naive 8-bit Quantization
- absolute maximum (absmax) quantization
- zero-point quantization

### Absolute maximum (absmax) Quantization

In [1]:
import torch

def absmax_quantize(X):
    scale = 127 / torch.max(torch.abs(X))
    
    X_quant = torch.round(scale * X)
    X_dequant = X_quant / scale

    return X_quant.to(torch.int8), X_dequant

In [2]:
weights = torch.randn((3, 3))
weights_quant, weights_dequant = absmax_quantize(weights)
weights, weights_quant, weights_dequant

(tensor([[ 0.7168, -2.2098,  0.1761],
         [-1.6349, -0.7506, -0.8688],
         [ 1.7245,  0.1274, -0.3302]]),
 tensor([[  41, -127,   10],
         [ -94,  -43,  -50],
         [  99,    7,  -19]], dtype=torch.int8),
 tensor([[ 0.7134, -2.2098,  0.1740],
         [-1.6356, -0.7482, -0.8700],
         [ 1.7226,  0.1218, -0.3306]]))

### Zero-point Quantization

In [3]:
def zeropoint_quantize(X):
    x_range = torch.max(X) - torch.min(X)
    x_range = 1 if x_range == 0 else x_range

    scale = 255 / x_range
    zeropoint = torch.round(-scale * torch.min(X) - 128)

    X_quant = torch.clip(torch.round(X * scale + zeropoint), -128, 127)
    X_dequant = (X_quant - zeropoint) / scale

    return X_quant.to(torch.int8), X_dequant

In [4]:
weights = torch.randn((3, 3)) + 0.5
weights_quant, weights_dequant = zeropoint_quantize(weights)
weights, weights_quant, weights_dequant

(tensor([[ 2.3903,  0.2432,  1.2972],
         [ 0.9187,  1.7886,  0.2253],
         [-0.0748, -0.2997,  0.6005]]),
 tensor([[ 127,  -77,   23],
         [ -13,   70,  -79],
         [-107, -128,  -43]], dtype=torch.int8),
 tensor([[ 2.3947,  0.2426,  1.2975],
         [ 0.9178,  1.7934,  0.2215],
         [-0.0738, -0.2954,  0.6013]]))

### using transformers library
- GPT2 model

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
torch.manual_seed(0)

device = 'cpu'

model_id = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir="/media/shin/T7/huggingface/models").to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="/media/shin/T7/huggingface/tokenizers")

print(model.get_memory_footprint())

  from .autonotebook import tqdm as notebook_tqdm


510342192


#### test GPT-2 first attention layer

In [6]:
weights = model.transformer.h[0].attn.c_attn.weight.data
print("Original weights:")
print(weights)

# Quantize layer using absmax quantization
weights_abs_quant, _ = absmax_quantize(weights)
print("\nAbsmax quantized weights:")
print(weights_abs_quant)

# Quantize layer using absmax quantization
weights_zp_quant, _ = zeropoint_quantize(weights)
print("\nZero-point quantized weights:")
print(weights_zp_quant)

Original weights:
tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.0513, -0.0584,  0.0250],
        [ 0.0874,  0.1473,  0.2387,  ..., -0.0525, -0.0113, -0.0156],
        [ 0.0039,  0.0695,  0.3668,  ...,  0.1143,  0.0363, -0.0318],
        ...,
        [-0.2592, -0.0164,  0.1991,  ...,  0.0095, -0.0516,  0.0319],
        [ 0.1517,  0.2170,  0.1043,  ...,  0.0293, -0.0429, -0.0475],
        [-0.4100, -0.1924, -0.2400,  ..., -0.0046,  0.0070,  0.0198]])

Absmax quantized weights:
tensor([[-21, -12,  -4,  ...,   2,  -3,   1],
        [  4,   7,  11,  ...,  -2,  -1,  -1],
        [  0,   3,  16,  ...,   5,   2,  -1],
        ...,
        [-12,  -1,   9,  ...,   0,  -2,   1],
        [  7,  10,   5,  ...,   1,  -2,  -2],
        [-18,  -9, -11,  ...,   0,   0,   1]], dtype=torch.int8)

Zero-point quantized weights:
tensor([[-20, -11,  -3,  ...,   3,  -2,   2],
        [  5,   8,  12,  ...,  -1,   0,   0],
        [  1,   4,  18,  ...,   6,   3,   0],
        ...,
        [-11,   0,  10,  ...,  

#### quantization all layers of GPT-2

In [7]:
import numpy as np
from copy import deepcopy

# Store original weights
weights = [param.data.clone() for param in model.parameters()]

# Create model to quantize
model_abs = deepcopy(model)

# Quantize all model weights
weights_abs = []
for param in model_abs.parameters():
    _, dequantized = absmax_quantize(param.data)
    param.data = dequantized
    weights_abs.append(dequantized)

# Create model to quantize
model_zp = deepcopy(model)

# Quantize all model weights
weights_zp = []
for param in model_zp.parameters():
    _, dequantized = zeropoint_quantize(param.data)
    param.data = dequantized
    weights_zp.append(dequantized)

#### generate text

In [8]:
def generate_text(model, input_text, max_length=50):
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    output = model.generate(
        input_ids,
        max_length=max_length,
        do_sample=True,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
        attention_mask=input_ids.new_ones(input_ids.shape)
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Generate text with original and quantized models
original_text = generate_text(model, "I have a dream")
absmax_text   = generate_text(model_abs, "I have a dream")
zp_text       = generate_text(model_zp, "I have a dream")

print(f"Original model:\n{original_text}")
print("-" * 50)
print(f"Absmax model:\n{absmax_text}")
print("-" * 50)
print(f"Zeropoint model:\n{zp_text}")

Original model:
I have a dream," he said. "You work as long as you can. No excuses. I want to be someone that I can make an impact and do things I enjoy and it gives me a big motivation to keep doing what I'm doing
--------------------------------------------------
Absmax model:
I have a dream on my hands, to turn my knees on you, or something else; and I will soon get to that point." (Lying down) "If that is possible--then I'll use it. (Pause.)" *
--------------------------------------------------
Zeropoint model:
I have a dream for you at the moment.

It is a dream to share an intimate home with another human to spend their nights away from one's home while simultaneously exploring and enjoying both life and art. We see the same kind of art


#### evaluate with perplexity

In [9]:
def calculate_perplexity(model, text):
    encodings = tokenizer(text, return_tensors="pt").to(device)

    input_ids = encodings.input_ids
    target_ids = input_ids.clone()

    with torch.no_grad():
        outputs = model(input_ids=input_ids, labels=target_ids)
    
    neg_log_likelihood = outputs.loss
    ppl = torch.exp(neg_log_likelihood)
    return ppl

In [10]:
ppl     = calculate_perplexity(model, original_text)
ppl_abs = calculate_perplexity(model_abs, absmax_text)
ppl_zp  = calculate_perplexity(model_zp, absmax_text)

print(f"Original perplexity:  {ppl.item():.2f}")
print(f"Absmax perplexity:    {ppl_abs.item():.2f}")
print(f"Zeropoint perplexity: {ppl_zp.item():.2f}")

Original perplexity:  12.41
Absmax perplexity:    32.97
Zeropoint perplexity: 31.69


### 8-bit Quantization with LLM.int8()

In [11]:
model_id = "gpt2"

model_int8 = AutoModelForCausalLM.from_pretrained(
    model_id, 
    device_map='auto',  # gpu, cpu 순으로 최대한 load
    load_in_8bit= True, # 734MB -> 418MB
    cache_dir="/media/shin/T7/huggingface/models"
)
print(f"Model size: {model_int8.get_memory_footprint():,} bytes")

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Model size: 176,527,896 bytes


#### original vs LLM.int8() by perplexity

In [12]:
text_int8 = generate_text(model_int8, "I have a dream")

print(f"Original model:\n{original_text}")
print("-" * 50)
print(f"LLM.int8() model:\n{text_int8}")



Original model:
I have a dream," he said. "You work as long as you can. No excuses. I want to be someone that I can make an impact and do things I enjoy and it gives me a big motivation to keep doing what I'm doing
--------------------------------------------------
LLM.int8() model:
I have a dream that will continue to be fulfilled. That's why I believe we need to bring back this legislation and that's why I strongly oppose it.

We have to bring back any illegal immigration. Even if we end up as part


In [13]:
print(f"Perplexity (original):   {ppl.item():.2f}")

ppl = calculate_perplexity(model_int8, text_int8)
print(f"Perplexity (LLM.int8()): {ppl.item():.2f}")

Perplexity (original):   12.41
Perplexity (LLM.int8()): 14.87


### Reference
https://towardsdatascience.com/introduction-to-weight-quantization-2494701b9c0c