In [None]:
import torch
import torch.nn as nn
import math
from typing import Dict, Any
import copy

class QuantizedLinear(nn.Module):
    """Custom 4-bit quantized linear layer"""
    def __init__(self, in_features, out_features, original_weight, original_bias=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Quantize the weight matrix
        quantized_weight, scale, zero_point = self.quantize_tensor_4bit(original_weight)
        
        # Store quantized parameters
        self.register_buffer('quantized_weight', quantized_weight)
        self.register_buffer('scale', scale)
        self.register_buffer('zero_point', zero_point)
        
        if original_bias is not None:
            self.register_buffer('bias', original_bias.clone())
        else:
            self.bias = None
    
    def quantize_tensor_4bit(self, tensor):
        """
        Quantize tensor to 4-bit precision using asymmetric quantization
        Returns: quantized_tensor, scale, zero_point
        """
        # Flatten tensor for per-tensor quantization
        flat_tensor = tensor.flatten()
        
        # Calculate min and max values
        min_val = flat_tensor.min().item()
        max_val = flat_tensor.max().item()
        
        # 4-bit can represent 16 values (0-15)
        qmin, qmax = 0, 15
        
        # Calculate scale and zero point
        scale = (max_val - min_val) / (qmax - qmin)
        zero_point = qmin - min_val / scale
        zero_point = max(qmin, min(qmax, round(zero_point)))
        
        # Quantize the tensor
        quantized = torch.round(tensor / scale + zero_point)
        quantized = torch.clamp(quantized, qmin, qmax)
        
        # Convert to uint8 for storage (we'll mask to 4-bit during computation)
        quantized = quantized.to(torch.uint8)
        
        return quantized, torch.tensor(scale), torch.tensor(zero_point)
    
    def dequantize_tensor(self, quantized_tensor, scale, zero_point):
        """Dequantize tensor back to float"""
        return scale * (quantized_tensor.float() - zero_point)
    
    def forward(self, x):
        # Dequantize weights for computation
        dequantized_weight = self.dequantize_tensor(
            self.quantized_weight, self.scale, self.zero_point
        )
        
        # Perform linear operation
        output = torch.nn.functional.linear(x, dequantized_weight, self.bias)
        return output

def calculate_model_size_reduction(original_model, quantized_model):
    """Calculate size reduction achieved by quantization"""
    def get_model_size(model):
        total_params = 0
        total_size_bytes = 0
        
        for param in model.parameters():
            total_params += param.numel()
            if param.dtype == torch.uint8:
                # 4-bit quantized (stored as uint8, but only 4 bits used)
                total_size_bytes += param.numel() * 0.5  # 4 bits = 0.5 bytes
            else:
                total_size_bytes += param.numel() * param.element_size()
        
        return total_params, total_size_bytes
    
    orig_params, orig_size = get_model_size(original_model)
    quant_params, quant_size = get_model_size(quantized_model)
    
    # Account for scale and zero_point parameters
    for module in quantized_model.modules():
        if isinstance(module, QuantizedLinear):
            quant_size += 8  # 4 bytes each for scale and zero_point (float32)
    
    reduction_ratio = orig_size / quant_size
    size_reduction_mb = (orig_size - quant_size) / (1024 * 1024)
    
    return {
        'original_size_mb': orig_size / (1024 * 1024),
        'quantized_size_mb': quant_size / (1024 * 1024),
        'reduction_ratio': reduction_ratio,
        'size_reduction_mb': size_reduction_mb,
        'original_params': orig_params,
        'quantized_params': quant_params
    }

def quantize_model_4bit(model):
    """
    Quantize a HuggingFace model to 4-bit precision without using BitsAndBytes
    
    Args:
        model: HuggingFace model to quantize
    
    Returns:
        quantized_model: 4-bit quantized version of the model
        quantization_info: Dictionary with quantization statistics
    """
    
    # Create a deep copy of the model to avoid modifying the original
    quantized_model = copy.deepcopy(model)
    
    # Statistics tracking
    total_layers = 0
    quantized_layers = 0
    original_params = sum(p.numel() for p in model.parameters())
    
    def replace_linear_layers(module, name=""):
        """Recursively replace Linear layers with QuantizedLinear layers"""
        nonlocal total_layers, quantized_layers
        
        for child_name, child_module in list(module.named_children()):
            full_name = f"{name}.{child_name}" if name else child_name
            
            if isinstance(child_module, nn.Linear):
                total_layers += 1
                print(f"Quantizing layer: {full_name}")
                
                # Create quantized replacement
                quantized_layer = QuantizedLinear(
                    in_features=child_module.in_features,
                    out_features=child_module.out_features,
                    original_weight=child_module.weight.data,
                    original_bias=child_module.bias.data if child_module.bias is not None else None
                )
                
                # Replace the layer
                setattr(module, child_name, quantized_layer)
                quantized_layers += 1
                
            else:
                # Recursively process child modules
                replace_linear_layers(child_module, full_name)
    
    # Start the quantization process
    print("Starting 4-bit quantization...")
    quantized_model.eval()  # Set to eval mode
    
    replace_linear_layers(quantized_model)
    
    # Calculate size reduction
    size_info = calculate_model_size_reduction(model, quantized_model)
    
    # Prepare quantization info
    quantization_info = {
        'total_layers_found': total_layers,
        'layers_quantized': quantized_layers,
        'original_parameters': original_params,
        'quantization_method': '4-bit asymmetric',
        'size_reduction': size_info
    }
    
    print(f"\nQuantization completed!")
    print(f"- Total Linear layers found: {total_layers}")
    print(f"- Layers quantized: {quantized_layers}")
    print(f"- Original model size: {size_info['original_size_mb']:.2f} MB")
    print(f"- Quantized model size: {size_info['quantized_size_mb']:.2f} MB")
    print(f"- Size reduction: {size_info['reduction_ratio']:.2f}x")
    print(f"- Space saved: {size_info['size_reduction_mb']:.2f} MB")
    
    return quantized_model, quantization_info

# Example usage with Llama 3.2 1B
def quantize_llama_3_2_1b():
    """Load and quantize Llama 3.2 1B model"""
    
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM
    except ImportError:
        print("Error: transformers library not found. Install with: pip install transformers")
        return None, None
    
    # Load Llama 3.2 1B model
    model_name = "meta-llama/Llama-3.2-1B"
    print(f"Loading {model_name}...")
    
    try:
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        # Load model in float32 (default precision)
        original_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map="cpu",  # Keep on CPU for quantization
            low_cpu_mem_usage=True
        )
        
        print(f"Model loaded successfully!")
        print(f"Model parameters: {sum(p.numel() for p in original_model.parameters()):,}")
        
        # Quantize the model
        quantized_model, info = quantize_model_4bit(original_model)
        
        return quantized_model, tokenizer, info
        
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Note: You may need to:")
        print("1. Accept the license agreement at https://huggingface.co/meta-llama/Llama-3.2-1B")
        print("2. Login with: huggingface-cli login")
        return None, None, None

