In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
# from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms as T

import random, os, pathlib, time
from tqdm import tqdm
# from sklearn import datasets

In [2]:
# device = torch.device("cuda:0")
device = torch.device("cpu")

In [3]:
from tqdm import tqdm
import os, time, sys
import json

In [4]:
import dtnnlib as dtnn

In [5]:
mnist_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(
        mean=[0.5,],
        std=[0.5,],
    ),
])

# train_dataset = datasets.FashionMNIST(root="../../../_Datasets/", train=True, download=True, transform=mnist_transform)
# test_dataset = datasets.FashionMNIST(root="../../../_Datasets/", train=False, download=True, transform=mnist_transform)
train_dataset = datasets.MNIST(root="../../../_Datasets/", train=True, download=True, transform=mnist_transform)
test_dataset = datasets.MNIST(root="../../../_Datasets/", train=False, download=True, transform=mnist_transform)

In [6]:
batch_size = 50
train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

In [7]:
for xx, yy in train_loader:
    xx, yy = xx.to(device), yy.to(device)
    print(xx.shape, yy.shape)
    break

torch.Size([50, 1, 28, 28]) torch.Size([50])


## 1 Layer epsilon Softmax MLP

In [8]:
class DistanceTransform_Epsilon(dtnn.DistanceTransformBase):
    
    def __init__(self, input_dim, num_centers, p=2, bias=False, epsilon=0.1):
        super().__init__(input_dim, num_centers, p=2)
        
        nc = num_centers
        if epsilon is not None:
            nc += 1
        self.scaler = nn.Parameter(torch.log(torch.ones(1, 1)*1))
        self.bias = nn.Parameter(torch.ones(1, nc)*0) if bias else None
        self.epsilon = epsilon
        
    def forward(self, x):
        dists = super().forward(x)
        
        if self.epsilon is not None:
            dists = torch.cat([dists, torch.ones(len(x), 1, dtype=x.dtype)*self.epsilon], dim=1)
        
        ### normalize similar to UMAP
        dists = dists/torch.sqrt(dists.var(dim=1, keepdim=True)+1e-9)
        
        ## scale the dists
#         dists = torch.exp(-dists + self.scaler)
        dists = 1-dists*torch.exp(self.scaler)
    
        if self.bias is not None: dists = dists+self.bias
        return dists

In [9]:
class LocalMLP_epsilonsoftmax(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, output_dim, epsilon=1.0):
        super().__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.new_hidden_dim = 0
        self.output_dim = output_dim
        
        self.layer0 = DistanceTransform_Epsilon(self.input_dim, self.hidden_dim, bias=True, epsilon=epsilon)
        hdim = self.hidden_dim
        if epsilon is not None:
            hdim += 1
            
#         self.scale_shift = dtnn.ScaleShift(hdim, scaler_init=3, shifter_init=0, scaler_const=True, shifter_const=True)
        self.scale_shift = dtnn.ScaleShift(hdim, scaler_init=5, shifter_init=0, scaler_const=True, shifter_const=True)
        self.softmax = nn.Softmax(dim=-1)

#         self.activ = dtnn.OneActiv(hdim, mode='relu', beta_init=np.log(1.2))
        self.activ = nn.ReLU()

        self.layer1 = nn.Linear(hdim, self.output_dim)
        
    def forward(self, x):
        xo = self.layer0(x)
        xo = self.scale_shift(xo)
        xo = self.softmax(xo)
        
        xo = self.activ(xo)
        xo = self.layer1(xo)
        return xo

In [10]:
h = 100
model = LocalMLP_epsilonsoftmax(784, h, 10, epsilon=None)

In [11]:
model.to(device)

LocalMLP_epsilonsoftmax(
  (layer0): DistanceTransform_Epsilon()
  (scale_shift): ScaleShift()
  (softmax): Softmax(dim=-1)
  (activ): ReLU()
  (layer1): Linear(in_features=100, out_features=10, bias=True)
)

