# L4-D - Building your own Quantizer: Load your Quantized Weights from Hugging Face Hub

In this lesson, you will learn memory efficient model loading.

The goal for this is, we can use any big machine to load a large model and quantize it
Then push the quantzied version so you can load it in your small machine.


In [55]:
import torch
from helpers import W8A16LinearLayer, W8A16LinearLayerDtype, replace_linear_with_target_and_quantize, replace_linear_with_target
from transformers import AutoModelForCausalLM, AutoTokenizer
from dotenv import load_dotenv
from huggingface_hub import HfApi, create_repo
import os

ImportError: cannot import name 'W8A16LinearLayerDtype' from 'helpers' (/Users/tango.tew/Library/CloudStorage/OneDrive-EY/Documents/repos/AI-Projects/model-quantization/linear_quantizer/helpers.py)

In [5]:
model_id = "facebook/opt-125m"

model = AutoModelForCausalLM.from_pretrained(
  model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

#### Quantize and Save your weights

In [6]:
replace_linear_with_target_and_quantize(model, W8A16LinearLayer, ["lm_head"])

# Make sure the model is quantized
model

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-11): 12 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): W8A16LinearLayer()
            (v_proj): W8A16LinearLayer()
            (q_proj): W8A16LinearLayer()
            (out_proj): W8A16LinearLayer()
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): W8A16LinearLayer()
          (fc2): W8A16LinearLayer()
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=768, out_features=50272, bias=False)
)

#### Save your model

In [7]:
# Get the model weights into dictionary
model_dict = model.state_dict()
save_path = "../models/quantized/fb_125m_quantized_state_dict.pth"
torch.save(model_dict, save_path)

#### Use Online Repository to push your quantized weights
In this case we're going to use HuggingFace

The below code is for demonstration purposes only.

You'll need your own Hugging Face username in order for it to run.
You'll add your usernmae in YOUR_HF_USERNAME = ""

In [18]:
# load the environment variables
load_dotenv()

# get the api token and the username from the environment variables
USERNAME = os.getenv("USERNAME")
# Get the Hugging Face API token from the environment variable
api_token = os.getenv('HUGGINGFACE_HUB_TOKEN')

if not api_token:
  raise ValueError("No API token found. Please set HUGGINGFACE_HUB_TOKEN in your .env file.")

if not USERNAME:
  raise ValueError("No username found. Please set USERNAME in your .env file.")


repo_id = f"{USERNAME}/opt-125m-quantized-deeplearningai"

# instantiate the HfApi class
api = HfApi()

# create a new repository if not already created
create_repo(repo_id=repo_id, token=api_token, repo_type="model", exist_ok=True)

api.upload_file(
  path_or_fileobj=save_path,
  path_in_repo="opt_125m_quantized_state_dict.pth",
  repo_id=repo_id,
  token=api_token
)

fb_125m_quantized_state_dict.pth: 100%|██████████| 166M/166M [00:25<00:00, 6.54MB/s] 


CommitInfo(commit_url='https://huggingface.co/tew9/opt-125m-quantized-deeplearningai/commit/a2599c2c1c9d3016cad31e52b4ab124507ea62be', commit_message='Upload opt_125m_quantized_state_dict.pth with huggingface_hub', commit_description='', oid='a2599c2c1c9d3016cad31e52b4ab124507ea62be', pr_url=None, pr_revision=None, pr_num=None)

## Load the Quantized Weights from huggingface
load the meta device(pytorch skeleton of the model - to be used to load the quantized weights from the hugging face) to make the model architecture match the one on the hugging face

- We use meta device to load the skeleton of the model config(architecture) but without loading the weights themselves which save us memory

In [58]:
from transformers import OPTForCausalLM, AutoTokenizer, AutoConfig

model_id = "facebook/opt-125m"
config = AutoConfig.from_pretrained(model_id)

with torch.device("meta"):
    model_arch = OPTForCausalLM(config)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# load quantized weights from the hugging face