def test_llama_quantization():
    """Test the quantized Llama model with sample text generation"""
    
    result = quantize_llama_3_2_1b()
    if result[0] is None:
        return None
    
    quantized_model, tokenizer, info = result
    
    # Test text generation
    test_prompt = "The future of artificial intelligence is"
    print(f"\nTesting with prompt: '{test_prompt}'")
    
    # Tokenize input
    inputs = tokenizer(test_prompt, return_tensors="pt", padding=True)
    
    # Generate text with quantized model
    print("\nGenerating text with quantized model...")
    with torch.no_grad():
        quantized_model.eval()
        outputs = quantized_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=50,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )
    
    # Decode and print results
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nGenerated text:")
    print(f"'{generated_text}'")
    
    return quantized_model, tokenizer, info

def compare_model_outputs():
    """Compare outputs between original and quantized models"""
    
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM
    except ImportError:
        print("Error: transformers library not found.")
        return
    
    model_name = "meta-llama/Llama-3.2-1B"
    
    # Load original model
    print("Loading original model...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    original_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True
    )
    
    # Quantize model
    print("Quantizing model...")
    quantized_model, _ = quantize_model_4bit(original_model)
    
    # Test prompt
    test_prompt = "Artificial intelligence will"
    inputs = tokenizer(test_prompt, return_tensors="pt")
    
    print(f"\nComparing outputs for prompt: '{test_prompt}'")
    
    # Generate with original model
    with torch.no_grad():
        original_outputs = original_model.generate(
            inputs.input_ids,
            max_length=30,
            num_return_sequences=1,
            temperature=0.1,  # Low temperature for more deterministic output
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )
    
    # Generate with quantized model
    with torch.no_grad():
        quantized_outputs = quantized_model.generate(
            inputs.input_ids,
            max_length=30,
            num_return_sequences=1,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )
    
    # Decode outputs
    original_text = tokenizer.decode(original_outputs[0], skip_special_tokens=True)
    quantized_text = tokenizer.decode(quantized_outputs[0], skip_special_tokens=True)
    
    print(f"\nOriginal model output:")
    print(f"'{original_text}'")
    print(f"\nQuantized model output:")
    print(f"'{quantized_text}'")
    
    # Calculate similarity (simple token overlap)
    original_tokens = set(tokenizer.encode(original_text))
    quantized_tokens = set(tokenizer.encode(quantized_text))
    
    if len(original_tokens.union(quantized_tokens)) > 0:
        similarity = len(original_tokens.intersection(quantized_tokens)) / len(original_tokens.union(quantized_tokens))
        print(f"\nToken similarity: {similarity:.3f}")

