In [2]:
import numpy as np
import mylibrary.nnlib as tnn
import matplotlib.pyplot as plt
import copy

from mpl_toolkits.mplot3d import Axes3D
import matplotlib

import torch
import torch.nn as nn

import mylibrary.datasets as datasets
import prunelib
from tqdm import tqdm
import random

import pickle

In [3]:
# Load MNIST through the project dataset helper; load() is unpacked as
# (train images, train labels, test images, test labels).
mnist = datasets.MNIST()
(train_data, train_label_,
 test_data, test_label_) = mnist.load()

# Rescale pixel values (presumably 0-255 integers) into [0, 1].
train_data, test_data = train_data / 255., test_data / 255.

In [4]:
# Number of training samples, plus the label vector converted by the project
# helper index_to_logit (presumably a one-hot encoding — verify against tnn).
train_size = len(train_label_)
train_label = tnn.Logits.index_to_logit(train_label_)

In [5]:
# torch.Tensor -> default float dtype inputs; LongTensor -> int64 class-index targets
# (the integer targets are what nn.CrossEntropyLoss expects).
xx, yy = torch.Tensor(train_data), torch.LongTensor(train_label_)
test_data = torch.Tensor(test_data)

## Compare

In [6]:
# MNIST: 28 * 28 = 784 input pixels, 10 output classes.
input_dim, output_dim = 784, 10

In [7]:
"""
settings:
1,2,3 -> net = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

4,5 -> net = nn.Sequential(
    nn.Linear(784, 100),
    nn.ReLU(),
    nn.Linear(100, 100),
    nn.ReLU(),
    nn.Linear(100, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
)

6, 7 -> net = nn.Sequential(
    nn.Linear(784, 200),
    nn.ReLU(),
    nn.Linear(200, 100),
    nn.ReLU(),
    nn.Linear(100, 100),
    nn.ReLU(),
    nn.Linear(100, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
)
8, 9 -> net = nn.Sequential(
    nn.Linear(784, 400),
    nn.ReLU(),
    nn.Linear(400, 300),
    nn.ReLU(),
    nn.Linear(300, 200),
    nn.ReLU(),
    nn.Linear(200, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
)
"""
print("Nets")

Nets


In [8]:
# Hidden-layer widths for each architecture (see the settings notes above).
config0, config1 = [256, 128, 64], [100, 100, 100]
config2, config3 = [200, 100, 100, 100], [400, 300, 200, 100]

# Architecture used in this run (an alias of config0, not a copy).
layer_dims = config0

In [9]:
def get_mlp(config, batch_norm=False, final_activation=None, in_dim=None, out_dim=None):
    """Build a fully connected ReLU network as an ``nn.Sequential``.

    Args:
        config: sequence of hidden-layer widths, e.g. ``[256, 128, 64]``.
        batch_norm: if True, insert ``nn.BatchNorm1d`` after every hidden
            ``nn.Linear`` (before its ReLU).
        final_activation: optional module appended after the output layer.
        in_dim: input width; defaults to the notebook-level ``input_dim``.
        out_dim: output width; defaults to the notebook-level ``output_dim``.

    Returns:
        ``nn.Sequential`` laid out as
        ``[Linear, (BatchNorm1d,) ReLU] * len(config) + [Linear, (final_activation)]``.
    """
    # Fall back to the notebook globals so existing calls keep working.
    if in_dim is None:
        in_dim = input_dim
    if out_dim is None:
        out_dim = output_dim

    widths = [in_dim] + list(config)
    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.append(nn.Linear(fan_in, fan_out))
        if batch_norm:
            layers.append(nn.BatchNorm1d(fan_out))
        layers.append(nn.ReLU())

    layers.append(nn.Linear(widths[-1], out_dim))
    # Explicit None check: nn.Module truthiness is not a reliable "was given" test
    # (e.g. an empty nn.Sequential is falsy).
    if final_activation is not None:
        layers.append(final_activation)
    return nn.Sequential(*layers)

