In [None]:
import numpy as np
import copy

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib

import torch
import torch.nn as nn
import torch.nn.functional as F
import dtnnlib as dtnn

%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (9, 8)

In [None]:
# Compute device used for the data tensors moved with .to(device).
# GPU alternative: device = torch.device("cuda:0")
device = torch.device("cpu")

## Building 2D model

In [None]:
def twospirals(n_points, noise=.5, angle=784):
    """
    Generate the two-spirals toy dataset.

    Parameters
    ----------
    n_points : int
        Number of samples drawn for each spiral arm.
    noise : float
        Scale applied to the uniform [0, 1) jitter added to each coordinate.
    angle : float
        Maximum sweep of each arm, in degrees.

    Returns
    -------
    tuple of numpy.ndarray
        ``(points, labels)``. ``points`` has shape ``(2*n_points, 2)``: the first
        ``n_points`` rows are one arm, the remaining rows are its point reflection
        through the origin. ``labels`` has shape ``(2*n_points,)`` and holds 0.0
        for the first arm and 1.0 for the second.
    """
    # Draw order matters for reproducibility: angle first, then x jitter, then y jitter.
    theta = np.sqrt(np.random.rand(n_points, 1)) * angle * (2*np.pi)/360
    jitter_x = np.random.rand(n_points, 1) * noise
    jitter_y = np.random.rand(n_points, 1) * noise

    arm_x = jitter_x - np.cos(theta) * theta
    arm_y = jitter_y + np.sin(theta) * theta

    first_arm = np.concatenate([arm_x, arm_y], axis=1)
    points = np.concatenate([first_arm, -first_arm], axis=0)
    labels = np.concatenate([np.zeros(n_points), np.ones(n_points)])
    return points, labels

In [None]:
# Fix NumPy's global seed so the sampled spirals are the same on every run.
np.random.seed(1)
x, y = twospirals(300, angle=560)
# Divide each coordinate by its column-wise maximum; flatten the labels to 1-D.
x, y = x/x.max(axis=0, keepdims=True), y.reshape(-1)
# Float tensors of the inputs and of the labels (labels as a column vector here).
xx, yy = torch.FloatTensor(x), torch.FloatTensor(y.reshape(-1,1))

# Individual input coordinates, used as the scatter-plot axes.
x1 = xx[:,0]
x2 = xx[:,1]

%matplotlib inline
# Scatter plot of the dataset, coloured by class label.
plt.figure(figsize=(5,5))
plt.scatter(x1, x2, c=y, marker='.')
plt.xlabel("x1")
plt.ylabel("x2")
plt.axis("equal")
plt.show()

In [None]:
# Move the inputs and labels onto the selected device.
xx, yy = xx.to(device), yy.to(device)

## Distance Based Classification

In [None]:
class DistanceTransform_Epsilon(dtnn.DistanceTransformBase):
    """
    Distance transform with a learnable scale and an optional constant "epsilon" column.

    The output of ``dtnn.DistanceTransformBase.forward`` is optionally extended with
    one extra column filled with ``epsilon``, mapped through
    ``(1 - dists) * exp(scaler)`` and finally shifted by an optional bias.

    Parameters
    ----------
    input_dim, num_centers :
        Forwarded unchanged to ``dtnn.DistanceTransformBase`` (presumably the input
        feature size and the number of centers -- confirm against dtnnlib).
    p :
        Forwarded to the base class as ``p`` (presumably the norm order -- confirm
        against dtnnlib). Defaults to 2.
    bias : bool
        If True, a learnable bias initialised to zero is added, with one entry per
        output column (including the epsilon column when present).
    epsilon : float or None
        Value of the constant extra column; ``None`` disables the extra column.
    itemp : float
        Initial scale. The learnable ``scaler`` is stored as ``log(itemp)`` and applied
        as ``exp(scaler)``.
    """

    def __init__(self, input_dim, num_centers, p=2, bias=False, epsilon=0.1, itemp=1):
        # Forward the caller's ``p`` (previously a hard-coded 2 silently ignored the argument).
        super().__init__(input_dim, num_centers, p=p)

        nc = num_centers
        # Stored in log-space; forward() applies exp() so the effective scale stays positive.
        self.scaler = nn.Parameter(torch.log(torch.ones(1, 1)*itemp))
        if epsilon is not None:
            # The epsilon column adds one more output, so the bias needs one more entry.
            nc += 1
        self.bias = nn.Parameter(torch.ones(1, nc)*0) if bias else None
        self.epsilon = epsilon

    def forward(self, x):
        """
        Return the scaled (and optionally biased) transform of ``x``.

        When ``epsilon`` is set, the result has one more column than the base class
        output; that extra column is appended last.
        """
        dists = super().forward(x)

        if self.epsilon is not None:
            # Constant column with the dtype/device of x, one row per sample in x.
            eps_col = torch.ones(len(x), 1).to(x)*self.epsilon
            dists = torch.cat([dists, eps_col], dim=1)

        ## scale the dists (1 is optional)
        dists = (1-dists)*torch.exp(self.scaler)

        if self.bias is not None:
            dists = dists+self.bias
        return dists

