<a href="https://colab.research.google.com/github/ekkiprop/llms/blob/main/G02_Quantized_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# If you're running on Colab
!pip install datasets bitsandbytes trl
import numpy as np
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils.modeling import find_tied_parameters, get_mixed_precision_context_manager
from accelerate.utils.operations import convert_outputs_to_fp32
from bitsandbytes.nn import Linear8bitLt, Linear4bit, LinearFP4, LinearNF4
from collections import Counter
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, AutoConfig
from transformers.integrations.bitsandbytes import get_keys_to_not_convert
from types import MethodType
from matplotlib import pyplot as plt


Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.1-py3-none-manylinux_2_24_x86_64.whl.metadata (5.8 kB)
Collecting trl
  Downloading trl-0.14.0-py3-none-any.whl.metadata (12 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch~=2.0->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.

In [None]:
def model_size(num_params, num_bits):

  return (num_params*(num_bits/8)/1e6)

In [None]:
model_size(360e6, 4)

In [None]:
torch.manual_seed(11)
weights = torch.randn(1000) * .07
weights.min(), weights.max()

In [None]:
# binning the values in to 4 bins
n_bins = 4
bins =  torch.linspace(weights.min(), weights.max(), n_bins+1)
bin_width = bins[1] - bins[0]
bins, bin_width

In [None]:
bins


In [None]:
import matplotlib.pyplot as plt
plt.hist(weights, bins)

In [None]:
# Computing the bind indexes for each weight
bin_indexes = (weights.view(-1, 1)>bins).to(torch.int).argmin(dim=1) * 1
print(weights[:20], bin_indexes[:20])

In [None]:
bins

In [None]:
bin_values = bins[:-1]
first_bin = bin_values[0]
bin_values

In [None]:
first_bin

In [None]:
# retrieving the approximate original values
torch.arange(0, n_bins) * bin_width + first_bin

In [None]:
approx_values  = bin_indexes * bin_width + first_bin
print(approx_values[:20])

##### Use MSE to check simillarity between the original weights and the approximated values


In [None]:
mse_fn = nn.MSELoss()
mse_fn(approx_values, weights).sqrt()

## Functions for Quantization and Dequantization



In [None]:
def quantize(weights, n_bits=8):
  assert n_bits <=16, "Using more bits may result in slow execution and/or crashing."
  n_bits  = 2 ** n_bits
  bins = torch.linspace(weights.min(), weights.max(), n_bits+1)
  first_bin =  bins[0]
  bin_width = bins[1] - bins[0]
  bin_indexes = ((weights.view(-1, 1) > bins).to(torch.int).argmin(dim=1) * 1)
  return bin_indexes, bin_width, first_bin


def dequantize(bin_indexes, bin_width, first_bin):
  approx_values = bin_indexes * bin_width + first_bin
  return approx_values

In [None]:
# Comparing RMSE of quantization choices
for n_bits in [2, 4, 8,16]:
  res = quantize(weights, n_bits=n_bits)
  approx_values = dequantize(*res)
  print(f'{n_bits}-bit Quantization:')
  print(approx_values[:6])
  print(weights[:6])
  print(mse_fn(approx_values, weights).sqrt())

In [None]:
weights.dtype


In [None]:
fp16_weights = weights.to(torch.float16)
fp16_weights.dtype

In [None]:
weights

In [None]:
mse_fn(fp16_weights, weights)

In [None]:
torch.manual_seed(14)
tiny_values = torch.randn(1000) * 1e-5
fp16_tiny_values = tiny_values.to(torch.float16)
mse_fn(fp16_tiny_values, tiny_values)

In [None]:

print(tiny_values[155:160])
print(fp16_tiny_values[155:160])

In [None]:
torch.manual_seed(19)
large_values = torch.randn(1000) * 1e5
fp16_large_values = large_values.to(torch.float16)
print(large_values[:5])
print(fp16_large_values[:5])

In [None]:
fp16_info = torch.finfo(torch.float16)
fp16_info

In [None]:
fp32_info = torch.finfo(torch.float32)
fp32_info

In [None]:
bf16_info = torch.finfo(torch.bfloat16)
fp16_info, bf16_info, fp32_info

In [None]:
smallest_subnormal = fp16_info.smallest_normal * 2**-10
smallest_subnormal


In [None]:
x = torch.tensor([0.5555555555])
torch.set_printoptions(precision= 9)
print(x)
print(x.to(torch.float32))
print(x.to(torch.float16))
print(x.to(torch.bfloat16))
torch.set_printoptions(precision=4)

### Loading Models

In [2]:
def get_parm_dtypes(iterable, top_k=3):
  return Counter([p.dtype for p in iterable]).most_common(top_k)

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", device_map = 'cuda:0')
print(model.get_memory_footprint()/1e6, get_parm_dtypes(model.parameters()))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/644 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/663M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

1324.785664 [(torch.float32, 388)]


In [5]:
#!wget https://huggingface.co/facebook/opt-350m/resolve/main/pytorch_model.bin
!ls -la pytorch_model.bin

-rw-r--r-- 1 root root 662513657 May 11  2022 pytorch_model.bin


In [6]:
state_dict = torch.load('pytorch_model.bin')
get_parm_dtypes(iter(state_dict.values()))

  state_dict = torch.load('pytorch_model.bin')


[(torch.float16, 388)]

In [8]:
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", device_map = 'cuda:0', torch_dtype = torch.float32
)