In [10]:
expindx: int = 8  # NOTE(review): not referenced in the visible cells; presumably an experiment index

In [11]:
# Build the BatchNorm variant of the selected architecture. Using `layer_dims`
# instead of a hard-coded config keeps this net consistent with
# remove_batchnorm(net, layer_dims) further down; the checkpoint loaded below
# must match this architecture.
net = get_mlp(layer_dims, batch_norm=True)
net

Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Linear(in_features=256, out_features=128, bias=True)
  (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
  (6): Linear(in_features=128, out_features=64, bias=True)
  (7): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): ReLU()
  (9): Linear(in_features=64, out_features=10, bias=True)
)

In [12]:
criterion = nn.CrossEntropyLoss()  # raw logits vs. integer class targets
optimizer = torch.optim.Adam(params=net.parameters(), lr=3e-3)

In [13]:
# Training loop (disabled). It runs full-batch Adam over the whole training set
# for 60 epochs. Its result appears to be what the cells below save to and reload
# from ./mnist_100_mlp_bn.pth.
# NOTE: `np.float` was removed in NumPy 1.24; the builtin `float` is used below.
# for epoch in range(60):
#     yout = net(xx)

#     loss = criterion(yout, yy)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

#     error = float(loss)
#     print(epoch, 'Error = ', error)
    
#     with torch.no_grad():
#         yout = net(xx)
#         out = torch.argmax(yout, axis=1)
#         acc = (out.data.numpy() == np.array(train_label_)).astype(float).mean()
#         print("Accuracy: ", acc)

In [14]:
# list(net._modules.items())

In [15]:
# torch.save({"model":net.state_dict(), "optimizer":optimizer.state_dict()},
#           "./mnist_100_mlp_bn.pth")

In [16]:
# Restore the trained BN network and its optimizer state.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine. load_state_dict then copies the tensors onto the model's own device.
checkpoint = torch.load("./mnist_100_mlp_bn.pth", map_location="cpu")
net.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])

In [17]:
# net = net.train()

## Fold BatchNorm into the preceding Linear layers (Linear + ReLU only network)

In [18]:
def remove_batchnorm(net, layer_dims=None):
    """Return a BatchNorm-free copy of ``net``, each BN folded into the Linear before it.

    A ``Linear(W, b)`` followed by ``BatchNorm1d(gamma, beta, mean, var, eps)``
    is, in eval mode, the single affine map

        W' = (gamma / sqrt(var + eps))[:, None] * W
        b' = beta + gamma / sqrt(var + eps) * (b - mean)

    so the returned network reproduces ``net.eval()``. The BN running
    statistics are used, so training-mode (batch-statistics) behaviour is not
    reproduced. ``net`` itself is not modified.

    Args:
        net: ``nn.Sequential`` of Linear / BatchNorm1d / other layers (e.g. ReLU).
        layer_dims: optional expected hidden widths. If given, they are checked
            against the hidden Linear widths of the result.

    Returns:
        New ``nn.Sequential`` with the same layers minus every BatchNorm1d.

    Raises:
        ValueError: a BatchNorm1d does not directly follow a Linear, a BN has
            no running statistics, or ``layer_dims`` does not match ``net``.
    """
    modules = list(net)
    fused_layers = []
    idx = 0
    with torch.no_grad():
        while idx < len(modules):
            module = modules[idx]
            if isinstance(module, nn.BatchNorm1d):
                raise ValueError(f"BatchNorm1d at index {idx} does not follow a Linear layer")
            if not isinstance(module, nn.Linear):
                # Activations etc. are copied unchanged.
                fused_layers.append(copy.deepcopy(module))
                idx += 1
                continue

            weight = module.weight.detach().clone()
            if module.bias is not None:
                bias = module.bias.detach().clone()
            else:
                bias = torch.zeros(module.out_features, dtype=weight.dtype, device=weight.device)

            following = modules[idx + 1] if idx + 1 < len(modules) else None
            if isinstance(following, nn.BatchNorm1d):
                if following.running_mean is None or following.running_var is None:
                    raise ValueError("BatchNorm1d without running statistics cannot be folded")
                # eps must be included: BatchNorm1d normalizes by sqrt(var + eps).
                scale = 1.0 / torch.sqrt(following.running_var + following.eps)
                if following.affine:
                    scale = following.weight * scale
                    shift = following.bias
                else:
                    shift = torch.zeros_like(bias)
                weight = scale.reshape(-1, 1) * weight
                bias = shift + scale * (bias - following.running_mean)
                idx += 1  # the BN layer has been absorbed

            fused = nn.Linear(module.in_features, module.out_features).to(
                device=weight.device, dtype=weight.dtype)
            fused.weight.copy_(weight)
            fused.bias.copy_(bias)
            fused_layers.append(fused)
            idx += 1

    if layer_dims is not None:
        hidden = [m.out_features for m in fused_layers if isinstance(m, nn.Linear)][:-1]
        if hidden != list(layer_dims):
            raise ValueError(f"layer_dims {list(layer_dims)} does not match network hidden widths {hidden}")

    return nn.Sequential(*fused_layers)

