# Spiral classification

In [None]:
import torch
from torch import nn, optim
from math import pi as π  # convenient constant for angles

In [None]:
from res.plot_lib import *

In [None]:
# Configure plotting once, before any figure is drawn further down.
# NOTE(review): set_default is presumably supplied by the star import from res.plot_lib — confirm.
set_default()  # apply plotting style defaults from res.plot_lib

In [None]:
# Pick the fastest available backend for full-batch training:
# CUDA first, then Apple's MPS, otherwise fall back to the CPU.
# `torch.backends.mps` does not exist on older PyTorch releases, so it is
# looked up defensively instead of letting the check raise AttributeError.
_mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

## Create the data

In [None]:
# Seed torch's RNG once so data generation and training are reproducible
seed = 12345
torch.manual_seed(seed)

# Problem-size constants:
#   N = samples per class, n = input dimensions,
#   K = number of classes,  d = hidden units
N, n, K, d = 1000, 2, 5, 100


In [None]:
# Generate spirals

# A shared parameter t in [0, 1] drives both the radius and the angle of every arm
t = torch.linspace(0, 1, N)
a = 0.2 + 0.8 * t  # radius grows linearly from 0.2 to 1.0

spirals, labels = [], []
for k in range(K):
    # Arm k starts k/K of a full turn further round; Gaussian jitter on the
    # angle makes neighbouring arms overlap slightly
    θ = (2 * t + k) * 2 * π / K + 0.2 * torch.randn(N)
    arm = torch.stack((a * θ.sin(), a * θ.cos()), dim=1)
    spirals.append(arm)
    labels.append(torch.full((N,), k, dtype=torch.long))

# Concatenate the per-class pieces into one data matrix and one label vector
X = torch.cat(spirals)
y = torch.cat(labels)

# X, y stay on the CPU for plotting; the *_dev copies are what training uses
X_dev = X.to(device)
y_dev = y.to(device)

print("Shapes:")
print("X:", tuple(X.shape))
print("y:", tuple(y.shape))


In [None]:
# And visualise them
# The CPU tensors X, y are passed here; the device copies are kept for training.
plot_data(X, y)  # colors correspond to class labels

## Build and train a neural net

In [None]:
# Optimiser hyper-parameters: Adam step size and L2 weight-decay strength
learning_rate, lambda_l2 = 1e-3, 1e-5

In [None]:
# Model definition
# As written, the network is two stacked linear layers with no activation in
# between, i.e. a linear classifier. Uncomment nn.ReLU() to get a non-linear
# decision boundary; swap the last layer for the commented 2-unit bottleneck
# to obtain 2D embeddings for the optional embedding plot.
model = nn.Sequential(
    nn.Linear(n, d),
    # nn.ReLU(),  # Uncomment this line for a non-linear model
    nn.Linear(d, K)  # (Optional) Comment this line and uncomment the next one to display 2D embeddings below
    # nn.Linear(d, 2), nn.Linear(2, K),
)
model.to(device)  # moves the parameters to the selected device in place

# Cross entropy given the linear output.
# reduction='none' keeps one value per sample: the training loop averages it
# explicitly, and the energy-level plots reuse C to get a per-point value.
C = nn.CrossEntropyLoss(reduction='none')

# Using Adam optimiser; weight_decay provides the L2 regularisation
optimiser = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=lambda_l2)

# Full-batch training loop (every step sees the whole dataset, no mini-batches).
# The counter is named `epoch` so it does not overwrite the `t` tensor created
# by the data-generation cell.
for epoch in range(2_000):
    # Feed forward to get the linear sum s
    s = model(X_dev)

    # Compute the per-sample free energy F and the scalar loss L
    F = C(s, y_dev)
    L = F.mean()

    # Zero the gradients left over from the previous step
    optimiser.zero_grad()

    # Backward pass to compute the gradient of the loss w.r.t. the params
    L.backward()

    # Update params
    optimiser.step()

    # Display epoch, L, and accuracy (computed from the pre-update output s)
    overwrite(f'[EPOCH]: {epoch}, [LOSS]: {L.item():.6f}, [ACCURACY]: {acc(s, y_dev):.3f}')

# Move model back to CPU for downstream plotting utilities.
# nn.Module.to works in place and returns the module itself, so `model_cpu`
# and `model` refer to the same object.
model_cpu = model.to('cpu')


