## 

In [3]:
from model_config import ModelConfig
from pruning_methods.wanda import wanda_pruning
from pruning_methods.magnitude import magnitude_pruning
import torch

In [4]:
import os
from dotenv import load_dotenv

load_dotenv()

token = os.getenv("HUGGINGFACE_TOKEN")

llama_model = "meta-llama/Llama-3.2-1B"
modelConfig = ModelConfig(token=token)
model = modelConfig.load_llm()

Loading model 'facebook/opt-350m' from cache directory '.cache/llm_weights/'...


## Magnitude Pruning

In [5]:
def count_parameters(model):
    """
    Count the total number of non-zero parameters in a model.
    
    Args:
        model (torch.nn.Module): The model to count parameters for
    
    Returns:
        tuple: (total non-zero parameters, trainable non-zero parameters)
    """
    total_nonzero_params = 0
    trainable_nonzero_params = 0
    
    for param in model.parameters():
        num_nonzero_params = torch.count_nonzero(param).item()  # Count non-zero elements
        total_nonzero_params += num_nonzero_params
        if param.requires_grad:
            trainable_nonzero_params += num_nonzero_params
    
    return total_nonzero_params, trainable_nonzero_params


In [6]:
original_model = modelConfig.model

prunned_model = modelConfig.copy_model()

pruning_result = magnitude_pruning(prunned_model, 0.5)

print(f"number of parameters in original model: {count_parameters(original_model)}")
print(f"number of parameters in prunned model: {count_parameters(prunned_model)}")

number of parameters in original model: (331195120, 331195120)
number of parameters in prunned model: (166761506, 166761506)


In [None]:
original_output = original_model()[0][
        "generated_text"
    ]

AttributeError: 'str' object has no attribute 'size'

In [7]:
from evaluation_pruning import global_evaluation

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
original_model.to(device)
prunned_model.to(device)

global_evaluation(modelConfig, original_model, prunned_model, modelConfig.tokenizer, device=device)

Wikitext Perplexity: 100%|██████████| 20/20 [01:27<00:00,  4.35s/it]
Wikitext Perplexity: 100%|██████████| 20/20 [01:26<00:00,  4.34s/it]


Original Model Perplexity:  23.599618911743164
Pruned Model Perplexity:  1771.341064453125
Model Memory Difference:  {'Pruned Model Size (bytes)': 662538026, 'Original Model Size (bytes)': 662538026, 'Space Saved (bytes)': 0, 'Percentage Saved (%)': 0.0}


TypeError: OPTForCausalLM.forward() got an unexpected keyword argument 'max_length'

## Wanda Pruning

In [None]:
# wanda_pruning(modelConfig)