In [17]:
import os

from datetime import datetime

import torch
from torch.utils.tensorboard.writer import SummaryWriter
from quinnVAE.heidenreich.vae import VAE
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import v2


# Mini-batch size shared by the training and evaluation loaders.
batch_size = 128

# Directory handed to torchvision as the MNIST dataset root.
mnist_root = "~/.pytorch/MNIST_data/"


def _flatten_and_center(image):
    """Flatten an image tensor to 1-D and shift every value down by 0.5."""
    return image.view(-1) - 0.5


# Preprocessing pipeline applied to every sample:
#   1. wrap the raw input as an image tensor,
#   2. cast to float32 with value scaling enabled,
#   3. flatten and subtract 0.5.
transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Lambda(_flatten_and_center),
    ]
)

# Training split (downloaded on first use).
train_data = datasets.MNIST(
    mnist_root,
    train=True,
    transform=transform,
    download=True,
)
# Held-out test split (downloaded on first use).
test_data = datasets.MNIST(
    mnist_root,
    train=False,
    transform=transform,
    download=True,
)

# Batch iterators: the training loader shuffles, the test loader keeps the
# dataset order.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=False
)

In [18]:
# --- Hyperparameters ---------------------------------------------------------
seed = 42  # fixed RNG seed for repeatable weight init and shuffle order
learning_rate = 1e-3
weight_decay = 1e-2
num_epochs = 50
latent_dim = 2
hidden_dim = 512

# Seed torch's global generator before the model is built and before any
# shuffled iteration starts. Accelerator kernels may still be nondeterministic.
torch.manual_seed(seed)

# --- Device: prefer CUDA, then Apple MPS, then fall back to CPU ---------------
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# input_dim=784 matches the flattened 28x28 images this notebook reshapes to
# (-1, 1, 28, 28) when logging.
model = VAE(input_dim=784, hidden_dim=hidden_dim, latent_dim=latent_dim).to(device)

# Pass learning_rate explicitly: it used to be defined above but never handed
# to the optimizer, so editing it had no effect. AdamW's default lr is 1e-3,
# so the current value leaves training unchanged.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay
)

# One TensorBoard run directory per launch, named by the start timestamp.
writer = SummaryWriter(f'runs/mnist/vae_{datetime.now().strftime("%Y%m%d-%H%M%S")}')

In [19]:
def train(model, dataloader, optimizer, prev_updates, writer=None):
    """
    Trains the model for one pass over the given data.

    Args:
        model (nn.Module): The model to train. Calling it on a batch must
            return an object exposing ``loss``, ``loss_recon`` and ``loss_kl``.
        dataloader (torch.utils.data.DataLoader): The data loader. It yields
            ``(data, target)`` pairs; the target is ignored.
        optimizer: The optimizer.
        prev_updates (int): Number of updates performed before this call;
            used as the offset of the global step counter.
        writer: Optional TensorBoard writer. When given, the losses and the
            gradient norm are logged every 100 updates.

    Returns:
        int: The updated step count, ``prev_updates + len(dataloader)``.
    """
    model.train()  # Set the model to training mode

    for batch_idx, (data, _) in enumerate(tqdm(dataloader)):
        n_upd = prev_updates + batch_idx

        data = data.to(device)

        optimizer.zero_grad()  # Zero the gradients

        output = model(data)  # Forward pass
        loss = output.loss

        loss.backward()

        # Gradient clipping. clip_grad_norm_ returns the total 2-norm of all
        # gradients measured *before* clipping, which is the value logged
        # below, so no separate loop over the parameters is needed.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        if n_upd % 100 == 0:
            # Only sync the norm to the host on logging steps.
            total_norm = grad_norm.item()

            print(
                f"Step {n_upd:,} (N samples: {n_upd*batch_size:,}), Loss: {loss.item():.4f} (Recon: {output.loss_recon.item():.4f}, KL: {output.loss_kl.item():.4f}) Grad: {total_norm:.4f}"
            )

            if writer is not None:
                global_step = n_upd
                writer.add_scalar("Loss/Train", loss.item(), global_step)
                writer.add_scalar(
                    "Loss/Train/BCE", output.loss_recon.item(), global_step
                )
                writer.add_scalar("Loss/Train/KLD", output.loss_kl.item(), global_step)
                writer.add_scalar("GradNorm/Train", total_norm, global_step)

        optimizer.step()  # Update the model parameters

    return prev_updates + len(dataloader)


