In [None]:
from IPython.display import display, HTML

# Widen the classic-notebook `.container` element to the full browser width.
display(HTML(data="<style>.container { width:100% !important; }</style>"))

import os

import torch
import yaml
from model_utils import get_model_name, load_config, save_model
from rhythmic_relationships.data import PartDataset
from rhythmic_relationships.model import VAE
from rhythmic_relationships.train import train
from torch.utils.data import DataLoader


import matplotlib.pyplot as plt
from tqdm import tqdm

from rhythmic_relationships import MODELS_DIR, CHECKPOINTS_DIRNAME

# Prefer Apple's Metal (MPS) backend when it can actually be used, else CPU.
# `torch.backends.mps.is_built()` only reports that PyTorch was compiled with
# MPS support; `is_available()` additionally checks that the running
# OS/hardware supports it, so moving tensors to DEVICE cannot fail later.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Relative path of the YAML training config loaded by the setup cell.
CONFIG_FILEPATH = "part_vae_config.yml"


def compute_loss(recons, x, mu, sigma, loss_fn):
    """Return the reconstruction loss plus the VAE KL-divergence term.

    Args:
        recons: first argument forwarded to ``loss_fn``.
        x: second argument (the target) forwarded to ``loss_fn``.
        mu: latent means; the KL term sums over dim 1 and averages over dim 0.
        sigma: treated by the formula below as the *log-variance*
            (it appears as ``1 + sigma - exp(sigma)``).
            NOTE(review): confirm the VAE actually returns log-variance here.
        loss_fn: callable ``loss_fn(recons, x)`` giving the reconstruction loss.

    Returns:
        ``loss_fn(recons, x)`` plus the batch mean of
        ``-0.5 * sum(1 + sigma - mu**2 - exp(sigma), dim=1)``.
    """
    recon_term = loss_fn(recons, x)
    kl_per_example = -0.5 * (1 + sigma - mu.pow(2) - torch.exp(sigma)).sum(dim=1)
    return recon_term + kl_per_example.mean(dim=0)

In [None]:
# Name this run, then load and echo the training config.
model_name = get_model_name()
print(f"{model_name=}\n")

config = load_config(CONFIG_FILEPATH)
print(yaml.dump(config))

# Values pulled out of the config for use in later cells.
device = DEVICE
x_dim = config["model"]["x_dim"]
y_dim = config["model"]["y_dim"]
clip_gradients = config["clip_gradients"]
num_epochs = config["num_epochs"]

# Data: shuffled mini-batches of the part dataset.
dataset = PartDataset(**config["dataset"])
loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=True)

# Model and optimiser; the whole `config["model"]` section is forwarded to VAE.
# Use `device` consistently (it is the same object as DEVICE).
model = VAE(**config["model"]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

# BCEWithLogitsLoss expects raw logits as its first argument.
reduction = config["loss_reduction"]
loss_fn = torch.nn.BCEWithLogitsLoss(reduction=reduction)

# Architecture summary via the public named_children() API
# (rather than the private `_modules` dict).
print('----\nLayers\n')
for lname, child in model.named_children():
    print('  ', lname, child)

print('----\nParameters\n')
n_params = 0
for pname, param in model.named_parameters():
    if param.requires_grad:
        print('  ', pname, param.numel())
        n_params += param.numel()

print(f'\nTotal number of parameters: {n_params}\n')

In [None]:
# Run one optimisation step on the first batch and record diagnostics
# (loss, log10 loss, per-parameter update:data ratios) for the cells below.
losses = []
log_losses = []
ud = []  # per step: [log10(lr * grad.std() / param.std()) for each parameter]
lr = config["lr"]

for i, x in enumerate(loader):
    # Forward pass: flatten each example to x_dim features.
    x = x.to(device).view(x.shape[0], x_dim)
    # Onset mask: 1.0 wherever the input is positive.
    x_binary = (x > 0).to(torch.float32)
    x_recon, mu, sigma = model(x_binary)

    # Compute loss.
    # loss_fn is BCEWithLogitsLoss, so x_recon must be passed as raw logits.
    # Thresholding it first (x_recon > 0) fed 0/1 "logits" into the sigmoid
    # and cut the gradient, so the onset reconstruction term never trained.
    onset_loss = compute_loss(x_recon, x_binary, mu, sigma, loss_fn)
    # NOTE(review): BCE targets should lie in [0, 1]; confirm x is normalised.
    velocity_loss = compute_loss(x_recon, x, mu, sigma, loss_fn)
    # NOTE(review): each compute_loss call adds the KL term, so it is counted
    # twice in the total below -- confirm this weighting is intended.
    loss = onset_loss + velocity_loss

    # Backprop
    optimizer.zero_grad()
    loss.backward()

    if clip_gradients:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)

    optimizer.step()

    log_loss = loss.log10().item()
    losses.append(loss.item())
    log_losses.append(log_loss)

    if i % 50 == 0:
        print(f"{i:4d}/{len(loader)}: {log_loss:.4f}")

    with torch.no_grad():
        ud.append(
            [((lr * p.grad).std() / p.data.std()).log10().item() for p in model.parameters()]
        )

    # Intentional: stop after the first batch so the following cells can
    # inspect the freshly initialised network. Remove to train on every batch.
    break

