In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
# from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms as T

import random, os, pathlib, time
from tqdm import tqdm

In [None]:
# Device used throughout the notebook; swap in the CUDA line below to run on GPU.
# device = torch.device("cuda:0")
device = torch.device("cpu")

In [None]:
from tqdm import tqdm
import os, time, sys
import json

In [None]:
import dtnnlib as dtnn

In [None]:
# ToTensor gives pixels in [0, 1]; Normalize maps them to [-1, 1] via (x - 0.5) / 0.5.
mnist_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),
])

dataset_kwargs = dict(root="../../_Datasets/", download=True, transform=mnist_transform)
train_dataset = datasets.FashionMNIST(train=True, **dataset_kwargs)
test_dataset = datasets.FashionMNIST(train=False, **dataset_kwargs)
# Plain MNIST instead (note the different relative root):
# train_dataset = datasets.MNIST(root="../../../_Datasets/", train=True, download=True, transform=mnist_transform)
# test_dataset = datasets.MNIST(root="../../../_Datasets/", train=False, download=True, transform=mnist_transform)

In [None]:
# Shared loader settings; only the training loader is shuffled.
batch_size = 50
loader_kwargs = dict(num_workers=4, batch_size=batch_size)
train_loader = data.DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
test_loader = data.DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Sanity check: print the shapes of a single training batch.
xx, yy = next(iter(train_loader))
xx, yy = xx.to(device), yy.to(device)
print(xx.shape, yy.shape)

## 1-Layer Epsilon-Softmax MLP

In [None]:
class DistanceTransform_Epsilon(dtnn.DistanceTransformBase):
    """Distance layer with an optional constant "epsilon" pseudo-center column.

    Distances to the learned centers come from ``dtnn.DistanceTransformBase``.
    If ``epsilon`` is not None, a constant column equal to ``epsilon`` is
    appended. Each row is then divided by its standard deviation and mapped to
    a similarity ``1 - d * exp(scaler)``, plus an optional bias.

    Args:
        input_dim: dimensionality of the (flattened) inputs.
        num_centers: number of learned centers.
        p: norm order forwarded to the base class. It was previously ignored
            (always 2); the default keeps the old behavior.
        bias: if True, add a learnable bias per output column (initialized to 0).
        epsilon: value of the extra constant distance column; None disables it.
    """

    def __init__(self, input_dim, num_centers, p=2, bias=False, epsilon=0.1):
        super().__init__(input_dim, num_centers, p=p)

        # One extra output column for the epsilon pseudo-center.
        nc = num_centers + (0 if epsilon is None else 1)
        # log(1) = 0, i.e. the initial multiplicative scale exp(scaler) is 1.
        self.scaler = nn.Parameter(torch.zeros(1, 1))
        self.bias = nn.Parameter(torch.zeros(1, nc)) if bias else None
        self.epsilon = epsilon

    def forward(self, x):
        dists = super().forward(x)

        if self.epsilon is not None:
            # Constant epsilon column, allocated on dists' device/dtype so the
            # concat also works when the model lives on a GPU.
            eps_col = torch.full((len(x), 1), self.epsilon,
                                 dtype=dists.dtype, device=dists.device)
            dists = torch.cat([dists, eps_col], dim=1)

        ### normalize similar to UMAP: divide each row by its std
        dists = dists/torch.sqrt(dists.var(dim=1, keepdim=True)+1e-9)

        ## scale the dists and turn them into similarities
        # (an exponential variant was tried: torch.exp(-dists + self.scaler))
        dists = 1-dists*torch.exp(self.scaler)

        if self.bias is not None:
            dists = dists+self.bias
        return dists