if __name__ == "__main__":
    print("Llama 3.2 1B - 4-bit Quantization")
    print("=" * 40)
    
    # Option 1: Just quantize and get stats
    print("\n1. Quantizing Llama 3.2 1B...")
    result = quantize_llama_3_2_1b()
    
    if result[0] is not None:
        print("\n2. Testing text generation...")
        test_llama_quantization()
        
        print("\n3. Comparing original vs quantized outputs...")
        compare_model_outputs()
    else:
        print("Failed to load model. Please check your Hugging Face access.")

Llama 3.2 1B - 4-bit Quantization

1. Quantizing Llama 3.2 1B...


  from .autonotebook import tqdm as notebook_tqdm


Loading meta-llama/Llama-3.2-1B...
Model loaded successfully!
Model parameters: 1,235,814,400
Starting 4-bit quantization...
Quantizing layer: model.layers.0.self_attn.q_proj
Quantizing layer: model.layers.0.self_attn.k_proj
Quantizing layer: model.layers.0.self_attn.v_proj
Quantizing layer: model.layers.0.self_attn.o_proj
Quantizing layer: model.layers.0.mlp.gate_proj
Quantizing layer: model.layers.0.mlp.up_proj
Quantizing layer: model.layers.0.mlp.down_proj
Quantizing layer: model.layers.1.self_attn.q_proj
Quantizing layer: model.layers.1.self_attn.k_proj
Quantizing layer: model.layers.1.self_attn.v_proj
Quantizing layer: model.layers.1.self_attn.o_proj
Quantizing layer: model.layers.1.mlp.gate_proj
Quantizing layer: model.layers.1.mlp.up_proj
Quantizing layer: model.layers.1.mlp.down_proj
Quantizing layer: model.layers.2.self_attn.q_proj
Quantizing layer: model.layers.2.self_attn.k_proj
Quantizing layer: model.layers.2.self_attn.v_proj
Quantizing layer: model.layers.2.self_attn.o_pr

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Quantization completed!
- Total Linear layers found: 113
- Layers quantized: 113
- Original model size: 4714.26 MB
- Quantized model size: 1002.26 MB
- Size reduction: 4.70x
- Space saved: 3712.00 MB