In [None]:
## reshape for multi-class classification (including epsilon)
## The loss takes 1-D integer class indices. .long() converts the dtype while keeping
## the tensor on its current device; .type(torch.LongTensor) would force it onto the CPU
## and break the loss computation when `device` is a GPU.
yy = yy.reshape(-1).long()

### DTeSM Residual 

In [None]:
class DTeSM(DistanceTransform_Epsilon):
    """
    Epsilon distance transform followed by a scale-shift and a softmax.

    The softmax runs over the last dimension, i.e. over all outputs of
    ``DistanceTransform_Epsilon`` including the epsilon column when enabled.
    The most recent softmax output is cached in ``temp_activ`` for inspection.
    """

    def __init__(self, input_dim, output_dim, epsilon=1.0, itemp=10):
        ### NOTE: Here, not using bias leads to more uniform centroid activation, and easy to compare..
        super().__init__(input_dim, output_dim, bias=False, epsilon=epsilon, itemp=itemp)

        self.scale_shift = dtnn.ScaleShift(-1, scaler_init=1, shifter_init=0, scaler_const=True, shifter_const=True)
        self.softmax = nn.Softmax(dim=-1)
        # Detached copy of the last forward output; None until forward() has run once.
        self.temp_activ = None

    def forward(self, x):
        """Return the softmax activations for ``x`` and cache them in ``temp_activ``."""
        activations = self.softmax(self.scale_shift(super().forward(x)))
        self.temp_activ = activations.data
        # The full activation vector is returned; the trailing epsilon entry is not dropped.
        return activations

In [None]:
class LocalResidual_DTeSM(nn.Module):
    """
    Residual block computing ``x + layer1(layer0(x))``.

    ``layer0`` is a ``DTeSM`` layer and ``layer1`` a linear map from its activations
    back to ``input_dim``, so the block's output has the same size as its input.
    """

    def __init__(self, input_dim, hidden_dim, epsilon=None, itemp=1.0):
        super().__init__()
        self.layer0 = DTeSM(input_dim, hidden_dim, epsilon, itemp)
        # With epsilon enabled, DTeSM emits one additional activation, so the linear
        # layer needs one more input feature.
        linear_in = hidden_dim if epsilon is None else hidden_dim + 1
        self.layer1 = nn.Linear(linear_in, input_dim)

    def forward(self, x):
        """Add the linear read-out of the DTeSM activations to the input ``x``."""
        return x + self.layer1(self.layer0(x))

In [None]:
# Seed PyTorch's global RNG so subsequent random operations are repeatable.
torch.manual_seed(123)

In [None]:
# Two-stage model: a local residual block on the 2-D input (10 centers plus the
# epsilon unit), followed by a bias-free linear layer producing 2 outputs.
model = nn.Sequential(
            LocalResidual_DTeSM(2, 10, epsilon=1.0, itemp=10.0),
            nn.Linear(2, 2, bias=False)
            )

# Adam over all model parameters; cross-entropy loss on the model outputs.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [None]:
# Sanity check: one forward pass over the full dataset to inspect the output shape.
yout = model(xx)
yout.shape

In [None]:
%matplotlib inline

for epoch in range(9000):
    yout = model(xx)
    
    loss = criterion(yout, yy)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    model[0].layer1.weight.data[:, -1] = model[0].layer1.weight.data[:, -1]*0 ## zero output epsilon

    if (epoch+1)%100 == 0:
        yout = model(xx)
        accuracy = (yout.max(dim=1)[1] == yy).type(torch.float).mean()                
        print(f'Epoch: {epoch}, Acc:{float(accuracy):.2f}, Loss:{float(loss)}')

    if (epoch+1)%1000 == 0:
        ax = plt.figure(figsize=(6,6)).add_subplot()
        out = yout.max(dim=1)[1].data.cpu().numpy()
        ax.scatter(x1, x2, c=out, marker= '.')
        ## plot centroids
        c = model[0].layer0.centers.data.cpu()
        ax.scatter(c[:,0], c[:,1], color='k', marker= 'x')
        plt.show()