## Randomly

In [12]:
N = model.layer0.centers.shape[0]
new_center = []
new_labels = []
count = 0
for i, (xx, yy) in enumerate(train_loader):
    xx = xx.reshape(-1, model.layer0.input_dim).to(model.layer0.centers.device)
    if count+xx.shape[0] < N:
        new_center.append(xx)
        new_labels.append(yy)
        count += xx.shape[0]
    elif count >= N:
        break
    else:
        new_center.append(xx[:N-count])
        new_labels.append(yy[:N-count])
        count = N
        break
        
new_center = torch.cat(new_center, dim=0)
new_labels = torch.cat(new_labels, dim=0)

In [13]:
new_center.shape, new_labels.shape

(torch.Size([100, 784]), torch.Size([100]))

## Set


In [14]:
weights = torch.zeros(len(new_labels), 10)
for i in range(len(new_labels)):
    weights[i, new_labels[i]] = 1.
# weights

In [15]:
weights.shape

torch.Size([100, 10])

In [16]:
model.layer0.centers.data = new_center.to(model.layer0.centers.device)
model.layer1.weight.data = weights.t().to(model.layer1.weight.data)

In [17]:
best_acc = -1
def test(epoch, model):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
#         for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device).view(-1, 28*28), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    
    acc = 100.*correct/total
    return acc

In [18]:
model.eval()

LocalMLP_epsilonsoftmax(
  (layer0): DistanceTransform_Epsilon()
  (scale_shift): ScaleShift()
  (softmax): Softmax(dim=-1)
  (activ): ReLU()
  (layer1): Linear(in_features=100, out_features=10, bias=True)
)

In [19]:
criterion = nn.CrossEntropyLoss()

In [20]:
test_acc = test(0, model)
test_acc

[Test] 0 Loss: 1.786 | Acc: 68.580 6858/10000


68.58

## Add new centers to the model

In [21]:
N_search = 30

In [22]:
def get_random_training_samples(N):
    new_center = []
    new_labels = []
    count = 0
    for i, (xx, yy) in enumerate(train_loader):
        xx = xx.reshape(xx.shape[0], -1)
        if count+xx.shape[0] < N:
            new_center.append(xx)
            new_labels.append(yy)
            count += xx.shape[0]
        elif count >= N:
            break
        else:
            new_center.append(xx[:N-count])
            new_labels.append(yy[:N-count])
            count = N
            break

    new_center = torch.cat(new_center, dim=0)
    new_labels = torch.cat(new_labels, dim=0)
    
    weights = torch.zeros(len(new_labels), 10)
    for i in range(len(new_labels)):
        weights[i, new_labels[i]] = 1.
    
    return new_center.to(device), weights.to(device)

In [23]:
get_random_training_samples(2)

(tensor([[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]]),
 tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]]))

In [24]:
def add_neurons_to_model(model, centers, values):
    c = torch.cat((model.layer0.centers.data, centers), dim=0)
    v = torch.cat((model.layer1.weight.data, values.t()), dim=1)
    s = torch.cat([model.layer0.bias.data, torch.ones(1, len(centers))*0], dim=1)

    model.layer0.centers.data = c
    model.layer1.weight.data = v
    model.layer0.bias.data = s
    pass

In [25]:
model.layer0.centers.data.shape, model.layer1.weight.data.shape

(torch.Size([100, 784]), torch.Size([10, 100]))

In [26]:
add_neurons_to_model(model, *get_random_training_samples(N_search))

In [27]:
model.layer0.centers.data.shape, model.layer1.weight.data.shape

(torch.Size([130, 784]), torch.Size([10, 130]))

In [28]:
test_acc2 = test(0, model)
test_acc2, test_acc

[Test] 0 Loss: 1.767 | Acc: 70.250 7025/10000


(70.25, 68.58)

## Calculate Neuron Significance

In [29]:
outputs, gradients = None, None
def capture_outputs(module, inp, out):
    global outputs
