In [1]:
import numpy as np
import mylibrary.nnlib as tnn
import matplotlib.pyplot as plt
import copy

from mpl_toolkits.mplot3d import Axes3D
import matplotlib

import torch
import torch.nn as nn

import mylibrary.datasets as datasets
import prunelib
from tqdm import tqdm
import random

In [2]:
# Load the MNIST train/test splits through the project-local dataset helper.
mnist = datasets.MNIST()
train_data, train_label_, test_data, test_label_ = mnist.load()

# Rescale the inputs by 1/255.
# NOTE(review): assumes the loader returns raw pixel intensities in [0, 255] — confirm.
train_data = train_data / 255.
test_data = test_data / 255.

In [3]:
# Convert the integer class labels with the project helper `index_to_logit`
# (presumably a one-hot style encoding — verify in mylibrary.nnlib).
# NOTE(review): `train_label` is not referenced again in the visible notebook;
# the loss below is computed from the integer labels in `yy`.
train_label = tnn.Logits.index_to_logit(train_label_)
train_size = len(train_label_)

In [4]:
# Wrap the data as torch tensors: float inputs (`xx`, `test_data`) and
# integer (long) class indices (`yy`) as required by nn.CrossEntropyLoss.
# Note that `test_data` is overwritten in place by its tensor version.
xx = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
yy = torch.LongTensor(train_label_)

In [5]:
# Fully connected classifier: 784 -> 256 -> 128 -> 64 -> 10 with ReLU between
# the Linear layers. The final layer emits raw scores; no softmax is applied
# here (the Softmax line is intentionally left commented out).
net = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
#     nn.Softmax(dim=1)
)

In [6]:
# Adam over all network parameters (learning rate 0.003) and a cross-entropy
# objective that is also reused later as the pruning evaluation metric.
optimizer = torch.optim.Adam(net.parameters(), lr=0.003)
criterion = nn.CrossEntropyLoss()

In [7]:
# for epoch in range(100):
#     yout = net(xx)

#     loss = criterion(yout, yy)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

#     error = float(loss)
#     print(epoch, 'Error = ', error)
    
#     with torch.no_grad():
#         yout = net(xx)
#         out = torch.argmax(yout, axis=1)
#         acc = (out.data.numpy() == np.array(train_label_)).astype(float).mean()
#         print("Accuracy: ", acc)

In [8]:
# Show the (name, module) pairs of the network's direct children.
list(net._modules.items())

[('0', Linear(in_features=784, out_features=256, bias=True)),
 ('1', ReLU()),
 ('2', Linear(in_features=256, out_features=128, bias=True)),
 ('3', ReLU()),
 ('4', Linear(in_features=128, out_features=64, bias=True)),
 ('5', ReLU()),
 ('6', Linear(in_features=64, out_features=10, bias=True))]

In [9]:
# torch.save({"model":net.state_dict(), "optimizer":optimizer.state_dict()},
#           "./mnist_100_mlp.pth")

In [10]:
# Restore the trained weights and the optimizer state from a local checkpoint.
# The cell that writes this file is commented out above, so the file must
# already exist on disk for this cell to run.
# NOTE(review): torch.load unpickles the file — only load checkpoints that
# come from a trusted source.
checkpoint = torch.load("./mnist_100_mlp.pth")
net.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])

## Oracle Pruning Modified

## Define pruning function

