In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
# from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms as T

import random, os, pathlib, time
from tqdm import tqdm
# from sklearn import datasets

In [2]:
# device = torch.device("cuda:0")
device = torch.device("cpu")

In [3]:
from tqdm import tqdm
import os, time, sys
import json

In [4]:
import dtnnlib as dtnn

In [5]:
# Convert images to tensors, then apply Normalize with mean 0.5 and std 0.5.
mnist_transform = T.Compose([T.ToTensor(), T.Normalize(mean=[0.5], std=[0.5])])

# Build the train and test FashionMNIST splits with identical settings;
# only the `train` flag differs between the two.
train_dataset, test_dataset = [
    datasets.FashionMNIST(
        root="../../../../_Datasets/",
        train=is_train,
        download=True,
        transform=mnist_transform,
    )
    for is_train in (True, False)
]

# train_dataset = datasets.MNIST(root="../../../../_Datasets/", train=True, download=True, transform=mnist_transform)
# test_dataset = datasets.MNIST(root="../../../../_Datasets/", train=False, download=True, transform=mnist_transform)

In [6]:
# Mini-batch size shared by both loaders.
batch_size = 50

# Training batches are shuffled; test batches keep the dataset order.
train_loader = data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
)
test_loader = data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)

In [7]:
# Peek at a single training batch to confirm the tensor shapes, then stop.
# NOTE: `xx` and `yy` are notebook-level names that later cells reassign and reuse.
for xx, yy in train_loader:
    xx, yy = xx.to(device), yy.to(device)
    print(xx.shape, yy.shape)
    break

torch.Size([50, 1, 28, 28]) torch.Size([50])


In [8]:
# Number of units in the hidden (distance-transform) layer.
h = 200

# 784 flattened input features -> h distance-transform units -> 10 outputs.
model = nn.Sequential(
    dtnn.DistanceTransform_MinExp(784, h),
    nn.LeakyReLU(),
    nn.Linear(h, 10),
)

In [9]:
model.to(device)

Sequential(
  (0): DistanceTransform_MinExp()
  (1): LeakyReLU(negative_slope=0.01)
  (2): Linear(in_features=200, out_features=10, bias=True)
)

In [10]:
# model[0].set_centroid_to_data_maxdist(train_loader)
# model[0].set_centroid_to_data(train_loader)
# model[0].set_centroid_to_data_randomly(train_loader)

## Pick initial centers randomly from the training data

In [11]:
# Collect the first N samples yielded by the shuffled train_loader, together with
# their labels, where N is the number of centers in the first layer.
# NOTE: the loop variables `xx` and `yy` stay defined after this cell and are
# reused by a later cell (`yout = model(xx)` / `criterion(yout, yy)`).
N = model[0].centers.shape[0]
new_center = []
new_labels = []
count = 0
for i, (xx, yy) in enumerate(train_loader):
    # Flatten each image to a vector of length input_dim, on the centers' device.
    xx = xx.reshape(-1, model[0].input_dim).to(model[0].centers.device)
    if count+xx.shape[0] < N:
        # The whole batch still stays below N: keep all of it.
        new_center.append(xx)
        new_labels.append(yy)
        count += xx.shape[0]
    elif count >= N:
        break
    else:
        # This batch reaches N: keep only the N-count rows still needed, then stop.
        new_center.append(xx[:N-count])
        new_labels.append(yy[:N-count])
        count = N
        break

# Merge the per-batch pieces into single tensors.
new_center = torch.cat(new_center, dim=0)
new_labels = torch.cat(new_labels, dim=0)

In [12]:
new_center.shape, new_labels.shape

(torch.Size([200, 784]), torch.Size([200]))

## Set the centers and the output weights


In [13]:
# One-hot encode the labels of the selected centers: row i holds a 1 in column
# new_labels[i] and 0 elsewhere (10 columns, one per class).
# F.one_hot replaces the per-row Python loop; .float() keeps the float32 dtype
# that torch.zeros produced before.
weights = F.one_hot(new_labels, num_classes=10).float()
# weights

In [14]:
weights.shape

torch.Size([200, 10])

In [15]:
# Overwrite the first layer's centers with the sampled training points.
model[0].centers.data = new_center.to(model[0].centers.device)
# Overwrite the last layer's weight with the transposed one-hot label matrix;
# .to(<tensor>) converts to that tensor's dtype and device.
model[-1].weight.data = weights.t().to(model[-1].weight.data)