In [None]:
class LocalMLP_epsilonsoftmax(nn.Module):
    """One-hidden-layer "local" MLP built on distances to centers.

    Pipeline: DistanceTransform_Epsilon -> constant ScaleShift (scale 5,
    shift 0) -> softmax over hidden units -> ReLU -> Linear. If ``epsilon``
    is not None, the hidden layer has one extra (epsilon) column.

    The detached softmax activations of the last forward pass are kept in
    ``self.temp_maximum``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, epsilon=1.0, bias=False):
        super().__init__()

        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.new_hidden_dim = 0
        self.output_dim = output_dim

        # Width of the hidden representation, including the epsilon column.
        width = hidden_dim if epsilon is None else hidden_dim + 1

        # Submodules are created in the same order as before, so parameter
        # order and random-init consumption are unchanged.
        self.layer0 = DistanceTransform_Epsilon(input_dim, hidden_dim, bias=bias, epsilon=epsilon)
        self.scale_shift = dtnn.ScaleShift(width, scaler_init=5, shifter_init=0,
                                           scaler_const=True, shifter_const=True)
        self.softmax = nn.Softmax(dim=-1)
        self.activ = nn.ReLU()
        self.layer1 = nn.Linear(width, output_dim)
        self.temp_maximum = None

    def forward(self, x):
        hidden = self.softmax(self.scale_shift(self.layer0(x)))
        self.temp_maximum = hidden.data
        return self.layer1(self.activ(hidden))

## Train Test Function

In [None]:
best_acc = -1
def test(epoch, model):
    """Evaluate ``model`` on the global ``test_loader`` and print loss/accuracy.

    Inputs are flattened to ``(-1, 28*28)`` and moved to ``device``, and the
    loss is the global ``criterion``.

    Args:
        epoch: label printed in the log line.
        model: module mapping flattened images to class logits.

    Returns:
        Accuracy in percent. Also updates the global ``best_acc`` (max seen so far).

    The model's previous train/eval mode is restored afterwards, so callers
    that set ``model.train()`` once before a loop are not silently left in
    eval mode.
    """
    global best_acc
    was_training = model.training
    model.eval()
    test_loss, correct, total, n_batches = 0.0, 0, 0, 0
    try:
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device).view(-1, 28*28), targets.to(device)
                outputs = model(inputs)
                test_loss += criterion(outputs, targets).item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                n_batches += 1
    finally:
        model.train(was_training)

    # max(..., 1) guards against an empty loader (previously a NameError / ZeroDivisionError).
    avg_loss = test_loss/max(n_batches, 1)
    acc = 100.*correct/max(total, 1)
    print(f"[Test] {epoch} Loss: {avg_loss:.3f} | Acc: {acc:.3f} {correct}/{total}")

    best_acc = max(best_acc, acc)
    return acc

## Helper Functions

In [None]:
def get_random_training_samples(N, num_classes=10):
    """Draw up to ``N`` random training images and their one-hot labels.

    Batches are taken from the global ``train_loader`` (built with
    ``shuffle=True``), so each call yields a different random subset.

    Args:
        N: number of samples to return; ``N <= 0`` returns empty tensors.
        num_classes: width of the one-hot label matrix (10 for (Fashion-)MNIST).

    Returns:
        ``(centers, weights)`` on ``device``. ``centers`` holds one flattened
        image per row. ``weights`` is the matching one-hot label matrix of shape
        ``(len(centers), num_classes)``. If the dataset has fewer than ``N``
        samples, fewer rows are returned.
    """
    center_chunks, label_chunks = [], []
    remaining = N
    if remaining > 0:
        for xx, yy in train_loader:
            take = min(remaining, xx.shape[0])
            center_chunks.append(xx.reshape(xx.shape[0], -1)[:take])
            label_chunks.append(yy[:take])
            remaining -= take
            if remaining == 0:
                break

    if center_chunks:
        centers = torch.cat(center_chunks, dim=0)
        labels = torch.cat(label_chunks, dim=0)
    else:
        # Nothing drawn (N <= 0): infer the flattened feature size from one
        # sample instead of crashing in torch.cat([]).
        feat_dim = train_loader.dataset[0][0].numel()
        centers = torch.empty(0, feat_dim)
        labels = torch.empty(0, dtype=torch.long)

    # One-hot targets in the default float dtype (same as the former torch.zeros-based loop).
    weights = F.one_hot(labels, num_classes).to(torch.get_default_dtype())
    return centers.to(device), weights.to(device)

In [None]:
# Quick check: (flattened images, one-hot labels) for 2 random training samples.
get_random_training_samples(2)

#### Calculate Neuron Significance

In [None]:
# Globals filled in by the hooks below on each forward/backward pass of the
# hooked module (detached CPU copies).
outputs, gradients = None, None

def capture_outputs(module, inp, out):
    """Forward hook: store the module's output in the global ``outputs``."""
    global outputs
    outputs = out.data.cpu()