In [11]:
class Pruner():
    """Evaluate a network with selected neurons zeroed out through forward hooks.

    The direct ``torch.nn.Linear`` children of ``net`` are collected in
    ``self.keys``. A pruning mask is a sequence with one 0/1 tensor per Linear
    layer except the last one; the last layer always receives an all-ones mask,
    so output units are never pruned.

    Hooks are registered only for the duration of a ``forward`` call that is
    given a ``prune_mask`` and are always removed afterwards, so the wrapped
    network is left unmodified between calls.

    Attributes:
        net: the wrapped network.
        keys: Linear layers of ``net`` in definition order.
        prune_mask: dict mapping a Linear layer to its float mask tensor.
        forward_hook: dict mapping a Linear layer to its live hook handle.
        activations: masked Linear outputs recorded during the most recent
            ``forward`` call (empty if that call ran without hooks).
    """

    def __init__(self, net, prune_mask=None):
        """Collect the Linear layers of ``net`` and optionally store a mask.

        A mask passed here is only stored; it is not applied by ``forward``
        unless a mask is also passed to ``forward`` (which overwrites the
        stored entries) or ``add_hook`` is called manually beforehand.
        """
        self.net = net
        self.keys = []
        self.prune_mask = {}
        self.forward_hook = {}

        self.activations = []

        for module in self.net._modules.values():
            if isinstance(module, torch.nn.Linear):
                self.keys.append(module)

        if prune_mask is not None:
            self.add_prune_mask(prune_mask)
        self.remove_hook()

    def add_prune_mask(self, prune_mask):
        """Store ``prune_mask`` (one tensor per non-final Linear layer).

        Masks are paired with layers positionally; surplus masks or layers are
        ignored by ``zip``. The final Linear layer always gets an all-ones mask.
        """
        if not self.keys:
            # Nothing to mask in a network without Linear layers.
            return
        for module, pm in zip(self.keys[:-1], prune_mask):
            self.prune_mask[module] = pm.type(torch.float)
        last = self.keys[-1]
        self.prune_mask[last] = torch.ones(last.out_features, dtype=torch.float)

    def prune_neurons(self, module, inp, out):
        """Forward hook: multiply the layer output by its mask and record it."""
        # Follow the output's device so a network moved to an accelerator works.
        mask = self.prune_mask[module].to(out.device)
        output = out * mask

        self.activations.append(output)
        return output

    def forward(self, x, prune_mask=None):
        """Run ``net`` on ``x``, masking neurons when ``prune_mask`` is given.

        Args:
            x: input batch passed straight to the network.
            prune_mask: optional non-empty sequence of mask tensors. ``None``
                or an empty sequence runs the network without installing hooks.

        Returns:
            The network output.
        """
        # Start from a fresh list so recorded activations (which keep the
        # autograd graph alive) do not pile up across repeated calls.
        self.activations = []

        # Explicit None/length test: truthiness of a tensor would raise.
        if prune_mask is not None and len(prune_mask) > 0:
            self.add_prune_mask(prune_mask)
            if len(self.forward_hook) == 0:
                self.add_hook()

        try:
            return self.net(x)
        finally:
            # Always detach the hooks, even if the forward pass fails, so the
            # shared network is never left in a pruned state.
            self.remove_hook()

    def add_hook(self):
        """Register ``prune_neurons`` on every Linear child of ``net``."""
        if len(self.forward_hook) > 0:
            self.remove_hook()

        self.forward_hook = {}
        for module in self.net._modules.values():
            if isinstance(module, torch.nn.Linear):
                if module in self.forward_hook:
                    # Same layer instance listed twice: one hook is enough and
                    # avoids losing the first handle.
                    continue
                self.forward_hook[module] = module.register_forward_hook(self.prune_neurons)

    def remove_hook(self):
        """Remove every hook registered by ``add_hook``; safe to call repeatedly."""
        for hook in self.forward_hook.values():
            hook.remove()
        self.forward_hook = {}

## Comparing all importance criteria

In [28]:
# Number of neurons to prune; also reused by the later evaluation cells.
num = 300
# Hook-based wrapper used to evaluate the network under a pruning mask.
pnet = Pruner(net)

# Reference output of the unpruned network on the full training set.
yout_normal = net.forward(xx).data.cpu()
print("loss is ", float(criterion(yout_normal, yy)))

taylorfo_mod = prunelib.Importance_TaylorFO_Modified(net, criterion)
# Only the first six entries of prunelib.taylorfo_mode_list are evaluated
# (the loop over the full list is kept below, commented out).
for i in range(6):
# for i in range(len(prunelib.taylorfo_mode_list)):
    print(prunelib.taylorfo_mode_list[i])
    # NOTE(review): `use_unit_grad` receives a fresh, unseeded random integer
    # on every iteration, so results differ between runs; its meaning is
    # defined in prunelib — confirm this is intended.
    importance = taylorfo_mod.compute_significance(xx, yy,
                                                 config=prunelib.taylorfo_mode_config[prunelib.taylorfo_mode_list[i]],
                                                 normalize=True, layerwise_norm=False,
                                                 use_unit_grad=np.random.randint(1024)
                                                  )
    pmask = prunelib.get_pruning_mask(importance, num_prune=num)

    # Loss of the pruned network on the same data.
    yout_prune = pnet.forward(xx, prune_mask=pmask).data.cpu()
    new_err = criterion(yout_prune, yy)
    print("new error = ", float(new_err))
    
    # Mean squared difference between pruned and unpruned outputs.
    deviation = ((yout_prune-yout_normal)**2).mean()
    print("deviation = ", float(deviation))
    print()
    
