In [None]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [None]:
import keras
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, Conv3D, DepthwiseConv2D, SeparableConv2D, Conv3DTranspose
from keras.layers import Flatten, MaxPool2D, AvgPool2D, GlobalAvgPool2D, UpSampling2D, BatchNormalization
from keras.layers import Concatenate, Add, Dropout, ReLU, Lambda, Activation, LeakyReLU, PReLU
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot

from time import time
import numpy as np

In [None]:
# Fix the global RNG seed so the randomly initialised weights and the random
# pruning mask used later in this notebook are reproducible across a
# "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    """LeNet-style CNN: two 3x3 conv + max-pool stages, then three linear layers.

    ``fc1`` expects 16 * 5 * 5 = 400 flattened features per sample, which is
    what the two un-padded conv / 2x2 max-pool stages produce for, e.g., a
    1-channel 28x28 input (28 -> 26 -> 13 -> 11 -> 5). The forward pass
    returns 10 raw scores per sample (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels of 5x5 feature maps
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch of images ``(N, 1, H, W)`` to scores of shape ``(N, 10)``."""
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension. Unlike computing the
        # feature count as nelement / batch_size (float division, and a
        # ZeroDivisionError for an empty batch), this works for any batch size.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Build the network, then place its parameters on the selected device.
model = LeNet()
model = model.to(device)

In [None]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-tensor table of trainable parameter counts and return the total.

    Only parameters with ``requires_grad`` set are listed and counted. The
    count is the number of elements in each tensor, regardless of their values.
    """
    table = PrettyTable(["Modules", "Parameters"])
    # (name, element count) for every trainable parameter, in model order.
    trainable_sizes = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    for name, size in trainable_sizes:
        table.add_row([name, size])
    total_params = sum(size for _, size in trainable_sizes)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
# Baseline: trainable parameter count of the freshly created, unpruned model.
count_parameters(model)

+--------------+------------+
|   Modules    | Parameters |
+--------------+------------+
| conv1.weight |     54     |
|  conv1.bias  |     6      |
| conv2.weight |    864     |
|  conv2.bias  |     16     |
|  fc1.weight  |   48000    |
|   fc1.bias   |    120     |
|  fc2.weight  |   10080    |
|   fc2.bias   |     84     |
|  fc3.weight  |    840     |
|   fc3.bias   |     10     |
+--------------+------------+
Total Trainable Params: 60074


60074

In [None]:
# Save the state_dict of the model before any pruning is applied.
PATH = './base_model.pth'
torch.save(model.state_dict(), PATH)

In [None]:
# Inspect the first convolution layer before pruning: its parameters are the
# plain `weight` and `bias` tensors.
module = model.conv1
print(list(module.named_parameters()))

[('weight', Parameter containing:
tensor([[[[-0.2289,  0.1756, -0.3209],
          [-0.1498, -0.1119,  0.1736],
          [ 0.0037,  0.0153, -0.0493]]],


        [[[-0.0493,  0.0318, -0.1283],
          [ 0.3133,  0.1460,  0.1923],
          [-0.3069,  0.0470,  0.1445]]],


        [[[ 0.3139, -0.1866, -0.2675],
          [ 0.0812, -0.0749,  0.1801],
          [-0.2216,  0.0712, -0.1755]]],


        [[[-0.0025, -0.3208,  0.2431],
          [-0.0657, -0.1588,  0.3007],
          [-0.0029, -0.0642, -0.2951]]],


        [[[ 0.0536, -0.1346, -0.0144],
          [ 0.0269, -0.2241, -0.2860],
          [ 0.0715,  0.2916,  0.2052]]],


        [[[-0.2662, -0.2328,  0.1503],
          [-0.3139, -0.2856, -0.0380],
          [ 0.1973,  0.0386,  0.1585]]]], requires_grad=True)), ('bias', Parameter containing:
tensor([-0.0885, -0.2712,  0.2385,  0.2115,  0.2645, -0.1051],
       requires_grad=True))]


In [None]:
# No buffers are registered on the layer before pruning (prints an empty list).
print(list(module.named_buffers()))

[]


In [None]:
# Randomly prune 30% of the entries of conv1's `weight` tensor (a float
# `amount` is interpreted as a fraction of the entries).
prune.random_unstructured(module, name="weight", amount=0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [None]:
# After pruning, the original values live in a parameter named `weight_orig`;
# `weight` is no longer listed among the parameters.
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([-0.0885, -0.2712,  0.2385,  0.2115,  0.2645, -0.1051],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.2289,  0.1756, -0.3209],
          [-0.1498, -0.1119,  0.1736],
          [ 0.0037,  0.0153, -0.0493]]],


        [[[-0.0493,  0.0318, -0.1283],
          [ 0.3133,  0.1460,  0.1923],
          [-0.3069,  0.0470,  0.1445]]],


        [[[ 0.3139, -0.1866, -0.2675],
          [ 0.0812, -0.0749,  0.1801],
          [-0.2216,  0.0712, -0.1755]]],


        [[[-0.0025, -0.3208,  0.2431],
          [-0.0657, -0.1588,  0.3007],
          [-0.0029, -0.0642, -0.2951]]],


        [[[ 0.0536, -0.1346, -0.0144],
          [ 0.0269, -0.2241, -0.2860],
          [ 0.0715,  0.2916,  0.2052]]],


        [[[-0.2662, -0.2328,  0.1503],
          [-0.3139, -0.2856, -0.0380],
          [ 0.1973,  0.0386,  0.1585]]]], requires_grad=True))]


