<a href="https://colab.research.google.com/github/simeonbetapudi/DeepLearningAIEthics/blob/main/VAEUNets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Environment setup (Colab): run this cell first, before any of the import cells below.
# Uninstall Colab's bigframes because it conflicts with other installs
%pip uninstall -y bigframes
# Install Lightning, also let's use "rich" progress bars (plus wandb for logging and einops)
# NOTE(review): -U upgrades to the newest releases and no versions are pinned, so a later
# re-run may pick up different package versions — consider pinning for reproducibility.
%pip install -Uqq lightning wandb rich einops

Found existing installation: bigframes 2.23.0
Uninstalling bigframes-2.23.0:
  Successfully uninstalled bigframes-2.23.0
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m828.5/828.5 kB[0m [31m51.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m243.4/243.4 kB[0m [31m22.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m56.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m832.4/832.4 kB[0m [31m57.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import torch
from torch import nn, optim, utils
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, RandomAffine, RandomErasing
import torchvision
import matplotlib.pyplot as plt
import torch.nn.functional as F
# Not advocating Lightning over raw pytorch, but it offers some useful abstractions
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import RichProgressBar
import wandb
import numpy as np
from einops import rearrange

In [None]:
# Build the MNIST datasets and dataloaders used by every later cell.
# optional: define additional data augmentation transformers for the dataloader
train_transforms = torchvision.transforms.Compose([
    ToTensor(),
    # uncomment next lines for extra augmentations
    #RandomAffine(degrees=15, translate=(0.1, 0.1)),
    #RandomErasing(p=0.2, scale=(0.02, 0.1))
])

# download=True fetches the data into ./data if it is not already there
train_ds = MNIST(root='./data', train=True,  download=True, transform=train_transforms)
test_ds  = MNIST(root='./data', train=False, download=True, transform=ToTensor())
# NOTE: there is no separate validation split; "val" is the very same object as "test"
val_ds = test_ds  #alias val for test
print(f"Data set lengths: train: {len(train_ds)}, test: {len(test_ds)}")



batch_size = 128   # could make this bigger; note for MNIST on Colab we're disk-speed limited, not GPU-limited
# only the training loader shuffles; the test/val loader keeps dataset order
train_dl = DataLoader(train_ds, batch_size=batch_size, num_workers=2, shuffle=True, persistent_workers=True)
test_dl = DataLoader(test_ds, batch_size=batch_size, num_workers=2, shuffle=False, persistent_workers=True)
val_dl = test_dl # alias val <--> test

100%|██████████| 9.91M/9.91M [00:01<00:00, 4.97MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 132kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.24MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.5MB/s]

Data set lengths: train: 60000, test: 10000





In [None]:
# @title Code for `show_xs` visualization tool
def show_xs(xs, show_stats=True):
    """Show one or more images side by side, each with its own colorbar.

    Args:
        xs: a single 2D image tensor, a batch tensor (iterated along its first
            dimension), or a list/tuple of image tensors. For every entry,
            leading dimensions are dropped (by taking index 0) until it is 2D.
        show_stats: if True, print shape, min and max of each displayed image.
    """
    if torch.is_tensor(xs) and xs.ndim == 2:
        # A lone 2D image: wrap it, otherwise list(xs) would split it into rows.
        xs = [xs]
    elif type(xs) is not list:
        xs = list(xs)
    ncols = len(xs)
    if ncols == 0:
        return  # nothing to draw; plt.subplots cannot build a zero-column grid
    fig, axs = plt.subplots(figsize=(3*ncols,2), ncols=ncols, squeeze=False)
    ax = axs.ravel()
    for col, x in enumerate(xs):
        # remove any batch/channel dimensions (keeps the first item of each)
        while x.ndim > 2:
            x = x[0]
        if show_stats:
            if ncols > 1: print(f"col {col}: ",end="")
            print(f"x.shape = {tuple(x.shape)}, min(x) = {torch.min(x)}, max(x) = {torch.max(x)}")
        digit = ax[col].imshow(x.detach().cpu().numpy(), cmap='gray')
        fig.colorbar(digit, ax=ax[col])
    plt.show()

In [None]:
# @title `test_inference` visualization code
@torch.no_grad()
def test_inference(model, idx=None, return_fig=False):
    """Plot test-set inputs (top row) above the model's reconstructions (bottom row).

    Relies on the notebook-level `test_ds` dataset and on the model exposing
    `device`, `use_conv` and `forward`.

    Args:
        model: the autoencoder to run. If `forward` takes one argument it is
            called as `forward(x)`; otherwise as `forward(x, cond)` with a
            one-hot encoding of the labels (10 classes).
        idx: test-set index or indices to show: an int, a range/list/tuple of
            ints, or an integer tensor. None picks one index at random.
        return_fig: if True return the matplotlib figure instead of showing it.

    Returns:
        The figure when `return_fig` is True, otherwise None.
    """
    import inspect

    # Normalize idx to a plain list of Python ints.
    if idx is None:
        # .item() matters: a 0-d tensor can be neither iterated nor treated as int
        idx = [torch.randint(len(test_ds), (1,)).item()]
    elif torch.is_tensor(idx):
        idx = idx.reshape(-1).tolist()
    elif isinstance(idx, int):
        idx = [idx]
    else:
        idx = list(idx)
    if not idx:
        raise ValueError("test_inference needs at least one index")

    was_training = model.training   # remember the mode so it can be restored
    model.eval()
    try:
        samples = [test_ds[i] for i in idx]   # fetch every (image, label) pair once
        x_batch = torch.stack([s[0] for s in samples]).to(model.device)   # images
        if not model.use_conv: x_batch = x_batch.view(x_batch.size(0), -1)
        if 1==len(inspect.signature(model.forward).parameters): # for ae or vae
            result = model.forward(x_batch)
        else:                                                   # c-vae (later in lesson)
            y_batch = torch.tensor([s[1] for s in samples]).to(model.device) # labels
            cond = F.one_hot(y_batch, num_classes=10).float()
            result = model.forward(x_batch, cond)
    finally:
        # put the model back the way we found it, even if inference failed
        model.train(was_training)

    z, recon = result[:2]
    recon = recon.reshape(len(idx), 28, 28)   # reshape also copes with non-contiguous output
    # squeeze=False keeps axs 2D even when a single column is requested
    fig, axs = plt.subplots(2, len(idx), figsize=(3*len(idx), 4), squeeze=False)
    for i in range(len(idx)):
        axs[0,i].imshow(x_batch[i].reshape(28,28).cpu(), cmap='gray')
        axs[1,i].imshow(recon[i].cpu(), cmap='gray')
    axs[0,0].set_ylabel('Input', fontsize=12)
    axs[1,0].set_ylabel('Reconstruction', fontsize=12)
    if return_fig: return fig
    plt.show()

In [None]:
#@title Code for viz tool `plot_latent_space`
@torch.no_grad()
def plot_latent_space(model, n_samples=2000):
    """Scatter-plot latent codes of test images in 2D, coloured by digit label.

    Relies on the notebook-level `test_dl` dataloader. Latents come from
    `model.encoder` when the model has one, otherwise from `model.encode`.
    If the encoder output is wider than `model.latent_dim` and the model has
    `reparam_sample`, only the first `latent_dim` columns (the mu part) are
    used. Latents with more than two dimensions are projected with PCA.

    Args:
        model: model exposing `device`, `use_conv`, `latent_dim` and either
            `encoder` or `encode`.
        n_samples: maximum number of test images to encode and plot.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    # Models in this notebook expose the encoder under different names.
    encode_fn = getattr(model, 'encoder', None)
    if encode_fn is None:
        # NOTE(review): VAEUNet.encode returns a *sampled* latent, so for that
        # model the points include sampling noise rather than being pure mu.
        encode_fn = model.encode

    was_training = model.training   # remember the mode so it can be restored
    model.eval()
    zs, labels = [], []
    n_seen = 0   # count samples, not batches, so we can stop early
    try:
        for x_batch, y_batch in test_dl:
            if n_seen >= n_samples:
                break
            x_batch = x_batch.to(model.device)
            if not model.use_conv:
                x_batch = x_batch.view(x_batch.size(0), -1)
            z_batch = encode_fn(x_batch).cpu()
            if hasattr(model, 'reparam_sample') and z_batch.shape[-1] > model.latent_dim:  # VAE case
                z_batch = z_batch[:, :model.latent_dim]  # just use mu part
            zs.append(z_batch)
            labels.append(y_batch)
            n_seen += z_batch.shape[0]
    finally:
        # put the model back the way we found it, even if encoding failed
        model.train(was_training)

    zs = torch.cat(zs)[:n_samples].numpy()
    labels = torch.cat(labels)[:n_samples].numpy()

    if zs.shape[-1] > 2:  # we'll make a 2D plot regardless of how many dims are in the latent space
        from sklearn.decomposition import PCA
        pca = PCA(n_components=2)
        zs_2d = pca.fit_transform(zs)
        title = f"Latent Space (PCA projection, explained variance: {pca.explained_variance_ratio_.sum():.2f})"
    else:
        zs_2d = zs
        title = "Latent Space Visualization"

    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(zs_2d[:, 0], zs_2d[:, 1], c=labels, cmap='tab10')
    plt.colorbar(scatter)
    plt.title(title)
    plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import lightning as L
import matplotlib.pyplot as plt
import wandb

class VAEUNet(L.LightningModule):
    """Variational autoencoder with a U-Net style convolutional encoder/decoder.

    The encoder has three stages (28x28 -> 14x14 -> 7x7). The 7x7 features are
    flattened and mapped to `mu` and `log_var`; a latent is sampled from them
    and expanded back to 7x7 features, which are upsampled twice. After each
    upsampling the matching encoder activations are concatenated in (skip
    connections).

    NOTE(review): because of the skip connections the decoder can read image
    detail directly from the encoder activations, bypassing the latent.

    Args:
        latent_dim: size of the latent vector.
        act: activation *class* (not instance); instantiated after every BatchNorm.
        scrunch_factor: weight applied to the KL term in the training loss.
    """

    def __init__(self, latent_dim=8, act=nn.LeakyReLU, scrunch_factor=1e-3):
        super().__init__()
        # configs
        self.latent_dim = latent_dim
        self.act = act
        self.scrunch_factor = scrunch_factor
        self.use_conv = True   # read by the notebook's viz helpers: inputs stay as images

        # -----------------------
        # Encoder (U-Net encoder path)
        # -----------------------
        # enc1: keep spatial 28x28
        self.enc1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # [B, 32, 28, 28]
            nn.BatchNorm2d(32),
            act(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            act(),
        )

        # enc2: downsample to 14x14
        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), # [B, 64, 14, 14]
            nn.BatchNorm2d(64),
            act(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            act(),
        )

        # enc3: downsample to 7x7 (bottleneck spatial)
        self.enc3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # [B, 128, 7, 7]
            nn.BatchNorm2d(128),
            act(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            act(),
        )

        # flatten size from enc3
        self.enc_out_h = 7
        self.enc_out_w = 7
        self.enc_out_ch = 128
        self.enc_out_dim = self.enc_out_ch * self.enc_out_h * self.enc_out_w  # 128*7*7

        # latent heads
        self.fc_mu = nn.Linear(self.enc_out_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.enc_out_dim, latent_dim)

        # -----------------------
        # Decoder (U-Net decoder path)
        # -----------------------
        # linear from latent to encoded spatial feature
        self.fc_dec = nn.Linear(latent_dim, self.enc_out_dim)

        # decode stage 1: from 128@7x7 -> upsample -> 64@14x14
        self.dec3_up = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # [B,64,14,14]
            nn.BatchNorm2d(64),
            act(),
        )
        # after concatenation with enc2 (64 channels) -> 128 channels in, reduce to 64
        self.dec3_conv = nn.Sequential(
            nn.Conv2d(64 + 64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            act(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            act(),
        )

        # decode stage 2: upsample 14x14 -> 28x28
        self.dec2_up = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # [B,32,28,28]
            nn.BatchNorm2d(32),
            act(),
        )
        # after concatenation with enc1 (32 channels) -> 64 channels in, reduce to 32
        self.dec2_conv = nn.Sequential(
            nn.Conv2d(32 + 32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            act(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            act(),
        )

        # final conv to single channel output
        self.dec_final = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),  # [B,1,28,28]
            nn.Sigmoid()
        )

    # -----------------------
    # Reparameterization
    # -----------------------
    def reparam_sample(self, mu, log_var):
        "this yields a data value by sampling from the learned gaussian distribution of latents"
        std = torch.exp(0.5*log_var)        # log-variance -> standard deviation
        noise = torch.randn_like(std)       # the gaussian distribution we sample from
        return mu + std * noise

    # -----------------------
    # Private helpers shared by forward / encode / decode
    # -----------------------
    def _encode_features(self, x):
        """Run the encoder path; return (x1, x2, mu, log_var).

        `x1` [B,32,28,28] and `x2` [B,64,14,14] are the skip activations.
        Accepts flattened input [B, 784] or images [B,1,28,28].
        """
        if x.ndim == 2:
            x = x.reshape(x.size(0), 1, 28, 28)  # safe reshape (not view)
        x1 = self.enc1(x)   # [B,32,28,28]
        x2 = self.enc2(x1)  # [B,64,14,14]
        x3 = self.enc3(x2)  # [B,128,7,7]
        z_flat = x3.reshape(x3.size(0), -1)  # [B, enc_out_dim]
        mu = self.fc_mu(z_flat)
        log_var = self.fc_logvar(z_flat)
        return x1, x2, mu, log_var

    def _decode_features(self, z, x1, x2):
        """Run the decoder path on latent `z` with skips `x1`, `x2`; return [B, 784]."""
        dec_in = self.fc_dec(z)                                  # [B, enc_out_dim]
        dec_in = dec_in.reshape(dec_in.size(0), self.enc_out_ch,
                                self.enc_out_h, self.enc_out_w)  # [B,128,7,7]

        d3 = self.dec3_up(dec_in)        # [B,64,14,14]
        d3 = torch.cat([d3, x2], dim=1)  # [B,128,14,14]
        d3 = self.dec3_conv(d3)          # [B,64,14,14]

        d2 = self.dec2_up(d3)            # [B,32,28,28]
        d2 = torch.cat([d2, x1], dim=1)  # [B,64,28,28]
        d2 = self.dec2_conv(d2)          # [B,32,28,28]

        x_hat_img = self.dec_final(d2)   # [B,1,28,28]
        # Flatten reconstruction to match pred_and_log (which compares flattened tensors)
        return x_hat_img.reshape(x_hat_img.size(0), -1)

    # -----------------------
    # Forward pass (U-Net VAE)
    # Returns: (z_concat, x_hat_flat, mu, log_var, z_hat)
    # -----------------------
    def forward(self, x):
        """Encode, sample a latent, and decode with skip connections.

        Args:
            x: flattened input [B, 784] or images [B,1,28,28].

        Returns:
            Tuple `(z_concat, x_hat_flat, mu, log_var, z_hat)` where `z_concat`
            is `mu` and `log_var` concatenated [B, 2*latent_dim], `x_hat_flat`
            is the reconstruction [B, 784] and `z_hat` is the sampled latent.
        """
        x1, x2, mu, log_var = self._encode_features(x)
        z_hat = self.reparam_sample(mu, log_var)
        x_hat_flat = self._decode_features(z_hat, x1, x2)
        # mu|log_var is returned first so the viz helpers can take result[:2]
        z_concat = torch.cat([mu, log_var], dim=1)  # [B, 2*latent_dim]
        return z_concat, x_hat_flat, mu, log_var, z_hat

    def encode(self, x):
        """Return a latent sampled from the encoder's distribution for `x`.

        Note that the result is stochastic (it goes through `reparam_sample`).
        """
        _, _, mu, log_var = self._encode_features(x)
        return self.reparam_sample(mu, log_var)

    def decode(self, x, skips=None):
        """Decode latent vectors into flattened images.

        Args:
            x: latent vectors [B, latent_dim] (the parameter keeps its original
                name `x` for backward compatibility).
            skips: optional `(x1, x2)` encoder activations of shapes
                [B,32,28,28] and [B,64,14,14]. When omitted, all-zero skips
                are used so that decoding from a latent alone is possible.
                NOTE(review): the decoder is trained with real skips, so
                zero-skip outputs may look poor — verify before relying on it.

        Returns:
            Reconstruction of shape [B, 784].
        """
        if skips is None:
            batch = x.size(0)
            # each skip level doubles the bottleneck's spatial size
            x2 = x.new_zeros(batch, 64, 2 * self.enc_out_h, 2 * self.enc_out_w)
            x1 = x.new_zeros(batch, 32, 4 * self.enc_out_h, 4 * self.enc_out_w)
        else:
            x1, x2 = skips
        return self._decode_features(x, x1, x2)

    # -----------------------
    # Loss and logging
    # -----------------------
    def pred_and_log(self, batch, batch_idx, log_prefix=''):
        """Compute the VAE loss for one batch and log its components.

        The loss is mean BCE between reconstruction and input plus
        `scrunch_factor` times the KL term (averaged over batch and latent
        dimensions). Logged keys are prefixed with `log_prefix`.

        Returns:
            The scalar total loss.
        """
        x, _ = batch   # labels are not used by this (unconditional) model
        # accept flattened or image inputs
        if x.ndim == 2:
            x_img = x.reshape(x.size(0), 1, 28, 28)
        else:
            x_img = x

        z, x_hat, mu, log_var, z_hat = self.forward(x_img)

        # recon loss (BCE) on flattened tensors
        recon_loss = F.binary_cross_entropy(x_hat, x_img.reshape(x_img.size(0), -1))
        kl_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
        rescaled_kl_loss = kl_loss * self.scrunch_factor
        loss = recon_loss + rescaled_kl_loss

        self.log(f'{log_prefix}loss', loss, prog_bar=True)
        self.log(f'{log_prefix}recon_loss', recon_loss)
        self.log(f'{log_prefix}kl_loss', kl_loss)
        self.log(f'{log_prefix}rescaled_kl_loss', rescaled_kl_loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self.pred_and_log(batch, batch_idx, log_prefix='train/')

    def validation_step(self, batch, batch_idx):
        return self.pred_and_log(batch, batch_idx, log_prefix='val/')

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=5e-4)

    def on_epoch_start(self):
        # Kept for backward compatibility; see on_train_epoch_start below.
        print('\n')

    def on_train_epoch_start(self):
        # `on_epoch_start` is not a hook in current Lightning releases, so
        # delegate from the per-training-epoch hook to get the blank line.
        self.on_epoch_start()

    def on_validation_epoch_end(self, demo_every=1):
        """Every `demo_every` epochs, log a reconstruction figure to the logger."""
        if self.current_epoch % demo_every != 0:
            return
        logger = self.logger
        # Skip quietly when there is no logger, or one without a wandb-style .log()
        if logger is None or not hasattr(logger.experiment, "log"):
            return
        # uses the notebook's test_inference(...) helper
        fig = test_inference(self, idx=range(5), return_fig=True)
        try:
            logger.experiment.log({"reconstructions": wandb.Image(fig), "epoch": self.current_epoch})
        finally:
            plt.close(fig)   # always release the figure, even if logging fails

# Instantiate the model with its default hyperparameters.
vae = VAEUNet()
# `model` is a second name for the same object: the logger cell uses `model`, the trainer cell uses `vae`.
model = vae

In [None]:
# Close any wandb run that may still be open (e.g. from re-running this notebook) before starting a new logger.
wandb.finish()
# NOTE(review): log_model="all" presumably uploads every checkpoint to wandb — confirm this is wanted.
wandb_logger = WandbLogger(log_model="all", project='vaeunet_tut')
wandb_logger.watch(model) # this thing complains too much upon re-runs, just ignore it and keep going

  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msimeonbetapudi[0m ([33msimeonbetapudi-belmont-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [None]:
# Train the model, validating on val_dl (which is the test loader), then close the wandb run.
# NOTE(review): no random seed is set anywhere in the notebook, so runs are not exactly reproducible.
epochs = 25  # VAEs require more steps to train than vanilla AEs, due to dual-objective loss
trainer = L.Trainer(max_epochs=epochs, devices="auto", logger=wandb_logger, callbacks=RichProgressBar(leave=True))
trainer.fit(model=vae, train_dataloaders=train_dl, val_dataloaders=val_dl)
# NOTE(review): this is skipped if fit() raises, leaving the run open; the logger cell's wandb.finish() covers re-runs.
wandb.finish()

INFO: 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:lightning.pytorch.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

INFO: `Trainer.fit` stopped: `max_epochs=25` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=25` reached.


0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/kl_loss,██▇▅▅▄▅▇▅▅▄▆▅▅▅▃▄▄▃▃▅▂▅▃▂▂▂▂▂▂▂▂▄▁▁▁▁▁▁▁
train/loss,█▄▃▃▁▁▂▂▁▁▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▂▂▂▁▂▂▁▁▁▁▁▁▁▁▂
train/recon_loss,█▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/rescaled_kl_loss,█▆▄▃▃▃▂▄▃▃▂▂▄▃▃▂▂▂▂▃▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▃▃▃▄▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇█████
val/kl_loss,█▅▄▅▃▅█▅▃▄▃▃▄▂▂▂▂▂▂▁▁▁▁▂▁
val/loss,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/recon_loss,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/rescaled_kl_loss,█▅▄▅▃▅█▅▃▄▃▃▄▂▂▂▂▂▂▁▁▁▁▂▁

0,1
epoch,24.0
train/kl_loss,7e-05
train/loss,0.05821
train/recon_loss,0.05821
train/rescaled_kl_loss,0.0
trainer/global_step,11724.0
val/kl_loss,0.00014
val/loss,0.05814
val/recon_loss,0.05814
val/rescaled_kl_loss,0.0