# print("Molchanov parameter")
# mol_parm = prunelib.Importance_Molchanov_2019(net, criterion)
# importance = mol_parm.compute_significance(xx, yy)
# pmask = prunelib.get_pruning_mask(importance, num_prune=num)

# yout_prune = pnet.forward(xx, prune_mask=pmask).data.cpu()
# new_err = criterion(yout_prune, yy)
# print("new error = ", float(new_err))

# deviation = ((yout_prune-yout_normal)**2).mean()
# print("deviation = ", float(deviation))
# print()

# print("APnZ")
# apnz = prunelib.Importance_APoZ(net, criterion)
# importance = apnz.compute_significance(xx, yy)
# pmask = prunelib.get_pruning_mask(importance, num_prune=num)

# yout_prune = pnet.forward(xx, prune_mask=pmask).data.cpu()
# new_err = criterion(yout_prune, yy)
# print("new error = ", float(new_err))

# deviation = ((yout_prune-yout_normal)**2).mean()
# print("deviation = ", float(deviation))
# print()

# print("Magnitude")
# mag = prunelib.Importance_Magnitude(net, criterion)
# importance = mag.compute_significance(xx, yy)
# pmask = prunelib.get_pruning_mask(importance, num_prune=num)

# yout_prune = pnet.forward(xx, prune_mask=pmask).data.cpu()
# new_err = criterion(yout_prune, yy)
# print("new error = ", float(new_err))

# deviation = ((yout_prune-yout_normal)**2).mean()
# print("deviation = ", float(deviation))
# print()

# num = 50
# loss is  0.1028500348329544
# taylorfo
# new error =  0.12499944865703583
# deviation =  1.1074118614196777

# taylorfo_abs
# new error =  0.11183208972215652
# deviation =  0.2707483768463135

# taylorfo_sq
# new error =  0.10807643085718155
# deviation =  0.2193640172481537

# taylorfo_norm
# new error =  0.2512335777282715
# deviation =  3.4444315433502197

# taylorfo_abs_norm
# new error =  0.1116778701543808
# deviation =  0.35614505410194397

# taylorfo_sq_norm
# new error =  0.10743308067321777
# deviation =  0.2153528928756714

loss is  0.1028500348329544
taylorfo
new error =  2.135969400405884
deviation =  25.84116554260254

taylorfo_abs
new error =  2.2881228923797607
deviation =  29.152713775634766

taylorfo_sq
new error =  2.2435989379882812
deviation =  29.51301383972168

taylorfo_norm
new error =  2.634556293487549
deviation =  31.42356300354004

taylorfo_abs_norm
new error =  2.2881228923797607
deviation =  29.152713775634766

taylorfo_sq_norm
new error =  2.244985342025757
deviation =  29.50473976135254



In [57]:
# loss is  0.1028500348329544
# taylorfo
# new error =  2.080207586288452
# deviation =  29.253026962280273

# taylorfo_abs
# new error =  2.3692681789398193
# deviation =  29.090253829956055

# taylorfo_sq
# new error =  2.1357648372650146
# deviation =  29.005739212036133

# taylorfo_norm
# new error =  2.237595558166504
# deviation =  30.961593627929688

# taylorfo_abs_norm
# new error =  2.386812925338745
# deviation =  29.190523147583008

# taylorfo_sq_norm
# new error =  2.1146740913391113
# deviation =  29.0556640625

513