In [None]:
# The binary pruning mask is stored as a buffer named `weight_mask`.
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[0., 1., 1.],
          [0., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 0.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 0.],
          [0., 1., 1.]]]]))]


In [None]:
# `module.weight` is now a derived tensor: `weight_orig` with the masked
# entries zeroed (note the `MulBackward0` grad_fn in the output).
print(module.weight)


tensor([[[[-0.0000,  0.1756, -0.3209],
          [-0.0000, -0.1119,  0.1736],
          [ 0.0037,  0.0000, -0.0493]]],


        [[[-0.0493,  0.0000, -0.0000],
          [ 0.3133,  0.0000,  0.1923],
          [-0.3069,  0.0470,  0.1445]]],


        [[[ 0.3139, -0.1866, -0.0000],
          [ 0.0000, -0.0000,  0.1801],
          [-0.2216,  0.0000, -0.1755]]],


        [[[-0.0025, -0.0000,  0.2431],
          [-0.0657, -0.1588,  0.0000],
          [-0.0029, -0.0642, -0.2951]]],


        [[[ 0.0536, -0.0000, -0.0144],
          [ 0.0000, -0.2241, -0.2860],
          [ 0.0715,  0.2916,  0.2052]]],


        [[[-0.2662, -0.2328,  0.1503],
          [-0.3139, -0.2856, -0.0000],
          [ 0.0000,  0.0386,  0.1585]]]], grad_fn=<MulBackward0>)


In [None]:
# The pruning method is registered as a forward pre-hook on the module; so far
# there is a single hook, for the random unstructured pruning of `weight`.
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f6b0cdb2a90>)])


In [None]:
# Prune the 3 bias entries with the smallest absolute value (an integer
# `amount` is an absolute number of entries, not a fraction).
prune.l1_unstructured(module, name="bias", amount=3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [None]:
# Both tensors are now re-parametrized: the parameters are `weight_orig` and
# `bias_orig`.
print(list(module.named_parameters()))

[('weight_orig', Parameter containing:
tensor([[[[-0.2289,  0.1756, -0.3209],
          [-0.1498, -0.1119,  0.1736],
          [ 0.0037,  0.0153, -0.0493]]],


        [[[-0.0493,  0.0318, -0.1283],
          [ 0.3133,  0.1460,  0.1923],
          [-0.3069,  0.0470,  0.1445]]],


        [[[ 0.3139, -0.1866, -0.2675],
          [ 0.0812, -0.0749,  0.1801],
          [-0.2216,  0.0712, -0.1755]]],


        [[[-0.0025, -0.3208,  0.2431],
          [-0.0657, -0.1588,  0.3007],
          [-0.0029, -0.0642, -0.2951]]],


        [[[ 0.0536, -0.1346, -0.0144],
          [ 0.0269, -0.2241, -0.2860],
          [ 0.0715,  0.2916,  0.2052]]],


        [[[-0.2662, -0.2328,  0.1503],
          [-0.3139, -0.2856, -0.0380],
          [ 0.1973,  0.0386,  0.1585]]]], requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.0885, -0.2712,  0.2385,  0.2115,  0.2645, -0.1051],
       requires_grad=True))]


