<a href="https://colab.research.google.com/github/pierredantas/LLMCompress/blob/main/Copy_of___Function_Pruning_Quantize.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
from transformers import BertModel, BertConfig
import copy

Function that loads a neural network model and creates a copy for processing.
- Inputs:
    - model_name (str): Path or name of the model to load. Default is 'bert-base-uncased'
    - model_type (str): Type of model to load ('bert' or 'pytorch'). Default is 'bert'
- Outputs:
    - base_model: The loaded model instance
    - base_model_copy: Deep copy of the loaded model

The function:
1. Loads the model based on the specified type:
    - BERT models: builds a BertModel from the BertConfig of model_name. Only the configuration is loaded, so the weights are randomly initialized rather than pretrained
    - PyTorch models: loads a saved model object with torch.load
2. Creates a deep copy of the loaded model
3. Returns both the original and the copied model

Keeping a deep copy preserves an untouched reference model, since the pruning step later modifies the model it receives in place.

In [25]:
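def load_model(model_name='bert-base-uncased', model_type='bert'):

    # Load model based on type
    if model_type.lower() == 'bert':
        # Note: BertModel(config) builds the architecture with randomly
        # initialized weights; BertModel.from_pretrained(model_name) would
        # load the pretrained weights instead.
        config = BertConfig.from_pretrained(model_name)
        base_model = BertModel(config)
    elif model_type.lower() == 'pytorch':
        # Expects a file holding a full pickled model object (not a state_dict)
        base_model = torch.load(model_name)
    else:
        raise ValueError(f"Model type {model_type} not supported")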

    # Create a deep copy of the model
    base_model_copy = copy.deepcopy(base_model)

    return base_model, base_model_copy
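
As a minimal sketch of why load_model returns a deep copy, the snippet below uses a small nn.Linear layer as a stand-in for BERT (an assumption made so the example runs quickly and offline). Modifying the copy in place should leave the original untouched, because deepcopy gives the copy its own tensors.

```python
import copy

import torch
from torch import nn

# Small stand-in for a loaded model (hypothetical, not BERT)
original = nn.Linear(4, 2)
model_copy = copy.deepcopy(original)

# Zero out the copy's weights in place
with torch.no_grad():
    model_copy.weight.zero_()

# The copy owns its own storage, so the original keeps its random weights
copy_is_zero = bool((model_copy.weight == 0).all())
original_is_zero = bool((original.weight == 0).all())
shares_storage = original.weight.data_ptr() == model_copy.weight.data_ptr()
print(copy_is_zero, original_is_zero, shares_storage)
```

The same reasoning applies to the BERT model: the copy can serve as a baseline after the original has been pruned.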

Function that calculates the percentage of zero-valued parameters in a neural network model.
- Inputs:
    - model: Any neural network model to analyze
- Outputs:
    - global_sparsity (float): Percentage of zero-valued weights in the model

The function:
1. Initializes counters for zeros and total elements
2. Iterates through all named parameters and keeps those whose name contains 'weight'
3. For each weight parameter:
    - Counts the number of zero values
    - Counts the total number of elements
    - Prints the layer-wise sparsity
    - Accumulates these counts
4. Calculates and returns the global sparsity percentage (zeros / total * 100)

Model sparsity is the proportion of weight parameters that are exactly zero, which indicates how much the model has been pruned or sparsified.

In [26]:
def calculate_sparsity(model):
    # Initialize counters
    total_zeros = 0
    total_elements = 0

    # Calculate for all model parameters
    print("\nLayer-wise sparsity:")
    print("-" * 50)

    for name, param in model.named_parameters():
        if 'weight' in name:  # Only count weight parameters
            # Count the entries that are exactly zero in this tensor
            zeros = (param == 0).sum().item()
            elements = param.nelement()
            layer_sparsity = 100. * zeros / elements
            print(f"Sparsity in {name}: {layer_sparsity:.2f}%")

            # Accumulate for global sparsity
            total_zeros += zeros
            total_elements += elements

    # Calculate and return global sparsity (0.0 if the model has no weight parameters)
    global_sparsity = 100. * total_zeros / total_elements if total_elements else 0.0

    return global_sparsity
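
The sparsity formula can be checked by hand on a tiny layer. The sketch below is illustrative only: the layer shape and the weight values are made up, with 3 of the 8 weights set to zero.

```python
import torch
from torch import nn

# Tiny layer with hand-set weights: 3 of 8 entries are zero
layer = nn.Linear(4, 2, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.0, 1.5, 0.0, -2.0],
                                     [0.3, 0.0, 0.7, 1.0]]))

zeros = (layer.weight == 0).sum().item()   # number of exact zeros
elements = layer.weight.nelement()         # total number of weights
sparsity = 100.0 * zeros / elements        # zeros / total * 100
print(f"{zeros} zeros out of {elements} weights: {sparsity:.2f}% sparsity")
# prints: 3 zeros out of 8 weights: 37.50% sparsity
```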

Function that computes and prints summary statistics about a model's parameters and size.
- Inputs:
    - model: The model instance to analyze
    - model_name (str, optional): Name identifier for the model. Default is "Model"
- Outputs:
    - None (prints statistics)

The function:
1. Calculates model metrics:
    - Counts total and trainable parameters
    - Counts non-zero weight values
    - Computes the model size in MB from the element size of each parameter
    - Determines model sparsity with calculate_sparsity()
2. Generates the statistics report:
    - Model identification information
    - Parameter counts and size
    - Sparsity measurement
3. Prints the report to the console

Model statistics provide insight into model complexity, memory usage, and the degree of compression achieved through sparsity. The reported size is the dense storage size, so zeroed weights still count toward it.

In [27]:
def print_model_stats(model, model_name="Model"):

    # Calculate number of parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Calculate non-zero weight values (only parameters whose name contains 'weight')
    nonzero_params = sum(torch.count_nonzero(p).item() for name, p in model.named_parameters() if 'weight' in name)

    # Calculate model size considering element size of parameters
    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)

    # Calculate sparsity
    sparsity = calculate_sparsity(model)

    # Create statistics string
    stats = f"=== {model_name} Statistics ===\n"
    stats += f"Total parameters: {total_params:,}\n"
    stats += f"Trainable parameters: {trainable_params:,}\n"
    stats += f"Non-zero parameters: {nonzero_params:,}\n"
    stats += f"Model size: {model_size:.2f} MB\n"
    stats += f"Model sparsity: {sparsity:.2f}%\n"
    stats += "="*25

    # Print statistics
    print(stats)
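
The size figure is plain arithmetic: number of elements times bytes per element, divided by 1024 * 1024. As a rough worked example, assuming a single float32 nn.Linear(1024, 1024) layer rather than BERT:

```python
from torch import nn

# 1024 * 1024 weights + 1024 biases, stored as float32 (4 bytes each) by default
layer = nn.Linear(1024, 1024)

total_params = sum(p.numel() for p in layer.parameters())
size_bytes = sum(p.numel() * p.element_size() for p in layer.parameters())
size_mb = size_bytes / (1024 * 1024)
print(f"{total_params:,} parameters, {size_mb:.2f} MB")
# prints: 1,049,600 parameters, 4.00 MB
```

Because the calculation depends only on element count and dtype, setting weights to zero does not change the result; only a smaller dtype or a sparse storage format would.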

Function that applies global unstructured magnitude (L1) pruning to a PyTorch model.
- Inputs:
    - model: PyTorch model to be pruned (modified in place)
    - amount (float): Fraction of the selected weights to prune, range 0 to 1
- Outputs:
    - model: The pruned PyTorch model

The function:
1. Collects the weight tensors of all nn.Linear and nn.Conv2d modules
2. Applies global unstructured L1 pruning with the specified amount
3. Makes the pruning permanent by removing the pruning reparametrization
4. Returns the pruned model

Global unstructured pruning zeroes the weights with the smallest absolute values, ranked across all selected layers together rather than layer by layer.

In [28]:
def prune_model(model, amount):

    # Get parameters that can be pruned
    parameters_to_prune = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            parameters_to_prune.append((module, 'weight'))

    # Apply global unstructured pruning across all collected weights
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,  # Pass the pruning class, not an instance
        amount=amount
    )

    # Make the pruning permanent (drops weight_orig/weight_mask, keeps the zeroed weight)
    for module, param_name in parameters_to_prune:
        prune.remove(module, param_name)

    return model
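
To make the two pruning stages concrete, here is a small sketch on a hypothetical two-layer network (the layer sizes and the 50% amount are chosen only for illustration). It shows the temporary weight_orig/weight_mask reparametrization that torch.nn.utils.prune creates, and that prune.remove restores a plain weight parameter with the zeros baked in.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
targets = [(m, 'weight') for m in model.modules() if isinstance(m, nn.Linear)]

# Prune the 50% smallest-magnitude weights across both layers together
prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=0.5)

first = model[0]
# While pruning is active, 'weight' is recomputed from weight_orig and weight_mask
reparam_names = sorted(name for name, _ in first.named_parameters())
has_mask = hasattr(first, 'weight_mask')

# Make pruning permanent: 'weight' becomes a regular parameter again
for module, param_name in targets:
    prune.remove(module, param_name)
final_names = sorted(name for name, _ in first.named_parameters())

zeros = sum((m.weight == 0).sum().item() for m, _ in targets)
total = sum(m.weight.nelement() for m, _ in targets)
print(reparam_names, has_mask, final_names, f"{zeros}/{total} weights are zero")
```

The 50% target holds for the two layers combined; each individual layer can end up slightly above or below it, which is the same effect visible in the per-layer percentages of the pruned BERT output.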

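Function that converts a model to float16 (half precision), used here as a simple form of quantization. It works on a deep copy, so the input model keeps its original dtype.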

In [29]:
def convert_to_float16(model):

    try:

        # Create a copy of the model first
        model_fp16 = copy.deepcopy(model)

        # Convert model to Float16
        model_fp16 = model_fp16.half()

        return model_fp16

    except Exception as e:
        print(f"Error during Float16 conversion: {str(e)}")
        return None
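
The effect of the cast can be sketched on a small layer (a hypothetical stand-in for the pruned BERT model): storage per parameter drops from 4 bytes to 2, and weights that were exactly zero stay zero.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(16, 16)
with torch.no_grad():
    model.weight[0].zero_()  # one zeroed row, standing in for pruned weights

# Same pattern as convert_to_float16: copy first, then cast the copy
model_fp16 = copy.deepcopy(model).half()

bytes_fp32 = sum(p.numel() * p.element_size() for p in model.parameters())
bytes_fp16 = sum(p.numel() * p.element_size() for p in model_fp16.parameters())
zeros_fp32 = (model.weight == 0).sum().item()
zeros_fp16 = (model_fp16.weight == 0).sum().item()
print(model.weight.dtype, model_fp16.weight.dtype)
print(f"{bytes_fp32} bytes -> {bytes_fp16} bytes")
print(f"zeros before: {zeros_fp32}, zeros after: {zeros_fp16}")
```

The zero count can only stay the same or grow after the cast, since very small non-zero values may round to zero in float16.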

Load BERT model and show the statistics

In [30]:
# Load the BERT model
print("Loading BERT model...")
original_model, original_model_copy = load_model(model_name='bert-base-uncased', model_type='bert')
print("Model loaded successfully!")

# Print model statistics
print_model_stats(
    model=original_model,
    model_name="BERT Base Uncased",
)

Loading BERT model...
Model loaded successfully!

Layer-wise sparsity:
--------------------------------------------------
Sparsity in embeddings.word_embeddings.weight: 0.00%
Sparsity in embeddings.position_embeddings.weight: 0.00%
Sparsity in embeddings.token_type_embeddings.weight: 0.00%
Sparsity in embeddings.LayerNorm.weight: 0.00%
Sparsity in encoder.layer.0.attention.self.query.weight: 0.00%
Sparsity in encoder.layer.0.attention.self.key.weight: 0.00%
Sparsity in encoder.layer.0.attention.self.value.weight: 0.00%
Sparsity in encoder.layer.0.attention.output.dense.weight: 0.00%
Sparsity in encoder.layer.0.attention.output.LayerNorm.weight: 0.00%
Sparsity in encoder.layer.0.intermediate.dense.weight: 0.00%
Sparsity in encoder.layer.0.output.dense.weight: 0.00%
Sparsity in encoder.layer.0.output.LayerNorm.weight: 0.00%
Sparsity in encoder.layer.1.attention.self.query.weight: 0.00%
Sparsity in encoder.layer.1.attention.self.key.weight: 0.00%
Sparsity in encoder.layer.1.attention.self

Implications of pruning different types of layers in BERT:

- Embedding Layers:
  - Risky to prune extensively as they map vocabulary tokens to dense representations
  - Heavy pruning could severely impact the model's ability to understand word meanings
  - Each zero in an embedding vector means a token loses part of its learned representation

- LayerNorm Parameters:
  - These are crucial for stabilizing network training
  - Very few parameters compared to other layers
  - Pruning these could destabilize the entire network
  - Generally not recommended to prune normalization layers

- Bias Terms:
  - Biases add important offsets to each neuron's activation
  - Relatively few parameters compared to weights
  - Pruning biases can significantly impact model performance
  - Usually kept intact as their memory footprint is small

- Other Layer Types:
  - Attention layers: Their query, key, value, and output projections are Linear layers; heavy pruning could damage the model's ability to focus on relevant parts of the input
  - Position embeddings: Pruning would hurt the model's understanding of word order
  - Intermediate layers: Can be pruned, but this might affect complex feature representations

- That's why the common practice is to:
  - Focus pruning on Linear/Conv2d weights, which hold most of the parameters in the encoder
  - Leave embedding, normalization, and bias terms intact
  - Preserve the model's fundamental abilities while reducing the number of non-zero weights
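
The selection rule described above can be sketched on a miniature model that mixes the same layer types (the layer sizes are made up, and the model is only inspected, not run). Filtering by module type picks up the Linear weights and skips the embedding and normalization layers; biases are excluded because only the 'weight' attribute would be passed to the pruning call.

```python
from torch import nn

# Hypothetical miniature containing the layer types discussed above
model = nn.Sequential(
    nn.Embedding(100, 16),   # kept intact
    nn.Linear(16, 32),       # weight is prunable, bias untouched
    nn.LayerNorm(32),        # kept intact
    nn.Linear(32, 16),       # weight is prunable, bias untouched
)

# Same type filter as prune_model
prunable = [(name, module) for name, module in model.named_modules()
            if isinstance(module, (nn.Linear, nn.Conv2d))]
prunable_names = [name for name, _ in prunable]
prunable_weights = sum(m.weight.numel() for _, m in prunable)
all_params = sum(p.numel() for p in model.parameters())
print(prunable_names, prunable_weights, all_params)
# prints: ['1', '3'] 1024 2736
```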

In [31]:
# Prune the model
print("\nPruning model...")
pruning_amount = 0.2
pruned_model = prune_model(model=original_model, amount=pruning_amount)
print("Model pruned successfully!")

# Print model statistics
print("\nGenerating pruned model statistics...")
print_model_stats(
    model=pruned_model,
    model_name="Pruned BERT",
)


Pruning model...
Model pruned successfully!

Generating pruned model statistics...

Layer-wise sparsity:
--------------------------------------------------
Sparsity in embeddings.word_embeddings.weight: 0.00%
Sparsity in embeddings.position_embeddings.weight: 0.00%
Sparsity in embeddings.token_type_embeddings.weight: 0.00%
Sparsity in embeddings.LayerNorm.weight: 0.00%
Sparsity in encoder.layer.0.attention.self.query.weight: 19.98%
Sparsity in encoder.layer.0.attention.self.key.weight: 19.99%
Sparsity in encoder.layer.0.attention.self.value.weight: 19.94%
Sparsity in encoder.layer.0.attention.output.dense.weight: 20.05%
Sparsity in encoder.layer.0.attention.output.LayerNorm.weight: 0.00%
Sparsity in encoder.layer.0.intermediate.dense.weight: 20.01%
Sparsity in encoder.layer.0.output.dense.weight: 20.01%
Sparsity in encoder.layer.0.output.LayerNorm.weight: 0.00%
Sparsity in encoder.layer.1.attention.self.query.weight: 19.95%
Sparsity in encoder.layer.1.attention.self.key.weight: 20.06%

In [32]:
# Quantize the model to float16
print("\nQuantizing model...")
fp16_model = convert_to_float16(pruned_model)
print("Model quantized float16 successfully!")

# Print model statistics
print("\nGenerating float16 quantized model statistics...")
print_model_stats(
    model=fp16_model,
    model_name="Float16 quantized BERT",
)


Quantizing model...
Model quantized float16 successfully!

Generating float16 quantized model statistics...

Layer-wise sparsity:
--------------------------------------------------
Sparsity in embeddings.word_embeddings.weight: 0.00%
Sparsity in embeddings.position_embeddings.weight: 0.00%
Sparsity in embeddings.token_type_embeddings.weight: 0.00%
Sparsity in embeddings.LayerNorm.weight: 0.00%
Sparsity in encoder.layer.0.attention.self.query.weight: 19.98%
Sparsity in encoder.layer.0.attention.self.key.weight: 19.99%
Sparsity in encoder.layer.0.attention.self.value.weight: 19.94%
Sparsity in encoder.layer.0.attention.output.dense.weight: 20.05%
Sparsity in encoder.layer.0.attention.output.LayerNorm.weight: 0.00%
Sparsity in encoder.layer.0.intermediate.dense.weight: 20.01%
Sparsity in encoder.layer.0.output.dense.weight: 20.01%
Sparsity in encoder.layer.0.output.LayerNorm.weight: 0.00%
Sparsity in encoder.layer.1.attention.self.query.weight: 19.95%
Sparsity in encoder.layer.1.attentio