In [13]:
# taylorfo_norm = prunelib.Importance_TaylorFO_Normalized(net, criterion)
# for i in range(len(prunelib.taylorfo_mode_list)):
#     print(prunelib.taylorfo_mode_list[i])
#     importance = taylorfo_norm.compute_significance(xx, yy,
#                                                  config=prunelib.taylorfo_mode_config[prunelib.taylorfo_mode_list[i]],
#                                                  normalize=True, layerwise_norm=False)
#     pmask = prunelib.get_pruning_mask(importance, num_prune=num)

#     yout_prune = pnet.forward(xx, prune_mask=pmask).data.cpu()
#     new_err = criterion(yout_prune, yy)
#     print("new error = ", float(new_err))
    
#     deviation = ((yout_prune-yout_normal)**2).mean()
#     print("deviation = ", float(deviation))
#     print()

taylorfo
new error =  2.4779293537139893
deviation =  28.82784652709961

taylorfo_abs
new error =  0.9035003185272217
deviation =  15.854421615600586

taylorfo_sq
new error =  2.304408311843872
deviation =  31.76415252685547

taylorfo_norm
new error =  2.4779293537139893
deviation =  28.82784652709961

taylorfo_abs_norm
new error =  0.9035003185272217
deviation =  15.854421615600586

taylorfo_sq_norm
new error =  0.9820285439491272
deviation =  19.799896240234375

taylorfo_nolin
new error =  2.4779293537139893
deviation =  28.82784652709961

taylorfo_abs_nolin
new error =  0.9035003185272217
deviation =  15.854421615600586

taylorfo_sq_nolin
new error =  2.304408311843872
deviation =  31.76415252685547

taylorfo_norm_nolin
new error =  2.4779293537139893
deviation =  28.82784652709961

taylorfo_abs_norm_nolin
new error =  0.9035003185272217
deviation =  15.854421615600586

taylorfo_sq_norm_nolin
new error =  0.9820285439491272
deviation =  19.799896240234375



In [14]:
# grad = torch.randn(100, 10)
# torch.norm(grad, dim=1).shape

torch.Size([100])

In [None]:
# class Importance_Molchanov2(prunelib.Importance):

#     def __init__(self, net, criterion):
#         self.net = net
#         self.criterion = criterion
        
#         self.inputs = {}
#         self.gradients = {}
#         self.forward_hook = {}
#         self.backward_hook = {}
#         self.keys = []
#         pass

#     def add_hook(self):
#         self.inputs = {}
#         self.gradients = {}
#         self.forward_hook = {}
#         self.backward_hook = {}
#         self.keys = []
        
#         for name, module in list(self.net._modules.items()):
#             if isinstance(module, torch.nn.Linear):
#                 hook = module.register_backward_hook(self.capture_gradients)
#                 self.backward_hook[module] = hook
#                 hook = module.register_forward_hook(self.capture_inputs)
#                 self.forward_hook[module] = hook
                
#                 self.inputs[module] = None
#                 self.gradients[module] = None
#                 self.keys.append(module)
        
#     def remove_hook(self):
#         for module in self.keys:
#             hook = self.forward_hook[module]
#             hook.remove()
#             hook = self.backward_hook[module]
#             hook.remove()
    
#     def capture_inputs(self, module, inp, out):
# #         print("inp >>")
# #         for i in inp:
# #             if i is not None:
# #                 print(i.shape)
# #         print("<< inp")
#         self.inputs[module] = inp[0].data
        
#     def capture_gradients(self, module, gradi, grado):
# #         print(grado[0].shape)
# #         print(module.weight.shape)
# #         print(module.bias.shape)
#         self.gradients[module] = grado[0]
        
#     def gather_inputs_gradients(self, x, t):
#         self.add_hook()

#         self.net.zero_grad()
#         y = self.net(x)
        
#         error = self.criterion(y,t)
#         error.backward()
        
#         self.remove_hook()
#         return
    
#     def compute_significance(self, x, t, normalize=True):
#         self.gather_inputs_gradients(x, t)

# #         importance = [0]*len(self.keys)
#         importance = []
#         for module in self.keys:
#             ## compute weight and bias gradients
#             inp = self.inputs[module]
#             inp = inp.reshape(inp.shape[0], 1, -1)
#             grd = self.gradients[module]
#             grd = grd.reshape(grd.shape[0], -1, 1)
            
