In [3]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [4]:

from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers import AdamW

In [5]:
# Load the pretrained 1.2B-parameter M2M100 translation checkpoint.
model = M2M100ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path="facebook/m2m100_1.2B")


In [74]:
# Tokenizer that matches the checkpoint loaded above.
tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model_name_or_path="facebook/m2m100_1.2B")

# Assessing Sparsity and creating plots

### Flattening weights for each layer

In [None]:
# Flatten every parameter of each encoder/decoder layer into one Python list per
# layer. Keys are layer prefixes such as "model.encoder.layers.3."; insertion
# order alternates encoder/decoder for each layer index, and the plotting cells
# below iterate d in that order.
# NOTE(review): every tracked weight is stored as a Python float, which is very
# memory-hungry for a model this size; consider numpy arrays if RAM runs out.
NUM_LAYERS = 24  # layer count assumed for facebook/m2m100_1.2B — TODO confirm via model.config
encoder = "model.encoder.layers."
decoder = "model.decoder.layers."

d = {}
for i in range(NUM_LAYERS):
    d[encoder + '{}.'.format(i)] = []
    d[decoder + '{}.'.format(i)] = []

KEYS = list(d.keys())
for name, param in model.named_parameters():
    # Every key is exactly four dotted components plus a trailing '.', so the
    # first four components of the parameter name identify its layer. The
    # trailing '.' also keeps "...layers.1." from matching "...layers.12.".
    prefix = '.'.join(name.split('.')[:4]) + '.'
    if prefix not in d:
        # Parameters outside the per-layer stacks are not tracked; skip them
        # *before* flattening so no time/memory is spent converting them.
        continue
    d[prefix].extend(param.detach().flatten().tolist())


### Plotting the percentage of weights in each value bin

In [None]:
import matplotlib.pyplot as plt

i = 0
for key, val in d.items():
    a = val
    to_plot_dict = {}
    to_plot_dict['-0.08'] = 0
    to_plot_dict['-0.065'] = 0
    to_plot_dict['-0.03'] = 0
    to_plot_dict['0'] = 0
    to_plot_dict['0.03'] = 0
    to_plot_dict['0.065'] = 0
    to_plot_dict['0.08'] = 0

    for elem in a:
        if elem <= -.08:
            to_plot_dict['-0.08'] += 1
        elif elem <= -.05 and elem > -.08:
            to_plot_dict['-0.065'] += 1
        elif elem <= -.01 and elem > -0.05:
            to_plot_dict['-0.03'] += 1
        elif elem <= .01 and elem > -0.01:
            to_plot_dict['0'] += 1
        elif elem <= .05 and elem > 0.01:
            to_plot_dict['0.03'] += 1
        elif elem <= .08 and elem > 0.05:
            to_plot_dict['0.065'] += 1
        elif elem >= 0.08:
            to_plot_dict['0.08'] += 1

    x = list(to_plot_dict.keys())
    counts = list(to_plot_dict.values())
    y = [ind/len(a) for ind in counts]
    plt.figure()
    plt.bar(x, y, color ='maroon',
            width = 0.1)
    plt.xlabel("avg value of weight")
    plt.ylabel("percentage")
    plt.title("Distribution of weights in layer {}".format(key))
    plt.show()


### Plotting densities

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# One density histogram of the flattened weights per encoder/decoder layer.
for key, val in d.items():
    fig, ax = plt.subplots()
    # density=True normalizes the bars to a probability density whose integral
    # over `range` is 1 (weights outside +/-0.25 are excluded from the
    # normalization), so the y axis is a density, not a raw frequency count.
    ax.hist(val, density=True, bins=1000, range=(-.25, .25))
    ax.set(title='Density histogram for {}'.format(key),
           xlabel='weight value', ylabel='Density')
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per layer

# Pruning selected encoder and decoder sublayers (L1 unstructured magnitude pruning)

In [75]:
#https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.l1_unstructured.html#torch.nn.utils.prune.l1_unstructured
def prune_layers(ratio):
    for i in range(24):
        list_encoder = ['model.encoder.layers.{}.self_attn.k_proj'.format(i),
                        'model.encoder.layers.{}.self_attn.v_proj'.format(i),
                        'model.encoder.layers.{}.self_attn.q_proj'.format(i),
                        'model.encoder.layers.{}.self_attn.out_proj'.format(i),
                        'model.encoder.layers.{}.self_attn_layer_norm'.format(i),
                        'model.encoder.layers.{}.fc1'.format(i),
                        'model.encoder.layers.{}.fc2'.format(i),
                        'model.encoder.layers.{}.final_layer_norm'.format(i)]
        
        list_decoder = ['model.decoder.layers.{}.self_attn.k_proj'.format(i),
                        'model.decoder.layers.{}.self_attn.v_proj'.format(i),
                        'model.decoder.layers.{}.self_attn.q_proj'.format(i),
                        'model.decoder.layers.{}.self_attn.out_proj'.format(i),
                        'model.decoder.layers.{}.self_attn_layer_norm'.format(i),
                        'model.decoder.layers.{}.fc1'.format(i),
                        'model.decoder.layers.{}.fc2'.format(i),
                        'model.decoder.layers.{}.final_layer_norm'.format(i)]

        for name, module in model.named_modules():
            if name in list_decoder or name in list_encoder:
                prune.l1_unstructured(module, name='weight', amount=ratio)
                prune.remove(module, 'weight')


In [None]:
ratio = 0.2  # fraction of each selected weight tensor to zero out
# Re-load the checkpoint so re-running this cell prunes a fresh model instead of
# compounding pruning on an already-pruned one (pruning is in place).
model = M2M100ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path="facebook/m2m100_1.2B")
prune_layers(ratio=ratio)