In [None]:
import numpy as np
import copy

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib

import torch
import torch.nn as nn
import torch.nn.functional as F
import dtnnlib as dtnn

# Render figures inline and set a default figure size (individual cells override it via figsize=...).
%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (9, 8)

In [None]:
# Device used for the training tensors and evaluation grids; swap the lines to use a GPU.
# NOTE(review): several cells create tensors without a device argument (e.g. torch.zeros(...)
# when building one-hot targets), so a CUDA run may need further changes — TODO confirm.
# device = torch.device("cuda:0")
device = torch.device("cpu")

## Building the 2D toy dataset (two spirals)

In [None]:
np.random.seed(1)
def twospirals(n_points, noise=.5, angle=784):
    """
    Generate the two-spirals toy dataset.

    Samples `n_points` points along one spiral and mirrors them through the
    origin to obtain the second spiral.

    Args:
        n_points: number of samples per spiral.
        noise: scale of the uniform [0, noise) jitter added to each coordinate.
        angle: maximal sweep of the spiral, in degrees.

    Returns:
        (points, labels): points has shape (2*n_points, 2); labels has shape
        (2*n_points,) with 0.0 for the first spiral and 1.0 for the mirrored one.
    """
    # Angle (radians) doubling as radius; sqrt spreads samples more evenly along the arc.
    theta = np.sqrt(np.random.rand(n_points, 1)) * angle * (2*np.pi)/360
    # Jitter is drawn x-first, then y, exactly as before, so seeded runs are unchanged.
    jitter_x = np.random.rand(n_points, 1) * noise
    spiral_x = -np.cos(theta)*theta + jitter_x
    jitter_y = np.random.rand(n_points, 1) * noise
    spiral_y = np.sin(theta)*theta + jitter_y

    first_arm = np.hstack((spiral_x, spiral_y))
    points = np.vstack((first_arm, -first_arm))
    labels = np.hstack((np.zeros(n_points), np.ones(n_points)))
    return points, labels

In [None]:
# 300 points per spiral; each coordinate is divided by its column maximum, labels flattened to (N,).
x, y = twospirals(300, angle=560)
x, y = x/x.max(axis=0, keepdims=True), y.reshape(-1)
# Float tensors for training; yy is (N, 1) here and is reshaped/cast to integer labels later.
xx, yy = torch.FloatTensor(x), torch.FloatTensor(y.reshape(-1,1))

# CPU coordinate views reused by all later scatter plots.
x1 = xx[:,0]
x2 = xx[:,1]

# (redundant: inline plotting is already enabled in the first cell)
%matplotlib inline
plt.figure(figsize=(5,5))
plt.scatter(x1, x2, c=y, marker='.')
# plt.savefig("./clf_toy_data.pdf")
plt.xlabel("x1")
plt.ylabel("x2")
plt.axis("equal")
plt.show()

In [None]:
# Move training inputs and targets to the selected device (x1/x2 above stay on the CPU for plotting).
xx, yy = xx.to(device), yy.to(device)

## Distance Based Classification

In [None]:
class DistanceTransform_Epsilon(dtnn.DistanceTransformBase):
    """Negated, scaled distances to learned centers, with an optional constant epsilon column.

    When `epsilon` is not None, one extra last column holding the constant
    "distance" `epsilon` is appended before negation, so the output has
    num_centers + 1 columns (presumably (batch, num_centers + 1) — the base
    class output shape is not visible here, TODO confirm).

    Args:
        input_dim: dimensionality of the inputs.
        num_centers: number of centers (managed by the base class).
        p: forwarded to the base class (presumably the norm order of the distance).
        bias: if True, add a learnable per-column bias initialised to 0.
        epsilon: constant distance for the extra column, or None to disable it.
    """

    def __init__(self, input_dim, num_centers, p=2, bias=False, epsilon=0.1):
        # Forward the caller's `p`; it used to be hard-coded to 2, silently ignoring the argument.
        super().__init__(input_dim, num_centers, p=p)

        nc = num_centers
        if epsilon is not None:
            nc += 1
        # Log-parameterised scale: exp(0) = 1 at initialisation and always positive.
        self.scaler = nn.Parameter(torch.zeros(1, 1))

        self.bias = nn.Parameter(torch.zeros(1, nc)) if bias else None
        self.epsilon = epsilon

    def forward(self, x):
        dists = super().forward(x)

        if self.epsilon is not None:
            # Create the constant column on the input's device/dtype so the layer also works on GPU.
            eps_col = torch.full((len(x), 1), self.epsilon, dtype=x.dtype, device=x.device)
            dists = torch.cat([dists, eps_col], dim=1)

        dists = -dists
        # Normalise by sqrt(number of columns), then apply the learned positive scale.
        dists = dists/np.sqrt(dists.shape[1])
        dists = dists*torch.exp(self.scaler)

        if self.bias is not None:
            dists = dists+self.bias
        return dists

In [None]:
class DT_epsilon_Classifier(DistanceTransform_Epsilon):
    """Distance-based classifier that returns softmax probabilities.

    `output_dim` is used as the number of centers (one per class). With
    `epsilon` set, the epsilon column is kept inside the softmax, so the
    result has output_dim + 1 columns.
    """

    def __init__(self, input_dim, output_dim, bias=True, epsilon=1.0):
        super().__init__(input_dim, output_dim, bias=bias, epsilon=epsilon)

    def forward(self, x):
        scores = super().forward(x)
        # The epsilon column is deliberately not sliced off before normalising.
        probabilities = F.softmax(scores, dim=-1)
        return probabilities

In [None]:
def log_nll_loss(output, target):
    """Mean negative log-likelihood for a model whose output is already a probability (e.g. softmax).

    Args:
        output: probabilities, shape (batch, classes).
        target: integer class indices, shape (batch,).

    Returns:
        Scalar mean NLL. Probabilities are clamped to the smallest positive normal
        value of their dtype before the log, so an exact 0 gives a large finite loss
        instead of inf (which would also produce NaN gradients).
    """
    tiny = torch.finfo(output.dtype).tiny
    return F.nll_loss(torch.log(output.clamp_min(tiny)), target)