In [16]:
best_acc = -1
def test(epoch, model):
    """Evaluate `model` on the notebook-level `test_loader` and return accuracy in percent.

    Relies on the notebook-level `test_loader`, `criterion` and `device`.
    Switches the model to eval mode (it is not switched back afterwards), prints
    the mean per-batch loss together with the accuracy, and returns the accuracy.

    Args:
        epoch: Label printed in the report line; not otherwise used.
        model: Module called on inputs flattened with view(-1, 28*28).

    Returns:
        Accuracy as a float percentage (0.0 if the loader yields no samples).
    """
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device).view(-1, 28*28), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    # max(..., 1) keeps the report well-defined when the loader is empty instead of
    # failing on an unbound loop variable or a division by zero.
    mean_loss = test_loss / max(num_batches, 1)
    acc = 100.*correct / max(total, 1)
    print(f"[Test] {epoch} Loss: {mean_loss:.3f} | Acc: {acc:.3f} {correct}/{total}")

    return acc

In [17]:
model.eval()

Sequential(
  (0): DistanceTransform_MinExp()
  (1): LeakyReLU(negative_slope=0.01)
  (2): Linear(in_features=200, out_features=10, bias=True)
)

In [18]:
criterion = nn.CrossEntropyLoss()

In [19]:
test_acc = test(0, model)
test_acc

[Test] 0 Loss: 1.110 | Acc: 69.350 6935/10000


69.35

## Add new centers to the model

In [20]:
N_search = 100

In [21]:
def get_random_training_samples(N, num_classes=10):
    """Draw up to N flattened training samples and their one-hot labels.

    Iterates the notebook-level `train_loader` and stops as soon as N samples
    have been collected; each sample is flattened to one row.

    Args:
        N: Number of samples requested; must not be negative. N == 0 yields
            empty tensors.
        num_classes: Width of the one-hot label matrix (default 10, the value
            that used to be hard-coded).

    Returns:
        Tuple (samples, one_hot), both moved to the notebook-level `device`:
        samples has shape (n, features) and one_hot has shape (n, num_classes),
        where n == N unless the loader runs out of data first.

    Raises:
        ValueError: If N is negative.
    """
    if N < 0:
        raise ValueError(f"N must be non-negative, got {N}")

    picked_samples, picked_labels = [], []
    remaining = N
    for xx, yy in train_loader:
        xx = xx.reshape(xx.shape[0], -1)
        # Take the whole batch, or only the part still needed to reach N.
        take = min(remaining, xx.shape[0])
        picked_samples.append(xx[:take])
        picked_labels.append(yy[:take])
        remaining -= take
        if remaining == 0:
            break

    new_center = torch.cat(picked_samples, dim=0)
    new_labels = torch.cat(picked_labels, dim=0)

    # Vectorised one-hot encoding; .float() matches the float32 matrix built before.
    weights = F.one_hot(new_labels.long(), num_classes=num_classes).float()

    return new_center.to(device), weights.to(device)

In [22]:
get_random_training_samples(2)

(tensor([[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         [-1.0000, -0.9922, -0.9843,  ..., -1.0000, -1.0000, -1.0000]]),
 tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))

In [23]:
def add_neurons_to_model(model, centers, values, init_scale=2.0):
    """Append new hidden units to the model in place.

    Args:
        model: Indexable model whose first element exposes `centers` and `scaler`
            and whose last element exposes `weight`.
        centers: Tensor of shape (k, features); appended as rows to
            `model[0].centers`.
        values: Tensor of shape (k, outputs); its transpose is appended as
            columns to `model[-1].weight`.
        init_scale: Scaler value given to each new unit (default 2.0, equal to
            the previously hard-coded 6/3).

    Returns:
        None. The three parameter tensors are replaced through `.data`.
    """
    first, last = model[0], model[-1]

    # Create the new scaler entries with the existing scaler's dtype and device so
    # the concatenation also works when the model does not live on the CPU.
    new_scale = torch.full(
        (1, len(centers)),
        init_scale,
        dtype=first.scaler.dtype,
        device=first.scaler.device,
    )

    # Build all three tensors before assigning any, so a failing concatenation
    # leaves the model untouched.
    c = torch.cat((first.centers.data, centers), dim=0)
    v = torch.cat((last.weight.data, values.t()), dim=1)
    s = torch.cat((first.scaler.data, new_scale), dim=1)

    first.centers.data = c
    last.weight.data = v
    first.scaler.data = s

In [24]:
model[0].centers.data.shape, model[-1].weight.data.shape

(torch.Size([200, 784]), torch.Size([10, 200]))

In [25]:
add_neurons_to_model(model, *get_random_training_samples(N_search))

In [26]:
model[0].centers.data.shape, model[-1].weight.data.shape

(torch.Size([300, 784]), torch.Size([10, 300]))

In [27]:
test_acc2 = test(0, model)
test_acc2, test_acc

[Test] 0 Loss: 1.101 | Acc: 70.090 7009/10000


(70.09, 69.35)

