In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Network Pruning (Pytorch)

In [1]:
import torch
import torch.quantization
import torch.nn as nn

In [2]:
# set seed for reproducibility
torch.manual_seed(0)

class SampleLinearModel(nn.Module):
  def __init__(self):
    super(SampleLinearModel, self).__init__()
    self.linear1 = nn.Linear(10, 10)

  def forward(self, x):
    x = self.linear1(x)
    return x

In [3]:
# Original model
original_model = SampleLinearModel()
print(original_model)

SampleLinearModel(
  (linear1): Linear(in_features=10, out_features=10, bias=True)
)


In [4]:
for param_name , param in original_model.named_parameters():
  print(param_name)

linear1.weight
linear1.bias


In [5]:
# create pruned model
import torch.nn.utils.prune as prune

pruned_model = SampleLinearModel()
parameters_to_prune = (
    (pruned_model.linear1, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.5
)

In [6]:
pruned_model.linear1.weight_mask

tensor([[0., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
        [0., 0., 1., 0., 1., 0., 0., 0., 0., 1.],
        [1., 1., 1., 0., 1., 0., 1., 0., 1., 1.],
        [1., 1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0., 1., 0.],
        [1., 0., 1., 0., 1., 1., 0., 0., 1., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0., 1., 1.],
        [1., 1., 1., 0., 1., 1., 1., 0., 1., 0.],
        [0., 1., 1., 1., 1., 1., 0., 1., 0., 0.]])

In [8]:
pruned_model.linear1._forward_pre_hooks

OrderedDict([(0, <torch.nn.utils.prune.CustomFromMask at 0x7817e6c612d0>)])

## Network Pruning (Tensorflow)

In [13]:
import tensorflow_model_optimization as tfmot
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pandas as pd
import numpy as np

# data
data = pd.read_csv("./sample_google_scholar.csv")
data = data.dropna()

def convert_first_ten_characters_into_tensor(data):
  first_ten_characters = data[:10]
  converted = [ord(char)/256 for char in first_ten_characters]
  while len(converted) < 10:
    converted.append(0.0)
  return np.array(converted)

converted_affiliation = data['affiliation'].map(convert_first_ten_characters_into_tensor)
affiliation = np.vstack(converted_affiliation.values)
converted_email = data['email'].str.contains('.edu')
labels = converted_email.values

# model
input_shape = 10

In [14]:
num_examples_train = len(affiliation)
batch_size = 16
epochs = 5

end_step = np.ceil(num_examples_train/batch_size).astype(np.int32)*epochs

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.3, final_sparsity=0.5, begin_step=0,
                                                             end_step=end_step)
}

In [15]:
model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Dense(128, activation='relu', name='layer1'),
        layers.Dense(64, activation='relu', name='layer2'),
        layers.Dense(1, activation='sigmoid', name='layer3'),
    ]
)

loss = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam()

In [16]:
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

In [17]:
model_for_pruning.compile(loss=loss, optimizer=optimizer)

callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]

model_for_pruning.fit(affiliation, labels, batch_size=6, epochs=5, validation_split=0.2, callbacks=callbacks)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x78178d5c7a90>

In [19]:
model_for_pruning.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_layer1  (None, 128)               2690      
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_layer2  (None, 64)                16450     
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_layer3  (None, 1)                 131       
  (PruneLowMagnitude)                                            
                                                                 
Total params: 19271 (75.29 KB)
Trainable params: 9729 (38.00 KB)
Non-trainable params: 9542 (37.29 KB)
_________________________________________________________________


In [20]:
final_tf_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

In [21]:
final_tf_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 layer1 (Dense)              (None, 128)               1408      
                                                                 
 layer2 (Dense)              (None, 64)                8256      
                                                                 
 layer3 (Dense)              (None, 1)                 65        
                                                                 
Total params: 9729 (38.00 KB)
Trainable params: 9729 (38.00 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [22]:
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 layer1 (Dense)              (None, 128)               1408      
                                                                 
 layer2 (Dense)              (None, 64)                8256      
                                                                 
 layer3 (Dense)              (None, 1)                 65        
                                                                 
Total params: 9729 (38.00 KB)
Trainable params: 9729 (38.00 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