In [None]:
# Flatten targets to (N,) and cast to int64 class indices for the NLL / cross-entropy losses.
# `.long()` keeps the tensor on its current device; `.type(torch.LongTensor)` would force it to CPU.
yy = yy.reshape(-1).long()

#### DTeSM layer, local MLP and local residual blocks

In [None]:
class DTeSM(DistanceTransform_Epsilon):
    """Distance transform -> constant ScaleShift -> softmax.

    The softmax output (one column per center, plus the epsilon column when
    enabled) is returned, and its autograd-free `.data` is cached in
    `temp_activ` for later visualisation.

    Args:
        input_dim, output_dim, bias, epsilon: forwarded to DistanceTransform_Epsilon
            (`output_dim` is the number of centers).
        itemp: `scaler_init` of a frozen dtnn.ScaleShift applied before the softmax
            (presumably an inverse temperature — TODO confirm ScaleShift semantics).
    """

    def __init__(self, input_dim, output_dim, bias=True, epsilon=1.0, itemp=10):
        # NOTE (original author): disabling the bias gives more uniform centroid activations,
        # which makes neurons easier to compare.
        super().__init__(input_dim, output_dim, bias=bias, epsilon=epsilon)

        # Both scale and shift are constant (not trained).
        self.scale_shift = dtnn.ScaleShift(-1, scaler_init=itemp, shifter_init=0, scaler_const=True, shifter_const=True)
        self.softmax = nn.Softmax(dim=-1)
        # Most recent softmax activation; None until the first forward call.
        self.temp_activ = None

    def forward(self, x):
        activation = self.softmax(self.scale_shift(super().forward(x)))
        self.temp_activ = activation.data
        return activation