In [None]:
# Output of the residual block, then the class scores from the final linear layer.
h1 = model[0](xx)
yout = model[1](h1)
# Predicted class index per sample, as a NumPy array.
out = yout.max(dim=1)[1].data.cpu().numpy()
## centroids and shift
c = model[0].layer0.centers.data.cpu()
d = model[0].layer1.weight.data.cpu().t() # transposed layer1 weights: one row per layer0 output; the layer1 bias is not included

In [None]:
# Evaluate the DTeSM layer at its own centers: row i holds the activations of all units at center i.
max_actv = model[0].layer0(model[0].layer0.centers.data).data.cpu()
# Keep the diagonal, i.e. each unit's activation at its own center (still a torch tensor here).
max_actv = max_actv.diag()#.numpy()
max_actv

In [None]:
! mkdir outputs/14_local_residual/

In [None]:
## Input space: data points coloured by predicted class, with one arrow per centroid
## starting at the centroid c[i] and pointing along d[i].
ax = plt.figure(figsize=(6,6)).add_subplot()
ax.scatter(x1, x2, c=out, marker= '.', alpha=0.3)

for i in range(c.shape[0]):
    ## tab10 has only 10 entries, so wrap at 10 (indices >= 10 would all map to the last colour).
    color = matplotlib.cm.tab10(i%10)
    ax.arrow(c[i,0], c[i,1], d[i,0], d[i,1], head_width=0.15, head_length=0.1, fc=color, ec=color, linestyle=(0, (5, 10)))
    ax.scatter(c[i,0], c[i,1], color=color, marker= 'x')

# ax.arrow(0, 0, d[len(c),0], d[len(c),1], head_width=0.15, head_length=0.1, fc="k", ec="k", linestyle=(0, (5, 10)), linewidth=2.0)
    
plt.xlim(-2.5, 2.5)
plt.ylim(-2.5, 2.5)
plt.axis("equal")
plt.savefig("outputs/14_local_residual/local_residual_input.pdf", bbox_inches='tight')
plt.show()

In [None]:
## Output space of the residual block: transformed points h1 coloured by predicted class,
## with the same centroid arrows (start c[i], direction d[i]) drawn for reference.
ax = plt.figure(figsize=(6,6)).add_subplot()
## Bring h1 to the CPU before plotting so this also works when the model runs on a GPU.
h1_cpu = h1.data.cpu()
ax.scatter(h1_cpu[:,0], h1_cpu[:,1], c=out, marker= '.', alpha=0.3)

for i in range(c.shape[0]):
    ## tab10 has only 10 entries, so wrap at 10 (indices >= 10 would all map to the last colour).
    color = matplotlib.cm.tab10(i%10)
    ax.arrow(c[i,0], c[i,1], d[i,0], d[i,1], head_width=0.15, head_length=0.1, fc=color, ec=color, linestyle=(0, (5, 10)))
    ax.scatter(c[i,0], c[i,1], color=color, marker= 'x')
    
# color = "k"
# ax.arrow(0, 0, d[len(c),0], d[len(c),1], head_width=0.15, head_length=0.1, fc="k", ec="k", linestyle=(0, (5, 10)), linewidth=2.0)
    
plt.xlim(-2.5, 2.5)
plt.ylim(-2.5, 2.5)
plt.axis("equal")
plt.savefig("outputs/14_local_residual/local_residual_output.pdf", bbox_inches='tight')
plt.show()

### Visualize residual neurons

In [None]:
# Regular num_points x num_points grid over [-2, 2] x [-2, 2].
num_points = 1000
grid_axis = np.linspace(-2, 2, num_points)
X1, X2 = np.meshgrid(grid_axis, grid_axis)

# Flatten the grid into an (num_points*num_points, 2) float tensor of (x1, x2) pairs.
grid_pairs = np.c_[X1.reshape(-1), X2.reshape(-1)]
XX = torch.Tensor(grid_pairs).to(device)
XX.shape