def capture_gradients(module, gradi, grado):
    """Backward hook: store the gradient w.r.t. the module's output in ``gradients``."""
    global gradients
    gradients = grado[0].data.cpu()

# Handles returned by register_*_hook; None when no hook is registered.
forw_hook = None
back_hook = None
def remove_hook():
    """Remove any registered hooks and reset the handles (safe to call when none are set)."""
    global forw_hook, back_hook
    for handle in (back_hook, forw_hook):
        if handle is not None:
            handle.remove()
    forw_hook = back_hook = None

In [None]:
def none_grad(module=None):
    """Reset every parameter's ``.grad`` to None (instead of zero-filling it).

    Args:
        module: the ``nn.Module`` whose gradients to clear; defaults to the
            global ``model`` so existing zero-argument calls keep working.
    """
    target = model if module is None else module
    for p in target.parameters():
        p.grad = None

In [None]:
# Loss used by test() and by the significance / fine-tuning passes below.
criterion = nn.CrossEntropyLoss()

# Noisy Selection With Epsilon

In [None]:
# Mean pairwise Euclidean distance among 20 random training images (the mean
# includes the zero self-distances on the diagonal); presumably a reference
# scale for picking `epsilon` below.
_a = get_random_training_samples(20)[0]
torch.cdist(_a, _a).mean()

In [None]:
# h hidden neurons plus an epsilon column with constant distance 15.0.
h = 100
model = LocalMLP_epsilonsoftmax(784, h, 10, epsilon=15.0)

In [None]:
# Number of neurons added (and afterwards pruned) per search step.
N_search = 30
# N_search = 1

In [None]:
# Move the model to the selected device (the cell output shows the model repr).
model.to(device)

In [None]:
## Initialization: centers = h random training images, output weights = their
## one-hot labels. The epsilon column gets an all-zero weight column.
## (layer1.bias keeps its default nn.Linear initialization.)
new_center, weights = get_random_training_samples(h)
if model.layer0.epsilon is not None:
    # Same device/dtype as `weights`, so the concat also works when device is CUDA.
    e = weights.new_zeros(1, weights.shape[1])
    weights = torch.cat([weights, e], dim=0)

model.layer0.centers.data = new_center.to(device)
model.layer1.weight.data = weights.t().to(device)
# print(weights.shape)

In [None]:
# Baseline accuracy of the data-initialized model, before any search step.
test_acc = test(0, model)

In [None]:
def add_neurons_to_model(model, centers, values):
    """Append new hidden neurons to ``model`` in place.

    Args:
        model: ``LocalMLP_epsilonsoftmax``-style model; uses ``layer0.centers``,
            ``layer0.bias``, ``layer0.epsilon`` and ``layer1.weight``.
        centers: ``(k, input_dim)`` tensor of new centers (appended as rows).
        values: ``(k, output_dim)`` tensor of output weights for the new neurons.

    If ``layer0.epsilon`` is not None, the last column of ``layer1.weight``
    (and of ``layer0.bias``) is the epsilon column, so the new columns are
    inserted just before it. Otherwise they are appended at the end. Either
    way, weight/bias columns stay aligned with the rows of ``layer0.centers``.
    New bias entries start at 0.
    """
    layer0, layer1 = model.layer0, model.layer1
    has_eps = layer0.epsilon is not None

    old_c = layer0.centers.data
    old_w = layer1.weight.data

    def _insert_cols(mat, cols):
        # Place `cols` before the trailing epsilon column if there is one.
        if has_eps:
            return torch.cat((mat[:, :-1], cols, mat[:, -1:]), dim=1)
        return torch.cat((mat, cols), dim=1)

    layer0.centers.data = torch.cat((old_c, centers.to(old_c)), dim=0)
    layer1.weight.data = _insert_cols(old_w, values.t().to(old_w))

    if layer0.bias is not None:
        old_b = layer0.bias.data
        # new_zeros keeps the bias device/dtype (previously always CPU).
        layer0.bias.data = _insert_cols(old_b, old_b.new_zeros(1, len(centers)))