In [None]:
class LocalMLP_DTeSM(nn.Module):
    """Two-layer local MLP: a DTeSM layer followed by a linear read-out.

    Args:
        input_dim: input dimensionality.
        hidden_dim: number of DTeSM centers.
        output_dim: number of outputs of the linear layer.
        epsilon: DTeSM epsilon distance, or None for no extra epsilon unit.
        itemp: `itemp` passed to DTeSM.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, epsilon=None, itemp=1.0):
        super().__init__()
        self.layer0 = DTeSM(input_dim, hidden_dim, bias=True, epsilon=epsilon, itemp=itemp)
        # The epsilon unit adds one more activation column for the linear layer to read.
        n_features = hidden_dim + (0 if epsilon is None else 1)
        self.layer1 = nn.Linear(n_features, output_dim)

    def forward(self, x):
        return self.layer1(self.layer0(x))

In [None]:
class LocalResidual_DTeSM(nn.Module):
    """Residual block computing x + Linear(DTeSM(x)); output has the input's dimensionality.

    Args:
        input_dim: input (and output) dimensionality.
        hidden_dim: number of DTeSM centers.
        epsilon: DTeSM epsilon distance, or None for no extra epsilon unit.
        itemp: `itemp` passed to DTeSM.
    """

    def __init__(self, input_dim, hidden_dim, epsilon=None, itemp=1.0):
        super().__init__()
        self.layer0 = DTeSM(input_dim, hidden_dim, bias=True, epsilon=epsilon, itemp=itemp)
        n_features = hidden_dim if epsilon is None else hidden_dim + 1
        self.layer1 = nn.Linear(n_features, input_dim)
        # Start without a constant shift: only the weighted activations move the input.
        self.layer1.bias.data.mul_(0.)

    def forward(self, x):
        shift = self.layer1(self.layer0(x))
        return x + shift

## Optimize Classifier to Data

In [None]:
# Distance classifier with one center per class (2 classes) plus an epsilon column (epsilon=1.0).
model = DT_epsilon_Classifier(2, 2, epsilon=1.0)
# model = DT_epsilon_Classifier(2, 2, epsilon=None)

In [None]:
# The model outputs probabilities, so the loss takes their log before NLL.
criterion = log_nll_loss

In [None]:
### ASSUMING: first half is class 0 and second half is class 1
# (this holds for twospirals, which stacks the class-0 arm first). Each class center starts at
# a random training point of its own class.
cls_randidx = torch.randint(len(yy)//2, (2,))+torch.LongTensor([0, len(yy)//2])
model.centers.data = xx[cls_randidx]

In [None]:
# Initial fit. The argmax runs over all columns, including the epsilon column (index 2),
# which never matches a label and therefore counts as an error.
yout = model(xx)
loss = criterion(yout, yy)

accuracy = (yout.max(dim=1)[1] == yy).type(torch.float).mean()                
print(f'Acc:{float(accuracy):.2f}, Loss:{float(loss)}')

In [None]:
# Best loss so far; acceptance threshold for the center-replacement search below.
loss_bkp = float(loss)

In [None]:
# Training points coloured by predicted column (argmax, epsilon column included) and the current centers.
ax = plt.figure(figsize=(6,6)).add_subplot()
out = yout.max(dim=1)[1].data.cpu().numpy()
ax.scatter(x1, x2, c=out, marker= '.')
## plot centroids
c = model.centers.data.cpu()
ax.scatter(c[:,0], c[:,1], ec='k', fc='r', marker= 'X', s=100)
plt.show()

In [None]:
#############################
### Development of replacing centers per class
# Greedy random search: for each class, move its center to a random training point of that
# class and keep the move only if the loss does not increase.
# ASSUMES samples are ordered by class (first half class 0, second half class 1), as from twospirals.
STEPS = 100
half = len(yy)//2
for steps in range(STEPS):
    # Pure search: no gradients needed, so avoid building an autograd graph on every try.
    with torch.no_grad():
        for i in range(2):## for each class
            backup_center = model.centers.data.clone()
            cls_randidx = torch.randint(half, (1,))[0] + i*half
            model.centers.data[i] = xx[cls_randidx]

            loss = criterion(model(xx), yy)
            if loss > loss_bkp:
                model.centers.data = backup_center  # reject: restore the previous centers
            else:
                loss_bkp = float(loss)  # accept: new best loss

    if (steps+1)%10 == 0:
        with torch.no_grad():
            yout = model(xx)
            loss = criterion(yout, yy)
        accuracy = (yout.max(dim=1)[1] == yy).type(torch.float).mean()
        ax = plt.figure(figsize=(6,6)).add_subplot()
        out = yout.max(dim=1)[1].cpu().numpy()
        ax.scatter(x1, x2, c=out, marker= '.')
        ## plot centroids
        c = model.centers.data.cpu()
        ax.scatter(c[:,0], c[:,1], ec='k', fc='r', marker= 'X', s=100)
        plt.show()
        print(f'Acc:{float(accuracy):.2f}, Loss:{float(loss)}')

In [None]:
# 1000 x 1000 evaluation grid over [-3, 3]^2 (linspace(-1.5, 1.5) scaled by 2), flattened to (1e6, 2).
num_points = 1000
X1 = np.linspace(-1.5, 1.5, num_points)*2
X2 = np.linspace(-1.5, 1.5, num_points)*2
X1, X2 = np.meshgrid(X1, X2)

# Row-major flattening: reshaping a per-point result to (num_points, num_points) matches X1/X2.
XX = torch.Tensor(np.c_[X1.reshape(-1), X2.reshape(-1)]).to(device)
XX.shape

### Re-run the cells below after changing parameters

In [None]:
# Classifier probabilities on every grid point; no autograd graph is needed for 1e6 points.
with torch.no_grad():
    YY = model(XX)
YY = YY.reshape(num_points, num_points, -1)

In [None]:
# Current predictions (used to colour later plots) and training accuracy.
yout = model(xx)
out = yout.max(dim=1)[1].data.cpu().numpy()
accuracy = (yout.max(dim=1)[1] == yy).type(torch.float).mean()
accuracy

In [None]:
# Probability each center assigns to its own column when the input is exactly that center
# (diagonal of the centers' outputs; the epsilon column is not on the diagonal).
max_actv = model(model.centers.data).data.cpu()
max_actv = max_actv.diag().numpy()
max_actv

In [None]:
# For every output column (one per class center, plus the epsilon column) draw its probability map
# over the grid, the data coloured by prediction, all centers, and the grid point where it peaks.
for idx in range(YY.shape[-1]):
    conf = YY[:,:,idx]
    conf = conf.data.cpu().numpy().reshape(X1.shape)

    ax = plt.figure(figsize=(6,6)).add_subplot()
    ax.scatter(x1, x2, c=out, marker= '.', alpha=0.3)

    ## plot centroids
    c = model.centers.data.cpu()
    for i in range(c.shape[0]):
        color = matplotlib.cm.tab20(i%20)
        ax.scatter(c[i,0], c[i,1], color=color, marker= 'x', s=100)

    # Only real centers can be highlighted; the epsilon column (idx == number of centers) has none.
    # An explicit bounds check replaces the former bare `except: pass`, which hid any other error too.
    if idx < c.shape[0]:
        ax.scatter(c[idx,0], c[idx,1], color="k", marker= 'X', s=100)
        # Sanity check: is the value at the center itself at least the best value on the grid?
        print(f"center:",max_actv[idx],"max_grid:",conf.max(), max_actv[idx] >= conf.max())

    # conf was reshaped from XX's row-major order, so its flat argmax indexes a row of XX.
    maxpt = XX[conf.argmax()].cpu()
    ax.scatter(maxpt[0], maxpt[1], color="r", marker= 'o', s=100)

    ax.imshow(conf, interpolation='nearest',
              extent=(X1.min(), X1.max(), X2.min(), X2.max()),
              alpha=0.6, cmap='gray',
              aspect='auto', origin='lower')

    LVLs = 20
    cs = ax.contour(X1, X2, conf, levels=LVLs, linestyles="None", colors="k", linewidths=1, zorder=-2)
    ax.clabel(cs, cs.levels, inline=True, fontsize=8, fmt="%1.2f")

    plt.show()

In [None]:
# Learned per-column bias and the log-scale; the effective distance scale is exp(scaler).
model.bias.data, model.scaler, torch.exp(model.scaler)

##### Changing parameters

In [None]:
# model.bias.data[0] = torch.Tensor([0, 0, -1])

In [None]:
# Manual tweak: log-scale 2 -> effective scale exp(2) ~= 7.39. Re-run the grid cells above to see its effect.
model.scaler.data[0,0] = 2.

In [None]:
# Keep this model as the fixed target classifier. This is an alias, not a copy;
# `model` is rebound to a new LocalMLP_DTeSM further below.
classifier = model

## Optimize a local residual MLP toward the class centers (noisy add-and-prune search)

In [None]:
# Regression target for each training point: the classifier center of that point's class.
tt_0 = classifier.centers.data[yy]
tt_0

In [None]:
# H0: number of residual neurons; N_search0: neurons added and pruned per search step.
H0 = 20
N_search0 = 1

In [None]:
residual0 = LocalResidual_DTeSM(2, H0, epsilon=0.4, itemp=7.0)

In [None]:
# Log-scale 2 (effective distance scale exp(2)), as for the classifier.
residual0.layer0.scaler.data[0,0] = 2.

In [None]:
residual0.layer1.weight.data[:, -1] = residual0.layer1.weight.data[:, -1]*0 ### zero out epsilon..

In [None]:
residual0

In [None]:
residual0.layer1.bias.data

In [None]:
## random init
# Place the H0 neurons on random training points. Each neuron's output-weight column is the shift
# that would move its own center onto that point's target class center (net of the layer1 bias).
randidx = torch.randperm(len(xx))[:H0]
residual0.layer0.centers.data = xx[randidx] 

diff = tt_0[randidx] - residual0.layer0.centers.data - residual0.layer1.bias.data
residual0.layer1.weight.data[:, :H0] = diff.t()

In [None]:
#### Visualize neurons -- the plotting code is further below (see link)

[Visualize Neurons (jump to code)](#Visualize-Neurons)

In [None]:
# Scratch cell (no code). NOTE(review): placeholder text kept commented out; presumably it once halted Run All.

### Add Neurons

In [None]:
def add_neurons_to_residual(model, centers, values):
    """Append DTeSM neurons, in place, to a model with `layer0` (DTeSM) and `layer1` (nn.Linear).

    Used for both LocalResidual_DTeSM and LocalMLP_DTeSM models.

    Args:
        model: object with `layer0` (centers, bias, epsilon) and `layer1` (weight).
        centers: (k, input_dim) tensor of new center positions.
        values: (k, out_dim) tensor; row j becomes the output-weight column of new neuron j.

    New neurons get a bias of 0. When layer0 has an epsilon unit, the new columns are
    inserted before it so the epsilon unit stays last.
    NOTE(review): `layer0.num_centers` is not updated here.
    """
    layer0, layer1 = model.layer0, model.layer1
    n_new = len(centers)

    weight = layer1.weight.data
    if layer0.epsilon is None:
        new_weight = torch.cat((weight, values.t()), dim=1)
    else:
        # Keep the epsilon unit's column last.
        new_weight = torch.cat((weight[:, :-1], values.t(), weight[:, -1:]), dim=1)

    layer0.centers.data = torch.cat((layer0.centers.data, centers), dim=0)
    layer1.weight.data = new_weight

    # A layer built with bias=False has no bias to extend (previously this raised AttributeError).
    if layer0.bias is not None:
        bias = layer0.bias.data
        # Create the zeros on the bias's device/dtype so this also works on GPU.
        zeros = torch.zeros(1, n_new, dtype=bias.dtype, device=bias.device)
        if layer0.epsilon is None:
            layer0.bias.data = torch.cat([bias, zeros], dim=1)
        else:
            layer0.bias.data = torch.cat([bias[:, :-1], zeros, bias[:, -1:]], dim=1)

In [None]:
# Manual single add step: pick N_search0 random training points as candidate centers.
randidx = torch.randperm(len(xx))[:N_search0]

In [None]:
# Shift that would move each candidate onto its target class center; the add itself is commented out.
shift_by = tt_0[randidx] - xx[randidx] - residual0.layer1.bias.data
# add_neurons_to_residual(residual0, xx[randidx], shift_by)

In [None]:
# Accuracy of the fixed classifier applied to the residual-transformed inputs.
with torch.no_grad():
    h1 = residual0(xx)
    yout = classifier(h1)
accuracy = (yout.max(dim=1)[1] == yy).type(torch.float).mean()
accuracy

[Visualize Neurons (jump to code)](#Visualize-Neurons)

### Prune Neurons

In [None]:
def remove_neurons_from_residual(model, importance, num_prune):
    """Prune, in place, the `num_prune` least important DTeSM neurons of a two-layer model.

    Args:
        model: object with `layer0` (centers, bias, epsilon) and `layer1` (weight).
        importance: 1-D scores; only the first N entries (N = number of centers) are used,
            so a trailing epsilon score is ignored and the epsilon unit is never pruned.
        num_prune: number of neurons to remove, 0 <= num_prune <= N.

    Surviving neurons keep their original relative order. The removed indices are printed.

    Raises:
        ValueError: if num_prune is outside [0, N].
    """
    layer0, layer1 = model.layer0, model.layer1
    n_centers = layer0.centers.shape[0]
    if not 0 <= num_prune <= n_centers:
        raise ValueError(f"num_prune must be in [0, {n_centers}], got {num_prune}")
    scores = importance[:n_centers]

    # Rank once and split, so kept/removed are complementary even with tied scores
    # (two independent topk calls could pick overlapping sets on ties).
    order = torch.argsort(scores, descending=True)
    keep_idx = order[:n_centers - num_prune].sort()[0]
    removing = order[n_centers - num_prune:]
    print(f"Removing:\n{removing.data.sort()[0]}")

    new_centers = layer0.centers.data[keep_idx.to(layer0.centers.device)]
    # Bias / output-weight columns: kept neurons, then the epsilon unit (always the last column).
    col_idx = keep_idx.to(layer1.weight.device)
    if layer0.epsilon is not None:
        eps_idx = torch.tensor([n_centers], dtype=col_idx.dtype, device=col_idx.device)
        col_idx = torch.cat([col_idx, eps_idx])

    layer0.centers.data = new_centers
    layer1.weight.data = layer1.weight.data[:, col_idx]
    if layer0.bias is not None:
        layer0.bias.data = layer0.bias.data[:, col_idx.to(layer0.bias.device)]

In [None]:
class ImportanceEstimator:
    """Accumulate per-unit significance scores for a DTeSM-style layer.

    Hooks on `module.softmax` capture the softmax activations `a` (forward) and the
    gradient of the loss with respect to them, `g` (backward). Each call to
    `accumulate_significance` adds `sum over the batch of (a * g) ** 2` per column.

    The tracked module must expose `centers`, `epsilon` and `softmax`. When `epsilon`
    is not None, the score vector has one extra (last) entry for the epsilon unit.
    Captured tensors and scores are kept on the CPU.
    """

    def __init__(self, module):
        self.module = module
        self.outputs = None    # last softmax output captured by the forward hook (CPU)
        self.gradients = None  # last gradient w.r.t. that output (CPU)
        self.back_hook = None
        self.forw_hook = None
        self.significance = None
        self.reset_significance()

    def reset_significance(self):
        """Zero the scores, sized to the module's current number of units."""
        n_units = self.module.centers.shape[0]
        if self.module.epsilon is not None:
            n_units += 1
        self.significance = torch.zeros(n_units)

    def accumulate_significance(self):
        """Add the squared activation*gradient from the last forward/backward pass."""
        if self.outputs is None or self.gradients is None:
            raise RuntimeError("nothing captured yet: call attach_hook(), then run forward and backward")
        with torch.no_grad():
            self.significance += torch.sum((self.outputs*self.gradients)**2, dim=0)

    def capture_outputs(self, module, inp, out):
        self.outputs = out.data.cpu()

    def capture_gradients(self, module, gradi, grado):
        self.gradients = grado[0].data.cpu()

    def attach_hook(self):
        """Register the hooks; any previously attached hooks are removed first, so none leak."""
        self.remove_hook()
        softmax = self.module.softmax
        self.forw_hook = softmax.register_forward_hook(self.capture_outputs)
        # register_backward_hook is deprecated; use the full backward hook when this torch has it.
        register_backward = getattr(softmax, "register_full_backward_hook", None) or softmax.register_backward_hook
        self.back_hook = register_backward(self.capture_gradients)

    def remove_hook(self):
        """Remove the hooks if attached; safe to call more than once."""
        for handle in (self.back_hook, self.forw_hook):
            if handle is not None:
                handle.remove()
        self.back_hook = None
        self.forw_hook = None