In [19]:
# Fold each BatchNorm1d into the Linear before it -> Linear/ReLU-only network.
net_ = remove_batchnorm(net=net, layer_dims=layer_dims)
net_

3


Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ReLU()
  (4): Linear(in_features=128, out_features=64, bias=True)
  (5): ReLU()
  (6): Linear(in_features=64, out_features=10, bias=True)
)

In [20]:
## accuracy of new network (training set)
# To compare with the original BN network, evaluate `net.eval()(xx)` instead
# (eval() makes BatchNorm use its running statistics).
with torch.no_grad():
    yout = net_(xx)
out = torch.argmax(yout, dim=1)
# `np.float` was removed in NumPy 1.24; the builtin `float` is its exact replacement.
acc = (out.numpy() == np.array(train_label_)).astype(float).mean()
print("Accuracy: ", acc)

Accuracy:  0.9959166666666667


## Oracle Pruning (modified): importance estimators for the BN network

In [62]:
class Importance_TaylorFO_Modified_BN(prunelib.Importance):
    """Modified first-order Taylor importance for the neurons of a BN network.

    Hooks capture the output of every ``BatchNorm1d`` child of ``net`` and of
    its last ``nn.Linear`` child, together with the loss gradient w.r.t. those
    outputs. Each BN layer's per-neuron score is built from
    ``activation * gradient`` summed over the batch. The score is shaped by
    ``config``, a dict read with the keys ``"grad_rescale"``, ``"imp_norm"``
    and ``"allow_linear"``.

    Args:
        net: model whose direct children (``net._modules``) are inspected.
        criterion: loss called as ``criterion(net(x), t)``.
        config: default options, used when ``compute_significance`` gets none.
    """

    def __init__(self, net, criterion, config=None):
        self.net = net
        self.config = config
        self.criterion = criterion

        # Per hooked module: captured output, output gradient, hook handles.
        self.activations = {}
        self.gradients = {}
        self.forward_hook = {}
        self.backward_hook = {}
        # Hooked modules: every BatchNorm1d in order, then the last Linear.
        self.keys = []

    def add_hook(self):
        """Reset the capture state and hook every BatchNorm1d plus the last Linear."""
        self.activations = {}
        self.gradients = {}
        self.forward_hook = {}
        self.backward_hook = {}
        self.keys = []

        # NOTE(review): register_backward_hook is deprecated in recent PyTorch in
        # favour of register_full_backward_hook; kept to preserve behaviour.
        modules = list(self.net._modules.items())
        for name, module in modules:
            if isinstance(module, torch.nn.BatchNorm1d):
                hook = module.register_backward_hook(self.capture_gradients)
                self.backward_hook[module] = hook
                hook = module.register_forward_hook(self.capture_inputs)
                self.forward_hook[module] = hook

                self.activations[module] = None
                self.gradients[module] = None
                self.keys.append(module)

        # The output layer is hooked too; its gradient drives "grad_rescale".
        for name, module in reversed(modules):
            if isinstance(module, torch.nn.Linear):
                hook = module.register_backward_hook(self.capture_gradients)
                self.backward_hook[module] = hook
                hook = module.register_forward_hook(self.capture_inputs)
                self.forward_hook[module] = hook

                self.activations[module] = None
                self.gradients[module] = None
                self.keys.append(module)
                break

    def remove_hook(self):
        """Remove the hooks registered by ``add_hook``; safe to call more than once."""
        for module in self.keys:
            for handles in (self.forward_hook, self.backward_hook):
                hook = handles.pop(module, None)
                if hook is not None:
                    hook.remove()

    def capture_inputs(self, module, inp, out):
        """Forward hook: store the module's *output* (taken via ``.data``)."""
        self.activations[module] = out.data

    def capture_gradients(self, module, gradi, grado):
        """Backward hook: store the loss gradient w.r.t. the module output."""
        self.gradients[module] = grado[0]

    def gather_inputs_gradients(self, x, t):
        """Run one forward/backward pass on ``(x, t)`` with capture hooks attached.

        The hooks are always removed afterwards, even if the forward or backward
        pass raises, so a failure never leaves stale hooks on ``self.net``.
        """
        self.add_hook()
        try:
            self.net.zero_grad()
            y = self.net(x)
            error = self.criterion(y, t)
            error.backward()
        finally:
            self.remove_hook()

    def compute_significance(self, x, t, config=None, normalize=True, layerwise_norm=False):
        """Score every neuron of every BatchNorm1d layer on the batch ``(x, t)``.

        Args:
            x, t: inputs and targets for a single forward/backward pass.
            config: options dict; falls back to ``self.config``.
            normalize: rescale all scores so their mean over all neurons is 1.
            layerwise_norm: divide each layer's scores by their L2 norm.

        Returns:
            ``(importance, aponz, std)``, each a list with one tensor per
            BatchNorm1d layer: the scores, the fraction of samples with a
            positive output, and the per-neuron std of the output.

        Raises:
            ValueError: if neither ``config`` nor ``self.config`` is set.
        """
        # Resolve the config before running the (expensive) forward/backward pass.
        if config is None:
            if self.config is None:
                raise ValueError("config is not known. Please specify the config.")
            config = self.config

        self.gather_inputs_gradients(x, t)

        importance = []
        if config["grad_rescale"]:
            # Per-sample L2 norm of the output-layer gradient, shape (batch, 1);
            # assumes 2-D (batch, features) outputs.
            scaler = torch.norm(self.gradients[self.keys[-1]], p=2, dim=1, keepdim=True) + 1e-5

        for module in self.keys[:-1]:
            z = self.activations[module] * self.gradients[module]
            if config["grad_rescale"]:
                z = z / scaler
            if config["imp_norm"] == "abs":
                z = z.abs()
            elif config["imp_norm"] == "sq":
                z = z.pow(2)

            z = z.sum(dim=0).abs()
            if not config["allow_linear"]:
                # Fraction of samples on which each neuron's output is positive.
                apnz = torch.mean(self.activations[module] > 0., dim=0, dtype=torch.float)
                # Empirically chosen weighting (original note: "tried on desmos").
                z = z * (1 - apnz) * 4

            if layerwise_norm:
                z = z / torch.norm(z, p=2)

            importance.append(z)

        if normalize:
            total = sum(imp.sum() for imp in importance)
            count = sum(len(imp) for imp in importance)
            # Make the total importance equal the number of neurons. Skip when there
            # is nothing to normalize; dividing would give 0/0 -> NaN or ZeroDivisionError.
            if count > 0 and total != 0:
                divider = total / count
                importance = [imp / divider for imp in importance]

        # The output-layer activation was only needed for hooking. Per-pass
        # gradients and hook handles are dropped; BN activations are kept for
        # get_aponz, which then frees them.
        del self.activations[self.keys[-1]]
        self.gradients = {}
        self.forward_hook = {}
        self.backward_hook = {}

        aponz, std = self.get_aponz(True, True)
        return importance, aponz, std

    def get_aponz(self, return_std=True, remove_activations=True):
        """Per BN layer, the fraction of samples with a positive output, optionally with its std.

        Uses the activations captured by the last ``compute_significance`` run.

        Returns:
            ``(aponz, std)`` if ``return_std`` else ``aponz``. Each is a list with
            one tensor per hooked BatchNorm1d layer. Returns ``None`` (after
            printing a message) if no activations are stored.
        """
        if len(self.activations) < 1:
            print("Activation has not been accumulated.. run compute_significance function")
            return
        aponz = []
        std = []
        for module in self.keys[:-1]:
            apnz = torch.mean(self.activations[module] > 0., dim=0, dtype=torch.float)
            aponz.append(apnz)
            if return_std:
                std.append(self.activations[module].std(dim=0))

        if remove_activations:
            self.activations = {}

        if return_std:
            return aponz, std
        return aponz

        