In [9]:
model

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 512, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 1024)
      (project_out): Linear(in_features=1024, out_features=512, bias=False)
      (project_in): Linear(in_features=512, out_features=1024, bias=False)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features

In [10]:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
batch = tokenizer(['This is a simple test'], return_tensors= 'pt')
batch['labels'] = batch['input_ids']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch = {k: v.to(device) for k, v in batch.items()}

out=model(**batch)
out.loss

tokenizer_config.json:   0%|          | 0.00/685 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/441 [00:00<?, ?B/s]

tensor(3.8001, device='cuda:0', grad_fn=<NllLossBackward0>)

In [11]:
batch

{'input_ids': tensor([[   2,  713,   16,   10, 2007, 1296]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]], device='cuda:0'),
 'labels': tensor([[   2,  713,   16,   10, 2007, 1296]], device='cuda:0')}

In [12]:
batch.items()

dict_items([('input_ids', tensor([[   2,  713,   16,   10, 2007, 1296]], device='cuda:0')), ('attention_mask', tensor([[1, 1, 1, 1, 1, 1]], device='cuda:0')), ('labels', tensor([[   2,  713,   16,   10, 2007, 1296]], device='cuda:0'))])

In [15]:
batch2 = tokenizer(['This is a simple test'], return_tensors= 'pt')
batch2