In [None]:
def none_grad(model):
    """Set every parameter's `.grad` to None instead of zeroing it.

    This is needed here because adding/pruning neurons changes parameter shapes, so
    stale gradient tensors of the old shape must not be reused.
    """
    for _name, tensor in model.named_parameters():
        tensor.grad = None

In [None]:
# Significance estimator attached to the residual block's DTeSM layer.
imp_est = ImportanceEstimator(residual0.layer0)

In [None]:
mse_loss = nn.MSELoss()

In [None]:
# One manual importance pass. Clear old gradients first (parameter shapes may have changed).
none_grad(residual0)

imp_est.attach_hook()
h1 = residual0(xx)
# Rescale each sample's incoming gradient to unit L2 norm, so every point contributes equally.
# NOTE(review): a row with zero gradient would produce NaN here.
h1.register_hook(lambda grad: grad/torch.norm(grad, dim=1, keepdim=True))
####################################
#         grad = torch.randn_like(yout)
#         yout.backward(gradient=grad)
###################################
mse_loss(h1, tt_0).backward()
imp_est.accumulate_significance()

imp_est.remove_hook()

In [None]:
imp_est.significance

In [None]:
# remove_neurons_from_residual(residual0, imp_est.significance, N_search0)

In [None]:
with torch.no_grad():
    h1 = residual0(xx)
    yout = classifier(h1)
