In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
import medmnist
from medmnist import ChestMNIST
import numpy as np
import os
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.fid import FrechetInceptionDistance

In [3]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Load MedMNIST Dataset (ChestMNIST as an example)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = ChestMNIST(root="./data", split="train", download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

Using downloaded and verified file: ./data\chestmnist.npz


In [5]:
# Generator Model
class Generator(nn.Module):
    def __init__(self, z_dim=100):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28*28),
            nn.Tanh()
        )

    def forward(self, z):
        return self.net(z).view(-1, 1, 28, 28)

In [6]:
# Discriminator Model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.net(x.view(x.size(0), -1))

In [7]:
# WGAN-GP Gradient Penalty
def gradient_penalty(D, real_data, fake_data):
    alpha = torch.rand(real_data.size(0), 1, 1, 1).to(device)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    d_interpolates = D(interpolates)
    grad_outputs = torch.ones_like(d_interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates, inputs=interpolates, grad_outputs=grad_outputs,
        create_graph=True, retain_graph=True)[0]
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

In [8]:
# Training Function
def train_gan(gan_type, num_epochs=50):
    writer = SummaryWriter(f"runs/{gan_type}")
    
    z_dim = 100
    generator = Generator(z_dim).to(device)
    discriminator = Discriminator().to(device)

    optim_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optim_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(num_epochs):
        for real, _ in tqdm(train_loader):
            real = real.to(device)

            # Generate fake images
            z = torch.randn(real.size(0), z_dim).to(device)
            fake = generator(z)

            # Discriminator update
            optim_D.zero_grad()
            real_loss, fake_loss = 0, 0

            if gan_type == "LS-GAN":
                real_loss = 0.5 * ((discriminator(real) - 1) ** 2).mean()
                fake_loss = 0.5 * (discriminator(fake) ** 2).mean()
            elif gan_type == "WGAN":
                real_loss = -discriminator(real).mean()
                fake_loss = discriminator(fake).mean()
            elif gan_type == "WGAN-GP":
                real_loss = -discriminator(real).mean()
                fake_loss = discriminator(fake).mean()
                gp = gradient_penalty(discriminator, real, fake)
                loss_D = real_loss + fake_loss + 10 * gp
            else:
                raise ValueError("Invalid GAN type")

            loss_D = real_loss + fake_loss
            loss_D.backward()
            optim_D.step()

            # Generator update
            if epoch % 5 == 0:
                optim_G.zero_grad()
                fake = generator(z)
                loss_G = -discriminator(fake).mean() if gan_type in ["WGAN", "WGAN-GP"] else ((discriminator(fake) - 1) ** 2).mean()
                loss_G.backward()
                optim_G.step()

                # TensorBoard Logging
                writer.add_scalar("Loss/Discriminator", loss_D.item(), epoch)
                writer.add_scalar("Loss/Generator", loss_G.item(), epoch)

        # Save generated images
        vutils.save_image(fake[:25], f"generated/{gan_type}_epoch_{epoch}.png", normalize=True)

    torch.save(generator.state_dict(), f"models/{gan_type}_generator.pth")
    writer.close()

In [9]:
import os

# Ensure directories exist
os.makedirs("generated", exist_ok=True)
os.makedirs("models", exist_ok=True)

In [10]:
# Train all three GANs
for gan in ["LS-GAN", "WGAN", "WGAN-GP"]:
    train_gan(gan)

100%|██████████| 1227/1227 [00:11<00:00, 104.78it/s]
100%|██████████| 1227/1227 [00:09<00:00, 136.14it/s]
100%|██████████| 1227/1227 [00:08<00:00, 137.97it/s]
100%|██████████| 1227/1227 [00:09<00:00, 136.27it/s]
100%|██████████| 1227/1227 [00:08<00:00, 137.85it/s]
100%|██████████| 1227/1227 [00:10<00:00, 112.14it/s]
100%|██████████| 1227/1227 [00:09<00:00, 134.62it/s]
100%|██████████| 1227/1227 [00:08<00:00, 136.83it/s]
100%|██████████| 1227/1227 [00:08<00:00, 142.18it/s]
100%|██████████| 1227/1227 [00:08<00:00, 139.87it/s]
100%|██████████| 1227/1227 [00:11<00:00, 111.23it/s]
100%|██████████| 1227/1227 [00:08<00:00, 140.13it/s]
100%|██████████| 1227/1227 [00:08<00:00, 140.67it/s]
100%|██████████| 1227/1227 [00:08<00:00, 141.08it/s]
100%|██████████| 1227/1227 [00:08<00:00, 140.83it/s]
100%|██████████| 1227/1227 [00:10<00:00, 112.09it/s]
100%|██████████| 1227/1227 [00:08<00:00, 138.08it/s]
100%|██████████| 1227/1227 [00:08<00:00, 139.81it/s]
100%|██████████| 1227/1227 [00:08<00:00, 140.5

In [13]:
import torch_fidelity
print(torch_fidelity.__version__)


0.3.0


In [23]:
def compute_metrics():
    inception = InceptionScore().to(device)
    fid = FrechetInceptionDistance().to(device)

    # Load real images from ChestMNIST (using train_loader)
    real_images, _ = next(iter(train_loader))  # Use train_loader instead of real_data_loader
    real_images = real_images.to(device)
    real_images = (real_images * 255).byte()  # Convert to uint8
    real_images = real_images.repeat(1, 3, 1, 1)  # Convert grayscale to RGB if needed
    fid.update(real_images, real=True)  # Update FID with real images

    for gan in ["LS-GAN", "WGAN", "WGAN-GP"]:
        generator = Generator().to(device)
        generator.load_state_dict(torch.load(f"models/{gan}_generator.pth"))
        generator.eval()
        
        fake_images = torch.cat([generator(torch.randn(16, 100).to(device)) for _ in range(5)], dim=0)  # Reduce batch size
        fake_images = (fake_images + 1) / 2  # Rescale to [0,1]
        fake_images = (fake_images * 255).byte()  # Convert to uint8
        fake_images = fake_images.repeat(1, 3, 1, 1)  # Convert grayscale to RGB

        for i in range(0, fake_images.shape[0], 16):  # Process in smaller batches
            inception.update(fake_images[i : i + 16])
            fid.update(fake_images[i : i + 16], real=False)


        score, _ = inception.compute()
        fid_value = fid.compute()

        print(f"{gan} - Inception Score: {score.item()}, FID: {fid_value.item()}")



In [24]:
compute_metrics()

  generator.load_state_dict(torch.load(f"models/{gan}_generator.pth"))


LS-GAN - Inception Score: 1.714969277381897, FID: 344.26531982421875
WGAN - Inception Score: 2.0240354537963867, FID: 337.7819519042969
WGAN-GP - Inception Score: 1.8891353607177734, FID: 339.98663330078125