#             wgrad = torch.matmul(grd, inp)#.pow(2).sum(dim=0) + 
#             w_ = module.weight.data.reshape(1, module.weight.data.shape[0], module.weight.data.shape[1])
#             wz = w_*wgrad
#             print(wz.shape)
#             bz = module.bias.data.reshape(1, -1)@grd.reshape(grd.shape[0], -1)
#             print(bz.shape)
#             z = wz.pow(2).sum(dim=[0,2]) + bz.pow(2).sum(dim=0)
#             print(z.shape)
#             print()
#             importance.append(z)
# #             for i in tqdm(range(len(x))):
# #                 dwi = grd[i:i+1].t()@inp[i:i+1]
# #                 dbi = grd[i]
# #                 z = (module.weight.data*dwi).pow(2).sum(dim=1) + \
# #                     (module.bias*dbi).pow(2)
# #                 impo += z
# #             impo = impo/len(x)
# #             importance.append(impo)

#         importance = importance[:-1]
#         if normalize:
#             sums = 0
#             count = 0
#             for imp in importance:
#                 sums += imp.sum()
#                 count += len(imp)
#             divider = sums/count ## total importance is number of neurons
#             for i in range(len(importance)):
#                 importance[i] = importance[i]/divider
            
#         return importance

In [None]:
# inp = torch.randn(60000, 784)
# grd = torch.randn(60000, 256)

# m1 = inp.reshape(inp.shape[0], 1, -1)
# m2 = grd.reshape(grd.shape[0], -1, 1)
# torch.matmul(m2, m1).shape

In [None]:
# for i in range(len(inp)):
#     print((grd[i:i+1].t()@inp[i:i+1]))

In [None]:
# print("Molchanov parameter 2") ## overflows the memory

# mol_parm = Importance_Molchanov2(net, criterion)
# importance = mol_parm.compute_significance(xx, yy)