accuracy = (yout.max(dim=1)[1] == yy).type(torch.float).mean()
accuracy

In [None]:
print(yout.shape)

In [None]:
# (Same evaluation as two cells above.)
with torch.no_grad():
    h1 = residual0(xx)
    yout = classifier(h1)
accuracy = (yout.max(dim=1)[1] == yy).type(torch.float).mean()
accuracy

In [None]:
# History of [accuracy (0-d tensor), phase label] pairs for the residual search.
accs_tup = [[accuracy, "init"]]

### Optimize Iteratively

In [None]:
# Noisy add-and-prune search for the residual block. Each step adds N_search0 neurons placed on
# random training points, then prunes the N_search0 least significant ones, so the size stays constant.
STEPS = 200
for step in range(STEPS):
    ## Add
    randidx = torch.randperm(len(xx))[:N_search0]
    # Output weight of each new neuron: the shift from its center to its target class center.
    shift_by = tt_0[randidx] - xx[randidx] - residual0.layer1.bias.data
    add_neurons_to_residual(residual0, xx[randidx], shift_by)
    with torch.no_grad():
        h1 = residual0(xx)
        yout = classifier(h1)
    accuracy_add = (yout.max(dim=1)[1] == yy).type(torch.float).mean()
    
    accs_tup += [[accuracy_add, "add"]]    
    
    ## Prune
    # Parameter shapes just changed, so drop stale gradients and resize the score vector.
    none_grad(residual0)
    imp_est.reset_significance()
    imp_est.attach_hook()
    
    h1 = residual0(xx)
#     h1.register_hook(lambda grad: grad/torch.norm(grad, dim=1, keepdim=True))
    ####################################
#     grad = torch.randn_like(h1)
#     h1.backward(gradient=grad)
    ###################################
    # Significance is measured against the regression-to-center objective, not the classifier loss.
    mse_loss(h1, tt_0).backward()
#     log_nll_loss(classifier(h1), yy).backward()

    ###################################
    imp_est.accumulate_significance()
    imp_est.remove_hook()
    
    remove_neurons_from_residual(residual0, imp_est.significance, N_search0)
    with torch.no_grad():
        h1 = residual0(xx)
        yout = classifier(h1)
        # criterion is still log_nll_loss at this point (it is reassigned later in the notebook).
        loss = criterion(yout, yy)
    accuracy_prune = (yout.max(dim=1)[1] == yy).type(torch.float).mean()
    
    accs_tup += [[accuracy_prune, "prune"]]    
    
    
    print(f'Step:{step}, AccAdd:{float(accuracy_add):.2f}, Acc:{float(accuracy_prune):.2f}, Loss:{float(loss):.3f}')

In [None]:
imp_est.significance

### Visualize Neurons

In [None]:
h1 = residual0(xx)
yout = classifier(h1)
out = yout.max(dim=1)[1].data.cpu().numpy()
## centroids and shift
c = residual0.layer0.centers.data.cpu()
d = residual0.layer1.weight.data.cpu().t()  # row i: shift vector of neuron i (last row: epsilon unit)

In [None]:
accuracy = (yout.max(dim=1)[1] == yy).type(torch.float).mean()
accuracy