Comparing outputs for prompt: 'Artificial intelligence will'


In [1]:
import torch
import torch.nn as nn
import copy
from transformers import AutoModelForCausalLM, AutoTokenizer

class QuantizedLinear(nn.Module):
    def __init__(self, in_features, out_features, weight, bias=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Convert to float32 for quantization math, then back to storage types
        weight_f32 = weight.float()
        
        # 4-bit quantization with better precision
        min_val = weight_f32.min().item()
        max_val = weight_f32.max().item()
        
        # Use full 4-bit range: -8 to 7 (signed) or 0 to 15 (unsigned)
        # Let's use signed for better center around zero
        qmin, qmax = -8, 7
        scale = (max_val - min_val) / (qmax - qmin)
        zero_point = qmin - min_val / scale
        zero_point = max(qmin, min(qmax, round(zero_point)))
        
        # Quantize
        quantized = torch.clamp(
            torch.round(weight_f32 / scale + zero_point), 
            qmin, qmax
        ).to(torch.int8)  # Use int8 for signed 4-bit values
        
        # Pack two 4-bit values into one byte for TRUE 4-bit storage
        flat_q = quantized.flatten()
        
        # Ensure even length for packing
        if len(flat_q) % 2 == 1:
            flat_q = torch.cat([flat_q, torch.zeros(1, dtype=torch.int8)])
        
        # Pack: shift first value left by 4 bits, add second value
        # Convert to unsigned for bitwise operations
        flat_q_unsigned = flat_q + 8  # Shift signed [-8,7] to unsigned [0,15]
        packed = (flat_q_unsigned[::2] << 4) + flat_q_unsigned[1::2]
        
        # Store as uint8
        self.register_buffer('packed_weight', packed.to(torch.uint8))
        self.register_buffer('scale', torch.tensor(scale, dtype=torch.float16))
        self.register_buffer('zero_point', torch.tensor(zero_point, dtype=torch.int8))
        self.register_buffer('original_shape', torch.tensor(weight.shape, dtype=torch.long))
        
        # Store bias in float16 if exists
        if bias is not None:
            self.register_buffer('bias', bias.to(torch.float16))
        else:
            self.bias = None
    
    def unpack_weights(self):
        """Unpack 4-bit weights back to original precision"""
        packed = self.packed_weight
        
        # Unpack: extract high and low nibbles
        high_nibble = (packed >> 4) & 0xF
        low_nibble = packed & 0xF
        
        # Interleave back to original order
        unpacked = torch.stack([high_nibble, low_nibble], dim=1).flatten()
        
        # Convert back to signed [-8, 7] range
        unpacked = unpacked.to(torch.int8) - 8
        
        # Trim to original number of elements
        orig_numel = self.original_shape.prod().item()
        unpacked = unpacked[:orig_numel]
        
        # Reshape to original shape
        unpacked = unpacked.view(tuple(self.original_shape.tolist()))
        
        return unpacked
    
    def forward(self, x):
        # Unpack and dequantize weights
        quantized_weight = self.unpack_weights()
        
        # Dequantize: convert back to float16 for computation
        weight_f16 = (self.scale * (quantized_weight.float() - self.zero_point.float())).half()
        
        # Ensure input is float16 too
        x = x.half()
        
        return torch.nn.functional.linear(x, weight_f16, self.bias)

def quantize_4bit(model):
    """
    Takes a model (preferably loaded in float16) and quantizes it to 4-bit
    Returns the same model object with layers replaced
    """
    quantized_layers = 0
    total_original_size = 0
    total_quantized_size = 0
    
    def replace_layers(module):
        nonlocal quantized_layers, total_original_size, total_quantized_size
        
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Linear):
                # Calculate original size
                original_size = child.weight.numel() * child.weight.element_size()
                if child.bias is not None:
                    original_size += child.bias.numel() * child.bias.element_size()
                total_original_size += original_size
                
                # Create quantized layer
                quantized_layer = QuantizedLinear(
                    child.in_features, 
                    child.out_features, 
                    child.weight.data, 
                    child.bias.data if child.bias is not None else None
                )
                
                # Calculate quantized size (approximate)
                # Packed weights: ~0.5 bytes per weight + scale + zero_point
                quantized_size = (child.weight.numel() * 0.5 +  # 4-bit weights
                                2 +  # scale (float16)
                                1 +  # zero_point (int8)
                                8)   # shape info
                if child.bias is not None:
                    quantized_size += child.bias.numel() * 2  # bias in float16
                total_quantized_size += quantized_size
                
                # Replace layer
                setattr(module, name, quantized_layer)
                quantized_layers += 1
                
                # Clean up
                del child
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    
            else:
                replace_layers(child)
    
    print("Starting 4-bit quantization...")
    replace_layers(model)
    
    # Print statistics
    compression_ratio = total_original_size / total_quantized_size
    memory_saved_mb = (total_original_size - total_quantized_size) / (1024 * 1024)
    
    print(f"Quantization complete!")
    print(f"├── Layers quantized: {quantized_layers}")
    print(f"├── Original size: {total_original_size / (1024*1024):.1f} MB")
    print(f"├── Quantized size: {total_quantized_size / (1024*1024):.1f} MB")
    print(f"├── Compression ratio: {compression_ratio:.2f}x")
    print(f"└── Memory saved: {memory_saved_mb:.1f} MB")
    
    return model

