In [3]:
import torch
import os
import random
import matplotlib.pyplot as plt
import pandas as pd
from model_config import ModelConfig
from pruning_methods.wanda import wanda_pruning
from pruning_methods.magnitude import magnitude_pruning
from evaluation_pruning import global_evaluation, generate_text, count_parameters, calculate_ecological_impact
from data_loading import get_wikitext2, get_wikitext2_unstructured
from dotenv import load_dotenv
from datasets import load_dataset
from plot_functions import plot_metrics, compare_prompt, compare_ecological_impact, plot_metrics_horizontal


### Settings

In [None]:
# Credentials: read the Hugging Face token from the environment after loading a
# local .env file (needed for the gated Llama checkpoint below).
load_dotenv()
token = os.environ.get("HUGGINGFACE_TOKEN")
llama_model = "meta-llama/Llama-3.2-1B"

# Prefer the first GPU when CUDA is visible; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Sparsity ratios swept by every benchmark: 0.0, 0.1, ..., 0.6.
# Ratio 0.0 is evaluated without pruning and serves as the baseline.
ratios = [step / 10 for step in range(7)]

def execute_benchmark(model_config, pruning_function, pruning_ratios=None):
    """Prune a fresh copy of the model at each sparsity ratio and evaluate it.

    Every ratio starts from ``model_config.copy_model()``, so the caller's model
    is never pruned. Ratio 0 is evaluated without pruning (the baseline).

    Args:
        model_config: a loaded ``ModelConfig``. Its ``nsamples``, ``seed``,
            ``seqlen`` and ``tokenizer`` are used to build the WikiText-2 loaders.
        pruning_function: unstructured method to apply, ``"magnitude"`` or
            ``"wanda"``. The spellings ``"magnitude_pruning"`` and
            ``"wanda_pruning"`` are accepted as aliases.
        pruning_ratios: sparsity ratios to sweep. Defaults to the notebook-level
            ``ratios`` list.

    Returns:
        list: one ``global_evaluation`` result per ratio, in sweep order.

    Raises:
        ValueError: if ``pruning_function`` is not a recognised method name.
    """
    # Map every accepted spelling onto a canonical method name. Unknown names
    # used to fall through both branches and silently skip pruning.
    aliases = {
        "magnitude": "magnitude",
        "magnitude_pruning": "magnitude",
        "wanda": "wanda",
        "wanda_pruning": "wanda",
    }
    method = aliases.get(pruning_function)
    if method is None:
        # Fail fast, before the (expensive) data loading below.
        raise ValueError(
            f"Unknown pruning_function {pruning_function!r}; "
            f"expected one of {sorted(aliases)}"
        )

    sweep = ratios if pruning_ratios is None else pruning_ratios

    trainloader, testloader = get_wikitext2_unstructured(
        model_config.nsamples, model_config.seed, model_config.seqlen, model_config.tokenizer
    )

    results = []
    for ratio in sweep:
        print(f"Pruning ratio: {ratio}")
        # Work on a copy so each ratio starts from the unpruned weights.
        tmp_model_config = model_config.copy_model()
        if ratio != 0:
            if method == "magnitude":
                magnitude_pruning(tmp_model_config.model, ratio)
            else:
                # Record the ratio on the copy being pruned instead of mutating
                # the caller's config, which leaked state between sweeps.
                # NOTE(review): wanda_pruning receives only the model, not the
                # ratio. Confirm how it obtains its target sparsity.
                tmp_model_config.sparsity_ratio = ratio
                wanda_pruning(tmp_model_config.model)

        result_eval = global_evaluation(
            tmp_model_config,
            ratio,
            trainloader=trainloader,
            testloader=testloader,
            is_structured=False,
            device=device,
        )
        results.append(result_eval)

        print(count_parameters(tmp_model_config.model))

    return results


def display_unstructured_results(results, pruning_ratios=None):
    """Display the ecological-impact table, prompt comparison and metric plot for a sweep.

    Args:
        results: list of per-ratio evaluation results. Each entry must provide
            ``"perplexity"`` and ``"model_size"`` keys.
        pruning_ratios: the ratios ``results`` was produced with. Defaults to the
            notebook-level ``ratios`` list.

    Raises:
        ValueError: if ``results`` and the ratios differ in length, so that no
            result is paired with the wrong ratio in the tables or plot.
    """
    sweep = ratios if pruning_ratios is None else pruning_ratios
    if len(results) != len(sweep):
        raise ValueError(
            f"Got {len(results)} results for {len(sweep)} pruning ratios; "
            "they must match one-to-one"
        )

    display(compare_ecological_impact(results, sweep))
    display(compare_prompt(results, sweep))

    # Extract the per-ratio series for plotting.
    perplexity = [result["perplexity"] for result in results]
    model_size = [result["model_size"] for result in results]
    # NOTE(review): the third positional argument is passed as None, as before.
    # It is presumably an optional metric series that is not collected here.
    # Confirm against plot_metrics_horizontal.
    plot_metrics_horizontal(sweep, perplexity, None, model_size)



### Magnitude Pruning - Facebook/OPT-350M

In [10]:
facebook_model_config = ModelConfig(model_name="facebook/opt-350m")
facebook_model_config.load_llm()

# execute_benchmark dispatches on the method key "magnitude" / "wanda".
results = execute_benchmark(facebook_model_config, "magnitude")
display_unstructured_results(results)

Loading model 'facebook/opt-350m' from cache directory '.my_cache/llm_weights/'...




### Magnitude Pruning - meta-llama/Llama-3.2-1B

In [None]:
# The Llama checkpoint is gated, so the Hugging Face token from the settings cell is passed.
llama_model_config = ModelConfig(model_name=llama_model, token=token)
llama_model_config.load_llm()

# execute_benchmark dispatches on the method key "magnitude" / "wanda".
results = execute_benchmark(llama_model_config, "magnitude")
display_unstructured_results(results)

### Wanda Pruning - Facebook/OPT-350M

In [None]:
facebook_model_config = ModelConfig(model_name="facebook/opt-350m")
facebook_model_config.load_llm()

# This is the Wanda section: it must request "wanda", not magnitude pruning.
results = execute_benchmark(facebook_model_config, "wanda")
display_unstructured_results(results)

### Wanda Pruning - meta-llama/Llama-3.2-1B

In [None]:
# The Llama checkpoint is gated, so the Hugging Face token from the settings cell is passed.
llama_model_config = ModelConfig(model_name=llama_model, token=token)
llama_model_config.load_llm()

# execute_benchmark dispatches on the method key "magnitude" / "wanda".
results = execute_benchmark(llama_model_config, "wanda")
display_unstructured_results(results)