class Importance_Molchanov_BN(prunelib.Importance):
    """Molchanov-style importance of BatchNorm1d channels.

    For each mini-batch the channel score is
    ``(gamma * dL/dgamma + beta * dL/dbeta) ** 2``, where gamma/beta are the BN
    weight/bias. Scores are averaged over mini-batches and can optionally be
    rescaled so their mean over all channels is 1.

    NOTE(review): ``self.net`` is called in whatever mode it is currently in. In
    training mode, every mini-batch forward pass also updates the BN running
    statistics; confirm this side effect is acceptable.
    """

    def __init__(self, net, criterion):
        self.net = net
        self.criterion = criterion
        # One entry per BatchNorm1d child of net, in registration order.
        self.keys = [
            layer
            for _, layer in list(self.net._modules.items())
            if isinstance(layer, torch.nn.BatchNorm1d)
        ]

    def compute_significance(self, x, t, normalize=True, batch_size=32):
        """Return a list with one importance tensor per BatchNorm1d layer.

        Args:
            x, t: inputs and targets, sliced into consecutive chunks of
                ``batch_size`` (the last chunk may be smaller).
            normalize: divide all scores by their global mean.
            batch_size: mini-batch size.
        """
        starts = list(range(0, len(x), batch_size))
        ends = starts[1:] + [len(x)]
        totals = [0] * len(self.keys)

        for b in tqdm(range(len(starts))):
            lo, hi = starts[b], ends[b]
            self.net.zero_grad()
            prediction = self.net(x[lo:hi])
            loss = self.criterion(prediction, t[lo:hi])
            loss.backward()

            # Accumulate this mini-batch's squared first-order term per channel.
            for k, bn in enumerate(self.keys):
                gain = bn.weight.data * bn.weight.grad
                shift = bn.bias.data * bn.bias.grad
                totals[k] += (gain + shift).pow(2)

        # Average over mini-batches.
        scores = [total / len(starts) for total in totals]

        if normalize:
            grand_sum = 0
            n_channels = 0
            for score in scores:
                grand_sum += score.sum()
                n_channels += len(score)
            mean_score = grand_sum / n_channels  # total importance == number of channels
            scores = [score / mean_score for score in scores]

        return scores

