In [1]:
pip install torch-fidelity wandb

Collecting torch-fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Downloading torch_fidelity-0.3.0-py3-none-any.whl (37 kB)
Installing collected packages: torch-fidelity
Successfully installed torch-fidelity-0.3.0


In [2]:
import os
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import wandb
import time
from torch_fidelity import calculate_metrics

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [3]:
config = {
    # dataset
    "dataset": "CIFAR-10",
    "image_size": 32,
    "channels": 3,

    # model
    "model": "beta-VAE",
    "latent_dim": 128,
    "encoder_channels": [32, 64, 128],

    # training
    "batch_size": 128,
    "epochs": 30,
    "lr": 2e-4,
    "recon_loss": "MSE",
    "beta": 0.25,

    # evaluation
    "fid_every": 5,
    "fid_samples": 1000,
    "kid_subset_size": 300,

    # logging
    "log_images": True,
    "num_log_images": 16,
}

In [4]:
run_name = (
    f"VAE_beta{config['beta']}_"
    f"z{config['latent_dim']}_"
    f"lr{config['lr']}"
)

wandb.init(
    project="generative-modeling-on-cifar-10",
    name=run_name,
    config=config
)

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mgioeba[0m ([33mgioeba-free-university-of-tbilisi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
real_folder = "/tmp/cifar10_real"
gen_folder = "/tmp/cifar10_gen"
os.makedirs(real_folder, exist_ok=True)
os.makedirs(gen_folder, exist_ok=True)

In [6]:
real_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=T.ToTensor()
)

if len(os.listdir(real_folder)) == 0:
    print("Saving real CIFAR-10 images...")
    for i in tqdm(range(10000)):
        img, _ = real_dataset[i]
        vutils.save_image(img, os.path.join(real_folder, f"real_{i}.png"))
else:
    print("Real images already exist, skipping.")

100%|██████████| 170M/170M [00:09<00:00, 18.3MB/s]


Saving real CIFAR-10 images...


100%|██████████| 10000/10000 [00:09<00:00, 1043.21it/s]


In [7]:
train_transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
])

train_set = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=False,
    transform=train_transform
)

