<a href="https://colab.research.google.com/github/osu-mp/ai-539-nlp-group/blob/main/SparsityCheck.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --upgrade pip
!pip install torch
!pip install fairseq
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pip
  Downloading pip-23.0.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m19.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 22.0.4
    Uninstalling pip-22.0.4:
      Successfully uninstalled pip-22.0.4
Successfully installed pip-23.0.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting fairseq
  Downloading fairseq-0.12.2.tar.gz (9.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.6/9.6 MB[0m [31m55.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25h

In [2]:
import torch
import torch.nn as nn
import os
from transformers import BartForConditionalGeneration

In [3]:
GOOGLE_DRIVE_PATH = os.path.join('drive', 'My Drive')
import sys
from google.colab import drive
sys.path.append(GOOGLE_DRIVE_PATH)
drive.mount('/content/drive/')

Mounted at /content/drive/


In [33]:
def get_model_sparsity(model: nn.Module) -> float:
    """
    calculate the sparsity of the given model
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    num_nonzeros, num_elements = 0, 0
    for param in model.parameters():
        num_nonzeros += param.count_nonzero()
        num_elements += param.numel()
    return 1 - float(num_nonzeros) / num_elements

def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param count_nonzero_only: only count nonzero weights
    """
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            num_counted_elements += param.count_nonzero()
        else:
            num_counted_elements += param.numel()
    return num_counted_elements



In [63]:
import copy
import torch.nn.utils.prune as prune
import copy
def get_weight_parameters(layer):
    '''
    Get all parameters/modules identified as 'weight'
    '''
    weight_parameters = []
    if len(list(layer.children())) > 0:
        for child in layer.children():
            for param in child.named_parameters():
                if 'weight' == param[0]:
                    weight_parameters.append((child, param[0]))
            weight_parameters.extend(get_weight_parameters(child))
    
    
    return weight_parameters


def prune_weight_parameters(model, prune_amount):
    '''
    Global pruning
    '''
    params_to_prune = get_weight_parameters(model)
  
    prune.global_unstructured(
        params_to_prune, 
        pruning_method=prune.L1Unstructured, 
        amount=prune_amount,
    )

    for module, name in params_to_prune:
        try:
            prune.remove(module, name)
            #print(module)
        except Exception as e:
            print(e)
    return model

def get_pruned_models(model, sparsity):
    model_to_prune = copy.deepcopy(model)
    pruned_model = prune_weight_parameters(model_to_prune, sparsity)
    return pruned_model

In [40]:
# Load the model in fairseq
from fairseq.models.bart import BARTModel
bart = BARTModel.from_pretrained('drive/MyDrive/AI539MUSSLT/bart.base', checkpoint_file='checkpoint_best.pt')
#bart.eval()  # disable dropout (or leave in train mode to finetune)

In [None]:
for module_name, module in bart.named_modules():
  print(module_name)

In [41]:
bart.fill_mask(['The cat <mask> on the <mask> .'], topk=3, beam=10)

[[('The cat is a cat on the', tensor(-2.4423)),
  ('The cat was on the catwalk', tensor(-2.4441)),
  ("The cat was on the cat's", tensor(-2.4598))]]

In [42]:
get_model_sparsity(bart.model)

9.080493007851409e-06

In [43]:
bart_total_params = sum(p.numel() for p in bart.parameters() if p.requires_grad)
bart_total_params

181928448

In [64]:

bart_pruned = get_pruned_models(bart.model, 0.1)
print(get_model_sparsity(bart_pruned))

Parameter 'weight' of module Embedding(50264, 768, padding_idx=1) has to be pruned before pruning can be removed
0.08258526451014414


In [60]:
get_model_sparsity(bart_pruned)



0.05253019582731777