In [None]:
import torch
import numpy as np
import pandas as pd
from torch.nn.utils import prune

In [None]:
def get_weight_parameters(layer):
    '''
    Collect every ``(module, 'weight')`` pair in ``layer``'s module tree.

    The tree is walked in pre-order: ``layer`` itself first, then each child
    followed by its descendants. A module is included when it directly owns
    a parameter named exactly ``'weight'``. Parameters of nested submodules
    are attributed to the submodule that owns them, never to an ancestor.

    The returned list has the ``(module, name)`` shape expected by
    ``torch.nn.utils.prune.global_unstructured``.

    Args:
        layer: the ``torch.nn.Module`` to search.

    Returns:
        A list of ``(module, 'weight')`` tuples. It is empty if no module
        owns a ``weight`` parameter. A submodule registered under several
        names appears only once.
    '''
    weight_parameters = []
    # modules() yields ``layer`` itself plus all descendants in pre-order,
    # visiting a shared submodule only once. Including ``layer`` makes a bare
    # leaf layer (e.g. nn.Linear) work. De-duplication prevents pruning the
    # same module twice.
    for module in layer.modules():
        # recurse=False: inspect only the parameters this module owns
        # directly, rather than re-walking its whole subtree each time.
        own_names = (name for name, _ in module.named_parameters(recurse=False))
        if 'weight' in own_names:
            weight_parameters.append((module, 'weight'))
    return weight_parameters


def prune_weight_parameters(model, prune_amount):
    '''
    Globally prune every ``weight`` parameter of ``model`` by L1 magnitude.

    All weights returned by ``get_weight_parameters`` are scored together.
    The ``prune_amount`` smallest-magnitude entries across the whole model
    are zeroed (global, unstructured pruning). ``prune.remove`` then makes
    the pruning permanent, so each module ends up with a plain ``weight``
    parameter that contains the zeros.

    Args:
        model: the ``torch.nn.Module`` to prune. It is modified in place.
        prune_amount: forwarded unchanged as ``amount`` to
            ``prune.global_unstructured``.

    Returns:
        The same ``model`` object, for chaining.
    '''
    # Keep one entry per distinct weight tensor. A tensor reachable through
    # several (module, name) pairs (a shared or tied weight) would otherwise
    # be counted twice toward the global threshold. prune.remove would also
    # run twice on the same module, and the second call raises.
    weight_parameters = []
    seen_ids = set()
    for module, name in get_weight_parameters(model):
        tensor_id = id(getattr(module, name))
        if tensor_id in seen_ids:
            continue
        seen_ids.add(tensor_id)
        weight_parameters.append((module, name))

    # Nothing to prune; global_unstructured would fail on an empty list.
    if not weight_parameters:
        return model

    prune.global_unstructured(
        weight_parameters,
        pruning_method=prune.L1Unstructured,
        amount=prune_amount,
    )

    # Fold each mask into its weight and drop the pruning reparametrization.
    for module, name in weight_parameters:
        prune.remove(module, name)
    return model

## Sparsifying GPT-2

In [None]:
from transformers import GPT2Model

# Load the pretrained "gpt2" checkpoint into a base GPT2Model.
gpt2_model = GPT2Model.from_pretrained(pretrained_model_name_or_path="gpt2")