#     print(inp)
    outputs = out.data.cpu()

def capture_gradients(module, gradi, grado):
    global gradients
#     print(gradi, '\n')
#     print(grado)
    gradients = grado[0].data.cpu()
        
forw_hook = model.softmax.register_forward_hook(capture_outputs)
back_hook = model.softmax.register_backward_hook(capture_gradients)
# back_hook = model[0].register_full_backward_hook(capture_gradients)


def remove_hook():
    back_hook.remove()
    forw_hook.remove()

In [30]:
significance = torch.zeros(model.layer0.centers.shape[0])

In [31]:
for xx, yy in train_loader:
    xx, yy = xx.to(device).view(-1, 28*28), yy.to(device)
    print(xx.shape, yy.shape)
    break

torch.Size([50, 784]) torch.Size([50])


In [32]:
yout = model(xx)

In [33]:
list(model.parameters())

[Parameter containing:
 tensor([[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]], requires_grad=True),
 Parameter containing:
 tensor([[0.]], requires_grad=True),
 Parameter containing:
 tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [34]:
def none_grad():
    for p in model.parameters():
        p.grad = None

In [35]:
none_grad()
yout.register_hook(lambda grad: grad/torch.norm(grad, dim=1, keepdim=True))

# grad = torch.randn_like(yout)
# ### grad = grad/torch.norm(grad, dim=1, keepdim=True)
# yout.backward(gradient=grad, retain_graph=False)

criterion(yout, yy).backward()

In [36]:
remove_hook()

In [37]:
outputs.shape, gradients.shape

(torch.Size([50, 130]), torch.Size([50, 130]))

In [38]:
with torch.no_grad():
    significance += torch.sum((outputs*gradients)**2, dim=0)
significance

tensor([1.0843e-05, 2.7837e-06, 1.3304e-03, 7.9738e-06, 6.5565e-02, 3.9431e-02,
        5.1774e-04, 2.4698e-08, 1.0038e-03, 1.6180e-06, 8.9945e-01, 1.4323e-02,
        1.0667e-05, 2.9981e-02, 2.3716e-08, 3.0156e-03, 1.2143e-06, 9.4740e-03,
        1.6931e-03, 8.9956e-01, 2.1828e-01, 8.2302e-08, 9.2242e-01, 7.5332e-04,
        5.0646e-06, 8.3960e-02, 2.1386e-03, 6.8199e-01, 3.0051e-05, 5.3785e-05,
        2.3041e-04, 2.4872e-01, 2.6056e-03, 5.1014e-01, 1.4312e-02, 4.5986e-05,
        7.6161e-04, 7.5922e-03, 4.4677e-03, 3.1681e-04, 8.0149e-01, 5.0286e-03,
        4.8214e-04, 1.3849e-02, 8.9730e-01, 7.6227e-02, 4.6019e-03, 7.4381e-06,
        4.3657e-05, 8.0637e-01, 1.1524e-04, 8.2728e-05, 2.8509e-06, 1.8518e-04,
        1.9607e-09, 6.0305e-08, 1.7996e-03, 3.6866e-01, 9.6218e-01, 1.1051e-02,
        4.0418e-02, 1.5905e-04, 1.6248e-02, 3.6764e-02, 7.3153e-01, 1.5388e-03,
        1.3211e-04, 5.9147e-01, 4.3695e-04, 1.1031e-04, 5.0696e-03, 1.6282e-01,
        2.8402e-08, 2.5376e-03, 2.9121e-

In [39]:
h

100

In [40]:
torch.topk(significance, k=h, sorted=True, largest=True)[0]

tensor([1.0382e+00, 9.6964e-01, 9.6218e-01, 9.2381e-01, 9.2242e-01, 8.9956e-01,
        8.9945e-01, 8.9910e-01, 8.9730e-01, 8.6995e-01, 8.1700e-01, 8.0637e-01,
        8.0149e-01, 7.3153e-01, 7.1337e-01, 6.8199e-01, 5.9147e-01, 5.6155e-01,
        5.1101e-01, 5.1014e-01, 5.0330e-01, 3.6866e-01, 3.6199e-01, 2.5965e-01,
        2.4872e-01, 2.1828e-01, 2.1335e-01, 1.7186e-01, 1.6282e-01, 1.3737e-01,
        1.2755e-01, 1.0344e-01, 8.3960e-02, 7.6227e-02, 6.5565e-02, 4.0418e-02,
        3.9431e-02, 3.6764e-02, 3.4498e-02, 3.2018e-02, 2.9981e-02, 2.1040e-02,
        1.6248e-02, 1.5340e-02, 1.4323e-02, 1.4312e-02, 1.3849e-02, 1.3463e-02,
        1.2450e-02, 1.2301e-02, 1.1051e-02, 9.4740e-03, 9.0918e-03, 7.5922e-03,
        5.1408e-03, 5.0696e-03, 5.0286e-03, 4.6019e-03, 4.4677e-03, 3.0235e-03,
        3.0156e-03, 2.6056e-03, 2.5376e-03, 2.1845e-03, 2.1386e-03, 2.0093e-03,
        1.8818e-03, 1.7996e-03, 1.7004e-03, 1.6931e-03, 1.5388e-03, 1.3304e-03,
        1.2417e-03, 1.0145e-03, 1.0038e-

In [41]:
topk_idx = torch.topk(significance, k=h, sorted=True)[1]
topk_idx

tensor([128, 103,  58, 129,  22,  19,  10,  89,  44,  79,  77,  49,  40,  64,
         87,  27,  67,  94,  95,  33,  80,  57, 118,  93,  31,  20, 109, 112,
         71, 106,  81,  96,  25,  45,   4,  60,   5,  63, 100, 110,  13, 111,
         62, 115,  11,  34,  43,  91, 119, 107,  59,  17, 120,  37, 105,  70,
         41,  46,  38, 101,  15,  32,  73, 124,  26,  82, 123,  56, 116,  18,
         65,   2, 127,  90,   8,  36,  23,  84,   6,  97,  42,  68,  39, 104,
         30,  92,  98,  53,  61, 125,  66,  50,  69, 108,  51,  99,  29,  35,
         48,  88])

In [42]:
def remove_neurons_from_model(model, importance, num_prune):
    N = model.layer0.centers.shape[0]
    topk_idx = torch.topk(importance, k=N-num_prune, largest=True)[1]
    
    c = model.layer0.centers.data[topk_idx]
    v = model.layer1.weight.data[:,topk_idx]
    s = model.layer0.bias.data[:,topk_idx]
    model.layer0.centers.data = c
    model.layer1.weight.data = v
    model.layer0.bias.data = s
    pass

In [43]:
remove_neurons_from_model(model, significance, N_search)

In [44]:
test_acc3 = test(0, model)

test_acc3, test_acc2, test_acc

[Test] 0 Loss: 1.794 | Acc: 67.230 6723/10000


(67.23, 70.25, 68.58)

In [45]:
model.layer0.centers.data.shape, model.layer1.weight.data.shape, model.layer0.bias.data.shape

(torch.Size([100, 784]), torch.Size([10, 100]), torch.Size([1, 100]))

In [46]:
asdasd  ### ^^ expected::: test_acc2 > test_acc3 > test_acc

NameError: name 'asdasd' is not defined

## Do this in Loop

In [None]:
add_neurons_to_model(model, *get_random_training_samples(N_search))

In [None]:
significance = torch.zeros(model.layer0.centers.shape[0])

forw_hook = model.softmax.register_forward_hook(capture_outputs)
back_hook = model.softmax.register_backward_hook(capture_gradients)

In [None]:
# optim = torch.optim.Adam(model.parameters())

In [None]:
for xx, yy in train_loader:
    xx = xx.to(device).view(-1, 28*28)
    yout = model(xx)
    
    none_grad()
    yout.register_hook(lambda grad: grad/torch.norm(grad, dim=1, keepdim=True))
    
#     grad = torch.randn_like(yout)
#     ### grad = grad/torch.norm(grad, dim=1, keepdim=True)
#     yout.backward(gradient=grad)
    
#     criterion(yout, yy).backward()
    
    with torch.no_grad():
        significance += torch.sum((outputs*gradients)**2, dim=0)

In [None]:
yout.shape

In [None]:
# grad.shape

In [None]:
remove_hook()

In [None]:
remove_neurons_from_model(model, significance, N_search)

In [None]:
test_acc3 = test(0, model)
test_acc3, test_acc2, test_acc

In [None]:
asdasd  ### ^^ expected test_acc2 > test_acc3 > test_acc

## Optimize for multiple steps

In [47]:
test_acc3

67.23

In [48]:
## Run multiple times for convergence
STEPS = 10
for s in range(STEPS):
    print(f"Adding and Pruning for STEP: {s}")
    add_neurons_to_model(model, *get_random_training_samples(N_search))
    
    significance = torch.zeros(model.layer0.centers.shape[0])

    forw_hook = model.softmax.register_forward_hook(capture_outputs)
    back_hook = model.softmax.register_backward_hook(capture_gradients)
    
    for xx, yy in train_loader:
        xx = xx.to(device).view(-1, 28*28)
        yout = model(xx)

        none_grad()
        yout.register_hook(lambda grad: grad/torch.norm(grad, dim=1, keepdim=True))
        ####################################
#         grad = torch.randn_like(yout)
#         ### grad = grad/torch.norm(grad, dim=1, keepdim=True)
#         yout.backward(gradient=grad)
        ###################################
        criterion(yout, yy).backward()
        with torch.no_grad():
            significance += torch.sum((outputs*gradients)**2, dim=0)
    
    remove_hook()
    remove_neurons_from_model(model, significance, N_search)
    test_acc3 = test(0, model)
#     print(f"Accuracy: {test_acc3}")

Adding and Pruning for STEP: 0
[Test] 0 Loss: 1.717 | Acc: 75.450 7545/10000
Adding and Pruning for STEP: 1
[Test] 0 Loss: 1.702 | Acc: 77.130 7713/10000
Adding and Pruning for STEP: 2
[Test] 0 Loss: 1.687 | Acc: 78.540 7854/10000
Adding and Pruning for STEP: 3
[Test] 0 Loss: 1.673 | Acc: 79.870 7987/10000
Adding and Pruning for STEP: 4
[Test] 0 Loss: 1.678 | Acc: 79.710 7971/10000
Adding and Pruning for STEP: 5
[Test] 0 Loss: 1.676 | Acc: 79.500 7950/10000
Adding and Pruning for STEP: 6
[Test] 0 Loss: 1.671 | Acc: 80.110 8011/10000
Adding and Pruning for STEP: 7
[Test] 0 Loss: 1.669 | Acc: 80.190 8019/10000
Adding and Pruning for STEP: 8
[Test] 0 Loss: 1.670 | Acc: 80.130 8013/10000
Adding and Pruning for STEP: 9
[Test] 0 Loss: 1.665 | Acc: 80.730 8073/10000


## Noisy Selection + Finetuening

In [49]:
"""
PROBLEM 1:The neuron that does not get pruned gets trained for longer,, 
    hence can drift largely from its initialized data point (even at lower learning rate).
    - Can freeze the centers of the MLP and train only values.

PROBLEM 2:The values of each neuron might fire at different magnitude bringing different amount of
    importance for classification (even the distance of center with other centers reduces its magnitude).
    - This should be carefully handeled at initialization (or normalizing the values to unit norm).
"""
print()




## Multilayer Noisy Selection

In [50]:
"""
TYPE 1: DT>eSM>DT>eSM>V
Type 2:  /DT>eSM>S\
        X---------+\>eSM>V
"""
print()