In [None]:
# Each neuron's own softmax activation when the input is exactly its center (diagonal entries).
# Side effect: this call overwrites residual0.layer0.temp_activ.
max_actv = residual0.layer0(residual0.layer0.centers.data).data.cpu()
max_actv = max_actv.diag()#.numpy()
max_actv

In [None]:
# Residual neurons as arrows from each center along its output-weight column (the shift it adds).
# Transformed points h1 are coloured by true label; the original inputs by predicted class.
ax = plt.figure(figsize=(6,5)).add_subplot()
h1_cpu = h1.data.cpu()
ax.scatter(h1_cpu[:,0], h1_cpu[:,1], c=yy.cpu(), marker= '.', alpha=0.3)

ax.scatter(x1, x2, c=out, marker= '.', alpha=0.3, cmap="coolwarm")

for i in range(c.shape[0]):
    # tab20 (20 colours), as in the other plots; tab10(i) clipped every neuron i >= 10 to one colour.
    color = matplotlib.cm.tab20(i%20)
    ax.arrow(c[i,0], c[i,1], d[i,0], d[i,1], head_width=0.15, head_length=0.1, fc=color, ec=color, linestyle=(0, (5, 10)))
    ax.scatter(c[i,0], c[i,1], color=color, marker= 'x')

# The epsilon unit (last weight column, if present) has no center, so its shift is drawn from the origin.
if d.shape[0] > len(c):
    ax.arrow(0, 0, d[len(c),0], d[len(c),1], head_width=0.15, head_length=0.1, fc="k", ec="k", linestyle=(0, (5, 10)), linewidth=2.0)

plt.show()

In [None]:
# Learned DTeSM bias, log-scale, and effective scale exp(scaler) of the residual block.
residual0.layer0.bias.data, residual0.layer0.scaler, torch.exp(residual0.layer0.scaler)

#### Visualize the activation region of each residual-layer neuron

In [None]:
# Run the grid through the residual block only to populate layer0.temp_activ (DTeSM softmax
# activations); no autograd graph is needed for 1e6 points.
with torch.no_grad():
    residual0(XX)
YY = residual0.layer0.temp_activ
YY = YY.reshape(num_points, num_points, -1)
YY.shape

In [None]:
# Activation of each neuron at its own center. YY was already extracted above, so the
# temp_activ overwrite caused by this call does not affect it.
max_actv = residual0.layer0(residual0.layer0.centers.data).data.cpu().diag()
max_actv_ = max_actv.numpy()
max_actv_

In [None]:
# One map per DTeSM column (each residual neuron, plus the epsilon unit): activation over the grid,
# the data coloured by prediction, all centers, and the grid point where the activation peaks.
for idx in range(YY.shape[-1]):
    conf = YY[:,:,idx]
    conf = conf.data.cpu().numpy().reshape(X1.shape)

    ax = plt.figure(figsize=(6,6)).add_subplot()
    ax.scatter(x1, x2, c=out, marker= '.', alpha=0.3)

    ## plot centroids
    for i in range(c.shape[0]):
        color = matplotlib.cm.tab20(i%20)
        ax.scatter(c[i,0], c[i,1], color=color, marker= 'x', s=100)

    # Only real centers can be highlighted; the epsilon column (idx == number of centers) has none.
    # An explicit bounds check replaces the former bare `except: pass`, which hid any other error too.
    if idx < c.shape[0]:
        ax.scatter(c[idx,0], c[idx,1], color="k", marker= 'X', s=100)
        print(f"center:",max_actv_[idx],"max_grid:",conf.max(), max_actv_[idx] >= conf.max())

    # conf was reshaped from XX's row-major order, so its flat argmax indexes a row of XX.
    maxpt = XX[conf.argmax()].cpu()
    ax.scatter(maxpt[0], maxpt[1], color="r", marker= 'o', s=100)

    ax.imshow(conf, interpolation='nearest',
              extent=(X1.min(), X1.max(), X2.min(), X2.max()),
              alpha=0.6, cmap='gray',
              aspect='auto', origin='lower')

    LVLs = 20
    cs = ax.contour(X1, X2, conf, levels=LVLs, linestyles="None", colors="k", linewidths=1, zorder=-2)
    ax.clabel(cs, cs.levels, inline=True, fontsize=8, fmt="%1.2f")

    plt.show()

In [None]:
# Inspect the residual DTeSM parameters: log-scale, effective scale, and per-column bias.
residual0.layer0.scaler, torch.exp(residual0.layer0.scaler), residual0.layer0.bias

In [None]:
# Constant ScaleShift scaler (set from itemp) and the epsilon distance.
residual0.layer0.scale_shift.scaler, residual0.layer0.epsilon

In [None]:
# Optional manual tweaks (disabled); re-run the activation-map cells above after enabling any of them.
# residual0.layer0.epsilon = 0.5

In [None]:
# residual0.layer0.scaler.data[0,0] = 3.0

In [None]:
# residual0.layer0.scale_shift.scaler = 5.0

## Train a 2-layer LocalMLP with the noisy add-and-prune search

In [None]:
H0 = 20
N_search0 = 1

# Rebinds `model`; `classifier` still refers to the earlier distance classifier.
model = LocalMLP_DTeSM(2, H0, 2, epsilon=0.4, itemp=7.)

In [None]:
model.layer0.scaler.data[0,0] = 2.0

In [None]:
# Zero the epsilon unit's output weights so it contributes nothing to the logits.
model.layer1.weight.data[:, -1] = model.layer1.weight.data[:, -1]*0

In [None]:
## random init
# Centers on H0 random training points; each neuron's output column is the one-hot vector of its point's label.
randidx = torch.randperm(len(xx))[:H0]
model.layer0.centers.data = xx[randidx] 
yy_0 = yy[randidx]
tt_0 = torch.zeros(H0, 2)
for i in range(len(tt_0)):
    tt_0[i, yy_0[i]] = 1.
model.layer1.weight.data[:, :H0] = tt_0.t()

In [None]:
randidx

In [None]:
model(xx).shape