# Example usage:
def load_and_quantize_model(model_name):
    """
    Load model in float16, then quantize to 4-bit
    """
    print(f"Loading {model_name} in float16...")
    
    # Load model in float16 first
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cpu",  # Load on CPU first
        low_cpu_mem_usage=True
    )
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} parameters")
    
    # Now quantize to 4-bit
    quantized_model = quantize_4bit(model)
    
    return quantized_model, tokenizer

# Test with your model
    # Example: Load and quantize Llama 3.2 1B
model, tokenizer = load_and_quantize_model("meta-llama/Llama-3.1-8B-Instruct")

# After quantization, you can move to GPU:
model = model.to("cuda")

print("Ready to use! Load your model like:")
print("model, tokenizer = load_and_quantize_model('your-model-name')")
print("model = model.to('cuda')  # Move to GPU after quantization")

  from .autonotebook import tqdm as notebook_tqdm


Loading meta-llama/Llama-3.1-8B-Instruct in float16...


Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.12it/s]


Model loaded: 8,030,261,248 parameters
Starting 4-bit quantization...
Quantization complete!
├── Layers quantized: 225
├── Original size: 14314.0 MB
├── Quantized size: 3578.5 MB
├── Compression ratio: 4.00x
└── Memory saved: 10735.5 MB
Ready to use! Load your model like:
model, tokenizer = load_and_quantize_model('your-model-name')
model = model.to('cuda')  # Move to GPU after quantization


In [2]:
import torch
import torch.nn as nn
import copy
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