In [63]:
# Sanity check: reducing std over dim 0 leaves one value per column.
torch.std(torch.randn(20, 10), dim=0).shape

torch.Size([10])

## Define the pruner (masks the outputs of hidden Linear layers)

In [64]:
class Pruner():
    """Apply per-neuron masks to the outputs of a network's Linear layers.

    ``prune_mask`` is a sequence with one mask per hidden ``nn.Linear`` (every
    Linear child except the last, in registration order). Each mask is
    multiplied into that layer's output; the final Linear gets an all-ones mask.
    A mask given to the constructor (or via ``add_prune_mask``) applies to every
    ``forward`` call. A mask passed to ``forward`` applies to that call only.
    Hooks are attached only while ``forward`` runs, so ``net`` is unmodified
    between calls.
    """

    def __init__(self, net, prune_mask=None):
        self.net = net
        self.keys = []          # Linear children, in registration order
        self.prune_mask = {}    # module -> float mask tensor
        self.forward_hook = {}  # module -> hook handle (only while forward runs)
        self.activations = []   # masked Linear outputs from the latest forward

        for name, module in list(self.net._modules.items()):
            if isinstance(module, torch.nn.Linear):
                self.keys.append(module)

        # Keep the constructor mask; clearing state here would silently discard it.
        if prune_mask is not None:
            self.add_prune_mask(prune_mask)

    def add_prune_mask(self, prune_mask):
        """Store ``prune_mask`` (one mask per hidden Linear) as float tensors."""
        for module, pm in zip(self.keys[:-1], prune_mask):
            self.prune_mask[module] = pm.type(torch.float)
        self.prune_mask[self.keys[-1]] = torch.ones(self.keys[-1].out_features, dtype=torch.float)

    def prune_neurons(self, module, inp, out):
        """Forward hook: multiply a Linear output by its mask (if any) and record it."""
        mask = self.prune_mask.get(module)
        if mask is None:
            output = out
        else:
            # Match the output's device/dtype so masks built on CPU also work elsewhere.
            output = out * mask.to(device=out.device, dtype=out.dtype)
        self.activations.append(output)
        return output

    def forward(self, x, prune_mask=None):
        """Run ``net`` on ``x`` with the masks applied.

        Args:
            x: network input.
            prune_mask: optional masks for this call only; they override the
                stored masks for the duration of the call.

        Returns:
            The network output. If no mask is stored or given, ``net`` runs unmasked.
        """
        stored_mask = self.prune_mask
        self.activations = []
        try:
            if prune_mask:
                self.prune_mask = {}
                self.add_prune_mask(prune_mask)
            if self.prune_mask:
                self.add_hook()
            return self.net(x)
        finally:
            # Always detach hooks, even on error, and restore the persistent masks.
            self._detach_hooks()
            self.prune_mask = stored_mask

    def add_hook(self):
        """(Re)register the masking hook on every Linear child of ``net``."""
        # Only drop old hooks here; clearing the masks would break prune_neurons.
        self._detach_hooks()
        for name, module in list(self.net._modules.items()):
            if isinstance(module, torch.nn.Linear):
                self.forward_hook[module] = module.register_forward_hook(self.prune_neurons)

    def _detach_hooks(self):
        """Remove registered forward hooks without touching masks or activations."""
        for hook in self.forward_hook.values():
            hook.remove()
        self.forward_hook = {}

    def remove_hook(self):
        """Remove all hooks and reset the stored masks and captured activations."""
        self._detach_hooks()
        self.prune_mask = {}
        self.activations = []

