## Imports and configuration

In [5]:
 %matplotlib inline

import random
import torch
from torch import nn, optim
import math
from IPython import display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from matplotlib.pyplot import axhline, axvline, grid, style
import seaborn
seaborn.set(style='ticks')
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*' and 3.8
# removed the old names, so `style.use('seaborn')` raises on current releases.
# Prefer the new name when it is available and fall back for older Matplotlib.
style.use('seaborn-v0_8' if 'seaborn-v0_8' in style.available else 'seaborn')

# Use the first GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Multiply a pixel count by `px` to get inches (the unit `figsize` expects).
px = 1/plt.rcParams['figure.dpi']  # pixels to inches


## Data: synthetic three-class spiral

In [6]:
random.seed(12345)
torch.manual_seed(12345)
N = 1000  # points per class
C = 3     # number of classes (spiral arms)
I = 2     # input features per point

# Build the spiral on the CPU with vectorised ops and move it to `device` once.
# The original per-element `X[ix] = ...` on a device tensor issued one tiny copy
# per point. RNG calls (one torch.randn(N) per class, in class order) are unchanged.
t = torch.linspace(0, 1, N)  # radial coordinate, identical for every arm
X_parts, y_parts = [], []
for c in range(C):
    # Angle spans 2 * (2*pi/C) radians starting at c * (2*pi/C), plus Gaussian jitter.
    inner_var = torch.linspace((2 * math.pi / C) * (c), (2 * math.pi / C) * (2 + c), N) + torch.randn(N) * 0.2
    # (N, 1) * (N, 2) -> (N, 2): each point is radius * (sin(angle), cos(angle)).
    X_parts.append(t.unsqueeze(1) * torch.stack((torch.sin(inner_var), torch.cos(inner_var)), dim=1))
    y_parts.append(torch.full((N,), c, dtype=torch.long))

X = torch.cat(X_parts).to(device)
y = torch.cat(y_parts).to(device)
assert X.shape == (N * C, I) and y.shape == (N * C,)

## Model and Training Definition

In [7]:
learning_rate = 1e-3
lambda_l2 = 1e-5   # passed to Adam as weight_decay (L2 penalty on all parameters)
hidden_units = 50  # width of both hidden layers

# Input/output sizes come from the data cell (I features, C classes) rather than
# repeating the literals 2 and 3, so changing the dataset keeps the model in sync.
model = nn.Sequential(
    nn.Linear(I, hidden_units, bias=True),
    nn.ReLU(),
    nn.Linear(hidden_units, hidden_units, bias=True),
    nn.ReLU(),
    nn.Linear(hidden_units, C, bias=True),
)
model.to(device)

# Consumes the raw logits produced by the last Linear layer and integer class targets.
criterion = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=lambda_l2)