In [None]:
## Activations of every DTeSM unit on the grid. This is inference only, so run it under
## no_grad to avoid building an autograd graph for all num_points*num_points samples.
## Calling layer0 directly returns the same activations it caches in `temp_activ`.
with torch.no_grad():
    YY = model[0].layer0(XX)
## Back to image layout: (grid rows, grid columns, units).
YY = YY.reshape(num_points, num_points, -1)
YY.shape

In [None]:
# Convert the per-center activations to a NumPy array. np.asarray accepts both a CPU
# tensor and an ndarray, so re-running this cell is safe (a second .numpy() call would
# fail because ndarrays have no such method).
max_actv = np.asarray(max_actv)
max_actv

In [None]:
## One figure per DTeSM output unit: its activation over the grid as a grey image,
## overlaid on the data and the centroids.
for idx in range(YY.shape[-1]):
    conf = YY[:,:,idx]
    conf = conf.data.cpu().numpy().reshape(X1.shape)
    
    ax = plt.figure(figsize=(6,6)).add_subplot()
    ax.scatter(x1, x2, c=out, marker= '.', alpha=0.3)

    ## plot centroids
    for i in range(c.shape[0]):
        color = matplotlib.cm.tab20(i%20)
        ax.scatter(c[i,0], c[i,1], color=color, marker= 'x', s=100)
    
    ## Highlight this unit's own centroid. Units without a centroid (index beyond the
    ## rows of `c`, e.g. the epsilon unit) are skipped explicitly instead of swallowing
    ## every exception.
    if idx < c.shape[0]:
        ax.scatter(c[idx,0], c[idx,1], color="k", marker= 'X', s=100)
        print("center:",max_actv[idx],"max_grid:",conf.max(), max_actv[idx] >= conf.max())
    
    ## Grid point with the largest activation (moved to the CPU for plotting).
    maxpt = XX[conf.argmax()].cpu()
    ax.scatter(maxpt[0], maxpt[1], color="r", marker= 'o', s=100)
    
    ax.imshow(conf, interpolation='nearest',
           extent=(X1.min(), X1.max(), X2.min(), X2.max()),
           alpha=0.6, cmap='gray',
           aspect='auto', origin='lower')
    
    LVLs = 20
#     LVLs = torch.linspace(0.0, 0.99, 20)
    cs = ax.contour(X1, X2, conf, levels=LVLs, linestyles="None", colors="k", linewidths=1, zorder=-2)
    ax.clabel(cs, cs.levels, inline=True, fontsize=8, fmt="%1.2f")
    
    plt.show()

In [None]:
# Inspect the learned log-scale, the effective scale exp(scaler), and the bias of the DTeSM layer.
model[0].layer0.scaler, torch.exp(model[0].layer0.scaler), model[0].layer0.bias

In [None]:
### only for epsilon neuron
## The epsilon unit is the last output column; its activation is shown as -log(activation).
idx = YY.shape[-1]-1
conf = YY[:,:,idx]
conf = -torch.log(conf)

conf = conf.data.cpu().numpy().reshape(X1.shape)

ax = plt.figure(figsize=(6,6)).add_subplot()
ax.scatter(x1, x2, c=out, marker= '.', alpha=0.3)

## plot centroids
for i in range(c.shape[0]):
    color = matplotlib.cm.tab20(i%20)
    ax.scatter(c[i,0], c[i,1], color=color, marker= 'x', s=100)

## Only highlight a centroid if this unit has one. The check replaces a bare
## `except: pass` that silently hid the out-of-range index (and any other error).
if idx < c.shape[0]:
    ax.scatter(c[idx,0], c[idx,1], color="k", marker= 'X', s=100)
    print("center:",max_actv[idx],"max_grid:",conf.max(), max_actv[idx] >= conf.max())

## Grid point with the largest -log(activation) (moved to the CPU for plotting).
maxpt = XX[conf.argmax()].cpu()
ax.scatter(maxpt[0], maxpt[1], color="r", marker= 'o', s=100)

ax.imshow(conf, interpolation='nearest',
       extent=(X1.min(), X1.max(), X2.min(), X2.max()),
       alpha=0.6, cmap='gray',
       aspect='auto', origin='lower')

LVLs = 20
# LVLs = torch.linspace(0.0, 0.99, 20)
cs = ax.contour(X1, X2, conf, levels=LVLs, linestyles="None", colors="k", linewidths=1, zorder=-2)
ax.clabel(cs, cs.levels, inline=True, fontsize=8, fmt="%1.2f")

plt.show()