In [None]:
# Plot trained model
# Print the module structure first, then pass the CPU data and the
# CPU-resident model to the plotting helper.
print(model_cpu)
plot_model(X, y, model_cpu)


In [None]:
# (Optional) Plot internal 2D embeddings if available
# NOTE(review): the model cell has the 2-unit bottleneck variant commented out;
# confirm plot_embeddings copes with a model that has no 2D layer before running this.
plot_embeddings(X, y, model_cpu, zoom=10)


In [None]:
# Compute linear output s for a fine grid over the input space

# Regular lattice over [-1.5, 1.5) on both axes with spacing 0.01
mesh = torch.arange(-1.5, 1.5, 0.01)
xx, yy = torch.meshgrid(mesh, mesh, indexing='ij')

# One (x, y) row per lattice point, in the same row-major order as xx / yy
grid = torch.stack((xx, yy), dim=-1).reshape(-1, 2)

with torch.no_grad():  # inference only, no autograd graph needed
    s = model_cpu(grid).detach().cpu()  # keep on CPU for plotting


In [None]:
# Choice of free energy (toggle to inspect different energy landscapes)
# Keep exactly one assignment active; the plotting cells below compare `fe`
# against these two strings to decide what to draw.
fe = 'cross-entropy'
# fe = 'negative linear output'

In [None]:
# Switch to non-interactive matplotlib (inline) for static plots
%matplotlib inline
# NOTE(review): set_default is called again here, presumably because switching
# the backend resets the plotting defaults — confirm.
set_default()

In [None]:
# ! mkdir {m}-levels

In [None]:
# Plot 2d energy levels

# Fail fast on an unrecognised `fe`: otherwise the loop below would draw
# nothing and leave F holding the stale training loss, which the 3D plot
# cell would then try to use.
if fe not in ('cross-entropy', 'negative linear output'):
    raise ValueError(f"Unknown free energy choice: {fe!r}")

# One figure per class k. F and k deliberately stay defined after the loop,
# because the 3D plot cell reuses the values from the last iteration.
for k in range(K):
    if fe == 'cross-entropy':
        # Energy of each grid point when class k is taken as the target
        target = torch.full((s.size(0),), k, dtype=torch.long, device='cpu')  # CPU for plotting
        F = C(s, target)
        F = F.reshape(xx.shape)
        plot_2d_energy_levels(X, y, (xx, yy, F, k, K), (0, 35), (1, 35, 4))

    elif fe == 'negative linear output':
        # Energy is the negated k-th linear output
        F = -s[:, k]
        F = F.reshape(xx.shape)
        plot_2d_energy_levels(X, y, (xx, yy, F, k, K), (-20, 20), (-20, 21, 2.5))

#     plt.savefig(f'{m}-levels/{k}.png', bbox_inches='tight')

In [None]:
# ! ffmpeg -framerate 1 -i {m}-levels/%d.png -r 25 -vf "crop=trunc(iw/2)*2:trunc(ih/2)*2" -pix_fmt yuv420p {m}-levels.mp4

In [None]:
# Switch to interactive matplotlib
# NOTE(review): the `widget` backend presumably needs the ipympl package — confirm it is installed.
%matplotlib widget

In [None]:
# 3D view of the energy surface for the class plotted last in the 2D loop
# (F and k still hold the values from that loop's final iteration).
# Each entry holds the three range arguments passed after the data tuple.
_levels_3d = {
    'cross-entropy': ((0, 18), (0, 19, 1), (0, 19, 2)),
    'negative linear output': ((-30, 20), (-30, 20, 1), (-30, 21, 5)),
}
if fe in _levels_3d:
    fig, ax = plot_3d_energy_levels(X, y, (xx, yy, F, k, K), *_levels_3d[fe])

In [None]:
# ! mkdir {m}-3d-levels

In [None]:
# Spin it around (and maybe save to disk)
δ = 10  # angular step between consecutive views
for angle in range(0, 360, δ):
    # Keep the elevation at 30 and sweep the azimuth through a full turn, starting at -60
    azimuth = angle - 60
    ax.view_init(30, azimuth)
    fig.canvas.draw()
#     plt.pause(.001)
#     plt.savefig(f'{m}-3d-levels/{angle:03d}.png', bbox_inches='tight')

In [None]:
# ! ffmpeg -i {m}-3d-levels/%03d.png -vf "crop=trunc(iw/2)*2:trunc(ih/2)*2" -pix_fmt yuv420p {m}-3d-levels.mp4