In [None]:
with torch.no_grad():
    yout = model(xx)
accuracy = (yout.max(dim=1)[1] == yy).type(torch.float).mean()
accuracy

In [None]:
# History of [accuracy (0-d tensor), phase label] pairs for the MLP search.
accs_tup2 = [[accuracy, "init"]]

## Visualize the MLP's hidden (DTeSM) layer

In [None]:
# Run the grid through the MLP only to populate layer0.temp_activ (hidden DTeSM activations);
# no autograd graph is needed for 1e6 points.
with torch.no_grad():
    model(XX)
YY = model.layer0.temp_activ
YY = YY.reshape(num_points, num_points, -1)
YY.shape

In [None]:
# Predicted class per training point (colours the plots below) and the hidden-layer centers.
out = yout.max(dim=1)[1].data.cpu().numpy()
## centroids and shift
c = model.layer0.centers.data.cpu()
# d = model.layer1.weight.data.cpu().t() #+ .cpu()

In [None]:
# Activation of each hidden neuron at its own center (overwrites layer0.temp_activ; YY is already extracted).
max_actv = model.layer0(model.layer0.centers.data).data.cpu().diag()
max_actv_ = max_actv.numpy()
max_actv_

In [None]:
# One map per hidden DTeSM column (each neuron, plus the epsilon unit): activation over the grid,
# the data coloured by prediction, all centers, and the grid point where the activation peaks.
for idx in range(YY.shape[-1]):
    conf = YY[:,:,idx]
    conf = conf.data.cpu().numpy().reshape(X1.shape)

    ax = plt.figure(figsize=(6,6)).add_subplot()
    ax.scatter(x1, x2, c=out, marker= '.', alpha=0.3)

    ## plot centroids
    for i in range(c.shape[0]):
        color = matplotlib.cm.tab20(i%20)
        ax.scatter(c[i,0], c[i,1], color=color, marker= 'x', s=100)

    # Only real centers can be highlighted; the epsilon column (idx == number of centers) has none.
    # An explicit bounds check replaces the former bare `except: pass`, which hid any other error too.
    if idx < c.shape[0]:
        ax.scatter(c[idx,0], c[idx,1], color="k", marker= 'X', s=100)
        print(f"center:",max_actv_[idx],"max_grid:",conf.max(), max_actv_[idx] >= conf.max())

    # conf was reshaped from XX's row-major order, so its flat argmax indexes a row of XX.
    maxpt = XX[conf.argmax()].cpu()
    ax.scatter(maxpt[0], maxpt[1], color="r", marker= 'o', s=100)

    ax.imshow(conf, interpolation='nearest',
              extent=(X1.min(), X1.max(), X2.min(), X2.max()),
              alpha=0.6, cmap='gray',
              aspect='auto', origin='lower')

    LVLs = 20
    cs = ax.contour(X1, X2, conf, levels=LVLs, linestyles="None", colors="k", linewidths=1, zorder=-2)
    ax.clabel(cs, cs.levels, inline=True, fontsize=8, fmt="%1.2f")

    plt.show()

## Add and Prune

In [None]:
# The MLP outputs logits, so switch to cross-entropy (this replaces log_nll_loss from here on).
criterion = nn.CrossEntropyLoss()

In [None]:
# Fresh estimator bound to the MLP's hidden DTeSM layer.
imp_est = ImportanceEstimator(model.layer0)

In [None]:
# Noisy add-and-prune search for the MLP. Each step adds N_search0 neurons on random training points
# (output column = one-hot label), then prunes the N_search0 least significant ones.
STEPS = 200
for step in range(STEPS):
    ## Add
    randidx = torch.randperm(len(xx))[:N_search0]
    yy_0 = yy[randidx]
    tt_0 = torch.zeros(len(yy_0), 2)
    for i in range(len(yy_0)):
        tt_0[i, yy_0[i]] = 1.
    add_neurons_to_residual(model, xx[randidx], tt_0)
    with torch.no_grad():
        yout = model(xx)
    accuracy_add = (yout.max(dim=1)[1] == yy).type(torch.float).mean()
    
    accs_tup2 += [[accuracy_add, "add"]]    
    
    ## Prune
    # Parameter shapes just changed: drop stale gradients and resize the score vector.
    none_grad(model)
    imp_est.reset_significance()
    imp_est.attach_hook()
    
    yout = model(xx)
#     yout.register_hook(lambda grad: grad/torch.norm(grad, dim=1, keepdim=True))
    ####################################
#     grad = torch.randn_like(yout)
#     yout.backward(gradient=grad)
    ###################################
    # Significance is measured against the classification (cross-entropy) loss.
    criterion(yout, yy).backward()
    ###################################
    imp_est.accumulate_significance()
    imp_est.remove_hook()
    
    remove_neurons_from_residual(model, imp_est.significance, N_search0)
    with torch.no_grad():
        yout = model(xx)
        loss = criterion(yout, yy)
    accuracy_prune = (yout.max(dim=1)[1] == yy).type(torch.float).mean()
    accs_tup2 += [[accuracy_prune, "prune"]]    

    print(f'Step:{step}, AccAdd:{float(accuracy_add):.2f}, Acc:{float(accuracy_prune):.2f}, Loss:{float(loss):.3f}')

In [None]:
# Scores from the last prune step (computed before that step's pruning).
imp_est.significance

### Display Adversarial Examples on 2D

In [None]:
## take some random input as sample
ridx = torch.randperm(len(xx))[:9]
txx = xx[ridx]
txx

In [None]:
# Make the samples a leaf tensor that records gradients (torch.autograd.Variable is deprecated).
txx = txx.detach().requires_grad_(True)

In [None]:
tyy = yy[ridx]
tyy

In [None]:
model.zero_grad()

In [None]:
# Loss towards the opposite label; stepping against its gradient pushes samples toward the other class.
loss = criterion(model(txx), 1-tyy)
loss.backward()