## Calculate Neuron Significance

In [28]:
# Latest tensors recorded by the hooks; both start out empty.
outputs = None
gradients = None


def capture_outputs(module, inp, out):
    """Forward hook: store the module output, moved to the CPU, in the global `outputs`."""
    global outputs
    raw_out = out.data
    outputs = raw_out.cpu()

def capture_gradients(module, gradi, grado):
    """Backward hook: store the first output-gradient, moved to the CPU, in the global `gradients`."""
    global gradients
    first_grad_out = grado[0]
    gradients = first_grad_out.data.cpu()
        
# Attach the capture hooks to the first layer; the returned handles are kept so
# remove_hook() can detach them later.
forw_hook = model[0].register_forward_hook(capture_outputs)
# NOTE(review): register_backward_hook is deprecated in PyTorch in favour of
# register_full_backward_hook (the commented alternative below) — confirm the hook
# still fires as expected before switching.
back_hook = model[0].register_backward_hook(capture_gradients)
# back_hook = model[0].register_full_backward_hook(capture_gradients)


def remove_hook():
    """Detach the backward hook, then the forward hook, via the notebook-level handles."""
    for handle in (back_hook, forw_hook):
        handle.remove()

In [29]:
significance = torch.zeros(model[0].centers.shape[0])

In [30]:
# Forward pass that lets the forward hook record the first layer's output.
# NOTE(review): `xx` is not defined in this cell — presumably it is the flattened
# batch left over from the center-sampling loop; this relies on notebook state.
yout = model(xx)



In [31]:
list(model.parameters())