In [None]:
class Importance_Molchanov2_1(prunelib.Importance):
    """Mini-batch estimate of a (parameter * gradient)^2 neuron importance.

    For every Linear child of ``net`` and every mini-batch, the score of an
    output neuron is the sum over its incoming weights of ``(w * dL/dw)^2``
    plus ``(b * dL/db)^2`` for its bias. Scores are averaged over the batches.
    """

    def __init__(self, net, criterion):
        """Store the network and loss and collect its Linear children."""
        self.net = net
        self.criterion = criterion
        self.keys = []
        for name, module in list(self.net._modules.items()):
            if isinstance(module, torch.nn.Linear):
                self.keys.append(module)

    def compute_significance(self, x, t, normalize=True, batch_size=32):
        """Return one importance tensor per Linear layer except the last.

        Args:
            x: inputs; sliced along the first dimension into mini-batches.
            t: targets aligned with ``x``.
            normalize: if True, rescale so the mean score over all returned
                neurons is 1.
            batch_size: number of samples per gradient evaluation.

        Returns:
            List of 1-D tensors (detached from the autograd graph), one per
            non-final Linear layer.

        Raises:
            ValueError: if ``x`` is empty or ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        if len(x) == 0:
            raise ValueError("cannot compute importance from an empty input")

        importance = [0]*len(self.keys)
        batch_starts = range(0, len(x), batch_size)
        for start in tqdm(batch_starts):
            stop = start + batch_size  # slicing clamps the final, shorter batch
            self.net.zero_grad()
            y = self.net(x[start:stop])
            error = self.criterion(y, t[start:stop])
            error.backward()

            ## compute importance for this batch
            # no_grad keeps the accumulated scores free of any autograd graph
            # (weights and biases are Parameters that require grad).
            with torch.no_grad():
                for j, module in enumerate(self.keys):
                    z = (module.weight*module.weight.grad).pow(2).sum(dim=1)
                    if module.bias is not None:
                        z = z + (module.bias*module.bias.grad).pow(2)
                    importance[j] = importance[j] + z

        ## compute mean over batches
        num_batches = len(batch_starts)
        importance = [imp/num_batches for imp in importance]

        # The output layer is never pruned, so its scores are dropped.
        importance = importance[:-1]
        if normalize:
            sums = 0
            count = 0
            for imp in importance:
                sums += imp.sum()
                count += len(imp)
            divider = sums/count ## total importance is number of neurons
            for i in range(len(importance)):
                importance[i] = importance[i]/divider

        return importance

In [None]:
# Evaluate pruning driven by the mini-batch Molchanov-style importance.
# Relies on `num`, `pnet` and `yout_normal` defined in the comparison cell above.
print("Molchanov parameter 2.1")

mol_parm = Importance_Molchanov2_1(net, criterion)
importance = mol_parm.compute_significance(xx, yy, batch_size=32)
pmask = prunelib.get_pruning_mask(importance, num_prune=num)

# Loss of the pruned network on the training set.
yout_prune = pnet.forward(xx, prune_mask=pmask).data.cpu()
new_err = criterion(yout_prune, yy)
print("new error = ", float(new_err))

# Mean squared difference between pruned and unpruned outputs.
deviation = ((yout_prune-yout_normal)**2).mean()
print("deviation = ", float(deviation))
print()

In [None]:
# for num=200
# 1024
# new error =  1.1853246688842773
# deviation =  21.11980628967285

# 512
# new error =  1.3359383344650269
# deviation =  22.091835021972656

# 256
# new error =  1.3257719278335571
# deviation =  22.431970596313477

# 128
# new error =  1.3867021799087524
# deviation =  22.838653564453125

# 64
# new error =  1.3867021799087524
# deviation =  22.838653564453125

# 32
# new error =  1.3867021799087524
# deviation =  22.838653564453125

# 16
# new error =  1.3867021799087524
# deviation =  22.838653564453125

# 8
# new error =  1.421226978302002
# deviation =  23.039220809936523

# 4
# new error =  1.421226978302002
# deviation =  23.039220809936523

### Magnitude-based and APoZ (ReLU only)

In [None]:
class Importance_APoZ(prunelib.Importance):
    """Activation-frequency importance measured at the ReLU outputs.

    Despite the APoZ name, the score computed here is the number of samples
    for which a ReLU output is strictly positive, so a higher score means a
    more frequently active neuron.
    """

    def __init__(self, net, criterion):
        """Store the network; ``criterion`` is kept but not used by this class."""
        self.net = net
        self.criterion = criterion

        self.activations = {}
        self.forward_hook = {}
        self.keys = []

    def add_hook(self):
        """Register an output-capturing hook on every ReLU child of ``net``."""
        # Drop handles from an earlier call first so they are not overwritten
        # and left attached to the network.
        self.remove_hook()

        self.activations = {}
        self.forward_hook = {}
        self.keys = []

        for name, module in list(self.net._modules.items()):
            if isinstance(module, torch.nn.ReLU):
                hook = module.register_forward_hook(self.capture_inputs)
                self.forward_hook[module] = hook

                self.activations[module] = None
                self.keys.append(module)

    def remove_hook(self):
        """Remove all registered hooks; safe to call more than once."""
        for hook in self.forward_hook.values():
            hook.remove()
        # ``self.keys`` is kept: compute_significance still iterates over it.
        self.forward_hook = {}

    def capture_inputs(self, module, inp, out):
        """Forward hook: remember the (detached) output of ``module``."""
        self.activations[module] = out.data

    def gather_activations(self, x, t):
        """Run ``x`` through the network once and record each ReLU output.

        ``t`` is accepted for interface symmetry and is not used.
        """
        self.add_hook()
        try:
            self.net.zero_grad()
            # Only activations are needed, so skip building an autograd graph.
            with torch.no_grad():
                self.net(x)
        finally:
            # Never leave hooks on the shared network, even on failure.
            self.remove_hook()
        return

    def compute_significance(self, x, t, normalize=True):
        """Return one score tensor per ReLU layer.

        Args:
            x: input batch fed to the network.
            t: unused; kept for a uniform ``compute_significance`` signature.
            normalize: if True, rescale so the mean score over all neurons is 1.

        Returns:
            List of 1-D float tensors with the count of positive activations
            per neuron (rescaled when ``normalize`` is True).
        """
        self.gather_activations(x, t)

        importance = []
        for module in self.keys:
            apnz = torch.sum(self.activations[module] > 0., dim=0, dtype=torch.float)
            importance.append(apnz)

        if normalize:
            sums = 0
            count = 0
            for imp in importance:
                sums += imp.sum()
                count += len(imp)
            divider = sums/count ## total importance is number of neurons
            for i in range(len(importance)):
                importance[i] = importance[i]/divider

        return importance

In [None]:
# Inspect the shape of the training input tensor.
xx.shape

In [None]:
# Evaluate pruning driven by the ReLU activation-count importance.
# Relies on `num`, `pnet` and `yout_normal` defined in the comparison cell above.
print("APnZ")

apnz = Importance_APoZ(net, criterion)
importance = apnz.compute_significance(xx, yy)
pmask = prunelib.get_pruning_mask(importance, num_prune=num)

# Loss of the pruned network on the training set.
yout_prune = pnet.forward(xx, prune_mask=pmask).data.cpu()
new_err = criterion(yout_prune, yy)
print("new error = ", float(new_err))

# Mean squared difference between pruned and unpruned outputs.
deviation = ((yout_prune-yout_normal)**2).mean()
print("deviation = ", float(deviation))
print()

In [None]:
# Defensive cleanup: compute_significance already removes its hooks after the
# forward pass, so this call is only needed if that pass was interrupted.
apnz.remove_hook()

In [None]:
class Importance_Magnitude(prunelib.Importance):
    """Weight-magnitude importance: L2 norm of each neuron's incoming weights."""

    def __init__(self, net, criterion=None):
        """Keep the network and gather its Linear children; ``criterion`` is ignored."""
        self.net = net
        self.keys = [
            layer
            for layer in self.net._modules.values()
            if isinstance(layer, torch.nn.Linear)
        ]

    def compute_significance(self, x=None, t=None, normalize=True):
        """Score every neuron of all Linear layers except the last.

        ``x`` and ``t`` are unused; they exist so the signature matches the
        other importance classes. With ``normalize`` the scores are divided by
        their overall mean, so they sum to the number of scored neurons.
        """
        # Row-wise L2 norm of the weight matrix; the final layer is skipped.
        importance = [
            torch.norm(layer.weight.data, p=2, dim=1)
            for layer in self.keys[:-1]
        ]

        if normalize:
            total = 0
            n_units = 0
            for score in importance:
                total += score.sum()
                n_units += len(score)
            mean_score = total/n_units ## total importance is number of neurons
            importance = [score/mean_score for score in importance]

        return importance