In [None]:
# add_neurons_to_model(model, *get_random_training_samples(N_search))

In [None]:
# model.layer0.centers.data.shape, model.layer1.weight.data.shape

In [None]:
def remove_neurons_from_model(model, importance, num_prune):
    """Prune the ``num_prune`` least important hidden neurons from ``model`` in place.

    Args:
        model: ``LocalMLP_epsilonsoftmax``-style model; uses ``layer0.centers``,
            ``layer0.bias``, ``layer0.epsilon`` and ``layer1.weight``.
        importance: 1-D per-neuron scores. Entries beyond the number of centers
            (e.g. the epsilon column's score) are ignored.
        num_prune: number of neurons to remove, clamped to ``[0, N]``.

    Surviving neurons are reordered by decreasing importance (``torch.topk``
    order). If ``layer0.epsilon`` is not None, the epsilon column (last column
    of ``layer1.weight`` / ``layer0.bias``) is always kept as the last column.
    """
    N = model.layer0.centers.shape[0]
    num_prune = max(0, min(int(num_prune), N))
    importance = importance[:N]
    keep_idx = torch.topk(importance, k=N-num_prune, largest=True)[1]

    # Report exactly the complement of the kept set (a second topk call could
    # pick a different set when scores tie).
    removed = torch.ones(N, dtype=torch.bool, device=keep_idx.device)
    removed[keep_idx] = False
    print(f"Removing:\n{removed.nonzero().flatten()}")

    ## value tensor and bias also carry the epsilon column, if present
    col_idx = keep_idx
    if model.layer0.epsilon is not None:
        col_idx = torch.cat([keep_idx, keep_idx.new_tensor([N])])

    model.layer0.centers.data = model.layer0.centers.data[keep_idx]
    model.layer1.weight.data = model.layer1.weight.data[:, col_idx]
    if model.layer0.bias is not None:
        model.layer0.bias.data = model.layer0.bias.data[:, col_idx]

In [None]:
# Score each hidden unit (plus the epsilon column) by the squared product of the
# softmax output and its gradient, summed over the whole training set.
significance = torch.zeros(model.layer0.centers.shape[0]+1)

# NOTE: register_backward_hook is deprecated in recent PyTorch in favour of
# register_full_backward_hook.
forw_hook = model.softmax.register_forward_hook(capture_outputs)
back_hook = model.softmax.register_backward_hook(capture_gradients)
try:
    for xx, yy in tqdm(train_loader):
        xx = xx.to(device).view(-1, 28*28)
        yy = yy.to(device)  # targets must be on the same device as the logits

        yout = model(xx)

        none_grad()
        loss = criterion(yout, yy)
        loss.backward()
        with torch.no_grad():
            significance += torch.sum((outputs*gradients)**2, dim=0)
finally:
    # Always detach the hooks (even on interrupt), so re-running the cell does
    # not stack duplicate hooks on model.softmax.
    remove_hook()
    none_grad()

significance.shape

In [None]:
# Drop the last batch's tensors captured by the hooks.
outputs, gradients = None, None

In [None]:
# Prune the N_search least significant neurons using the scores computed above.
remove_neurons_from_model(model, significance, N_search)