def test(model, dataloader, cur_step, writer=None, pixel_offset=0.5):
    """
    Tests the model on the given data.

    Losses are averaged over the batches of ``dataloader``, printed, and, when
    a writer is given, logged to TensorBoard together with example images.

    Args:
        model (nn.Module): The model to test. Calling it on a batch must
            return an object exposing ``loss``, ``loss_recon``, ``loss_kl``
            and ``x_recon``; it must also provide ``decode``.
        dataloader (torch.utils.data.DataLoader): The data loader. It yields
            ``(data, target)`` pairs; the target is ignored.
        cur_step (int): The current step, used as the TensorBoard global step.
        writer: The TensorBoard writer (optional).
        pixel_offset (float): Value added back to the inputs before they are
            logged as images. The default undoes the ``- 0.5`` shift applied
            by this notebook's transform.
    """
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    test_recon_loss = 0
    test_kl_loss = 0

    with torch.no_grad():
        for data, _ in tqdm(dataloader, desc="Testing"):
            data = data.to(device)
            data = data.view(data.size(0), -1)  # Flatten the data

            output = model(data, compute_loss=True)  # Forward pass

            test_loss += output.loss.item()
            test_recon_loss += output.loss_recon.item()
            test_kl_loss += output.loss_kl.item()

    # Average of the per-batch values.
    test_loss /= len(dataloader)
    test_recon_loss /= len(dataloader)
    test_kl_loss /= len(dataloader)
    print(
        f"====> Test set loss: {test_loss:.4f} (BCE: {test_recon_loss:.4f}, KLD: {test_kl_loss:.4f})"
    )

    if writer is not None:
        # Log the averaged metrics (the same numbers printed above), not the
        # values of whichever batch happened to be last.
        writer.add_scalar("Loss/Test", test_loss, global_step=cur_step)
        writer.add_scalar("Loss/Test/BCE", test_recon_loss, global_step=cur_step)
        writer.add_scalar("Loss/Test/KLD", test_kl_loss, global_step=cur_step)

        # Log reconstructions (taken from the last batch of the loop above)
        writer.add_images(
            "Test/Reconstructions",
            output.x_recon.view(-1, 1, 28, 28),
            global_step=cur_step,
        )
        # add_images scales float images by 255 and clips out-of-range values,
        # so the shifted inputs must be moved back up or their negative half
        # is clipped to black.
        writer.add_images(
            "Test/Originals",
            (data + pixel_offset).view(-1, 1, 28, 28),
            global_step=cur_step,
        )

        # Log random samples from the latent space; no gradients are needed.
        with torch.no_grad():
            z = torch.randn(16, latent_dim).to(device)
            samples = model.decode(z)
        writer.add_images(
            "Test/Samples", samples.view(-1, 1, 28, 28), global_step=cur_step
        )

In [20]:
# Run the train/evaluate cycle for num_epochs epochs. prev_updates carries the
# global step count from one epoch to the next so TensorBoard steps keep
# increasing across epochs.
prev_updates = 0
try:
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        prev_updates = train(model, train_loader, optimizer, prev_updates, writer=writer)
        test(model, test_loader, prev_updates, writer=writer)
finally:
    # Write buffered TensorBoard events to disk even when the run is stopped
    # early (e.g. by interrupting the kernel). flush() rather than close()
    # keeps the writer usable if this cell is run again.
    writer.flush()

Epoch 1/50


  1%|          | 4/469 [00:00<00:44, 10.40it/s]

Step 0 (N samples: 0), Loss: 543.7944 (Recon: 543.5302, KL: 0.2643) Grad: 13.0586


 22%|██▏       | 102/469 [00:04<00:16, 22.89it/s]

Step 100 (N samples: 12,800), Loss: 192.5526 (Recon: 190.6017, KL: 1.9509) Grad: 43.1330


 43%|████▎     | 203/469 [00:07<00:11, 23.08it/s]

Step 200 (N samples: 25,600), Loss: 185.9642 (Recon: 183.7439, KL: 2.2203) Grad: 21.2906


 65%|██████▍   | 303/469 [00:11<00:06, 24.18it/s]

Step 300 (N samples: 38,400), Loss: 179.5865 (Recon: 177.1703, KL: 2.4161) Grad: 22.2967


 87%|████████▋ | 406/469 [00:15<00:02, 28.81it/s]

Step 400 (N samples: 51,200), Loss: 175.5581 (Recon: 172.2065, KL: 3.3516) Grad: 16.5375


100%|██████████| 469/469 [00:18<00:00, 26.05it/s]
Testing: 100%|██████████| 79/79 [00:02<00:00, 32.33it/s]


====> Test set loss: 172.6249 (BCE: 168.7793, KLD: 3.8455)
Epoch 2/50


  8%|▊         | 36/469 [00:01<00:15, 28.01it/s]

Step 500 (N samples: 64,000), Loss: 172.4530 (Recon: 168.2575, KL: 4.1956) Grad: 31.1336


 28%|██▊       | 133/469 [00:10<00:30, 10.95it/s]

Step 600 (N samples: 76,800), Loss: 168.0890 (Recon: 163.4968, KL: 4.5922) Grad: 27.6743


 49%|████▉     | 232/469 [00:18<00:24,  9.62it/s]