for param in model_arch.parameters():
  print(param)


Parameter containing:
tensor(..., device='meta', size=(50272, 768), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(2050, 768), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768, 768), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768, 768), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768, 768), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768, 768), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(768,), requires_

As we can see the tensor do not have weights at all (...), but we have the entire model architecture

In [76]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from helpers import  W8A16LinearLayer
def w8_a16_forward(weight, input, scales, bias=None):
    casted_weights = weight.to(input.dtype)
    output = F.linear(input, casted_weights) * scales
    
    if bias is not None:
        output = output + bias
      
    return output

class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 127, (out_features, in_features), dtype=torch.int8
            )
        )
        
        self.register_buffer("scale", 
                             torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", 
                                 torch.randn((out_features), 
                                             dtype=dtype))
        
        else:
            self.bias = None

    def quantize(self, weights):
        w_fp32 = weights.clone().to(torch.float32)
        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)
        int8_weights = torch.round(weights / scales.unsqueeze(1)).to(torch.int8)

        self.int8_weights.data = int8_weights
        self.scale.data = scales
    
    def forward(self, input):
        return w8_a16_forward(self.int8_weights, 
                              input, self.scale, self.bias)

def replace_linear_with_target(module, target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and name not in module_name_to_exclude:
            old_bias = child.bias
            
            new_module = target_class(child.in_features, 
                                      child.out_features, 
                                      old_bias is not None, 
                                      child.weight.dtype)
            new_module.to(child.weight.device)
            
            setattr(module, name, new_module)
            
            # Copy weights and bias (if exists)
            new_module.quantize(child.weight.data)
            if old_bias is not None:
                new_module.bias.data = old_bias.data.clone().to(new_module.bias.dtype)
                
        else:
            replace_linear_with_target(child, target_class, module_name_to_exclude)

# # Example usage
# # from transformers import OPTForCausalLM, AutoTokenizer, AutoConfig

# # model_id = "facebook/opt-125m"
# # config = AutoConfig.from_pretrained(model_id)

# # with torch.device("meta"):
# #     model_arch = OPTForCausalLM(config)

# # tokenizer = AutoTokenizer.from_pretrained(model_id)

# replace_linear_with_target(model_arch, W8A16LinearLayer, ["lm_head"])


In [77]:
replace_linear_with_target(model, W8A16LinearLayer, ["lm_head"])
model

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0): OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): W8A16LinearLayer()
            (v_proj): W8A16LinearLayer()
            (q_proj): W8A16LinearLayer()
            (out_proj): W8A16LinearLayer()
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): W8A16LinearLayerDtype()
          (fc2): W8A16LinearLayerDtype()
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1-11): 11 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): W8A16LinearLayerDtype()
            (v_proj): W8A16LinearLayerD

### Load the model weights from hugging face

In [78]:
from huggingface_hub import hf_hub_download

# Download the quantized state dict and cache it
state_dict_cache_path = hf_hub_download(repo_id, "opt_125m_quantized_state_dict.pth")
# Load the quantized state dict
quantized_state_dict = torch.load(state_dict_cache_path)

# Remove incompatible keys
quantized_state_dict = {k.replace('scales', 'scale'): v for k, v in quantized_state_dict.items()}

# check the shape of one of the tensors
# Assign the model from the meta device the weight from the quantized state dict
model.load_state_dict(quantized_state_dict)