## Collect importance scores and APoNZ (fraction of positive activations per neuron)

In [65]:
## Taylor-FO variants: the first three modes listed by prunelib, each with its own config.
methods = prunelib.taylorfo_mode_list[:3]
classes = [
    Importance_TaylorFO_Modified_BN(net, criterion, config=prunelib.taylorfo_mode_config[mode])
    for mode in methods
]

## Molchanov_group, Molchanov_BN and APnZ estimators.
methods += ["Molchanov_group", "Molchanov_BN", "APnZ"]
classes += [
    prunelib.Importance_Molchanov_2019(net, criterion),
    Importance_Molchanov_BN(net, criterion),
    prunelib.Importance_APoZ(net, criterion),
]

In [66]:
## gather all importances
importances = []
aponzs = []
stds = []
for i in range(len(methods)):
    print(methods[i])
    imp = classes[i].compute_significance(xx, yy)
    if methods[i].startswith("taylorfo"):
        imp, aponz, std = imp
        aponzs.append(aponz)
        stds.append(std)
    importances.append(imp)

taylorfo
taylorfo_abs
taylorfo_sq


  1%|          | 12/1875 [00:00<00:16, 114.17it/s]

Molchanov_group


100%|██████████| 1875/1875 [00:15<00:00, 123.38it/s]
  1%|          | 11/1875 [00:00<00:17, 107.86it/s]

