In [1]:
# packages
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import safetensors
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# base model id
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
# checkpoint after QLoRA training
checkpoint_path = "./qlora_results/checkpoint-75"

In [3]:
# configuration for loading the model in 4-bit quantized form with NF4 quantization format
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [4]:
# Load base model with the above quantization config
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto"
)


Loading checkpoint shards: 100%|██████████| 3/3 [00:16<00:00,  5.38s/it]


In [5]:
# Inspect the fine tuned QLoRA model from checkpoint
model = PeftModel.from_pretrained(base_model, checkpoint_path)

In [6]:
# safetensors file
adapter_file = f"{checkpoint_path}/adapter_model.safetensors"
# reading safetensors file
with safetensors.safe_open(adapter_file, framework="pt") as f:
    adapter_weights = {}
    for k in f.keys():
        adapter_weights[k] = f.get_tensor(k)

print("found {} adapter weights".format(len(adapter_weights)))

found 448 adapter weights


In [7]:
# inspecting adapter weights from checkpoint by lora_A and lora_B
adapter_groups = {}
for name, tensor in adapter_weights.items():
    parts = name.split('.')
    if 'lora_A' in name or 'lora_B' in name:
        base_name = '.'.join(parts[:-2])
        if base_name not in adapter_groups:
            adapter_groups[base_name] = {}
        adapter_groups[base_name][parts[-2]] = tensor

# Number of adapter groups
print(f"Adapter groups: {len(adapter_groups)}")

# printing information for first three adapter groups
for i, (module_name, adapters) in enumerate(list(adapter_groups.items())[:3]):
    print(f"\nAdapter Group {i+1}: {module_name}")
    if 'lora_A' in adapters:
        lora_A = adapters['lora_A']
        print(f"LoRA-A (down-projection):")
        print(f"Shape: {lora_A.shape}")
        print(f"Dtype: {lora_A.dtype}")
    if 'lora_B' in adapters:
        lora_B = adapters['lora_B']
        print(f"LoRA-B (up-projection):")
        print(f"Shape: {lora_B.shape}")
        print(f"Dtype: {lora_B.dtype}")

Adapter groups: 224

Adapter Group 1: base_model.model.model.layers.0.mlp.down_proj
LoRA-A (down-projection):
Shape: torch.Size([64, 14336])
Dtype: torch.float32
LoRA-B (up-projection):
Shape: torch.Size([4096, 64])
Dtype: torch.float32

Adapter Group 2: base_model.model.model.layers.0.mlp.gate_proj
LoRA-A (down-projection):
Shape: torch.Size([64, 4096])
Dtype: torch.float32
LoRA-B (up-projection):
Shape: torch.Size([14336, 64])
Dtype: torch.float32

Adapter Group 3: base_model.model.model.layers.0.mlp.up_proj
LoRA-A (down-projection):
Shape: torch.Size([64, 4096])
Dtype: torch.float32
LoRA-B (up-projection):
Shape: torch.Size([14336, 64])
Dtype: torch.float32


### Explanation on why we see int8 weights when we loaded the base model as int4

If you observe in the output we see in the weight type is int8 but we loaded the base model as int4. this is a common point of confusion when looking at the low-level memory details. 

The reason the weights appear as 8-bit integers even though we loaded the model with a 4-bit configuration is due to a technique called data packing, which is necessary because of how computer hardware and libraries like PyTorch handle memory.

The main problem is that computer hardware and programming libraries like PyTorch typically cannot address or store data in units smaller than an 8-bit (torch.unit8).

The smallest slot in memory you can put data into is an 8-bit unit. If the system stored a single 4-bit number in a full 8-bit slot, we would be wasting half of the memory. 

In [8]:
# loading base model
base_model = model.get_base_model()

# checking quantized layers
quantized_layers = []
for name, module in base_model.named_modules():
    if hasattr(module, 'weight') and hasattr(module, 'quant_state'):
        quantized_layers.append((name, module))

for i, (name, module) in enumerate(quantized_layers[:3]):
    print(f"\nQuantized Layer {i+1}: {name}")
    
    weight = module.weight
    print(f"Weight shape: {weight.shape}")
    print(f"Weight dtype: {weight.dtype}")
    print(f"Weight storage: {weight.data.dtype if hasattr(weight, 'data') else 'N/A'}")

    storage_elements = weight.data.numel() 

    total_logical_elements = storage_elements * 2 

    storage_bits = storage_elements * 8

    effective_bits = storage_bits / total_logical_elements

    print(f"Effective bits per weight: {effective_bits:.2f}")
    
    weight_memory = weight.numel() * 4
    quantized_memory = weight.data.numel() * 1 if hasattr(weight, 'data') else weight_memory
    print(f"Memory (fp32 equivalent): {weight_memory / 1024:.2f} KB")
    print(f"Memory (quantized): {quantized_memory / 1024:.2f} KB")
    print(f"Compression ratio: {weight_memory / quantized_memory:.1f}x")


Quantized Layer 1: model.layers.0.self_attn.q_proj.base_layer
Weight shape: torch.Size([8388608, 1])
Weight dtype: torch.uint8
Weight storage: torch.uint8
Effective bits per weight: 4.00
Memory (fp32 equivalent): 32768.00 KB
Memory (quantized): 8192.00 KB
Compression ratio: 4.0x

Quantized Layer 2: model.layers.0.self_attn.k_proj.base_layer
Weight shape: torch.Size([2097152, 1])
Weight dtype: torch.uint8
Weight storage: torch.uint8
Effective bits per weight: 4.00
Memory (fp32 equivalent): 8192.00 KB
Memory (quantized): 2048.00 KB
Compression ratio: 4.0x

Quantized Layer 3: model.layers.0.self_attn.v_proj.base_layer
Weight shape: torch.Size([2097152, 1])
Weight dtype: torch.uint8
Weight storage: torch.uint8
Effective bits per weight: 4.00
Memory (fp32 equivalent): 8192.00 KB
Memory (quantized): 2048.00 KB
Compression ratio: 4.0x


In [9]:
# --- Memory Usage Comparison ---
print("MEMORY USAGE COMPARISON")

# Calculate LoRA adapter memory
lora_memory = sum(tensor.numel() * tensor.element_size() for tensor in adapter_weights.values())

# Estimate base model memory (4-bit)
base_memory = 0
for name, param in base_model.named_parameters():
    if 'weight' in name and not any(lora_name in name for lora_name in ['lora_A', 'lora_B']):
        base_memory += param.numel() * param.element_size()

# Estimate what full model would be in fp16
full_model_memory = sum(p.numel() for p in base_model.parameters()) * 2  # fp16 bytes

print(f"Memory Usage:")
print(f"4-bit base model: {base_memory / 1024**2:.2f} MB")
print(f"16-bit LoRA adapters: {lora_memory / 1024**2:.2f} MB")
print(f"Total QLoRA model: {(base_memory + lora_memory) / 1024**2:.2f} MB")
print(f"Equivalent fp16 model: {full_model_memory / 1024**2:.2f} MB")
print(f"Memory savings: {(1 - (base_memory + lora_memory) / full_model_memory) * 100:.1f}%")

MEMORY USAGE COMPARISON
Memory Usage:
4-bit base model: 3828.51 MB
16-bit LoRA adapters: 640.00 MB
Total QLoRA model: 4468.51 MB
Equivalent fp16 model: 7476.51 MB
Memory savings: 40.2%