In [None]:
# Inspect center / output-weight shapes after pruning.
model.layer0.centers.data.shape, model.layer1.weight.data.shape

In [None]:
# Accuracy after a single pruning pass.
test_acc = test(0, model)

### Redo Experiment: Iterative Add → Score → Prune

In [None]:
h = 100
model = LocalMLP_epsilonsoftmax(784, h, 10, epsilon=15.0)

N_search = 30
# N_search = 1

model.to(device)

## Initialization: centers = h random training images, output weights = their
## one-hot labels. The epsilon column gets an all-zero weight column.
new_center, weights = get_random_training_samples(h)
if model.layer0.epsilon is not None:
    # Same device/dtype as `weights`, so the concat also works when device is CUDA.
    e = weights.new_zeros(1, weights.shape[1])
    weights = torch.cat([weights, e], dim=0)

model.layer0.centers.data = new_center.to(device)
model.layer1.weight.data = weights.t().to(device)

# History of [accuracy, event] pairs: "init", then "add"/"prune" per search step.
accs_tup = [[test(0, model), "init"]]

In [None]:
## Run multiple times for convergence. Each step adds N_search random training
## samples as neurons, scores all neurons on the training set, and prunes the
## N_search least significant ones (no gradient updates in this experiment).
EPOCHS = 30 # 10

for s in range(EPOCHS):
    model.train()
    print(f"Adding, Finetuening and Pruning for STEP: {s}")

    c, v = get_random_training_samples(N_search)
    add_neurons_to_model(model, c, v)

    accs_tup += [[test(0, model), "add"]]

    # +1 slot for the epsilon column
    significance = torch.zeros(model.layer0.centers.shape[0]+1)

    forw_hook = model.softmax.register_forward_hook(capture_outputs)
    back_hook = model.softmax.register_backward_hook(capture_gradients)
    try:
        for xx, yy in tqdm(train_loader):
            xx = xx.to(device).view(-1, 28*28)
            yy = yy.to(device)  # targets must be on the same device as the logits

            yout = model(xx)

            none_grad()
            loss = criterion(yout, yy)
            loss.backward()
            with torch.no_grad():
                # Squared product; the signed sum was tried and did not converge well.
                significance += torch.sum((outputs*gradients)**2, dim=0)
    finally:
        # Never leave hooks attached if the pass is interrupted.
        remove_hook()

    remove_neurons_from_model(model, significance, N_search)

    accs_tup += [[test(0, model), "prune"]]

In [None]:
# Optionally truncate the event history (e.g. first 21 entries) before plotting.
# accs_tup_ = accs_tup[:21]
accs_tup_ = accs_tup

In [None]:
# Accuracy after each init/add/prune event. The list is deliberately not named
# `data`, which would shadow the `torch.utils.data` module used by the loader cell.
acc_values = [acc for acc, _ in accs_tup_]
fig, ax = plt.subplots()
ax.plot(acc_values, linestyle='dashed', zorder=-1, color='pink')

# (marker, color) per event type; anything else is a "prune" event.
event_styles = {"init": ("o", 'b'), "add": ("+", 'g')}
for i, (acc, label) in enumerate(accs_tup_):
    marker, color = event_styles.get(label, ('_', 'orange'))
    ax.scatter(i, acc, marker=marker, lw=4, color=color, s=100)

ax.set_xlabel("noisy center search")
ax.set_ylabel("accuracy")
# fig.savefig("./outputs/19_noisy_search_fMNIST.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Display the last 25 rows of layer0.centers as 28x28 images.
_, axs = plt.subplots(5, 5, figsize=(10, 10))
axs = axs.ravel()

c = model.layer0.centers.data.cpu().numpy().reshape(-1, 28, 28)
imgs = c[-len(axs):]  # use c[:len(axs)] for the first 25 instead

for ax, img in zip(axs, imgs):
    ax.imshow(img)
    ax.set_axis_off()

plt.show()

## Noisy Selection + Fine-tuning (without epsilon)