train_loader = DataLoader(
    train_set,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

In [8]:
class VAE(nn.Module):
    def __init__(self, config):
        super().__init__()
        ch = config["encoder_channels"]
        latent_dim = config["latent_dim"]

        self.encoder = nn.Sequential(
            nn.Conv2d(3, ch[0], 4, 2, 1), nn.ReLU(),
            nn.Conv2d(ch[0], ch[1], 4, 2, 1), nn.ReLU(),
            nn.Conv2d(ch[1], ch[2], 4, 2, 1), nn.ReLU()
        )

        self.fc_mu = nn.Linear(ch[2]*4*4, latent_dim)
        self.fc_logvar = nn.Linear(ch[2]*4*4, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, ch[2]*4*4)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(ch[2], ch[1], 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(ch[1], ch[0], 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(ch[0], 3, 4, 2, 1), nn.Tanh()
        )

    def encode(self, x):
        h = self.encoder(x).view(x.size(0), -1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.fc_dec(z).view(-1, config["encoder_channels"][-1], 4, 4)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [9]:
def vae_loss(x_hat, x, mu, logvar):
    recon = F.mse_loss(x_hat, x, reduction="sum")
    kl = -0.5*torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl

In [None]:
model = VAE(config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

wandb.watch(model, log="gradients", log_freq=500)

fid_scores = []
kid_scores = []
epochs_list = []

for epoch in range(1, config["epochs"] + 1):
    model.train()
    epoch_start = time.time()

    total_loss = 0.0
    total_recon = 0.0
    total_kl = 0.0
    latent_mu_mean = 0.0
    latent_mu_std = 0.0

    for x, _ in tqdm(train_loader, desc=f"Epoch {epoch}"):
        x = x.to(device)

        optimizer.zero_grad()
        x_hat, mu, logvar = model(x)

        recon = F.mse_loss(x_hat, x, reduction="sum")
        kl = -0.5 * torch.sum(
            1 + logvar - mu.pow(2) - logvar.exp()
        )

        loss = recon + config["beta"] * kl
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_recon += recon.item()
        total_kl += kl.item()

        latent_mu_mean += mu.mean().item()
        latent_mu_std += mu.std().item()

    avg_loss = total_loss / len(train_set)
    avg_recon = total_recon / len(train_set)
    avg_kl = total_kl / len(train_set)
    latent_mu_mean /= len(train_loader)
    latent_mu_std /= len(train_loader)

    epoch_time = time.time() - epoch_start

    wandb.log({
        "epoch": epoch,
        "train/total_loss": avg_loss,
        "train/recon_loss": avg_recon,
        "train/kl_loss": avg_kl,
        "latent/mu_mean": latent_mu_mean,
        "latent/mu_std": latent_mu_std,
        "time/epoch_sec": epoch_time,
        "beta": config["beta"],
    })

    print(
        f"Epoch {epoch} | "
        f"Loss: {avg_loss:.4f} | "
        f"Recon: {avg_recon:.4f} | "
        f"KL: {avg_kl:.4f}"
    )

    if epoch % config["fid_every"] == 0:
        model.eval()

        if os.path.exists(gen_folder):
            shutil.rmtree(gen_folder)
        os.makedirs(gen_folder, exist_ok=True)

        with torch.no_grad():
            n_samples = config["fid_samples"]
            batch_size = 64
            idx = 0

            for _ in range(n_samples // batch_size):
                z = torch.randn(
                    batch_size,
                    config["latent_dim"],
                    device=device
                )

                imgs = model.decode(z)
                imgs = (imgs + 1) / 2
                imgs = imgs.clamp(0, 1)

                for j in range(imgs.size(0)):
                    vutils.save_image(
                        imgs[j],
                        os.path.join(gen_folder, f"{idx}.png")
                    )
                    idx += 1

        metrics = calculate_metrics(
            input1=gen_folder,
            input2=real_folder,
            fid=True,
            kid=True,
            kid_subset_size=config["kid_subset_size"]
        )

        fid = metrics["frechet_inception_distance"]
        kid = metrics["kernel_inception_distance_mean"]

        fid_scores.append(fid)
        kid_scores.append(kid)
        epochs_list.append(epoch)

        wandb.log({
            "metrics/FID": fid,
            "metrics/KID": kid,
        })

        if config["log_images"]:
            with torch.no_grad():
                z = torch.randn(
                    config["num_log_images"],
                    config["latent_dim"],
                    device=device
                )
                samples = model.decode(z)
                samples = (samples + 1) / 2
                samples = samples.clamp(0, 1)

            grid = vutils.make_grid(samples, nrow=4)
            wandb.log({
                "samples": wandb.Image(
                    grid.permute(1, 2, 0).cpu().numpy(),
                    caption=f"Epoch {epoch}"
                )
            })

        print(
            f"Epoch {epoch} | "
            f"FID: {fid:.2f} | "
            f"KID: {kid:.5f}"
        )

Epoch 1: 100%|██████████| 391/391 [00:17<00:00, 22.96it/s]


Epoch 1 | Loss: 352.5398 | Recon: 330.7979 | KL: 86.9676


Epoch 2: 100%|██████████| 391/391 [00:15<00:00, 24.88it/s]


Epoch 2 | Loss: 201.9501 | Recon: 171.5416 | KL: 121.6337


Epoch 3: 100%|██████████| 391/391 [00:15<00:00, 26.05it/s]


Epoch 3 | Loss: 174.8085 | Recon: 143.4349 | KL: 125.4946


Epoch 4: 100%|██████████| 391/391 [00:14<00:00, 26.62it/s]


Epoch 4 | Loss: 160.0682 | Recon: 128.3359 | KL: 126.9292


Epoch 5: 100%|██████████| 391/391 [00:14<00:00, 27.25it/s]


Epoch 5 | Loss: 148.5971 | Recon: 116.0759 | KL: 130.0849


Creating feature extractor "inception-v3-compat" with features ['2048']
Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 267MB/s]
Extracting features from input1
Looking for samples non-recursivelty in "/tmp/cifar10_gen" with extensions png,jpg,jpeg
Found 960 samples
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)
Processing samples
Extracting features from input2
Looking for samples non-recursivelty in "/tmp/cifar10_real" with extensions png,jpg,jpeg
Found 10000 samples
Processing samples
Frechet Inception Distance: 250.81035202984603
Kernel Inception Distance: 0.25072011470794675 ± 0.005563293500981493


Epoch 5 | FID: 250.81 | KID: 0.25072


Epoch 6: 100%|██████████| 391/391 [00:16<00:00, 23.86it/s]


Epoch 6 | Loss: 141.1249 | Recon: 107.9993 | KL: 132.5023


Epoch 7: 100%|██████████| 391/391 [00:15<00:00, 24.99it/s]


Epoch 7 | Loss: 136.0787 | Recon: 102.4557 | KL: 134.4922


Epoch 8: 100%|██████████| 391/391 [00:15<00:00, 25.12it/s]


Epoch 8 | Loss: 132.2259 | Recon: 98.2369 | KL: 135.9559


Epoch 9: 100%|██████████| 391/391 [00:15<00:00, 25.23it/s]


Epoch 9 | Loss: 129.0459 | Recon: 94.5739 | KL: 137.8877


Epoch 10: 100%|██████████| 391/391 [00:15<00:00, 25.12it/s]


Epoch 10 | Loss: 126.4948 | Recon: 91.6288 | KL: 139.4640


Creating feature extractor "inception-v3-compat" with features ['2048']
Extracting features from input1
Looking for samples non-recursivelty in "/tmp/cifar10_gen" with extensions png,jpg,jpeg
Found 960 samples
Processing samples
Extracting features from input2
Looking for samples non-recursivelty in "/tmp/cifar10_real" with extensions png,jpg,jpeg
Found 10000 samples
Processing samples:  15%|█▍        | 1472/10000 [00:06<00:33, 256.75samples/s]

In [None]:
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(epochs_list, fid_scores, marker="o")
plt.title("FID over epochs")
plt.xlabel("Epoch")
plt.ylabel("FID")
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(epochs_list, kid_scores, marker="o")
plt.title("KID over epochs")
plt.xlabel("Epoch")
plt.ylabel("KID")
plt.grid(True)

plt.tight_layout()
plt.show()

wandb.log({
    "FID_KID_curves": wandb.Image(plt.gcf())
})

In [None]:
model.eval()
with torch.no_grad():
    z = torch.randn(16, config["latent_dim"]).to(device)
    samples = model.decode(z)
    samples = (samples + 1) / 2
    samples = samples.clamp(0, 1)

grid = vutils.make_grid(samples, nrow=4)
plt.figure(figsize=(6, 6))
plt.imshow(grid.permute(1, 2, 0).cpu())
plt.axis("off")
plt.title("Final VAE Samples")
plt.show()

wandb.log({
    "final_samples": wandb.Image(
        grid.permute(1, 2, 0).cpu().numpy(),
        caption="Final VAE Samples"
    )
})

wandb.finish()