RuntimeError: Error(s) in loading state_dict for OPTForCausalLM:
	Missing key(s) in state_dict: "model.decoder.layers.0.fc1.scales", "model.decoder.layers.0.fc2.scales", "model.decoder.layers.1.self_attn.k_proj.scales", "model.decoder.layers.1.self_attn.v_proj.scales", "model.decoder.layers.1.self_attn.q_proj.scales", "model.decoder.layers.1.self_attn.out_proj.scales", "model.decoder.layers.1.fc1.scales", "model.decoder.layers.1.fc2.scales", "model.decoder.layers.2.self_attn.k_proj.scales", "model.decoder.layers.2.self_attn.v_proj.scales", "model.decoder.layers.2.self_attn.q_proj.scales", "model.decoder.layers.2.self_attn.out_proj.scales", "model.decoder.layers.2.fc1.scales", "model.decoder.layers.2.fc2.scales", "model.decoder.layers.3.self_attn.k_proj.scales", "model.decoder.layers.3.self_attn.v_proj.scales", "model.decoder.layers.3.self_attn.q_proj.scales", "model.decoder.layers.3.self_attn.out_proj.scales", "model.decoder.layers.3.fc1.scales", "model.decoder.layers.3.fc2.scales", "model.decoder.layers.4.self_attn.k_proj.scales", "model.decoder.layers.4.self_attn.v_proj.scales", "model.decoder.layers.4.self_attn.q_proj.scales", "model.decoder.layers.4.self_attn.out_proj.scales", "model.decoder.layers.4.fc1.scales", "model.decoder.layers.4.fc2.scales", "model.decoder.layers.5.self_attn.k_proj.scales", "model.decoder.layers.5.self_attn.v_proj.scales", "model.decoder.layers.5.self_attn.q_proj.scales", "model.decoder.layers.5.self_attn.out_proj.scales", "model.decoder.layers.5.fc1.scales", "model.decoder.layers.5.fc2.scales", "model.decoder.layers.6.self_attn.k_proj.scales", "model.decoder.layers.6.self_attn.v_proj.scales", "model.decoder.layers.6.self_attn.q_proj.scales", "model.decoder.layers.6.self_attn.out_proj.scales", "model.decoder.layers.6.fc1.scales", "model.decoder.layers.6.fc2.scales", "model.decoder.layers.7.self_attn.k_proj.scales", "model.decoder.layers.7.self_attn.v_proj.scales", "model.decoder.layers.7.self_attn.q_proj.scales", "model.decoder.layers.7.self_attn.out_proj.scales", "model.decoder.layers.7.fc1.scales", "model.decoder.layers.7.fc2.scales", "model.decoder.layers.8.self_attn.k_proj.scales", "model.decoder.layers.8.self_attn.v_proj.scales", "model.decoder.layers.8.self_attn.q_proj.scales", "model.decoder.layers.8.self_attn.out_proj.scales", "model.decoder.layers.8.fc1.scales", "model.decoder.layers.8.fc2.scales", "model.decoder.layers.9.self_attn.k_proj.scales", "model.decoder.layers.9.self_attn.v_proj.scales", "model.decoder.layers.9.self_attn.q_proj.scales", "model.decoder.layers.9.self_attn.out_proj.scales", "model.decoder.layers.9.fc1.scales", "model.decoder.layers.9.fc2.scales", "model.decoder.layers.10.self_attn.k_proj.scales", "model.decoder.layers.10.self_attn.v_proj.scales", "model.decoder.layers.10.self_attn.q_proj.scales", "model.decoder.layers.10.self_attn.out_proj.scales", "model.decoder.layers.10.fc1.scales", "model.decoder.layers.10.fc2.scales", "model.decoder.layers.11.self_attn.k_proj.scales", "model.decoder.layers.11.self_attn.v_proj.scales", "model.decoder.layers.11.self_attn.q_proj.scales", "model.decoder.layers.11.self_attn.out_proj.scales", "model.decoder.layers.11.fc1.scales", "model.decoder.layers.11.fc2.scales". 
	Unexpected key(s) in state_dict: "model.decoder.layers.0.fc1.scale", "model.decoder.layers.0.fc2.scale", "model.decoder.layers.1.self_attn.k_proj.scale", "model.decoder.layers.1.self_attn.v_proj.scale", "model.decoder.layers.1.self_attn.q_proj.scale", "model.decoder.layers.1.self_attn.out_proj.scale", "model.decoder.layers.1.fc1.scale", "model.decoder.layers.1.fc2.scale", "model.decoder.layers.2.self_attn.k_proj.scale", "model.decoder.layers.2.self_attn.v_proj.scale", "model.decoder.layers.2.self_attn.q_proj.scale", "model.decoder.layers.2.self_attn.out_proj.scale", "model.decoder.layers.2.fc1.scale", "model.decoder.layers.2.fc2.scale", "model.decoder.layers.3.self_attn.k_proj.scale", "model.decoder.layers.3.self_attn.v_proj.scale", "model.decoder.layers.3.self_attn.q_proj.scale", "model.decoder.layers.3.self_attn.out_proj.scale", "model.decoder.layers.3.fc1.scale", "model.decoder.layers.3.fc2.scale", "model.decoder.layers.4.self_attn.k_proj.scale", "model.decoder.layers.4.self_attn.v_proj.scale", "model.decoder.layers.4.self_attn.q_proj.scale", "model.decoder.layers.4.self_attn.out_proj.scale", "model.decoder.layers.4.fc1.scale", "model.decoder.layers.4.fc2.scale", "model.decoder.layers.5.self_attn.k_proj.scale", "model.decoder.layers.5.self_attn.v_proj.scale", "model.decoder.layers.5.self_attn.q_proj.scale", "model.decoder.layers.5.self_attn.out_proj.scale", "model.decoder.layers.5.fc1.scale", "model.decoder.layers.5.fc2.scale", "model.decoder.layers.6.self_attn.k_proj.scale", "model.decoder.layers.6.self_attn.v_proj.scale", "model.decoder.layers.6.self_attn.q_proj.scale", "model.decoder.layers.6.self_attn.out_proj.scale", "model.decoder.layers.6.fc1.scale", "model.decoder.layers.6.fc2.scale", "model.decoder.layers.7.self_attn.k_proj.scale", "model.decoder.layers.7.self_attn.v_proj.scale", "model.decoder.layers.7.self_attn.q_proj.scale", "model.decoder.layers.7.self_attn.out_proj.scale", "model.decoder.layers.7.fc1.scale", "model.decoder.layers.7.fc2.scale", "model.decoder.layers.8.self_attn.k_proj.scale", "model.decoder.layers.8.self_attn.v_proj.scale", "model.decoder.layers.8.self_attn.q_proj.scale", "model.decoder.layers.8.self_attn.out_proj.scale", "model.decoder.layers.8.fc1.scale", "model.decoder.layers.8.fc2.scale", "model.decoder.layers.9.self_attn.k_proj.scale", "model.decoder.layers.9.self_attn.v_proj.scale", "model.decoder.layers.9.self_attn.q_proj.scale", "model.decoder.layers.9.self_attn.out_proj.scale", "model.decoder.layers.9.fc1.scale", "model.decoder.layers.9.fc2.scale", "model.decoder.layers.10.self_attn.k_proj.scale", "model.decoder.layers.10.self_attn.v_proj.scale", "model.decoder.layers.10.self_attn.q_proj.scale", "model.decoder.layers.10.self_attn.out_proj.scale", "model.decoder.layers.10.fc1.scale", "model.decoder.layers.10.fc2.scale", "model.decoder.layers.11.self_attn.k_proj.scale", "model.decoder.layers.11.self_attn.v_proj.scale", "model.decoder.layers.11.self_attn.q_proj.scale", "model.decoder.layers.11.self_attn.out_proj.scale", "model.decoder.layers.11.fc1.scale", "model.decoder.layers.11.fc2.scale". 
	size mismatch for model.decoder.layers.0.self_attn.k_proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1, 768]).
	size mismatch for model.decoder.layers.0.self_attn.v_proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1, 768]).
	size mismatch for model.decoder.layers.0.self_attn.q_proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1, 768]).
	size mismatch for model.decoder.layers.0.self_attn.out_proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1, 768]).
	size mismatch for model.decoder.layers.0.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1, 3072]).
	size mismatch for model.decoder.layers.0.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1, 768]).