In [None]:
def add_neurons_to_model(model, centers, values):
    """Append new hidden neurons to ``model`` in place.

    Args:
        model: ``LocalMLP_epsilonsoftmax``-style model; uses ``layer0.centers``,
            ``layer0.bias``, ``layer0.epsilon`` and ``layer1.weight``.
        centers: ``(k, input_dim)`` tensor of new centers (appended as rows).
        values: ``(k, output_dim)`` tensor of output weights for the new neurons.

    If ``layer0.epsilon`` is not None, the last column of ``layer1.weight``
    (and of ``layer0.bias``) is the epsilon column, so the new columns are
    inserted just before it. Otherwise they are appended at the end. Either
    way, weight/bias columns stay aligned with the rows of ``layer0.centers``.
    New bias entries start at 0. A missing bias is skipped instead of crashing.
    """
    layer0, layer1 = model.layer0, model.layer1
    has_eps = layer0.epsilon is not None

    old_c = layer0.centers.data
    old_w = layer1.weight.data

    def _insert_cols(mat, cols):
        # Place `cols` before the trailing epsilon column if there is one.
        if has_eps:
            return torch.cat((mat[:, :-1], cols, mat[:, -1:]), dim=1)
        return torch.cat((mat, cols), dim=1)

    layer0.centers.data = torch.cat((old_c, centers.to(old_c)), dim=0)
    layer1.weight.data = _insert_cols(old_w, values.t().to(old_w))

    if layer0.bias is not None:
        old_b = layer0.bias.data
        # new_zeros keeps the bias device/dtype (previously always CPU).
        layer0.bias.data = _insert_cols(old_b, old_b.new_zeros(1, len(centers)))

In [None]:
# Fresh model for the fine-tuning experiment: no epsilon column, learnable bias.
h = 100
model = LocalMLP_epsilonsoftmax(784, h, 10, epsilon=None, bias=True)

In [None]:
# Move the model to the selected device.
model.to(device)

In [None]:
# add_neurons_to_model(model, *get_random_training_samples(N_search))

In [None]:
# Inspect current center / output-weight shapes.
model.layer0.centers.data.shape, model.layer1.weight.data.shape

In [None]:
def remove_neurons_from_model(model, importance, num_prune):
    """Prune the ``num_prune`` least important hidden neurons from ``model`` in place.

    Args:
        model: ``LocalMLP_epsilonsoftmax``-style model; uses ``layer0.centers``,
            ``layer0.bias``, ``layer0.epsilon`` and ``layer1.weight``.
        importance: 1-D per-neuron scores. Entries beyond the number of centers
            (e.g. an epsilon column's score) are ignored.
        num_prune: number of neurons to remove, clamped to ``[0, N]``.

    Surviving neurons are reordered by decreasing importance (``torch.topk``
    order). If ``layer0.epsilon`` is not None, its column is always kept last.
    A missing bias is skipped instead of crashing.
    """
    N = model.layer0.centers.shape[0]
    num_prune = max(0, min(int(num_prune), N))
    importance = importance[:N]
    keep_idx = torch.topk(importance, k=N-num_prune, largest=True)[1]

    # Report exactly the complement of the kept set (a second topk call could
    # pick a different set when scores tie).
    removed = torch.ones(N, dtype=torch.bool, device=keep_idx.device)
    removed[keep_idx] = False
    print(f"Removing:\n{removed.nonzero().flatten()}")

    col_idx = keep_idx
    if model.layer0.epsilon is not None:
        col_idx = torch.cat([keep_idx, keep_idx.new_tensor([N])])

    model.layer0.centers.data = model.layer0.centers.data[keep_idx]
    model.layer1.weight.data = model.layer1.weight.data[:, col_idx]
    if model.layer0.bias is not None:
        model.layer0.bias.data = model.layer0.bias.data[:, col_idx]

In [None]:
# remove_neurons_from_model(model, significance, N_search)