In [None]:
txx.grad

In [None]:
# Narrower 1000 x 1000 grid over [-1.5, 1.5]^2. This overwrites X1/X2/XX, so re-running earlier
# visualisation cells after this point uses this grid.
num_points = 1000
X1 = np.linspace(-1.5, 1.5, num_points)
X2 = np.linspace(-1.5, 1.5, num_points)
X1, X2 = np.meshgrid(X1, X2)

XX = torch.Tensor(np.c_[X1.reshape(-1), X2.reshape(-1)]).to(device)
XX.shape

In [None]:
# oYY =model(XX).argmax(dim=-1).reshape(num_points, num_points)

# Probability of class 1 at every grid point (softmax over the MLP logits), used as the background.
# Column 1 equals the former `@ [0, 1]` product. No autograd graph is needed for 1e6 points.
with torch.no_grad():
    oYY = torch.nn.functional.softmax(model(XX), dim=-1)[:, 1]
oYY = oYY.reshape(num_points, num_points)

# The same forward pass stored the hidden DTeSM activations for this grid.
YY = model.layer0.temp_activ
YY = YY.reshape(num_points, num_points, -1)
YY.shape

In [None]:
oYY

In [None]:
# yout here is the MLP output on the training points from the last search step.
out = yout.max(dim=1)[1].data.cpu().numpy()
## centroids and shift
c = model.layer0.centers.data.cpu()
# d = model.layer1.weight.data.cpu().t() #+ .cpu()

In [None]:
# Hidden activation of each neuron at its own center (overwrites temp_activ; YY was captured above).
max_actv = model.layer0(model.layer0.centers.data).data.cpu().diag()
max_actv_ = max_actv.numpy()
max_actv_

In [None]:
import os

# Background: P(class 1) from the MLP. Contours: one hidden DTeSM column (by default the epsilon
# column). Red arrows: normalised adversarial step directions from the sampled points.
# Use the actual center count so idx always points at the epsilon column, even if
# layer0.num_centers was not kept in sync by add/prune.
idx = model.layer0.centers.shape[0]
conf = YY[:,:,idx]
conf = conf.data.cpu().numpy().reshape(X1.shape)

ax = plt.figure(figsize=(6,6)).add_subplot()
ax.scatter(x1, x2, c=out, marker= '.', ec='k',alpha=0.3)

## plot centroids
for i in range(c.shape[0]):
    color = matplotlib.cm.tab20(i%20)
    ax.scatter(c[i,0], c[i,1], color=color, marker= 'x', s=100)

# Only a real center can be highlighted (explicit check instead of a bare `except: pass`).
if idx < c.shape[0]:
    ax.scatter(c[idx,0], c[idx,1], color="k", marker= 'X', s=100)
    print(f"center:",max_actv_[idx],"max_grid:",conf.max(), max_actv_[idx] >= conf.max())

ax.imshow(oYY.data.cpu().numpy(), interpolation='nearest',
          extent=(X1.min(), X1.max(), X2.min(), X2.max()),
          alpha=0.3,
          aspect='auto', origin='lower')

# Squared linspace: contour levels are denser near 0.
LVLs = torch.linspace(0.0, 1.0, 10)**2
cs = ax.contour(X1, X2, conf, levels=LVLs, linestyles="None", colors="k", linewidths=1, zorder=-2)
ax.clabel(cs, cs.levels, inline=True, fontsize=8, fmt="%1.2f")

_c = txx.data.cpu().numpy()
ax.scatter(_c[:,0], _c[:,1], facecolor="w", edgecolor="r", marker= 'o', s=100)
# Step against the gradient of the flipped-label loss, rescaled to length 0.3.
_d = -txx.grad.data.cpu().numpy()
_d = _d/np.linalg.norm(_d, axis=-1, keepdims=True)*0.3
for i in range(len(_d)):
    ax.arrow(_c[i,0], _c[i,1], _d[i,0], _d[i,1], head_width=0.1, head_length=0.1, fc="r", ec="r", linestyle="solid")

# savefig fails if the output directory does not exist yet.
os.makedirs("outputs", exist_ok=True)
plt.savefig("outputs/adversarial-2d-demo.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Normalised adversarial step vectors drawn in the figure above.
_d

### Plot results

In [None]:
import pickle

In [None]:
# Accuracy history of the epsilon-highway variant, presumably saved by a companion notebook — TODO confirm.
# SECURITY: pickle.load can execute arbitrary code; only load files you produced yourself.
with open('outputs/18_epsHighway_accs_noisy.pkl', 'rb') as handle:
    accs_tup3 = pickle.load(handle)

In [None]:
len(accs_tup), len(accs_tup2), len(accs_tup3)

In [None]:
import os

# Accuracy after every add/prune step for the three search variants. Each history is a list of
# [accuracy, phase_label] pairs; accuracy may be a 0-d tensor, so convert it explicitly with float().
plt.figure(figsize=(10,4))

data_res = np.array([(i, float(acc)) for i, (acc, _) in enumerate(accs_tup)])
plt.plot(data_res[:,0], data_res[:,1], marker='.', linestyle='dotted', color='tab:orange', label=r"$\epsilon$-residual")

data_mlp = np.array([(i, float(acc)) for i, (acc, _) in enumerate(accs_tup2)])
plt.plot(data_mlp[:,0], data_mlp[:,1], marker='.', linestyle='dotted', color='tab:green', label=r"$\epsilon$-mlp")

data_hig = np.array([(i, float(acc)) for i, (acc, _) in enumerate(accs_tup3)])
plt.plot(data_hig[:,0], data_hig[:,1], marker='.', linestyle='dotted', color='tab:blue', label=r"$\epsilon$-highway")

plt.legend()
plt.xlabel("noisy search steps")
plt.ylabel("accuracy")
os.makedirs("outputs", exist_ok=True)
plt.savefig("outputs/2d_toy_noisy_search.pdf", bbox_inches="tight")
plt.show()