Step 700 (N samples: 89,600), Loss: 156.3945 (Recon: 151.4286, KL: 4.9660) Grad: 35.0010


 72%|███████▏  | 337/469 [00:24<00:06, 21.27it/s]

Step 800 (N samples: 102,400), Loss: 162.8124 (Recon: 157.9067, KL: 4.9057) Grad: 44.3753


 93%|█████████▎| 434/469 [00:34<00:02, 14.68it/s]

Step 900 (N samples: 115,200), Loss: 156.6857 (Recon: 151.3972, KL: 5.2885) Grad: 59.2117


100%|██████████| 469/469 [00:36<00:00, 12.85it/s]
Testing: 100%|██████████| 79/79 [00:02<00:00, 35.62it/s]


====> Test set loss: 156.2012 (BCE: 151.0607, KLD: 5.1405)
Epoch 3/50


 14%|█▍        | 67/469 [00:03<00:18, 21.19it/s]

Step 1,000 (N samples: 128,000), Loss: 158.8552 (Recon: 153.6088, KL: 5.2464) Grad: 116.3155


 36%|███▌      | 167/469 [00:08<00:14, 21.17it/s]

Step 1,100 (N samples: 140,800), Loss: 149.2159 (Recon: 143.8959, KL: 5.3199) Grad: 52.4785


 57%|█████▋    | 267/469 [00:13<00:08, 24.43it/s]

Step 1,200 (N samples: 153,600), Loss: 148.2742 (Recon: 142.2313, KL: 6.0429) Grad: 72.5730


 78%|███████▊  | 365/469 [00:18<00:03, 26.61it/s]

Step 1,300 (N samples: 166,400), Loss: 155.0316 (Recon: 149.5329, KL: 5.4986) Grad: 72.3593


100%|█████████▉| 467/469 [00:23<00:00, 27.41it/s]

Step 1,400 (N samples: 179,200), Loss: 153.5122 (Recon: 147.7253, KL: 5.7869) Grad: 144.9660


100%|██████████| 469/469 [00:23<00:00, 19.99it/s]
Testing: 100%|██████████| 79/79 [00:02<00:00, 37.23it/s]


====> Test set loss: 151.6138 (BCE: 146.1112, KLD: 5.5026)
Epoch 4/50


 20%|█▉        | 92/469 [00:05<00:16, 22.21it/s]

Step 1,500 (N samples: 192,000), Loss: 146.3910 (Recon: 140.6899, KL: 5.7011) Grad: 102.9026


 42%|████▏     | 197/469 [00:12<00:17, 15.98it/s]

Step 1,600 (N samples: 204,800), Loss: 147.9927 (Recon: 142.0916, KL: 5.9012) Grad: 105.0604


 63%|██████▎   | 295/469 [00:18<00:17,  9.79it/s]

Step 1,700 (N samples: 217,600), Loss: 149.6030 (Recon: 143.7484, KL: 5.8546) Grad: 85.4328


 84%|████████▍ | 394/469 [00:23<00:03, 24.10it/s]

Step 1,800 (N samples: 230,400), Loss: 141.2421 (Recon: 135.3399, KL: 5.9021) Grad: 68.3273


100%|██████████| 469/469 [00:28<00:00, 16.25it/s]
Testing: 100%|██████████| 79/79 [00:02<00:00, 34.11it/s]


====> Test set loss: 147.9984 (BCE: 142.1282, KLD: 5.8703)
Epoch 5/50


  6%|▌         | 27/469 [00:01<00:21, 20.21it/s]

Step 1,900 (N samples: 243,200), Loss: 148.6717 (Recon: 142.7024, KL: 5.9693) Grad: 74.5154


 27%|██▋       | 126/469 [00:07<00:33, 10.29it/s]

Step 2,000 (N samples: 256,000), Loss: 142.3420 (Recon: 136.1892, KL: 6.1528) Grad: 83.8534


 48%|████▊     | 226/469 [00:14<00:12, 20.23it/s]

Step 2,100 (N samples: 268,800), Loss: 135.3863 (Recon: 129.5072, KL: 5.8792) Grad: 71.4240


 70%|██████▉   | 328/469 [00:20<00:06, 20.38it/s]

Step 2,200 (N samples: 281,600), Loss: 146.0477 (Recon: 139.7527, KL: 6.2949) Grad: 91.4337


 91%|█████████ | 425/469 [00:29<00:03, 13.24it/s]

Step 2,300 (N samples: 294,400), Loss: 147.2084 (Recon: 141.0905, KL: 6.1179) Grad: 62.8742


100%|██████████| 469/469 [00:32<00:00, 14.40it/s]
Testing: 100%|██████████| 79/79 [00:03<00:00, 25.98it/s]


