<a href="https://colab.research.google.com/github/pierredantas/LLMCompress/blob/main/Claude_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn.utils.prune as prune
import os
import numpy as np
import psutil
from copy import deepcopy

In [5]:
def count_parameters(model):
    """Count the number of non-zero parameters in the model"""
    total_params = 0
    for param in model.parameters():
        if param.requires_grad:
            total_params += (param != 0).sum().item()  # Only count non-zero parameters
    return total_params

def get_model_size(model):
    """Get model size in MB, considering only non-zero elements"""
    # Save only non-zero elements
    sparse_state_dict = {}
    for name, param in model.state_dict().items():
        if 'weight' in name:
            non_zero_mask = param != 0
            values = param[non_zero_mask]
            indices = non_zero_mask.nonzero()
            sparse_state_dict[f"{name}_values"] = values
            sparse_state_dict[f"{name}_indices"] = indices
        else:
            sparse_state_dict[name] = param

    # Save and get size
    torch.save(sparse_state_dict, "temp.p")
    size = os.path.getsize("temp.p")/1e6
    os.remove('temp.p')
    return size

def get_memory_allocation():
    """Get current memory allocation in MB"""
    process = psutil.Process(os.getpid())
    memory_mb = process.memory_info().rss / (1024 * 1024)
    return memory_mb

def get_size_mb(path):
    """Get model size in MB from path"""
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(path):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            total_size += os.path.getsize(fp)
    return total_size / (1024 * 1024)  # Convert bytes to MB

def calculate_sparsity(model):
    """Calculate the sparsity (percentage of zero weights) in the model"""
    zero_count = 0
    total_count = 0

    for param in model.parameters():
        if param.requires_grad:
            zero_count += torch.sum(param == 0).item()
            total_count += param.numel()

    return (zero_count / total_count) * 100 if total_count > 0 else 0

def print_model_stats(model, model_path=None):
    """Print comprehensive statistics about the model"""
    # Get all stats
    params = count_parameters(model)
    memory = get_memory_allocation()
    sparsity = calculate_sparsity(model)

    # Print all stats
    print(f"Number of parameters: {params/1e6:.2f}M")
    if model_path and os.path.exists(model_path):
        size = get_size_mb(model_path)
        print(f"Size: {size:.2f} MB")
    else:
        # If no path provided or path doesn't exist, calculate size from the model directly
        torch.save(model.state_dict(), "temp.p")
        size = os.path.getsize("temp.p") / (1024 * 1024)  # Convert to MB
        os.remove("temp.p")
        print(f"Size: {size:.2f} MB")
    print(f"Memory allocation: {memory:.2f} MB")
    print(f"Sparsity: {sparsity:.2f}%")

class SparseLinear(torch.nn.Linear):
    def forward(self, input):
        # Only use non-zero weights in forward pass
        sparse_weight = self.weight * (self.weight != 0)
        return torch.nn.functional.linear(input, sparse_weight, self.bias)

def global_pruning(model, pruning_threshold):
    """Apply global pruning to the model with better memory management"""
    # Instead of deepcopy, work on the model directly

    # Collect weights for threshold calculation
    all_weights = []
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            all_weights.append(module.weight.data.abs().view(-1))

    # Calculate global threshold
    all_weights = torch.cat(all_weights)
    k = int(len(all_weights) * pruning_threshold)
    threshold = torch.kthvalue(all_weights, k)[0]

    # Free the temporary tensors
    del all_weights
    torch.cuda.empty_cache() if torch.cuda.is_available() else None

    # Apply pruning using the global threshold
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # Create and apply mask
            with torch.no_grad():  # Prevent storing gradient history
                mask = (module.weight.data.abs() > threshold).float()
                module.weight.data *= mask

    return model

In [6]:
# 1. Load BERT model
print("1. Loading BERT model...")
model = AutoModel.from_pretrained('bert-base-uncased')

# 2. Save original model
print("\n2. Saving original model...")
model.save_pretrained("original_model")

# 3. Print original model statistics
print("\n3. Original model statistics:")
print_model_stats(model)

1. Loading BERT model...

2. Saving original model...

3. Original model statistics:
Number of parameters: 109.48M
Size: 417.70 MB
Memory allocation: 2242.09 MB
Sparsity: 0.00%


#Objectives with pruning:
1. reduce number of parameters (qty)
2. Increase sparsity (%)

#Non-objectives
1. reduce memory allocation (MB)
2. size (MB)


In [10]:
# 4. Apply pruning
pruning_threshold = 0.3  # This can be modified (e.g., 0.2 for 20%, 0.8 for 80%)
print(f"\n4. Applying global pruning with threshold {pruning_threshold}...")
pruned_model = global_pruning(model, pruning_threshold)

# 5. Save pruned model
print("\n5. Saving pruned model...")
pruned_model.save_pretrained("pruned_model")

# 6. Print pruned model statistics
print("\n6. Pruned model statistics:")
print_model_stats(pruned_model)


4. Applying global pruning with threshold 0.3...

5. Saving pruned model...

6. Pruned model statistics:
Number of parameters: 75.27M
Size: 417.70 MB
Memory allocation: 2158.53 MB
Sparsity: 31.25%