In [None]:
# Number of neurons added (and afterwards pruned) per search step.
N_search = 30
# N_search = 1

In [None]:
# Initialize centers with h random training images and output weights with
# their one-hot labels (this model has no epsilon column).
new_center, weights = get_random_training_samples(h)
model.layer0.centers.data = new_center.to(device)
model.layer1.weight.data = weights.t().to(device)

In [None]:
# Baseline accuracy before fine-tuning.
test_acc = test(0, model)

In [None]:
learning_rate = 0.01

# Centers get a much smaller LR (default: change little from the data points
# they were initialized with); every other parameter uses the base LR.
p1 = [param for name, param in model.named_parameters() if name == "layer0.centers"]
p2 = [param for name, param in model.named_parameters() if name != "layer0.centers"]

params = [
    {"params": p1, "lr": learning_rate*0.03},
    {"params": p2},
]

optimizer = torch.optim.Adam(params, lr=learning_rate)

In [None]:
# Cosine LR decay over the outer add / fine-tune / prune steps.
# NOTE(review): T_max reads EPOCHS from the earlier experiment cell (EPOCHS is
# re-assigned in the next cell); keep the two values in sync.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [None]:
## Run multiple times for convergence. Each step adds N_search random samples
## as neurons, fine-tunes for one epoch while accumulating significance
## scores, then prunes the N_search least significant neurons.
EPOCHS = 30

for s in range(EPOCHS):
    # test() at the end of each step may leave the model in eval mode.
    model.train()
    print(f"Adding, Finetuening and Pruning for STEP: {s}")

    c, v = get_random_training_samples(N_search)
    # Scale the new one-hot values to the magnitude of the existing weights:
    # the average of the global max and the mean per-neuron max.
    v *= (model.layer1.weight.data.max() + model.layer1.weight.data.max(dim=0)[0].mean())/2
    add_neurons_to_model(model, c, v)

    # Neurons were reordered/replaced by the previous prune + add, so Adam's
    # per-element moment estimates no longer line up with them; reset them.
    optimizer.state.clear()

    significance = torch.zeros(model.layer0.centers.shape[0])

    forw_hook = model.softmax.register_forward_hook(capture_outputs)
    back_hook = model.softmax.register_backward_hook(capture_gradients)
    try:
        for xx, yy in tqdm(train_loader):
            xx = xx.to(device).view(-1, 28*28)
            yy = yy.to(device)  # targets must be on the same device as the logits

            yout = model(xx)

            none_grad()
            loss = criterion(yout, yy)
            loss.backward()
            with torch.no_grad():
                # Squared product; the signed sum was tried and did not converge well.
                significance += torch.sum((outputs*gradients)**2, dim=0)

            optimizer.step()
    finally:
        # Never leave hooks attached if the pass is interrupted.
        remove_hook()

    remove_neurons_from_model(model, significance, N_search)
    test_acc3 = test(0, model)
    # The cosine schedule (T_max = EPOCHS) advances once per outer step.
    scheduler.step()

## TODO: Would fine-tuning after the final pruning step improve performance?

In [None]:
# Display the last 25 rows of layer0.centers as 28x28 images.
_, axs = plt.subplots(5, 5, figsize=(10, 10))
axs = axs.ravel()

c = model.layer0.centers.data.cpu().numpy().reshape(-1, 28, 28)
imgs = c[-len(axs):]  # use c[:len(axs)] for the first 25 instead

for ax, img in zip(axs, imgs):
    ax.imshow(img)
    ax.set_axis_off()

plt.show()

In [None]:
# Per-neuron maximum output weight (max over classes), and its mean over neurons.
model.layer1.weight.data.max(dim=0)[0], model.layer1.weight.data.max(dim=0)[0].mean()

In [None]:
# Global max output weight vs. mean per-neuron max: the two terms averaged when
# scaling the values of newly added neurons in the fine-tuning loop.
model.layer1.weight.data.max(), model.layer1.weight.data.max(dim=0)[0].mean()