class QuantizedLinear(nn.Module):
    def __init__(self, in_features, out_features, weight, bias=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Convert to float32 for quantization math, then back to storage types
        weight_f32 = weight.float()
        
        # 4-bit quantization with better precision
        min_val = weight_f32.min().item()
        max_val = weight_f32.max().item()
        
        # Use full 4-bit range: -8 to 7 (signed) or 0 to 15 (unsigned)
        # Let's use signed for better center around zero
        qmin, qmax = -8, 7
        scale = (max_val - min_val) / (qmax - qmin)
        zero_point = qmin - min_val / scale
        zero_point = max(qmin, min(qmax, round(zero_point)))
        
        # Quantize
        quantized = torch.clamp(
            torch.round(weight_f32 / scale + zero_point), 
            qmin, qmax
        ).to(torch.int8)  # Use int8 for signed 4-bit values
        
        # Pack two 4-bit values into one byte for TRUE 4-bit storage
        flat_q = quantized.flatten()
        
        # Ensure even length for packing
        if len(flat_q) % 2 == 1:
            flat_q = torch.cat([flat_q, torch.zeros(1, dtype=torch.int8)])
        
        # Pack: shift first value left by 4 bits, add second value
        # Convert to unsigned for bitwise operations
        flat_q_unsigned = flat_q + 8  # Shift signed [-8,7] to unsigned [0,15]
        packed = (flat_q_unsigned[::2] << 4) + flat_q_unsigned[1::2]
        
        # Store as uint8
        self.register_buffer('packed_weight', packed.to(torch.uint8))
        self.register_buffer('scale', torch.tensor(scale, dtype=torch.float16))
        self.register_buffer('zero_point', torch.tensor(zero_point, dtype=torch.int8))
        self.register_buffer('original_shape', torch.tensor(weight.shape, dtype=torch.long))
        
        # Store bias in float16 if exists
        if bias is not None:
            self.register_buffer('bias', bias.to(torch.float16))
        else:
            self.bias = None
    
    def unpack_weights(self):
        """Unpack 4-bit weights back to original precision"""
        packed = self.packed_weight
        
        # Unpack: extract high and low nibbles
        high_nibble = (packed >> 4) & 0xF
        low_nibble = packed & 0xF
        
        # Interleave back to original order
        unpacked = torch.stack([high_nibble, low_nibble], dim=1).flatten()
        
        # Convert back to signed [-8, 7] range
        unpacked = unpacked.to(torch.int8) - 8
        
        # Trim to original number of elements
        orig_numel = self.original_shape.prod().item()
        unpacked = unpacked[:orig_numel]
        
        # Reshape to original shape
        unpacked = unpacked.view(tuple(self.original_shape.tolist()))
        
        return unpacked
    
    def forward(self, x):
        """
        Forward pass with TRUE 4-bit computation - no dequantization to float16!
        Uses 4-bit arithmetic similar to BnB
        """
        device = x.device
        
        # Keep everything in quantized form - no full dequantization
        # Move quantized tensors to device
        packed_weight = self.packed_weight.to(device, non_blocking=True)
        scale = self.scale.to(device, non_blocking=True)
        zero_point = self.zero_point.to(device, non_blocking=True)
        
        # Ensure input is float16
        if x.dtype != torch.float16:
            x = x.half()
        
        # Compute output using 4-bit quantized matrix multiplication
        # This is the key: we avoid creating full float16 weight matrix
        output = self._quantized_matmul(x, packed_weight, scale, zero_point)
        
        # Add bias if exists
        if self.bias is not None:
            bias = self.bias.to(device, non_blocking=True)
            output = output + bias
            
        return output
    
    def _quantized_matmul(self, x, packed_weight, scale, zero_point):
        """
        Perform matrix multiplication in 4-bit space without full dequantization
        This mimics BnB's approach of computing with quantized values
        """
        batch_size, seq_len, in_features = x.shape
        out_features = self.out_features
        
        # Unpack weights in small chunks to save memory
        chunk_size = 1024  # Process in chunks
        output = torch.zeros(batch_size, seq_len, out_features, dtype=x.dtype, device=x.device)
        
        for i in range(0, out_features, chunk_size):
            end_idx = min(i + chunk_size, out_features)
            chunk_rows = end_idx - i
            
            # Calculate which packed elements we need for this chunk
            elements_per_row = in_features
            start_packed = i * elements_per_row // 2
            end_packed = end_idx * elements_per_row // 2
            
            if end_packed > len(packed_weight):
                end_packed = len(packed_weight)
            
            # Unpack only this chunk
            chunk_packed = packed_weight[start_packed:end_packed]
            chunk_unpacked = self._unpack_chunk(chunk_packed, chunk_rows, in_features)
            
            # Dequantize only this small chunk
            chunk_weights = scale * (chunk_unpacked.half() - zero_point.half())
            
            # Compute chunk of output
            output[:, :, i:end_idx] = torch.matmul(x, chunk_weights.T)
            
            # Clean up chunk immediately
            del chunk_unpacked, chunk_weights
            
        return output
    
    def _unpack_chunk(self, packed_chunk, rows, cols):
        """Unpack a small chunk of weights"""
        # Unpack nibbles
        high_nibble = (packed_chunk >> 4) & 0xF
        low_nibble = packed_chunk & 0xF
        
        # Interleave
        unpacked = torch.stack([high_nibble, low_nibble], dim=1).flatten()
        
        # Convert to signed and reshape
        unpacked = unpacked.to(torch.int8) - 8
        
        # Trim and reshape to chunk size
        needed_elements = rows * cols
        unpacked = unpacked[:needed_elements]
        unpacked = unpacked.view(rows, cols)
        
        return unpacked

def quantize_4bit(model):
    """
    Takes a model (preferably loaded in float16) and quantizes it to 4-bit
    Returns the same model object with layers replaced
    """
    quantized_layers = 0
    total_original_size = 0
    total_quantized_size = 0
    
    def replace_layers(module):
        nonlocal quantized_layers, total_original_size, total_quantized_size
        
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Linear):
                # Calculate original size
                original_size = child.weight.numel() * child.weight.element_size()
                if child.bias is not None:
                    original_size += child.bias.numel() * child.bias.element_size()
                total_original_size += original_size
                
                # Create quantized layer
                quantized_layer = QuantizedLinear(
                    child.in_features, 
                    child.out_features, 
                    child.weight.data, 
                    child.bias.data if child.bias is not None else None
                )
                
                # Calculate quantized size (approximate)
                # Packed weights: ~0.5 bytes per weight + scale + zero_point
                quantized_size = (child.weight.numel() * 0.5 +  # 4-bit weights
                                2 +  # scale (float16)
                                1 +  # zero_point (int8)
                                8)   # shape info
                if child.bias is not None:
                    quantized_size += child.bias.numel() * 2  # bias in float16
                total_quantized_size += quantized_size
                
                # Replace layer
                setattr(module, name, quantized_layer)
                quantized_layers += 1
                
                # Clean up
                del child
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    
            else:
                replace_layers(child)
    
    print("Starting 4-bit quantization...")
    replace_layers(model)
    
    # Print statistics
    compression_ratio = total_original_size / total_quantized_size
    memory_saved_mb = (total_original_size - total_quantized_size) / (1024 * 1024)
    
    print(f"Quantization complete!")
    print(f"├── Layers quantized: {quantized_layers}")
    print(f"├── Original size: {total_original_size / (1024*1024):.1f} MB")
    print(f"├── Quantized size: {total_quantized_size / (1024*1024):.1f} MB")
    print(f"├── Compression ratio: {compression_ratio:.2f}x")
    print(f"└── Memory saved: {memory_saved_mb:.1f} MB")
    
    return model