{'input_ids': tensor([[   2,  713,   16,   10, 2007, 1296]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [18]:
x = batch.items()

x

dict_items([('input_ids', tensor([[   2,  713,   16,   10, 2007, 1296]], device='cuda:0')), ('attention_mask', tensor([[1, 1, 1, 1, 1, 1]], device='cuda:0')), ('labels', tensor([[   2,  713,   16,   10, 2007, 1296]], device='cuda:0'))])

In [21]:
out.loss

tensor(3.8001, device='cuda:0', grad_fn=<NllLossBackward0>)

In [23]:
supported =torch.cuda.is_bf16_supported(including_emulation=False)
dtype16 = (torch.bfloat16 if supported else torch.float16)
dtype16

torch.bfloat16

In [24]:
model.get_memory_footprint()/1e6

1324.785664

In [33]:
model.to(torch.bfloat16)
print(model.get_memory_footprint()/1e6, get_parm_dtypes(model.parameters()))

662.392832 [(torch.bfloat16, 388)]


In [30]:
model2 = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", device_map = 'cuda:0'
)
print(model2.get_memory_footprint()/1e6, get_parm_dtypes(model2.parameters()))

1324.785664 [(torch.float32, 388)]


In [34]:
out = model(**batch)
out2 = model2(**batch)
out.loss, out2.loss

(tensor(3.8125, device='cuda:0', dtype=torch.bfloat16,
        grad_fn=<NllLossBackward0>),
 tensor(3.8001, device='cuda:0', grad_fn=<NllLossBackward0>))

### Mixed precision


In [61]:
class MixedModel(nn.Module):
  def __init__(self, dtype):
    super().__init__()
    self.a = nn.Linear(1000, 1000, dtype=dtype)
    self.b = nn.Linear(1000, 1000, dtype=dtype)

  def forward(self, x):
    return self.b(self.a(x))

In [62]:
mixed32 = MixedModel(torch.float32)
mixed32.to("cuda")

MixedModel(
  (a): Linear(in_features=1000, out_features=1000, bias=True)
  (b): Linear(in_features=1000, out_features=1000, bias=True)
)

In [48]:
%timeit mixed32(torch.randn(1000, 1000, dtype=torch.float32, device='cuda'))

487 µs ± 7.64 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [49]:
mixed16 = MixedModel(torch.bfloat16)
mixed16.to("cuda")
%timeit mixed16(torch.randn(1000, 1000, dtype=torch.bfloat16, device="cuda"))

121 µs ± 472 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [50]:
mmixed16 = MixedModel(torch.float16)
mmixed16.to("cuda")
%timeit mmixed16(torch.randn(1000, 1000, dtype=torch.float16, device="cuda"))

126 µs ± 198 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### Using PyTorch autocast Context manager


In [52]:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
  %timeit mixed32(torch.randn(1000, 1000, dtype=torch.float32, device="cuda"))

133 µs ± 2.21 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [63]:
autocast_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
# Original forward method
model_forward_func = mixed32.forward.__func__
# wrapping the method with the new context manager
new_forward = autocast_context(model_forward_func)
# assign the wrapped method back to the model

In [64]:
mixed32.forward = MethodType(new_forward, mixed32)

In [65]:
res = mixed32(torch.randn(1000, 1000, dtype=torch.float32, device="cuda"))
res.dtype

torch.bfloat16

In [73]:
mixed32.forward =MethodType(convert_outputs_to_fp32(mixed32.forward.__func__), mixed32)

In [74]:
res  =mixed32(torch.randn(1000, 1000, dtype=torch.float32, device="cuda"))
res.dtype

torch.float32

In [75]:
%timeit mixed32(torch.randn(1000, 1000, dtype=torch.float32, device="cuda"))

246 µs ± 6.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


### BitsAndBytes

In [78]:
bnb_config = BitsAndBytesConfig()
bnb_config

BitsAndBytesConfig {
  "_load_in_4bit": false,
  "_load_in_8bit": false,
  "bnb_4bit_compute_dtype": "float32",
  "bnb_4bit_quant_storage": "uint8",
  "bnb_4bit_quant_type": "fp4",
  "bnb_4bit_use_double_quant": false,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": false,
  "load_in_8bit": false,
  "quant_method": "bitsandbytes"
}

In [87]:
bnb_config_q8 =  BitsAndBytesConfig(load_in_8bit=True)
model_q8 = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", device_map= "cuda:0", quantization_config =  bnb_config_q8, torch_dtype=torch.float32
)

print(model_q8.get_memory_footprint()/1e6, get_parm_dtypes(model_q8.parameters()))

415.670272 [(torch.float32, 242), (torch.int8, 146)]


In [88]:
out = model_q8(**batch)
out.loss



tensor(3.7965, device='cuda:0', dtype=torch.float32,
       grad_fn=<NllLossBackward0>)

In [98]:
bnb_config_q8 =  BitsAndBytesConfig(load_in_8bit=True)
model_q8_32 = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", device_map= "cuda:0", quantization_config =  bnb_config_q8, torch_dtype=torch.float32
)

In [99]:
dec_layer = model_q8_32.model.decoder.layers[0]
dec_layer