# Guard against an empty loader, where no loss was ever computed.
if log_losses:
    print(log_losses[-1])

In [None]:
# Show the repr of `model.encoder` as this cell's output; later cells index
# into it and iterate over its child modules.
model.encoder

In [None]:
# Sign pattern of each torch.nn.Linear weight matrix in the encoder
# (white = weight > 0, black = weight <= 0; rows are output units because
# Linear.weight has shape (out_features, in_features)).
# NOTE: this inspects weights only, not activations.
linear_layers = [m for m in model.encoder if isinstance(m, torch.nn.Linear)]
n_layers = len(linear_layers)
assert n_layers > 0, "model.encoder contains no torch.nn.Linear layers"

# squeeze=False keeps `axes` 2-D even when there is a single Linear layer,
# where plt.subplots would otherwise return one non-indexable Axes.
fig, axes = plt.subplots(1, n_layers, figsize=(10, 2), squeeze=False)
for ix, layer in enumerate(linear_layers):
    weights = layer.weight.detach().cpu().numpy()
    axes[0, ix].imshow(weights > 0, cmap='gray', interpolation='nearest')
    axes[0, ix].set_title(str(layer))

In [None]:
# Histogram (50 bins) of every weight in the encoder's first layer.
encoder_w0_values = model.encoder[0].weight.detach().cpu().reshape(-1).tolist()
plt.hist(encoder_w0_values, bins=50);

In [None]:
# log10 of the total loss recorded after each training step.
plt.plot(log_losses)
plt.xlabel("step")
plt.ylabel("log10(loss)")
plt.title("training loss");

In [None]:
# Dead-unit check: run the last training batch (`x_binary`) through the encoder
# up to and including its first ReLU, and show where that output is zero.
# Rows are examples, columns are units; a fully white column is a unit that is
# inactive for the whole batch. If the encoder has no ReLU, `h` is the
# encoder's final output.
# NOTE(review): assumes applying the encoder's children in iteration order
# reproduces its forward pass (true for nn.Sequential) -- confirm.
with torch.no_grad():
    h = x_binary
    for module in model.encoder:
        h = module(h)
        if isinstance(module, torch.nn.ReLU):
            break

plt.figure(figsize=(5, 5))
plt.imshow((h == 0).float().cpu().numpy(), cmap='gray', interpolation='nearest');

In [None]:
# Activation statistics and histograms for every ReLU in the encoder,
# computed on the last training batch (`x_binary`).
# NOTE(review): assumes applying the encoder's children in iteration order
# reproduces its forward pass (true for nn.Sequential) -- confirm.
plt.figure(figsize=(20, 4))
legends = []

with torch.no_grad():
    t = x_binary
    for ix, layer in enumerate(model.encoder):
        t = layer(t)
        if isinstance(layer, torch.nn.ReLU):
            # For ReLU the informative statistic is the share of exact zeros
            # (inactive units); a tanh-style |t| > 0.97 "saturation" test does
            # not apply to an unbounded activation.
            zero_pct = (t == 0).float().mean().item() * 100
            print('layer %d (%10s): mean %+.2f, std %.2f, zero: %.2f%%' % (ix, layer.__class__.__name__, t.mean().item(), t.std().item(), zero_pct))
            # torch.histogram operates on CPU tensors; move off the accelerator.
            hy, hx = torch.histogram(t.cpu(), density=True)
            plt.plot(hx[:-1], hy)
            legends.append(f'layer {ix} ({layer.__class__.__name__})')
plt.legend(legends);
plt.title('activation distribution');

In [None]:
# Update:data ratio per step for each 2-D (weight-matrix) parameter, on a log10
# scale. Each row of `ud` holds log10(lr * grad.std() / param.std()) for every
# parameter, in `model.parameters()` order (see the training cell).
plt.figure(figsize=(20, 4))
legends = []

print(f'---\nTotal number of parameters: {n_params}')
for i, p in enumerate(model.parameters()):
    if p.ndim == 2:
        plt.plot([step[i] for step in ud])
        legends.append('param %d' % i)
# Stanford CS231n suggests the ratio of update magnitude to parameter magnitude
# should be roughly 1e-3, i.e. -3 on this log10 scale (black reference line).
# See https://cs231n.github.io/neural-networks-3/#ratio
plt.plot([0, len(ud)], [-3, -3], 'k')
plt.xlabel('step')
plt.ylabel('log10(update / data)')
plt.legend(legends);