In [None]:
# The buffers now hold one mask per pruned tensor: `weight_mask` and `bias_mask`.
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[0., 1., 1.],
          [0., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 0.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 0.],
          [0., 1., 1.]]]])), ('bias_mask', tensor([0., 1., 1., 0., 1., 0.]))]


In [None]:
# Effective bias: `bias_orig` with the three masked entries set to zero.
print(module.bias)

tensor([-0.0000, -0.2712,  0.2385,  0.0000,  0.2645, -0.0000],
       grad_fn=<MulBackward0>)


In [None]:
# Two forward pre-hooks are now registered, one per pruned tensor.
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f6b0cdb2a90>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f6b468f4588>)])


In [None]:
# Prune `weight` a second time, now in a structured way: whole slices along
# dim=0 (the output channels) are removed, picking the 50% with the lowest
# L2 norm (n=2).
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this zeroes out all the connections of 50% (3 out of 6)
# of the channels, while preserving the action of the previous (random
# unstructured) mask on the remaining channels.
print(module.weight)

tensor([[[[-0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[-0.0493,  0.0000, -0.0000],
          [ 0.3133,  0.0000,  0.1923],
          [-0.3069,  0.0470,  0.1445]]],


        [[[ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000]]],


        [[[-0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000]]],


        [[[ 0.0536, -0.0000, -0.0144],
          [ 0.0000, -0.2241, -0.2860],
          [ 0.0715,  0.2916,  0.2052]]],


        [[[-0.2662, -0.2328,  0.1503],
          [-0.3139, -0.2856, -0.0000],
          [ 0.0000,  0.0386,  0.1585]]]], grad_fn=<MulBackward0>)


In [None]:
# Select the forward pre-hook that manages the pruning of `weight`. Hooks that
# are not pruning methods may not have a `_tensor_name` attribute, so look it
# up defensively instead of assuming it exists.
hook = next(
    (
        candidate
        for candidate in module._forward_pre_hooks.values()
        if getattr(candidate, "_tensor_name", None) == "weight"
    ),
    None,
)
if hook is None:
    # A plain for/break loop would silently fall through to the last hook
    # (or leave `hook` undefined); fail loudly instead.
    raise LookupError("no pruning hook registered for 'weight' on this module")

# `weight` was pruned twice, so its hook is a container that can be iterated
# to list the pruning methods applied, in order.
print(list(hook))  # pruning history in the container

[<torch.nn.utils.prune.RandomUnstructured object at 0x7f6b0cdb2a90>, <torch.nn.utils.prune.LnStructured object at 0x7f6b0cdb2e48>]


In [None]:
# For conv1 the state_dict now stores `weight_orig` / `bias_orig` plus the two
# masks in place of plain `weight` / `bias`; the other layers are unchanged.
print(model.state_dict().keys())

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])


In [None]:
# The underlying parameters are untouched by the structured pruning step:
# still `weight_orig` and `bias_orig`, holding their original values.
print(list(module.named_parameters()))

[('weight_orig', Parameter containing:
tensor([[[[-0.2289,  0.1756, -0.3209],
          [-0.1498, -0.1119,  0.1736],
          [ 0.0037,  0.0153, -0.0493]]],


        [[[-0.0493,  0.0318, -0.1283],
          [ 0.3133,  0.1460,  0.1923],
          [-0.3069,  0.0470,  0.1445]]],


        [[[ 0.3139, -0.1866, -0.2675],
          [ 0.0812, -0.0749,  0.1801],
          [-0.2216,  0.0712, -0.1755]]],


        [[[-0.0025, -0.3208,  0.2431],
          [-0.0657, -0.1588,  0.3007],
          [-0.0029, -0.0642, -0.2951]]],


        [[[ 0.0536, -0.1346, -0.0144],
          [ 0.0269, -0.2241, -0.2860],
          [ 0.0715,  0.2916,  0.2052]]],


        [[[-0.2662, -0.2328,  0.1503],
          [-0.3139, -0.2856, -0.0380],
          [ 0.1973,  0.0386,  0.1585]]]], requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.0885, -0.2712,  0.2385,  0.2115,  0.2645, -0.1051],
       requires_grad=True))]