OPTDecoderLayer(
  (self_attn): OPTSdpaAttention(
    (k_proj): Linear8bitLt(in_features=1024, out_features=1024, bias=True)
    (v_proj): Linear8bitLt(in_features=1024, out_features=1024, bias=True)
    (q_proj): Linear8bitLt(in_features=1024, out_features=1024, bias=True)
    (out_proj): Linear8bitLt(in_features=1024, out_features=1024, bias=True)
  )
  (activation_fn): ReLU()
  (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (fc1): Linear8bitLt(in_features=1024, out_features=4096, bias=True)
  (fc2): Linear8bitLt(in_features=4096, out_features=1024, bias=True)
  (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)

In [100]:
q8_layer= dec_layer.self_attn.k_proj
q8_layer

Linear8bitLt(in_features=1024, out_features=1024, bias=True)

In [102]:
q8_state = q8_layer.state_dict()
q8_state

OrderedDict([('weight',
              tensor([[ -67, -113,  -89,  ...,   65,  -16,  -87],
                      [  60,  120,   90,  ...,  -50,   32,   80],
                      [  47,  127,   86,  ...,  -34,    8,   90],
                      ...,
                      [ -65,   65,   34,  ...,  -64,   35,   64],
                      [  57,   67,   21,  ...,   63,  -64,  -64],
                      [ -64,   63,  -11,  ...,  -64,   34,   63]], device='cuda:0',
                     dtype=torch.int8)),
             ('bias',
              tensor([-0.0134,  0.0082,  0.0161,  ..., -0.0242, -0.0150,  0.0203],
                     device='cuda:0', dtype=torch.float32)),
             ('SCB',
              tensor([0.1250, 0.1252, 0.1250,  ..., 0.1252, 0.1250, 0.1254], device='cuda:0',
                     dtype=torch.float32)),
             ('weight_format', tensor(0, dtype=torch.uint8))])

In [103]:
print(model.model.decoder.embed_tokens)
print(model.lm_head)


Embedding(50272, 512, padding_idx=1)
Linear(in_features=512, out_features=50272, bias=False)


In [105]:
torch.allclose(model.model.decoder.embed_tokens.weight, model.lm_head.weight)

True

In [107]:
config = AutoConfig.from_pretrained('facebook/opt-350m')
config.tie_word_embeddings

True

In [108]:
find_tied_parameters(model)

[['lm_head.weight', 'model.decoder.embed_tokens.weight']]

In [109]:
with init_empty_weights():
  empty_model = AutoModelForCausalLM.from_config(config)

empty_model.lm_head.weight

Parameter containing:
tensor(..., device='meta', size=(50272, 512), requires_grad=True)

In [112]:
skip_modules = get_keys_to_not_convert(empty_model)
skip_modules

['model.decoder.embed_tokens', 'lm_head']

In [113]:
get_keys_to_not_convert(model)

['model.decoder.embed_tokens',
 'lm_head',
 'model.decoder.layers.23.final_layer_norm']

In [114]:
for module in skip_modules:
  parm = next(model_q8.get_submodule(module).parameters())
  print(f"{module}:{parm.dtype}")

model.decoder.embed_tokens:torch.float32
lm_head:torch.float32


In [121]:
## Using own list o modules to skip
# This configuration will raise an exception while trying to load weihgts for the tied layer

#bnb_config_skip = BitsAndBytesConfig(load_in_8bit = True, llm_int8_skip_modules= ['o_proj'])

bnb_config_skip = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=['o_proj', 'lm_head'])
model_skip = AutoModelForCausalLM.from_pretrained(
  "facebook/opt-350m", device_map = "cuda:0", torch_dtype=torch.float32, quantization_config=bnb_config_skip)

In [122]:
n_in = 10
n_out = 10

torch.manual_seed(11)
fp_layer = nn.Linear(n_in, n_out)
int8_layer = Linear8bitLt(n_in, n_out, has_fp16_weights=False)
int8_layer.load_state_dict(fp_layer.state_dict())
int8_layer.state_dict()

OrderedDict([('weight',
              tensor([[ 0.2844,  0.2170,  0.0247,  0.0281,  0.1041,  0.2937,  0.0831, -0.2136,
                        0.2078, -0.1361],
                      [-0.0537,  0.2445, -0.1245,  0.1865, -0.1038,  0.2362,  0.1284,  0.2510,
                        0.0729, -0.1195],
                      [-0.0383, -0.1476, -0.2729,  0.2769,  0.2600,  0.0114,  0.1547,  0.0714,
                        0.1445,  0.0250],
                      [ 0.0281, -0.1902, -0.1605, -0.1133,  0.1787,  0.1006, -0.1053,  0.1143,
                       -0.2415,  0.2174],
                      [ 0.0386,  0.0244,  0.1877,  0.0071,  0.2849, -0.0574,  0.0275,  0.1121,
                        0.0426, -0.1801],
                      [ 0.0083, -0.0654,  0.0756,  0.0439,  0.0812, -0.1807, -0.2128, -0.0198,
                        0.2100, -0.1630],
                      [ 0.0874, -0.2396,  0.2269, -0.2751,  0.2140,  0.2590,  0.2130, -0.2461,
                        0.2158, -0.3159],
                 

In [123]:
int8_layer =int8_layer.to(0)
int8_state = int8_layer.state_dict()
int8_state

OrderedDict([('weight',
              tensor([[ 123,   94,   11,   12,   45,  127,   36,  -92,   90,  -59],
                      [ -27,  124,  -63,   94,  -53,  120,   65,  127,   37,  -60],
                      [ -18,  -68, -125,  127,  119,    5,   71,   33,   66,   11],
                      [  15, -100,  -84,  -60,   94,   53,  -55,   60, -127,  114],
                      [  17,   11,   84,    3,  127,  -26,   12,   50,   19,  -80],
                      [   5,  -39,   45,   26,   48, -108, -127,  -12,  125,  -97],
                      [  35,  -96,   91, -111,   86,  104,   86,  -99,   87, -127],
                      [ 127,   16,  114,  -37,   59,   73,   41, -115,   85,  107],
                      [ -95,  116,  -18,   11,  -11,   56,   34,    5, -127,  -65],
                      [-127,   18,  -94, -119, -112,   56,    5,  -99,  -74,  -12]],
                     device='cuda:0', dtype=torch.int8)),
             ('bias',
              tensor([-0.1973, -0.0102,  0.2695,  0.117

In [135]:
supported = torch.cuda.is_bf16_supported(including_emulation=False)
compute_dtype = (torch.bfloat16 if supported else torch.float32)
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_use_double_quant=True,

)


In [137]:
model_q4 = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", device_map = "cuda:0",
    quantization_config=nf4_config
)

print(model_q4.get_memory_footprint()/1e6, get_parm_dtypes(model_q4.parameters()))

207.835136 [(torch.float16, 242), (torch.uint8, 146)]


In [138]:
out = model_q4(**batch)
out.loss



tensor(4.4492, device='cuda:0', grad_fn=<NllLossBackward0>)

In [139]:
dec_layer = model_q4.model.decoder.layers[0]
dec_layer

OPTDecoderLayer(
  (self_attn): OPTSdpaAttention(
    (k_proj): Linear4bit(in_features=1024, out_features=1024, bias=True)
    (v_proj): Linear4bit(in_features=1024, out_features=1024, bias=True)
    (q_proj): Linear4bit(in_features=1024, out_features=1024, bias=True)
    (out_proj): Linear4bit(in_features=1024, out_features=1024, bias=True)
  )
  (activation_fn): ReLU()
  (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (fc1): Linear4bit(in_features=1024, out_features=4096, bias=True)
  (fc2): Linear4bit(in_features=4096, out_features=1024, bias=True)
  (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)

In [140]:
q4_layer = dec_layer.self_attn.k_proj
q4_layer

Linear4bit(in_features=1024, out_features=1024, bias=True)

In [141]:
q4_layer.state_dict()

OrderedDict([('weight',
              tensor([[ 32],
                      [ 29],
                      [208],
                      ...,
                      [ 66],
                      [ 34],
                      [172]], device='cuda:0', dtype=torch.uint8)),
             ('bias',
              tensor([-0.0134,  0.0082,  0.0161,  ..., -0.0242, -0.0150,  0.0203],
                     device='cuda:0')),
             ('weight.absmax',
              tensor([255, 255,   0,  ...,   0,   0, 255], device='cuda:0',
                     dtype=torch.uint8)),
             ('weight.quant_map',
              tensor([-1.0000, -0.6963, -0.5249, -0.3950, -0.2844, -0.1848, -0.0911,  0.0000,
                       0.0796,  0.1609,  0.2461,  0.3379,  0.4407,  0.5625,  0.7231,  1.0000],
                     device='cuda:0')),
             ('weight.nested_absmax',
              tensor([0.0077, 0.0142, 0.0153, 0.0138, 0.0399, 0.0409, 0.0417, 0.0426, 0.0053,
                      0.0053, 0.0053, 0.0053, 0

### FP4 VS NF4 LAYERS

In [142]:
n_in  = 10
n_out = 10
torch.manual_seed(11)
fp16_layer = nn.Linear(n_in, n_out)
fp16_layer

Linear(in_features=10, out_features=10, bias=True)

In [143]:
fp4_layer = LinearFP4(n_in, n_out)
fp4_layer.load_state_dict(fp16_layer.state_dict())

nf4_model = LinearNF4(n_in, n_out)
nf4_model.load_state_dict(fp16_layer.state_dict())

<All keys matched successfully>

In [144]:
fp4_layer = LinearFP4(n_in, n_out)
fp4_layer.load_state_dict(fp16_layer.state_dict())

<All keys matched successfully>

In [147]:
nf4_model.state_dict()

OrderedDict([('weight',
              tensor([[ 0.2844,  0.2170,  0.0247,  0.0281,  0.1041,  0.2937,  0.0831, -0.2136,
                        0.2078, -0.1361],
                      [-0.0537,  0.2445, -0.1245,  0.1865, -0.1038,  0.2362,  0.1284,  0.2510,
                        0.0729, -0.1195],
                      [-0.0383, -0.1476, -0.2729,  0.2769,  0.2600,  0.0114,  0.1547,  0.0714,
                        0.1445,  0.0250],
                      [ 0.0281, -0.1902, -0.1605, -0.1133,  0.1787,  0.1006, -0.1053,  0.1143,
                       -0.2415,  0.2174],
                      [ 0.0386,  0.0244,  0.1877,  0.0071,  0.2849, -0.0574,  0.0275,  0.1121,
                        0.0426, -0.1801],
                      [ 0.0083, -0.0654,  0.0756,  0.0439,  0.0812, -0.1807, -0.2128, -0.0198,
                        0.2100, -0.1630],
                      [ 0.0874, -0.2396,  0.2269, -0.2751,  0.2140,  0.2590,  0.2130, -0.2461,
                        0.2158, -0.3159],
                 

In [150]:
fp4_layer = fp4_layer.to(0)
fp4_state = fp4_layer.state_dict()
fp4_state['weight.quant_map'], fp4_state['weight'].shape

(tensor([ 0.0000,  0.0052,  0.6665,  1.0000,  0.3333,  0.5000,  0.1666,  0.2500,
          0.0000, -0.0052, -0.6665, -1.0000, -0.3333, -0.5000, -0.1666, -0.2500],
        device='cuda:0'),
 torch.Size([50, 1]))

In [151]:
nf4_model = nf4_model.to(0)
nf4_state = nf4_model.state_dict()
nf4_state['weight.quant_map'], nf4_state['weight'].shape

(tensor([-1.0000, -0.6963, -0.5249, -0.3950, -0.2844, -0.1848, -0.0911,  0.0000,
          0.0796,  0.1609,  0.2461,  0.3379,  0.4407,  0.5625,  0.7231,  1.0000],
        device='cuda:0'),
 torch.Size([50, 1]))

In [153]:
supported = torch.cuda.is_bf16_supported(including_emulation=False)
compute_dtype = (torch.bfloat16 if supported else torch.float32)
nf4_config  =BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype = compute_dtype
)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", device_map = "cuda:0", torch_dtype=torch.float32,
    quantization_config  = nf4_config
)

In [155]:
model.state_dict()

OrderedDict([('model.decoder.embed_tokens.weight',
              tensor([[-0.0353,  0.0629, -0.0628,  ..., -0.0625,  0.0188,  0.0313],
                      [ 0.0213,  0.0379, -0.0625,  ..., -0.0625, -0.0167,  0.0313],
                      [-0.0484, -0.0648,  0.0690,  ...,  0.0656, -0.0626, -0.0485],
                      ...,
                      [ 0.0723,  0.0312, -0.0634,  ..., -0.0625, -0.0053, -0.0755],
                      [ 0.0596, -0.0695, -0.0626,  ...,  0.0736, -0.0040,  0.0409],
                      [-0.0237,  0.0327, -0.0636,  ..., -0.0625, -0.0248,  0.0315]],
                     device='cuda:0', dtype=torch.float32)),
             ('model.decoder.embed_positions.weight',
              tensor([[-0.0066, -0.0121, -0.0097,  ..., -0.0013,  0.0037, -0.0047],
                      [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [ 0.0211, -0.0379, -0.0188,  ...,  0.0145,  0.0212,  0.0219],
                      ...,
                      [