In [None]:
from pathlib import Path

import numpy as np

import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [None]:
import torch
import torch.nn as nn
import dtnnlib as dtnn

In [None]:
## Synthetic 2-D regression target sampled on a num_points x num_points grid.
## (Alternatives tried earlier: X1 in [-2.5, 1.9], X2 in [-2.2, 2.1];
##  target = plain radius sqrt(X1**2 + X2**2).)
num_points = 75
X1, X2 = np.meshgrid(
    np.linspace(-2.5, 2.5, num_points),
    np.linspace(-2.5, 3, num_points),
)

# Ripple in the radius plus a small linear tilt along each input.
Y = (np.sin(np.sqrt(X1**2 + X2**2)) * 2 - 1.
     - 0.1 * X1
     + 0.02 * X2)

## Min-max scale each array to [-1, 1]; each uses its own unscaled min/max.
X1, X2, Y = (2*(arr - arr.min())/(arr.max() - arr.min()) - 1 for arr in (X1, X2, Y))

In [None]:
x1 = X1.reshape(-1)
x2 = X2.reshape(-1)
y = Y.reshape(-1)

xx = np.c_[x1, x2]
yy = Y.reshape(-1,1)
xx, yy = torch.FloatTensor(xx), torch.FloatTensor(yy)


%matplotlib inline
fig = plt.figure(figsize=(10,7))
ax = plt.axes(projection='3d')
ax.plot_surface(X1, X2, Y, cmap='plasma')
ax.set_xlabel('X1')
ax.set_ylabel('X2')
ax.set_zlabel('Y')
plt.show()

In [None]:
## Piecewise-linear model: 2 inputs -> h ReLU units -> 1 output.
## The seed fixes the initial weights, so every run starts from the same network.
torch.manual_seed(103)

h = 8
net = nn.Sequential(nn.Linear(2, h), nn.ReLU(), nn.Linear(h, 1))

In [None]:
# Compute device for the ReLU MLP experiment. CPU is used; uncomment the next line
# (and comment out the CPU line) to run on the first CUDA GPU instead.
# device = torch.device("cuda:0")
device = torch.device("cpu")

In [None]:
## Put data and model on the chosen device (the last line echoes the model summary).
xx = xx.to(device)
yy = yy.to(device)
net.to(device)

In [None]:
## Mean-squared-error objective optimized with Adam.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3)

In [None]:
## Full-batch training: every epoch uses all grid points at once.
NUM_EPOCHS = 5000
LOG_EVERY = 200

for epoch in range(NUM_EPOCHS):
    yout = net(xx)
    loss = criterion(yout, yy)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY == 0:
        error = loss.item()
        print(f'Epoch:{epoch} | Error:{error}')

# The last `yout` from the loop was computed *before* the final optimizer step and
# still carries its autograd graph. Re-evaluate so later plots show the trained net.
with torch.no_grad():
    yout = net(xx)

In [None]:
# %matplotlib tk   # uncomment for an interactive window

## Overlay the network's predictions (red) on the target samples (default colour).
yout_ = yout.detach().cpu().reshape(Y.shape)

fig = plt.figure(figsize=(9, 8))
ax = plt.axes(projection='3d')
ax.scatter(X1, X2, Y, marker='.')
ax.scatter(X1, X2, yout_, color='r', marker='.')
ax.set(xlabel='X1', ylabel='X2', zlabel='Y')
plt.show()

#### Visualize individual neurons

In [None]:
i = -1  # neuron index; the next cell increments it before use, so its first run shows neuron 0

In [None]:
### run below iteratively
i += 1
a = net[1](net[0](xx)).data[:,i]

# %matplotlib tk
%matplotlib inline
fig = plt.figure(figsize=(7,5))
ax = plt.axes(projection='3d')
ax.plot_surface(X1, X2, a.reshape(X1.shape), cmap='plasma')
ax.set_xlabel('X1')
ax.set_ylabel('X2')
ax.set_zlabel('Y')
# plt.pause(10)
plt.show()

##### Extracting the linear pieces (each neuron's boundary line)

In [None]:
###### Boundary line of each hidden ReLU neuron, i.e. where w.x + b == 0:
##   x1*w1 + x2*w2 + b = 0   =>   x2 = (-b - w1*x1) / w2 = m*x1 + c,
##   with slope m = -w1/w2 and intercept c = -b/w2.
## Each segment is drawn between x1 = -2 and x1 = 2.
## NOTE(review): assumes w2 != 0; a vertical boundary would give inf/nan here.
lines, vecs, mcs = [], [], []
with torch.no_grad():
    # Row i of the first layer's weight is neuron i's (w1, w2).
    for i, ((w1, w2), b) in enumerate(zip(net[0].weight, net[0].bias)):
        x2_a = -(b + w1*-2)/w2
        x2_b = -(b + w1*2)/w2
        lines.append([(-2, x2_a), (2, x2_b)])   # segment endpoints (x1, x2)
        vecs.append([w1, w2])                   # weight vector (normal to the line)
        mcs.append([-w1/w2, -b/w2])             # slope m and intercept c