In [None]:
# `weight_mask` now combines both pruning rounds: three output channels are
# fully masked, the other three keep their random unstructured mask.
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 0., 0.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 0., 1.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 0.],
          [0., 1., 1.]]]])), ('bias_mask', tensor([0., 1., 1., 0., 1., 0.]))]


In [None]:
# Effective weight after both pruning rounds (`weight_orig` times the combined mask).
print(module.weight)

tensor([[[[-0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[-0.0493,  0.0000, -0.0000],
          [ 0.3133,  0.0000,  0.1923],
          [-0.3069,  0.0470,  0.1445]]],


        [[[ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000]]],


        [[[-0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000]]],


        [[[ 0.0536, -0.0000, -0.0144],
          [ 0.0000, -0.2241, -0.2860],
          [ 0.0715,  0.2916,  0.2052]]],


        [[[-0.2662, -0.2328,  0.1503],
          [-0.3139, -0.2856, -0.0000],
          [ 0.0000,  0.0386,  0.1585]]]], grad_fn=<MulBackward0>)


In [None]:
# Make the pruning of `weight` permanent: the re-parametrization is removed and
# `weight` is an ordinary parameter again, holding the masked values.
# `bias` is still re-parametrized at this point (`bias_orig` remains).
prune.remove(module, 'weight')
print(list(module.named_parameters()))

[('bias_orig', Parameter containing:
tensor([-0.0885, -0.2712,  0.2385,  0.2115,  0.2645, -0.1051],
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[-0.0493,  0.0000, -0.0000],
          [ 0.3133,  0.0000,  0.1923],
          [-0.3069,  0.0470,  0.1445]]],


        [[[ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000]]],


        [[[-0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000]]],


        [[[ 0.0536, -0.0000, -0.0144],
          [ 0.0000, -0.2241, -0.2860],
          [ 0.0715,  0.2916,  0.2052]]],


        [[[-0.2662, -0.2328,  0.1503],
          [-0.3139, -0.2856, -0.0000],
          [ 0.0000,  0.0386,  0.1585]]]], requires_grad=True))]


In [None]:
# Only `bias_mask` is left among the buffers; `weight_mask` was removed.
print(list(module.named_buffers()))

[('bias_mask', tensor([0., 1., 1., 0., 1., 0.]))]


In [None]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-tensor table of trainable parameter counts and return the total.

    Only parameters with ``requires_grad`` set are listed and counted. The
    count is the number of elements in each tensor, so entries zeroed by
    pruning are still counted.

    Note: this re-defines, with the same behaviour, the ``count_parameters``
    helper already defined in an earlier cell of this notebook.
    """
    table = PrettyTable(["Modules", "Parameters"])
    # (name, element count) for every trainable parameter, in model order.
    trainable_sizes = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    for name, size in trainable_sizes:
        table.add_row([name, size])
    total_params = sum(size for _, size in trainable_sizes)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
# The total is unchanged after pruning: pruned entries are set to zero, but the
# tensors keep their shape, so the element count stays the same.
count_parameters(model)

+-----------------+------------+
|     Modules     | Parameters |
+-----------------+------------+
| conv1.bias_orig |     6      |
|   conv1.weight  |     54     |
|   conv2.weight  |    864     |
|    conv2.bias   |     16     |
|    fc1.weight   |   48000    |
|     fc1.bias    |    120     |
|    fc2.weight   |   10080    |
|     fc2.bias    |     84     |
|    fc3.weight   |    840     |
|     fc3.bias    |     10     |
+-----------------+------------+
Total Trainable Params: 60074


60074

In [None]:
PATH = './prune_model.pth'

# `conv1.bias` is still re-parametrized at this point (only `weight` was made
# permanent with prune.remove), so the raw state_dict would contain
# `conv1.bias_orig` and `conv1.bias_mask` instead of `conv1.bias` and could not
# be loaded into a freshly constructed LeNet. Make every remaining pruning
# permanent before saving so the checkpoint has the standard keys.
for submodule in model.modules():
    # Collect the names first: prune.remove() mutates the module's parameters.
    still_pruned = [
        param_name[:-len("_orig")]
        for param_name, _ in submodule.named_parameters(recurse=False)
        if param_name.endswith("_orig")
    ]
    for tensor_name in still_pruned:
        prune.remove(submodule, tensor_name)

torch.save(model.state_dict(), PATH)