In [None]:
# Evaluate pruning driven by the weight-magnitude importance.
# Relies on `num`, `pnet` and `yout_normal` defined in the comparison cell above.
print("Magnitude")

mag = Importance_Magnitude(net, criterion)
importance = mag.compute_significance(xx, yy)
pmask = prunelib.get_pruning_mask(importance, num_prune=num)

# Loss of the pruned network on the training set.
yout_prune = pnet.forward(xx, prune_mask=pmask).data.cpu()
new_err = criterion(yout_prune, yy)
print("new error = ", float(new_err))

# Mean squared difference between pruned and unpruned outputs.
deviation = ((yout_prune-yout_normal)**2).mean()
print("deviation = ", float(deviation))
print()

In [None]:
# loss is  0.1028500348329544
# oracle
# new error =  1.8523311614990234
# deviation =  24.276596069335938

# oracle_abs
# new error =  1.2125073671340942
# deviation =  14.391016960144043

# oracle_sq
# new error =  0.8678733706474304
# deviation =  17.924470901489258

# oracle_norm
# new error =  2.3006534576416016
# deviation =  27.330440521240234

# oracle_abs_norm
# new error =  1.2079159021377563
# deviation =  14.306313514709473

# oracle_sq_norm
# new error =  0.8907394409179688
# deviation =  18.013774871826172

# oracle_nolin
# new error =  1.917508840560913
# deviation =  24.35315704345703

# oracle_abs_nolin
# new error =  1.8538734912872314
# deviation =  16.1265926361084

# oracle_sq_nolin
# new error =  0.8453637957572937
# deviation =  14.958662033081055

# oracle_norm_nolin
# new error =  2.6378610134124756
# deviation =  29.217300415039062

# oracle_abs_norm_nolin
# new error =  1.8637737035751343
# deviation =  16.142126083374023

# oracle_sq_norm_nolin
# new error =  0.8662921786308289
# deviation =  15.479593276977539