====> Test set loss: 147.9070 (BCE: 141.7283, KLD: 6.1787)
Epoch 6/50


 13%|█▎        | 60/469 [00:03<00:17, 23.27it/s]

Step 2,400 (N samples: 307,200), Loss: 141.5906 (Recon: 135.5733, KL: 6.0173) Grad: 61.5298


 34%|███▎      | 158/469 [00:09<00:20, 14.91it/s]

Step 2,500 (N samples: 320,000), Loss: 145.6479 (Recon: 139.6842, KL: 5.9637) Grad: 87.3049


 55%|█████▌    | 258/469 [00:15<00:15, 13.57it/s]

Step 2,600 (N samples: 332,800), Loss: 147.1515 (Recon: 140.8918, KL: 6.2597) Grad: 115.9016


 76%|███████▌  | 357/469 [00:20<00:05, 20.89it/s]

Step 2,700 (N samples: 345,600), Loss: 140.8871 (Recon: 134.8356, KL: 6.0516) Grad: 95.4267


 97%|█████████▋| 457/469 [00:27<00:01, 11.51it/s]

Step 2,800 (N samples: 358,400), Loss: 146.0135 (Recon: 139.5813, KL: 6.4322) Grad: 139.0553


100%|██████████| 469/469 [00:28<00:00, 16.54it/s]
Testing: 100%|██████████| 79/79 [00:02<00:00, 34.97it/s]


====> Test set loss: 145.5326 (BCE: 139.5785, KLD: 5.9541)
Epoch 7/50


 19%|█▉        | 89/469 [00:04<00:27, 13.87it/s]

Step 2,900 (N samples: 371,200), Loss: 142.5108 (Recon: 136.5913, KL: 5.9195) Grad: 108.3269


 40%|████      | 189/469 [00:10<00:18, 15.07it/s]

Step 3,000 (N samples: 384,000), Loss: 143.1807 (Recon: 136.9395, KL: 6.2412) Grad: 86.2966


 61%|██████▏   | 288/469 [00:16<00:08, 22.11it/s]

Step 3,100 (N samples: 396,800), Loss: 151.5217 (Recon: 145.1451, KL: 6.3766) Grad: 88.7741


 83%|████████▎ | 389/469 [00:21<00:03, 22.23it/s]

Step 3,200 (N samples: 409,600), Loss: 147.9424 (Recon: 141.6516, KL: 6.2909) Grad: 85.4856


100%|██████████| 469/469 [00:25<00:00, 18.38it/s]
Testing: 100%|██████████| 79/79 [00:02<00:00, 36.02it/s]


====> Test set loss: 145.3576 (BCE: 139.1467, KLD: 6.2109)
Epoch 8/50


  4%|▍         | 19/469 [00:01<00:28, 15.86it/s]

Step 3,300 (N samples: 422,400), Loss: 149.3853 (Recon: 143.0665, KL: 6.3188) Grad: 81.4185


 26%|██▌       | 121/469 [00:05<00:17, 20.35it/s]

Step 3,400 (N samples: 435,200), Loss: 149.8342 (Recon: 143.5990, KL: 6.2353) Grad: 47.8728


 47%|████▋     | 221/469 [00:11<00:11, 20.91it/s]

Step 3,500 (N samples: 448,000), Loss: 145.5519 (Recon: 139.3127, KL: 6.2392) Grad: 108.5208


 68%|██████▊   | 321/469 [00:16<00:08, 17.71it/s]

Step 3,600 (N samples: 460,800), Loss: 144.0203 (Recon: 137.8681, KL: 6.1522) Grad: 68.2381


 90%|████████▉ | 420/469 [00:22<00:02, 20.08it/s]

Step 3,700 (N samples: 473,600), Loss: 138.6896 (Recon: 132.3293, KL: 6.3603) Grad: 67.9654


100%|██████████| 469/469 [00:24<00:00, 19.03it/s]
Testing: 100%|██████████| 79/79 [00:02<00:00, 27.99it/s]


====> Test set loss: 147.4522 (BCE: 141.2852, KLD: 6.1670)
Epoch 9/50


 11%|█         | 52/469 [00:03<00:24, 17.01it/s]

Step 3,800 (N samples: 486,400), Loss: 151.0556 (Recon: 144.7807, KL: 6.2748) Grad: 160.8693


 32%|███▏      | 152/469 [00:08<00:15, 19.91it/s]

Step 3,900 (N samples: 499,200), Loss: 145.5035 (Recon: 139.1232, KL: 6.3803) Grad: 113.0111


 53%|█████▎    | 249/469 [00:15<00:48,  4.52it/s]

Step 4,000 (N samples: 512,000), Loss: 142.6138 (Recon: 136.4306, KL: 6.1832) Grad: 115.2933


 55%|█████▌    | 259/469 [00:22<00:18, 11.62it/s]


KeyboardInterrupt: 