# Quantisation

In [1]:
import copy
import random
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
random.seed(42)

if torch.cuda.is_available():
    logger.info("Using GPU")
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    logger.info("Using MPS")
    device = torch.device("mps")
else:
    logger.info("Using CPU")
    device = torch.device("cpu")
device

INFO:__main__:Using MPS


device(type='mps')

## Helper Functions

In [4]:
def generate_token(ins: dict, model) -> (torch.Tensor, torch.Tensor):
    with torch.no_grad():
        _output = model(**ins)

    _next_token = _output.logits[:, -1, :].argmax(dim=1)
    return _output, _next_token

def chat(model, tokeniser, inputs_t0, no_of_tokens = 100):
    generated_tokens = dict()
    inputs_tx = inputs_t0
    if "position_ids" in inputs_t0:
        position_ids = inputs_t0["position_ids"]
    else:
        position_ids = None
    durations_cached_s = []
    for _ in range(no_of_tokens):
        t0 = time.time()
        output, next_token_ids = generate_token(inputs_tx, model)
        durations_cached_s.append(time.time() - t0)
        
        inputs_tx = {
            "input_ids": next_token_ids.reshape((-1, 1)),
            "attention_mask": torch.cat(
                [inputs_tx["attention_mask"], torch.ones((inputs_t0["input_ids"].shape[0], 1))],
                dim=1
            ),
            "past_key_values": output.past_key_values,
        }
        
        if position_ids is not None:
            position_ids = position_ids[:, -1].unsqueeze(-1) + 1
            inputs_tx["position_ids"] = position_ids
            
        
        next_tokens = tokeniser.batch_decode(next_token_ids.reshape((inputs_t0["input_ids"].shape[0], 1)))
        for i, token in enumerate(next_tokens):
            if i not in generated_tokens:
                generated_tokens[i] = []
            generated_tokens[i].append(token)
            
    return ["".join(generated_tokens[i]) for i in generated_tokens.keys()], durations_cached_s

# fix dtype post quantization to "pretend" to be fp32
def get_float32_dtype(self):
    return torch.float32

## Setup

In [5]:
gpt_tokeniser = AutoTokenizer.from_pretrained("openai-community/gpt2", )
gpt2 = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

gpt_tokeniser.pad_token = gpt_tokeniser.eos_token
gpt2.config.pad_token_id = gpt2.config.eos_token_id

# pad on the left so we can append new tokens on the right
gpt_tokeniser.padding_side = "left"
gpt_tokeniser.truncation_side = "left"

gpt2

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [10]:
GPT2Model.dtype = property(get_float32_dtype)
logger.info(f" GPT2 footprint: {round(gpt2.get_memory_footprint() / (1024 * 1024), 2)} MB")

INFO:__main__: GPT2 footprint: 486.7 MB


## Scripting Quantisation

In [32]:
from typing import Tuple
def quantise(t: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    _min, _max = t.min(), t.max()
    scale = (_max - _min) / 255
    zero_point = _min
    t = (t - zero_point) / scale
    t = torch.clamp(t, min=0, max=255)
    t = t.type(torch.uint8)
    return t, (scale, zero_point)

def dequantise(t: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    scale, zero_point = state
    return t.to(torch.float32) * scale + zero_point

def quantise_model(model: torch.nn.Module) -> Tuple[torch.nn.Module, dict]:
    states = {}
    for name, param in model.named_parameters():
        param.requires_grad = False
        param.data, state = quantise(param.data)
        states[name] = state
    return model, states

def dequantise_model(model: torch.nn.Module, states: dict) -> torch.nn.Module:
    for name, param in model.named_parameters():
        param.data = dequantise(param.data, states[name])
    return model

In [29]:
_t = gpt2.transformer.h[0].attn.c_attn.weight.data
logger.info(f" Min and Max before quantisation: {_t.min().item()}, {_t.max().item()}, {_t.dtype}")

_tq, _state = quantise(_t)
logger.info(f" Min and Max after quantisation: {_tq.min().item()}, {_tq.max().item()}, {_tq.dtype}")

_tc = dequantise(_tq, _state)
logger.info(f" Min and Max after dequantisation: {_tc.min().item()}, {_tc.max().item()}, {_tc.dtype}")

_total_loss = torch.abs(_t - _tc).sum()
logger.info(f" Total loss: {_total_loss}")

INFO:__main__: Min and Max before quantisation: -2.8436343669891357, 2.7956299781799316, torch.float32
INFO:__main__: Min and Max after quantisation: 0, 255, torch.uint8
INFO:__main__: Min and Max after dequantisation: -2.8436343669891357, 2.7956297397613525, torch.float32
INFO:__main__: Total loss: 19566.611328125


## Applying Quantisation

### Before

In [34]:
logger.info(f" GPT2 footprint: {round(gpt2.get_memory_footprint() / (1024 * 1024), 2)} MB")

INFO:__main__: GPT2 footprint: 486.7 MB


In [48]:
inputs_t0 = gpt_tokeniser("I woke up to the rain and as I looked outside the window", return_tensors="pt")
expected_response, _ = chat(gpt2, gpt_tokeniser, inputs_t0, 20)
expected_response

[' I saw a man standing in the middle of the street. He was wearing a black hoodie and']

### After

In [50]:
gpt2q, _states = quantise_model(copy.deepcopy(gpt2))
logger.info(f"After quantisation: {gpt2q.get_memory_footprint() / (1024 * 1024)} MB")

gpt2q = dequantise_model(gpt2q, _states)
logger.info(f"After dequantisation: {gpt2q.get_memory_footprint() / (1024 * 1024)} MB")

INFO:__main__:After quantisation: 130.6750946044922 MB
INFO:__main__:After dequantisation: 486.7002410888672 MB


In [53]:
quantised_response, _ = chat(gpt2q, gpt_tokeniser, inputs_t0, 20)
logger.info(f"Results:\nExpected Result:{expected_response[0]}\nQuantised Result:{quantised_response[0]}")

INFO:__main__:Results:
Expected Result: I saw a man standing in the middle of the street. He was wearing a black hoodie and
Quantised Result: I saw a man standing there with his head down on the ground. He was wearing a shirt and
