# Lesson 3, Exercise 2: Strategic Pruning on Llama-3.2-1B - Impact of Method and Target

**Goal:**
In this exercise you will explore how different pruning strategies affect a pretrained language model, Llama-3.2-1B. You will investigate how the choice of pruning *method* (e.g., magnitude-based vs. random) and pruning *target* (e.g., MLP layers vs. attention layers) affects the model's output quality and inference speed at similar sparsity levels, before any fine-tuning is applied.

## 2. Imports and Configuration

In [None]:
import os
import torch
import torch.nn.utils.prune as prune
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import copy # For creating fresh model copies
import pandas as pd

os.environ["HF_HUB_OFFLINE"] = "1"
MODEL_NAME = "/voc/shared/models/llama/Llama-3.2-1B"

PRUNING_AMOUNT = 0.3 # Target ~30% sparsity in selected layers
NUM_LAYERS_TO_TARGET_EXAMPLE = 2 # Target first N layers for simplicity in this example

PROMPT_TEXT = "The future of AI is"
MAX_NEW_TOKENS_PRUNING = 30 
NUM_SPEED_RUNS = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    model_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    model_dtype = torch.float32 # Half precision is slow or unsupported for many CPU ops
print(f"Using device: {device}, Model dtype: {model_dtype}")

# ### TODO: Define Llama-3.2-1B Layer Names
# Based on inspecting the model architecture (print(model) after loading once).
LLAMA_MLP_GATE_PROJ_TARGETS = [f"model.layers.{i}.mlp.gate_proj" for i in range(NUM_LAYERS_TO_TARGET_EXAMPLE)]
LLAMA_MLP_UP_PROJ_TARGETS = [____ for i in range(NUM_LAYERS_TO_TARGET_EXAMPLE)]
LLAMA_MLP_DOWN_PROJ_TARGETS = [____ for i in range(NUM_LAYERS_TO_TARGET_EXAMPLE)]

LLAMA_ATTN_Q_PROJ_TARGETS = [____ for i in range(NUM_LAYERS_TO_TARGET_EXAMPLE)]
LLAMA_ATTN_K_PROJ_TARGETS = [____ for i in range(NUM_LAYERS_TO_TARGET_EXAMPLE)]
LLAMA_ATTN_V_PROJ_TARGETS = [____ for i in range(NUM_LAYERS_TO_TARGET_EXAMPLE)]
LLAMA_ATTN_O_PROJ_TARGETS = [____ for i in range(NUM_LAYERS_TO_TARGET_EXAMPLE)]
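
To verify layer names before filling in the lists above, it can help to enumerate the dotted module names directly rather than reading the full `print(model)` output. The sketch below uses a small hypothetical toy model (`ToyModel`, `ToyBlock`, and `ToyMLP` are illustrative names, not part of Llama or Transformers) with a similar nested layout. The same two calls, `named_modules()` and `get_submodule()`, should apply to the loaded model.

```python
import torch.nn as nn

# Hypothetical toy model (NOT Llama) that mimics a nested layout of the form
# layers.<i>.mlp.<projection>
class ToyMLP(nn.Module):
    def __init__(self, dim=8, hidden=16):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = ToyMLP()

class ToyModel(nn.Module):
    def __init__(self, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock() for _ in range(num_layers)])

toy = ToyModel()

# named_modules() yields (dotted_name, module) pairs; keep only the Linear layers.
linear_names = [name for name, mod in toy.named_modules() if isinstance(mod, nn.Linear)]
print(linear_names)

# A dotted name can be resolved back to its module with the built-in get_submodule().
gate0 = toy.get_submodule("layers.0.mlp.gate_proj")
print(type(gate0).__name__, tuple(gate0.weight.shape))
```

On the real model, filtering the names by a substring such as `"mlp"` or `"self_attn"` is one way to narrow the list, assuming the architecture uses those names.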

## 3. Helper Functions

In [None]:
def get_module_by_name(model, module_name_str):
    """Gets a module from a model using its string name."""
    ### TODO: Implement this helper function
    # Split module_name_str by '.' and walk down the module hierarchy with getattr.
    # Example: "model.layers.0.mlp.gate_proj" refers to model.layers[0].mlp.gate_proj.
    # Numeric parts such as "0" index into an nn.ModuleList; getattr(module_list, "0")
    # resolves them because ModuleList registers its children under string keys.
    names = module_name_str.split('.')
    module = model
    for name_part in names:
        module = None # TODO: use getattr to fetch the relevant module part
    return module

def calculate_sparsity(module):
    """Calculates sparsity of a module's weight if it exists."""
    if hasattr(module, 'weight') and module.weight is not None:
        fraction = None # TODO: fraction of number of weights that are zero over total number of weights
        return fraction
    return 0.0

def apply_pruning_to_layers(model, layer_names_list, amount, method='l1_unstructured'):
    """Applies global unstructured pruning to a list of specified layers."""
    parameters_to_prune_tuples = []
    valid_layer_names_pruned = []
    for name_str in layer_names_list:
        try:
            ### TODO: Get the module using get_module_by_name
            module = None # Placeholder

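            # Compare against None explicitly: container modules define __len__,
            # so a bare truthiness test can be False for a valid (empty) module.
            if module is not None and hasattr(module, 'weight') and module.weight is not None:
                parameters_to_prune_tuples.append((module, 'weight'))
                valid_layer_names_pruned.append(name_str)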
            else:
                print(f"Warning: Layer {name_str} has no 'weight' or weight is None. Skipping.")
        except AttributeError:
            print(f"Warning: Layer {name_str} not found in model. Make sure names are correct. Skipping.")

    if not parameters_to_prune_tuples:
        print("No valid parameters found to prune for the given layer names.")
        return [] # Return empty list if no layers were pruned

    ### TODO: Apply global_unstructured pruning using torch.nn.utils.prune
    # Use prune.L1Unstructured or prune.RandomUnstructured based on 'method'.
    # Example: 
    # if method == 'l1_unstructured':
    #     pruning_method_class = prune.L1Unstructured
    # elif method == 'random_unstructured':
    #     pruning_method_class = prune.RandomUnstructured
    # else: raise ValueError("Unsupported pruning method")
    pass # Replace with pruning application
    
    ### TODO: Make pruning permanent for all pruned parameters
    # Iterate through parameters_to_prune_tuples and use prune.remove().
    pass # Replace with permanent pruning

    print(f"Applied {method} pruning (amount {amount*100:.1f}%) to {len(valid_layer_names_pruned)} layers.")
    return valid_layer_names_pruned # Return names of layers actually processed

def measure_generation_speed_and_quality(model, tokenizer, prompt, max_new_tokens, num_runs):
    total_time = 0
    generated_text_sample = "Error: Generation did not run."
    input_ids = tokenizer(prompt, return_tensors="pt").to(model.device if hasattr(model, 'device') and model.device is not None else device)
    
    with torch.no_grad():
        for i in range(num_runs):
            current_device = model.device if hasattr(model, 'device') and model.device is not None else device
            if current_device.type == 'cuda':
                torch.cuda.synchronize(current_device) 
            start_time = time.perf_counter()

            ### TODO: Generate text using model.generate()
            # Use do_sample=False for consistent speed testing.
            # outputs = model.generate(...)
            outputs = None # Placeholder

            if current_device.type == 'cuda':
                torch.cuda.synchronize(current_device)
            end_time = time.perf_counter()

            if i == 0 and outputs is not None:
                generated_text_sample = tokenizer.decode(outputs[0], skip_special_tokens=True)
            if outputs is not None:
                total_time += (end_time - start_time)
            else:
                total_time = float('inf') # Indicate error
                break

    avg_time = total_time / num_runs if num_runs > 0 and total_time != float('inf') else float('nan')
    return avg_time, generated_text_sample
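
For reference, the pattern that `apply_pruning_to_layers` builds on can be tried in isolation first. The following is a minimal sketch on two small stand-in `nn.Linear` layers (hypothetical, not Llama modules): it applies global magnitude pruning, makes it permanent, and measures the resulting fraction of zero weights. It is an illustration of the `torch.nn.utils.prune` calls under these assumptions, not the exercise solution.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
# Two small stand-in layers (hypothetical; not Llama modules).
layer_a = nn.Linear(16, 16, bias=False)
layer_b = nn.Linear(16, 16, bias=False)
params = [(layer_a, "weight"), (layer_b, "weight")]

# Global pruning ranks the weights of both layers together and zeroes
# the 30% with the smallest absolute value.
prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=0.3)

# While pruning is active, each module holds weight_orig plus a weight_mask buffer.
has_mask_before_remove = hasattr(layer_a, "weight_mask")

# prune.remove() bakes the mask into .weight and drops the reparametrization.
for module, param_name in params:
    prune.remove(module, param_name)

def zero_fraction(t):
    """Fraction of entries in tensor t that are exactly zero."""
    return (t == 0).sum().item() / t.numel()

sparsity_a = zero_fraction(layer_a.weight)
sparsity_b = zero_fraction(layer_b.weight)
# The layers are equal-sized, so the mean of the two equals the global rate.
overall = (sparsity_a + sparsity_b) / 2
print(f"layer_a: {sparsity_a:.3f}, layer_b: {sparsity_b:.3f}, overall: {overall:.3f}")
```

Because the ranking is global, the per-layer sparsities may differ from each other even though the overall rate is close to the requested amount. Swapping in `prune.RandomUnstructured` selects weights at random instead of by magnitude.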

## 4. Load Original Model and Tokenizer (Baseline)

In [None]:
results_summary_list = []
original_model = None
tokenizer = None

print(f"Loading original model: {MODEL_NAME}...")
try:
    ### TODO: Load the Llama-3.2-1B model using AutoModelForCausalLM
    ### TODO: Load the tokenizer for the model
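    original_model, tokenizer = None, None

    # Check for unfinished TODOs first, so the error message is meaningful.
    if original_model is None or tokenizer is None:
        raise ValueError("Original model or tokenizer not loaded. Check TODOs.")
    # Some tokenizers define no pad token; fall back to EOS in that case.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token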
    
    print("Original model and tokenizer loaded.")
    # You can print model architecture here to verify layer names if needed: print(original_model)

    ### TODO: Measure generation speed and get output for the original model using measure_generation_speed_and_quality
    original_avg_time, original_output = float('nan'), "N/A (TODO)"

    results_summary_list.append({
        "Configuration": "Original",
        "Avg Sparsity Targeted (%)": 0,
        "Avg Inference Time (s)": f"{original_avg_time:.4f}",
        "Speed-up vs Original": "1.00x",
        "Output Sample": original_output
    })
    print(f"Original Model: Avg Time={original_avg_time:.4f}s, Output='{original_output[:100]}...'\n")

except Exception as e:
    print(f"CRITICAL ERROR loading original model {MODEL_NAME}: {e}")
    print("Ensure MODEL_NAME points to a valid local model directory (HF_HUB_OFFLINE is set) and that sufficient resources are available.")
    original_model = None # Prevent further execution if base model fails

## 5. Pruning Experiments

Helper function to run a single pruning experiment configuration:

In [None]:
def run_single_pruning_experiment(config_name, base_model_to_copy, target_layer_names, pruning_amt, prune_method):
    if base_model_to_copy is None: # Guard if original model loading failed
        print(f"Skipping {config_name} as base model is not available.")
        results_summary_list.append({"Configuration": config_name, "Error": "Base model not loaded"})
        return
        
    print(f"\n--- Running: {config_name} (Method: {prune_method}) ---")
    pruned_model_instance = None
    try:
        ### TODO: Create a deepcopy of the base_model_to_copy for this experiment
        # This is to ensure each pruning experiment starts from a fresh, unpruned state.
        pass # Replace with model copying

        if pruned_model_instance is None:
            raise ValueError("Pruned model instance not created.")

        ### TODO: Apply pruning to the target_layer_names using apply_pruning_to_layers helper
        processed_layers = [] # Placeholder

        avg_sparsity_achieved = 0.0
        if processed_layers: # Only calculate if some layers were actually pruned
            ### TODO: Calculate the average sparsity across the processed_layers, as a percentage
            ###       (calculate_sparsity returns a fraction, so multiply by 100).
            ###       Use the get_module_by_name and calculate_sparsity helper functions.
            pass
        
        print(f"Average sparsity in targeted layers: {avg_sparsity_achieved:.2f}%")

        ### TODO: Measure generation speed and quality for the pruned_model_instance using measure_generation_speed_and_quality 
        avg_time, output = float('nan'), "N/A (TODO)"
        
        speed_up_val = original_avg_time / avg_time if avg_time > 0 and not pd.isna(original_avg_time) else float('nan')

        results_summary_list.append({
            "Configuration": config_name,
            "Avg Sparsity Targeted (%)": f"{avg_sparsity_achieved:.2f}",
            "Avg Inference Time (s)": f"{avg_time:.4f}",
            "Speed-up vs Original": f"{speed_up_val:.2f}x",
            "Output Sample": output
        })
        print(f"{config_name}: Avg Time={avg_time:.4f}s, Speed-up={speed_up_val:.2f}x, Output='{output[:100]}...'\n")

    except Exception as e:
        print(f"Error during {config_name}: {e}")
        results_summary_list.append({"Configuration": config_name, "Error": str(e)})
    finally:
        del pruned_model_instance # Important to free memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
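
The reason each experiment starts from a deep copy can be checked on a small scale. The sketch below, which uses a single stand-in `nn.Linear` as a hypothetical base model, prunes a copy and confirms that the original weights are untouched.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
base = nn.Linear(8, 8, bias=False)       # hypothetical stand-in for the base model
snapshot = base.weight.detach().clone()  # reference copy of the original weights

trial = copy.deepcopy(base)              # independent parameters, same values
prune.l1_unstructured(trial, name="weight", amount=0.5)
prune.remove(trial, "weight")

base_unchanged = torch.equal(base.weight, snapshot)
trial_zero_fraction = (trial.weight == 0).float().mean().item()
print(base_unchanged, trial_zero_fraction)
```

Without the copy, each configuration would prune an already pruned model, and the results of later experiments would not be comparable to the baseline.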

### 5.1 Part A: Pruning Method Comparison

In [None]:
if original_model is not None: # Proceed only if base model loaded successfully
    print("\n--- Part A: Pruning Method Comparison (Targeting MLP Gate Projections) ---")
    part_a_targets = LLAMA_MLP_GATE_PROJ_TARGETS # Example, choose a consistent set

    ### TODO: Run experiment for Magnitude Pruning on part_a_targets using run_single_pruning_experiment function

    ### TODO: Run experiment for Random Pruning on part_a_targets
else:
    print("Skipping Part A due to original model loading failure.")

### 5.2 Part B: Pruning Target Comparison

In [None]:
if original_model is not None: # Proceed only if base model loaded successfully
    print("\n\n--- Part B: Pruning Target Comparison (All using Magnitude Pruning) ---")
    
    ### TODO: Define the list of ALL MLP layers to target for Part B and run Magnitude Pruning experiment
    mlp_layers_for_b = None # TODO: combine various MLP layers. E.g. LLAMA_MLP_GATE_PROJ_TARGETS + LLAMA_MLP_UP_PROJ_TARGETS + etc

    ### TODO: Define the list of ALL Attention projection layers to target for Part B and run Magnitude Pruning experiment
    attn_layers_for_b = None # TODO: combine various ATTN layers
else:
    print("Skipping Part B due to original model loading failure.")

## 6. Display Final Results Summary

In [None]:
df_final_results = pd.DataFrame(results_summary_list)
print("\n\n--- Overall Pruning Experiment Results Summary ---")
print(df_final_results.to_string())

## 7. Analysis and Discussion

Based on the 'Overall Pruning Experiment Results Summary' table:

1.  **Part A - Pruning Method:**
    *   **TODO**: Compare the output quality and inference speed-up between magnitude pruning and random pruning for Llama-3.2-1B at similar sparsity levels. Did one perform better? Why might that be?

2.  **Part B - Pruning Target:**
    *   **TODO**: Compare the impact of pruning MLP layers versus attention projection layers on Llama-3.2-1B's output quality and speed. Which type of layer seemed more critical or sensitive to pruning before fine-tuning? Offer potential reasons.

3.  **Challenges with Llama-3.2-1B:**
    *   **TODO**: Reflect on any practical challenges encountered (e.g., memory usage, computation time, verifying layer names) while applying these pruning techniques to a model of Llama-3.2-1B's scale.

4.  **Inference Speed-up:**
    *   **TODO**: Was there a noticeable inference speed-up from unstructured pruning in your experiments? Discuss why or why not, considering the nature of unstructured pruning and standard hardware.

5.  **Necessity of Fine-tuning:**
    *   **TODO**: Emphasize the likely necessity of fine-tuning to make such pruned models practically useful, based on the quality of outputs you observed.
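
As a starting point for question 4, it may help to look at what unstructured pruning leaves behind in memory. The sketch below is a minimal illustration on a hypothetical stand-in layer: after pruning and `prune.remove()`, the weight is still an ordinary dense tensor of the same shape and size, with the pruned entries stored as explicit zeros.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(64, 64, bias=False)  # hypothetical stand-in layer
shape_before = tuple(layer.weight.shape)
bytes_before = layer.weight.numel() * layer.weight.element_size()

prune.random_unstructured(layer, name="weight", amount=0.5)
prune.remove(layer, "weight")

shape_after = tuple(layer.weight.shape)
bytes_after = layer.weight.numel() * layer.weight.element_size()
nonzero_fraction = (layer.weight != 0).float().mean().item()

# The tensor is still dense: same shape and storage size, but only half
# of the stored values are non-zero.
print(shape_before, shape_after, bytes_before, bytes_after, nonzero_fraction)
```

Since a standard dense matrix multiplication processes every stored element, zeros included, this observation is worth keeping in mind when interpreting the timing column of the results table.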