In [None]:
## Convert the per-neuron lists (holding 0-d tensors) into numpy arrays for plotting.
lines, vecs, mcs = [torch.Tensor(seq).numpy() for seq in (lines, vecs, mcs)]
lines

In [None]:
## Keep `lines` as (n_neurons, 2 endpoints, 2 coords).
## A per-endpoint unit-norm rescaling (on the (-1, 2) view) was tried here and is disabled:
##   lines = lines / np.linalg.norm(lines, axis=1, keepdims=True)
lines = lines.reshape(-1, 2, 2)

In [None]:
## Rescale each neuron's weight vector (w1, w2) to unit length; used as arrow directions.
vecs = np.divide(vecs, np.linalg.norm(vecs, axis=-1, keepdims=True))

In [None]:
## interpolation of points along the lines
# interp = np.linspace(0, 1, 20)
# interp.shape, lines.shape

In [None]:
# interp = interp.reshape(-1,1)

In [None]:
## First neuron's boundary segment: [[x1_a, x2_a], [x1_b, x2_b]].
lines[0, :, :]

In [None]:
## First neuron's normalized weight direction.
vecs[0, :]

In [None]:
## plot the lines
%matplotlib inline
_x0, _x1 = 0.75, 0.5
# _x0, _x1 = 0.20, 1.25
actv = net[:2](torch.Tensor([[_x0, _x1]])).data.reshape(-1)

plt.figure(figsize=(6,5))
for i, line in enumerate(lines):
    if i == 0: continue
    color = matplotlib.cm.tab10(i)
    plt.plot(line[:,0], line[:,1], c=color, lw=2)
#     plt.arrow(0, 0, vecs[i][0]/4, vecs[i][1]/4, head_width=0.05, head_length=0.04, fc=color, ec=color, linestyle='solid', alpha=0.5)
    
    distance = ((lines[i, 0] - lines[i, 1])**2).sum()**0.5
    interp = np.linspace(0, 1, int(distance)*10).reshape(-1,1)
    pts = lines[i,0]*interp + lines[i,1]*(1-interp)
    for j in range(len(pts)): 
        plt.arrow(pts[j,0], pts[j,1], vecs[i][0]/10, vecs[i][1]/10, head_width=0.05, head_length=0.04, fc=color, ec=color, linestyle='solid', alpha=0.5)
    ## Perpendicular to the line from point
    _m, _c = mcs[i,0], mcs[i,1]
    _x = (_x0 + _m*(_x1 - _c))/(_m*_m + 1)
    _y = _m*_x + _c
    dist = ((_x0-_x)**2+(_x1-_y)**2)**0.5
    dist = dist*(actv[i]>0)
    plt.scatter(_x, _y, edgecolors='k', facecolors='yellow', s=10, lw=1, marker='o', zorder=99)
    plt.plot([_x, _x0], [_y, _x1], lw=dist*4, color='k', zorder=10)
#     plt.plot([_x, _x0], [_y, _x1], lw=actv[i]*4, color='k', zorder=10)
        
plt.scatter(_x0, _x1, edgecolors='k', facecolors='yellow', s=100, lw=1, marker='X', zorder=99)

plt.axis("equal")
plt.xlim(-1, 1.5)
plt.ylim(-1, 1.5)
plt.xlabel('X1')
plt.ylabel('X2')
# plt.tick_params(left = False, right = False , labelleft = False ,
#                 labelbottom = False, bottom = False)
plt.savefig("./outputs/00_neuron_viz/linear_neurons_2d.pdf", bbox_inches='tight')

In [None]:
## Per-feature (min, max) of the inputs; should be [-1, 1] given the min-max scaling of X1/X2.
xx.min(dim=0).values, xx.max(dim=0).values

In [None]:
# my_cmap = matplotlib.colors.LinearSegmentedColormap.from_list("cust0", ['green', 'white'], gamma=0.4)
# my_cmap

## Visualize Radial Neurons