[Parameter containing:
 tensor([[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]], requires_grad=True),
 Parameter containing:
 tensor([[2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
          2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
          2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
          2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
          2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
          2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
          2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
          2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 

In [32]:
def none_grad():
    """Reset the gradient of every parameter of the notebook-level `model` to None."""
    for param in model.parameters():
        param.grad = None

In [33]:
none_grad()
# grad = torch.randn_like(yout)
# grad = grad/torch.norm(grad, dim=1, keepdim=True)
# yout.backward(gradient=grad, retain_graph=False)
criterion(yout, yy).backward()

In [34]:
remove_hook()

In [35]:
outputs.shape, gradients.shape

(torch.Size([50, 300]), torch.Size([50, 300]))

In [36]:
# Add this batch's contribution: for each hidden unit, the sum over the batch of
# (captured activation * captured output-gradient) squared.
with torch.no_grad():
    significance += torch.sum((outputs*gradients)**2, dim=0)
significance

tensor([3.4651e-09, 5.5334e-11, 2.0848e-10, 1.8164e-10, 2.4516e-08, 4.7279e-09,
        6.8161e-11, 6.0861e-12, 1.3654e-09, 3.2271e-09, 6.1094e-09, 4.4848e-08,
        1.7369e-08, 7.6331e-11, 9.0758e-07, 4.7269e-08, 4.4653e-09, 4.4455e-08,
        1.1083e-08, 1.4347e-07, 2.8366e-09, 2.4867e-07, 7.2180e-08, 1.2822e-12,
        1.2969e-08, 1.2357e-09, 8.8349e-08, 2.0735e-10, 2.6791e-09, 4.8465e-11,
        3.5571e-07, 1.1953e-08, 7.3189e-10, 1.8456e-11, 6.9988e-08, 4.4973e-08,
        5.6264e-08, 7.9606e-10, 3.5879e-08, 1.8810e-12, 2.4822e-08, 5.2071e-09,
        6.4649e-10, 2.0374e-09, 4.7842e-08, 2.8759e-10, 3.9555e-10, 4.8837e-08,
        1.1647e-07, 7.6402e-10, 2.0775e-10, 1.5987e-08, 8.4379e-10, 1.0612e-10,
        1.0072e-08, 4.3593e-10, 1.0037e-12, 5.3185e-11, 4.0479e-08, 1.1381e-08,
        3.8899e-13, 2.9403e-09, 5.3745e-08, 1.5199e-07, 6.0395e-12, 3.5818e-07,
        5.1228e-08, 2.0990e-07, 8.4953e-10, 5.3625e-10, 7.7983e-08, 3.2948e-10,
        4.7550e-09, 4.2167e-10, 4.0250e-

In [37]:
h

200

In [38]:
topk_idx = torch.topk(significance, k=h)[1]
topk_idx

tensor([190, 166, 172, 191, 175, 181, 153, 189, 150, 157, 185, 162, 168, 177,
        163, 184, 160, 154, 159, 158, 192, 176, 155, 173, 186, 169, 179, 187,
        195, 194, 196, 198, 161, 199, 180, 197, 156, 178, 165, 174, 152, 167,
        151, 188, 171, 183, 170, 182, 193, 164, 122, 275, 110, 232,  14, 239,
        277, 121, 127, 145, 125,  96,  65,  30, 285, 102,  89, 103, 210,  21,
        267, 236, 262,  67,  99, 234, 288, 292, 230, 124,  91,  63,  19, 104,
         76, 218,  48, 272, 100, 113, 284, 260,  98,  88, 112,  26, 205,  79,
         93,  70, 243, 138,  78, 140,  22,  81,  34, 109, 276, 211, 270,  36,
         62, 287, 146, 246,  66, 223,  47,  44,  15, 212, 123,  35,  11,  17,
        224, 283,  58,  74, 242,  38,  95, 273, 229, 134, 214, 245, 131, 106,
         40,   4, 257, 139, 144, 208, 206, 259, 264, 204, 235, 241, 279,  12,
        247, 149, 107,  51, 105, 253, 278,  80, 240, 297, 298,  24, 286,  31,
        248,  59,  18,  54, 263, 132, 219, 252, 269, 126, 261, 2

In [39]:
def remove_neurons_from_model(model, importance, num_prune):
    """Drop the `num_prune` least important hidden units from the model in place.

    The kept units are selected with torch.topk on `importance`, so they end up
    ordered from most to least important.

    Args:
        model: Indexable model whose first element exposes `centers` and `scaler`
            and whose last element exposes `weight`.
        importance: 1-D tensor with one score per current hidden unit.
        num_prune: Number of units to remove; must be between 0 and the current
            number of units.

    Returns:
        None. The three parameter tensors are replaced through `.data`.

    Raises:
        ValueError: If `num_prune` is out of range or `importance` does not have
            one entry per hidden unit.
    """
    N = model[0].centers.shape[0]
    if not 0 <= num_prune <= N:
        raise ValueError(f"num_prune must be in [0, {N}], got {num_prune}")
    if importance.numel() != N:
        raise ValueError(
            f"importance has {importance.numel()} entries but the model has {N} units"
        )

    # Rank by the `importance` argument itself (the body used to read the global
    # `significance` and ignore the parameter).
    topk_idx = torch.topk(importance, k=N-num_prune, largest=True)[1]
    topk_idx = topk_idx.to(model[0].centers.device)

    c = model[0].centers.data[topk_idx]
    v = model[-1].weight.data[:, topk_idx]
    s = model[0].scaler.data[:, topk_idx]
    model[0].centers.data = c
    model[-1].weight.data = v
    model[0].scaler.data = s

In [40]:
remove_neurons_from_model(model, significance, N_search)

In [41]:
test_acc3 = test(0, model)
test_acc3, test_acc2, test_acc

[Test] 0 Loss: 1.179 | Acc: 65.840 6584/10000


(65.84, 70.09, 69.35)

## Do this in a loop

In [42]:
add_neurons_to_model(model, *get_random_training_samples(N_search))

In [43]:
# Start a fresh accumulator with one entry per current center.
significance = torch.zeros(model[0].centers.shape[0])

# Register the capture hooks on the first layer again and keep the new handles
# in the names that remove_hook() reads.
forw_hook = model[0].register_forward_hook(capture_outputs)
back_hook = model[0].register_backward_hook(capture_gradients)

In [44]:
# optim = torch.optim.Adam(model.parameters())

In [45]:
# One pass over train_loader: for every batch, run a forward and backward pass so
# the hooks record the first layer's output and output-gradient, then add the
# per-unit sum of (output * gradient)**2 to `significance`.
for xx, yy in train_loader:
    xx = xx.to(device).view(-1, 28*28)
    # Targets must sit on the same device as the model output for the loss,
    # matching what test() does with its targets.
    yy = yy.to(device)
    yout = model(xx)

    none_grad()

#     grad = torch.randn_like(yout)
#     grad = grad/torch.norm(grad, dim=1, keepdim=True)
#     yout.backward(gradient=grad)

    criterion(yout, yy).backward()

    # `outputs` and `gradients` were moved to the CPU by the hooks.
    with torch.no_grad():
        significance += torch.sum((outputs*gradients)**2, dim=0)



In [46]:
yout.shape

torch.Size([50, 10])

In [47]:
# grad.shape

In [48]:
remove_hook()

In [49]:
remove_neurons_from_model(model, significance, N_search)

In [50]:
test_acc3 = test(0, model)
test_acc3, test_acc2, test_acc

[Test] 0 Loss: 1.634 | Acc: 57.200 5720/10000


(57.2, 70.09, 69.35)

In [51]:
# NOTE(review): undefined name — presumably a deliberate stop so that "Run All"
# halts here; confirm, or remove this cell.
asdasd

NameError: name 'asdasd' is not defined