Molchanov_BN


100%|██████████| 1875/1875 [00:14<00:00, 128.50it/s]


APnZ


In [67]:
# Fraction of positive outputs per neuron, one tensor per BN layer, from the first Taylor-FO estimator.
aponzs[0]

[tensor([0.4595, 0.4791, 0.5166, 0.4861, 0.4944, 0.5337, 0.4693, 0.3742, 0.5517,
         0.4649, 0.5320, 0.4928, 0.4915, 0.4889, 0.5059, 0.4861, 0.5688, 0.4711,
         0.4488, 0.5264, 0.4800, 0.5444, 0.5210, 0.4251, 0.5261, 0.5160, 0.5289,
         0.4850, 0.4975, 0.5541, 0.4770, 0.4448, 0.4113, 0.4777, 0.4723, 0.4626,
         0.4122, 0.4629, 0.5275, 0.4990, 0.5132, 0.4519, 0.5701, 0.4670, 0.5548,
         0.4778, 0.5104, 0.5284, 0.4530, 0.5938, 0.3519, 0.4909, 0.4295, 0.4384,
         0.5258, 0.4792, 0.4690, 0.4087, 0.5297, 0.4581, 0.4712, 0.4616, 0.4160,
         0.5481, 0.4981, 0.5123, 0.4907, 0.4775, 0.4674, 0.5692, 0.4992, 0.4146,
         0.4655, 0.4683, 0.4863, 0.4900, 0.5031, 0.5070, 0.4227, 0.4076, 0.4235,
         0.4841, 0.4986, 0.5068, 0.4988, 0.4849, 0.5232, 0.4821, 0.4702, 0.4958,
         0.4115, 0.4232, 0.4593, 0.4891, 0.3406, 0.4801, 0.4864, 0.4383, 0.5098,
         0.4651, 0.4922, 0.5456, 0.4212, 0.4891, 0.4661, 0.4977, 0.5504, 0.5469,
         0.4725, 0.4326, 0.4

In [81]:
# For each BN layer hooked by the first Taylor-FO estimator, print (1 - 3*beta/gamma) / 2
# per channel (beta = BN bias, gamma = BN weight).
# NOTE(review): presumably a closed-form estimate to compare with aponzs[0]; the
# origin of the constants 3 and 1/2 is not documented here. Confirm.
for bns in classes[0].keys[:-1]:
    neg_scaled_shift = -bns.bias.data * 3 / bns.weight.data
    print((neg_scaled_shift + 1) / 2)

tensor([0.5561, 0.6352, 0.5023, 0.5189, 0.5694, 0.4510, 0.6529, 0.6222, 0.5174,
        0.6005, 0.5191, 0.4678, 0.5526, 0.5611, 0.5083, 0.5567, 0.5057, 0.5329,
        0.6497, 0.5042, 0.6027, 0.5472, 0.6309, 0.5837, 0.5707, 0.5306, 0.5622,
        0.5831, 0.5243, 0.4830, 0.5475, 0.5679, 0.5639, 0.4948, 0.5485, 0.5767,
        0.5637, 0.4863, 0.6139, 0.5428, 0.5831, 0.5952, 0.5368, 0.5031, 0.4752,
        0.5230, 0.4949, 0.5379, 0.5413, 0.5656, 0.6220, 0.6306, 0.5856, 0.5674,
        0.6088, 0.6318, 0.5420, 0.6031, 0.4025, 0.5256, 0.5206, 0.5399, 0.6209,
        0.5127, 0.4317, 0.5459, 0.5434, 0.6067, 0.5251, 0.3440, 0.5300, 0.4433,
        0.5023, 0.6380, 0.5146, 0.5759, 0.5137, 0.5357, 0.5722, 0.5486, 0.7055,
        0.5236, 0.5066, 0.5864, 0.5412, 0.5445, 0.4892, 0.6100, 0.5556, 0.4538,
        0.6001, 0.5635, 0.6861, 0.6062, 0.6894, 0.6244, 0.5819, 0.6021, 0.5455,
        0.6175, 0.5067, 0.6357, 0.7128, 0.6425, 0.5474, 0.4082, 0.5430, 0.5556,
        0.6264, 0.5220, 0.5381, 0.5396, 

In [68]:
# Per-neuron std of the BN outputs from the first Taylor-FO estimator.
stds[0]

[[tensor([1.0047, 1.0040, 0.9974, 0.9939, 0.9754, 1.0131, 0.9484, 1.0526, 1.0099,
          1.0419, 0.9687, 1.0016, 1.0218, 1.0697, 1.0319, 0.9778, 0.9604, 0.9736,
          0.9775, 0.9658, 0.9999, 1.0364, 0.9842, 0.9982, 1.0128, 0.9619, 0.9997,
          0.9969, 1.0345, 1.0102, 1.0058, 1.0090, 1.0048, 1.0381, 1.0308, 0.9535,
          1.0027, 1.0268, 0.9707, 0.9913, 0.9427, 1.0130, 0.9583, 1.0028, 1.0203,
          1.0211, 1.0575, 1.0910, 1.0033, 0.9904, 1.0057, 0.9605, 1.0136, 0.9985,
          1.0124, 0.9619, 1.0070, 1.0375, 1.0939, 1.0133, 1.0059, 1.0350, 0.9605,
          0.9993, 1.0747, 1.0371, 0.9977, 0.9568, 1.0040, 0.9930, 0.9859, 1.0164,
          1.0248, 0.9603, 0.9895, 0.9816, 0.9840, 0.9857, 0.9541, 1.0218, 0.9861,
          0.9882, 1.0116, 0.9890, 0.9961, 0.9799, 1.0911, 0.9819, 1.0171, 1.0628,
          0.9829, 0.9850, 0.9589, 0.9577, 0.9977, 0.9939, 1.0112, 0.9768, 1.0686,
          1.0034, 1.0120, 1.0165, 0.9852, 0.9822, 0.9795, 1.0676, 0.9930, 0.9687,
          0.8955

In [71]:
# BN scale parameters (gamma) of each BN layer hooked by the first Taylor-FO estimator.
bn_layers = classes[0].keys[:-1]
for bns in bn_layers:
    print(bns.weight.data)

tensor([1.0047, 1.0040, 0.9974, 0.9939, 0.9754, 1.0131, 0.9484, 1.0526, 1.0099,
        1.0419, 0.9687, 1.0016, 1.0218, 1.0697, 1.0320, 0.9778, 0.9604, 0.9736,
        0.9775, 0.9658, 0.9999, 1.0364, 0.9842, 0.9982, 1.0128, 0.9619, 0.9997,
        0.9969, 1.0345, 1.0102, 1.0058, 1.0090, 1.0048, 1.0382, 1.0308, 0.9535,
        1.0027, 1.0268, 0.9707, 0.9913, 0.9427, 1.0130, 0.9583, 1.0028, 1.0203,
        1.0211, 1.0575, 1.0910, 1.0033, 0.9904, 1.0057, 0.9605, 1.0137, 0.9985,
        1.0124, 0.9619, 1.0070, 1.0375, 1.0939, 1.0134, 1.0059, 1.0350, 0.9605,
        0.9993, 1.0747, 1.0371, 0.9977, 0.9568, 1.0040, 0.9930, 0.9859, 1.0164,
        1.0248, 0.9603, 0.9895, 0.9816, 0.9840, 0.9857, 0.9541, 1.0218, 0.9861,
        0.9882, 1.0117, 0.9890, 0.9961, 0.9799, 1.0911, 0.9819, 1.0171, 1.0628,
        0.9829, 0.9850, 0.9589, 0.9577, 0.9977, 0.9939, 1.0113, 0.9768, 1.0686,
        1.0034, 1.0120, 1.0165, 0.9852, 0.9822, 0.9796, 1.0676, 0.9930, 0.9687,
        0.8955, 0.9724, 1.0391, 1.0722, 