In [None]:
class One_Actv(nn.Module):
    """Gaussian-bump activation with one learnable input scale shared by all units.

    Computes ``exp(-(x * exp(s)) ** 2)`` element-wise. ``s`` is the single-element
    parameter ``self.scaler``, initialized to 0 (an initial scale of 1). Learning the
    log-scale keeps the effective scale ``exp(s)`` strictly positive.
    """

    def __init__(self):
        super().__init__()
        self.scaler = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        scaled = x * self.scaler.exp()
        return torch.exp(-scaled.pow(2))

In [None]:
## Radial network: 2 inputs -> h distance-based units -> Gaussian activation -> 1 output.
## NOTE(review): dtnn.DistanceTransformBase is project code; judging by the `.centers`
## attribute used further down it measures distances to learned centers — confirm.
torch.manual_seed(103)

h = 6
net = nn.Sequential(dtnn.DistanceTransformBase(2, h), One_Actv(), nn.Linear(h, 1))

In [None]:
# Compute device for the radial-network experiment.
device = torch.device("cpu")

In [None]:
## Put data and the radial model on the chosen device (the last line echoes the model).
xx = xx.to(device)
yy = yy.to(device)
net.to(device)

In [None]:
## MSE objective with Adam; the radial network needs a 10x larger learning rate.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-2)

In [None]:
## Full-batch training of the radial network on all grid points.
NUM_EPOCHS = 5000
LOG_EVERY = 200

for epoch in range(NUM_EPOCHS):
    yout = net(xx)
    loss = criterion(yout, yy)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY == 0:
        error = loss.item()
        print(f'Epoch:{epoch} | Error:{error}')

# The last `yout` from the loop predates the final optimizer step and still holds its
# autograd graph; re-evaluate so the next plot shows the trained network.
with torch.no_grad():
    yout = net(xx)

In [None]:
%matplotlib inline
# %matplotlib tk


yout_ = yout.data.cpu().reshape(Y.shape)

fig = plt.figure(figsize=(9,8))
ax = plt.axes(projection='3d')
ax.scatter(X1, X2, Y, marker= '.')
ax.scatter(X1, X2, yout_, color='r', marker='.')
ax.set_xlabel('X1')
ax.set_ylabel('X2')
ax.set_zlabel('Y')
plt.show()

In [None]:
## Learned centers of the first layer, detached and on CPU for plotting.
centers = net[0].centers.detach().cpu()
centers

In [None]:
## Hidden activations (after One_Actv) at every grid point, on CPU.
with torch.no_grad():
    actf = net[:2](xx).cpu()
# actf = net[0](xx).data.cpu()   # alternative: first-layer outputs before the activation

In [None]:
## Draw a random grid point to inspect. The draw is seeded so the notebook reproduces
## on re-run (change the seed to explore other points).
index = int(np.random.default_rng(seed=0).integers(len(actf)))
index

In [None]:
## Grid point shown in the figure below (other indices tried: 3660, 603, 5012, 4993).
index = 4988

In [None]:
## Activation of every radial neuron at the selected grid point.
actf[index, :]

In [None]:
## Draw each radial neuron's center with concentric translucent discs. Connect it to the
## selected grid point with a black line whose width is 3x that neuron's activation there.
fig = plt.figure(figsize=(6,5))
ax = fig.gca()
alpha = 0.3

# index = np.random.randint(len(actf))
_x0, _x1 = xx[index][0].item(), xx[index][1].item()

for i, cent in enumerate(centers):
    color = matplotlib.cm.tab10(i)
    cx, cy = float(cent[0]), float(cent[1])  # plain floats for matplotlib patches
    # Discs of diameter scale*0.07 with opacity alpha/log2(scale): larger discs are fainter.
    for scale in [2, 4, 8, 16]:
        ell = matplotlib.patches.Ellipse((cx, cy), scale*0.07, scale*0.07, edgecolor=color, facecolor=color, lw=2)
        ell.set_alpha(alpha/np.log2(scale))
        ax.add_artist(ell)

    plt.plot([cx, _x0], [cy, _x1], lw=actf[index][i].item()*3, color='k', zorder=10)
    plt.scatter(cx, cy, color=color, zorder=100)
plt.scatter(_x0, _x1, edgecolors='k', facecolors='yellow', s=100, lw=1, marker='X', zorder=99)

plt.axis("equal")
plt.xlim(-1.75, 1.75)
plt.ylim(-1.5, 2.0)

plt.tick_params(left = False, right = False , labelleft = False ,
                labelbottom = False, bottom = False)
out_path = Path("./outputs/00_neuron_viz/dist_rbf_neurons_2d.pdf")
out_path.parent.mkdir(parents=True, exist_ok=True)  # savefig fails if the directory is missing
plt.savefig(out_path, bbox_inches='tight')