# Training Def
def epoch(epoch):
    """Run one full-batch optimisation step on the whole dataset.

    Args:
        epoch: Index of the current epoch. Not used by the step itself; kept so
            the animation callback can call ``epoch(i)``.

    Returns:
        The scalar training loss of this step as a tensor detached from the
        autograd graph (call ``.item()`` for a Python float).
    """
    # The plotting code in this notebook calls model.cpu() between frames, so put
    # the model back on the device holding X and y before the forward pass.
    model.to(device)

    y_pred = model(X)  # logits, one row per sample
    loss = criterion(y_pred, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Detach so callers holding the result do not keep autograd bookkeeping alive.
    return loss.detach()

In [8]:
%matplotlib tk

fig, (ax, axloss) = plt.subplots(ncols=2, figsize=(960*px, 480*px))

ax.set(xlim=(-1.1, 1.1), ylim=(-1.1, 1.1))
ax.set_xlabel('Age')
ax.set_ylabel('Body Weight')

axloss.set(xlim=(0, 1000), ylim=(0, 1))
axloss.set_xlabel('Epoch')
axloss.set_ylabel('Loss')

# Evaluation grid covering the plotted area at a 0.01 step.
mesh = np.arange(-1.1, 1.1, 0.01)
xx, yy = np.meshgrid(mesh, mesh)
# One row per grid point, columns (x, y), float32 to match the model weights.
data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float()

# Colour limits shared by every contour frame. Predicted class ids range over
# 0..C-1, the same range the scatter's `c=y` is normalised to, so a region and
# the points of the same class get the same Paired_r colour. (The previous
# limits mixed grid coordinates with class ids.)
vmin = 0
vmax = C - 1

# Predict on whichever device the model currently lives on instead of moving the
# model around; only the class indices come back to the CPU as a NumPy array.
with torch.no_grad():
    Z = model(data.to(next(model.parameters()).device)).argmax(dim=1).cpu().numpy().reshape(xx.shape)
# Initial decision regions, drawn with the same cmap/alpha/limits as later frames.
p = [ax.contourf(xx, yy, Z, cmap=plt.cm.Paired_r, alpha=0.3, vmin=vmin, vmax=vmax)]
pl, = axloss.plot([])

X_np = X.cpu().numpy()
sc = ax.scatter(X_np[:, 0], X_np[:, 1], c=y.cpu(), s=20, cmap=plt.cm.Paired_r)
def plot_contour():
    """Redraw the current model's decision regions on ``ax``.

    Removes the contour set stored in ``p[0]`` (if any), evaluates the model on
    the grid ``data`` and draws a filled contour of the predicted class ids,
    using the shared ``vmin``/``vmax`` colour limits.

    Returns:
        The new contour set, which is also stored back into ``p[0]``.
    """
    previous = p[0]
    if previous is not None:
        try:
            # Matplotlib >= 3.8: a ContourSet is a single artist and removes itself.
            previous.remove()
        except AttributeError:
            # Older Matplotlib: remove each per-level collection individually.
            for collection in previous.collections:
                collection.remove()

    # Evaluate where the model already lives (avoids shuttling the parameters
    # between GPU and CPU every frame); only class ids come back to the CPU.
    with torch.no_grad():
        Z = model(data.to(next(model.parameters()).device)).argmax(dim=1).cpu().numpy().reshape(xx.shape)

    p[0] = ax.contourf(xx, yy, Z, cmap=plt.cm.Paired_r, alpha=0.3, vmin=vmin, vmax=vmax)
    return p[0]

# Text box near the top of the loss panel (axes-fraction coordinates), updated
# every frame with the current epoch and loss.
titleLoss = axloss.text(
    0.5, 0.85, "",
    transform=axloss.transAxes,
    ha="center",
    bbox=dict(facecolor='w', alpha=0.5, pad=5),
)

# Frame indices for the animation (one training step per frame) and a float
# buffer that receives the loss recorded at each of them.
epochs = np.arange(1000)
yloss = np.zeros(epochs.shape, dtype=np.float64)

def animate(i):
    """FuncAnimation callback: train one epoch, then refresh both panels.

    Args:
        i: Frame index, used as the epoch number and as the slot in ``yloss``.

    Returns:
        A one-element tuple with the loss line artist (what FuncAnimation
        expects from the callback when blitting is enabled).
    """
    loss_value = epoch(i).item()
    yloss[i] = loss_value

    plot_contour()
    # Slice up to and including index i. The previous `[:i]` dropped the point
    # just computed, so the newest loss, and the final one, were never drawn.
    pl.set_data(epochs[:i + 1], yloss[:i + 1])
    titleLoss.set_text("Epoch %d / Loss %.3f" % (i, loss_value))

    return pl,

# One training epoch per frame. Keep the `anim` reference alive; Matplotlib
# requires a live reference for the animation to keep running.
anim = FuncAnimation(
    fig,
    animate,
    frames=epochs,
    interval=1,
    repeat=False,
    blit=False,
)
# anim.save('./fig/training_spiral1.gif', writer='pillow', fps=30, dpi = 100)
plt.show()