def load_quantize_and_move_to_gpu(model_name):
    """
    Load on CPU -> Quantize on CPU -> Move to GPU (BnB style)
    No additional memory consumption during inference
    """
    print(f"Loading {model_name} on CPU...")
    
    # 1. Load model on CPU in float16
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cpu",
        low_cpu_mem_usage=True
    )
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} parameters")
    
    # 2. Quantize on CPU (no memory spike)
    print("Quantizing on CPU...")
    quantized_model = quantize_4bit(model)
    
    # 3. Move quantized model to GPU (only 4-bit weights transferred)
    print("Moving quantized model to GPU...")
    quantized_model = quantized_model.to("cuda")
    
    # 4. Clear CPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    print("✅ Model ready for inference with no additional GPU memory overhead!")
    
    return quantized_model, tokenizer

# Memory-efficient inference function
def efficient_inference_test(model_name="meta-llama/Llama-3.1-8B-Instruct"):
    """
    Complete pipeline: CPU quantize -> GPU move -> Inference (BnB style)
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Load, quantize on CPU, move to GPU
    model, tokenizer = load_quantize_and_move_to_gpu(model_name)
    
    print(f"\n💾 GPU Memory after loading:")
    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated() / 1024**3
        print(f"   Allocated: {memory_allocated:.2f} GB")
    
    print(f"\n💬 Starting efficient inference (no memory growth)...")
    print("=" * 50)
    
    model.eval()
    
    while True:
        user_input = input("\n👤 You: ").strip()
        
        if user_input.lower() in ['quit', 'exit', 'q']:
            print("👋 Goodbye!")
            break
            
        if not user_input:
            continue
        
        # Check GPU memory before inference
        if torch.cuda.is_available():
            mem_before = torch.cuda.memory_allocated() / 1024**3
        
        # Tokenize
        inputs = tokenizer(user_input, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Generate with no memory growth
        print("🤖 Generating...")
        start_time = time.time()
        
        with torch.no_grad():
            outputs = model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=100,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1
            )
        
        generation_time = time.time() - start_time
        
        # Check GPU memory after inference
        if torch.cuda.is_available():
            mem_after = torch.cuda.memory_allocated() / 1024**3
            mem_growth = mem_after - mem_before
        
        # Decode response
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        assistant_response = full_response[len(user_input):].strip()
        
        print(f"🤖 Assistant: {assistant_response}")
        print(f"   ⏱️  {generation_time:.1f}s", end="")
        
        if torch.cuda.is_available():
            print(f" | 📊 Memory growth: {mem_growth:.3f} GB")
        else:
            print()
        
        # Minimal cleanup (should be near zero growth)
        del inputs, outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Test with your model
if __name__ == "__main__":
    # Example: Load and quantize Llama 3.2 1B
    model, tokenizer = load_quantize_and_move_to_gpu("meta-llama/Llama-3.1-8B-Instruct")
    
    # After quantization, you can move to GPU:
    model = model.to("cuda")
    
    print("Ready to use! Load your model like:")
    print("model, tokenizer = load_and_quantize_model('your-model-name')")
    print("model = model.to('cuda')  # Move to GPU after quantization")

Loading meta-llama/Llama-3.1-8B-Instruct on CPU...


Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.07it/s]


Model loaded: 8,030,261,248 parameters
Quantizing on CPU...
Starting 4-bit quantization...
Quantization complete!
├── Layers quantized: 225
├── Original size: 14314.0 MB
├── Quantized size: 3578.5 MB
├── Compression ratio: 4.00x
└── Memory saved: 10735.5 MB
Moving quantized model to GPU...
✅ Model ready for inference with no additional GPU memory overhead!
Ready to use! Load your model like:
model, tokenizer = load_and_quantize_model('your-model-name')
model = model.to('cuda')  # Move to GPU after quantization


In [3]:
message='hi'
to=tokenizer(message,return_tensors='pt').to('cuda')
print(to)
start=time.time()
out=model.generate(**to,max_new_tokens=100)
print(time.time()-start)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


{'input_ids': tensor([[128000,   6151]], device='cuda:0'), 'attention_mask': tensor([[1, 1]], device='cuda:0')}
187.9736738204956


In [None]:
print(tokenizer.decode(out[0]))

<|begin_of_text|>hiæk�adieravaluczuczuczучаreatreatuczucz rencontgan addObserveruczuczuczuczucz Percgan Boehucz Bearduczuczuczgan ReyesuczτανuczohnuczuczuczuczuczuczuczuczuczuczuczuczuczuczuczuczuczuczriluczgandejuczWISE般_SDWISEWISE Lauderdale LauderdaleWISE竹竹rilucz竹Tokenizeruczril Lauderdale竹 Percrilucz LauderdaleTokenizeruczhay Perciten竹TokenizerrilWISEWISE竹rilTokenizerucz竹254 Lauderdale ReyesWISEucz竹


: 