In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so that weight initialisation, noise sampling and
# DataLoader shuffling repeat across Restart & Run All. (Some CUDA kernels are
# still nondeterministic, so GPU runs may differ slightly.)
SEED = 0
torch.manual_seed(SEED)


# Data loading and preprocessing

In [None]:
# ToTensor() yields pixels in [0, 1]; Normalize with mean 0.5 / std 0.5 maps
# them to [-1, 1], the same range the generator's final Tanh produces.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# MNIST training split, downloaded into ./data on first run.
mnist = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset=mnist, batch_size=128, shuffle=True)


# Model definitions

## Generator

In [None]:
class Generator(nn.Module):
    """Fully connected generator: latent noise -> flattened image.

    Args:
        z_dim: length of the input noise vector.
        img_dim: number of output values (28*28 for a flattened MNIST digit).

    The network is Linear(z_dim, 256) -> ReLU -> Linear(256, 512) -> ReLU ->
    Linear(512, img_dim) -> Tanh, so every output lies in [-1, 1].
    """

    def __init__(self, z_dim=100, img_dim=28*28):
        super().__init__()
        hidden_widths = (256, 512)

        layers = []
        fan_in = z_dim
        for fan_out in hidden_widths:
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(True)]
            fan_in = fan_out
        # Output head: Tanh bounds pixels to the [-1, 1] training range.
        layers += [nn.Linear(fan_in, img_dim), nn.Tanh()]

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise of shape (batch, z_dim) to images of shape (batch, img_dim)."""
        return self.net(x)



## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator: flattened image -> probability of being real.

    Args:
        img_dim: number of input values (28*28 for a flattened MNIST digit).

    The network is Linear(img_dim, 512) -> LeakyReLU(0.2) -> Linear(512, 256)
    -> LeakyReLU(0.2) -> Linear(256, 1) -> Sigmoid. The final Sigmoid produces
    probabilities, which is what nn.BCELoss in the training cell expects.
    """

    def __init__(self, img_dim=28*28):
        super().__init__()
        widths = (img_dim, 512, 256)

        blocks = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            blocks.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)])
        # Single-logit head squashed to (0, 1).
        blocks.extend([nn.Linear(widths[-1], 1), nn.Sigmoid()])

        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Score images of shape (batch, img_dim); returns shape (batch, 1)."""
        return self.net(x)


# Training

In [None]:
z_dim = 100
generator = Generator(z_dim).to(device)
discriminator = Discriminator().to(device)
lr = 0.0002

# Discriminator ends in Sigmoid, so plain BCELoss on probabilities.
criterion = nn.BCELoss()
optimizer_gen = optim.Adam(generator.parameters(), lr=lr)
optimizer_disc = optim.Adam(discriminator.parameters(), lr=lr)

num_epochs = 50  # Change as needed

# Mean loss per epoch, kept for the plot below. Printing only the last batch's
# loss (as before) is dominated by batch-to-batch noise.
history_disc, history_gen = [], []

for epoch in range(num_epochs):
    # Accumulate on-device so we don't force a GPU sync with .item() every batch.
    epoch_loss_disc = torch.zeros((), device=device)
    epoch_loss_gen = torch.zeros((), device=device)
    n_batches = 0

    for real, _ in dataloader:
        # Flatten each image to a 28*28 vector to match the MLP input.
        real = real.view(-1, 28*28).to(device)
        batch_size = real.size(0)

        # Labels for real and fake images
        label_real = torch.ones(batch_size, 1, device=device)
        label_fake = torch.zeros(batch_size, 1, device=device)

        # --- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0.
        noise = torch.randn(batch_size, z_dim, device=device)
        fake = generator(noise)
        loss_disc_real = criterion(discriminator(real), label_real)
        # detach(): this loss must not backpropagate into the generator.
        loss_disc_fake = criterion(discriminator(fake.detach()), label_fake)
        loss_disc = (loss_disc_real + loss_disc_fake) / 2

        optimizer_disc.zero_grad()
        loss_disc.backward()
        optimizer_disc.step()

        # --- Generator step: push the (just updated) D(G(z)) -> 1.
        # This backward also leaves gradients on the discriminator's
        # parameters; optimizer_disc.zero_grad() clears them before the next
        # discriminator step.
        loss_gen = criterion(discriminator(fake), label_real)

        optimizer_gen.zero_grad()
        loss_gen.backward()
        optimizer_gen.step()

        epoch_loss_disc += loss_disc.detach()
        epoch_loss_gen += loss_gen.detach()
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError("dataloader yielded no batches; nothing to train on")

    history_disc.append(epoch_loss_disc.item() / n_batches)
    history_gen.append(epoch_loss_gen.item() / n_batches)
    print(f"Epoch {epoch + 1}/{num_epochs}: "
          f"Loss_D {history_disc[-1]:.4f}, Loss_G {history_gen[-1]:.4f}")

# Plot the per-epoch mean losses.
fig, ax = plt.subplots(figsize=(8, 4))
epochs_axis = range(1, len(history_disc) + 1)
ax.plot(epochs_axis, history_disc, label="Discriminator")
ax.plot(epochs_axis, history_gen, label="Generator")
ax.set(title="GAN training losses (mean per epoch)", xlabel="Epoch", ylabel="BCE loss")
ax.legend()
plt.show()

# Sampling and visualization with the trained generator

In [None]:
n_images = 16

# Inference only: no_grad avoids building (and then discarding) an autograd graph.
with torch.no_grad():
    noise = torch.randn(n_images, z_dim, device=device)
    gen_imgs = generator(noise).cpu().numpy().reshape(n_images, 28, 28)

# squeeze=False always returns a 2-D array of Axes, so this also works when
# n_images == 1 (otherwise plt.subplots returns a single, non-iterable Axes).
fig, axes = plt.subplots(1, n_images, figsize=(n_images, 1), squeeze=False)
for img, ax in zip(gen_imgs, axes[0]):
    # Map the Tanh range [-1, 1] back to [0, 1]. Fixing vmin/vmax stops imshow
    # from stretching each image's contrast independently, so the samples are
    # shown at their true intensities and are comparable with each other.
    ax.imshow((img + 1) / 2, cmap='gray', vmin=0, vmax=1)
    ax.axis('off')
plt.show()


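# Computing Inception Score (IS) and Fréchet Inception Distance (FID)

Both metrics are computed with an ImageNet-trained Inception-v3 network. On MNIST digits they are therefore rough, relative indicators for comparing models rather than absolute quality scores. Higher IS is better; lower FID is better.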

In [None]:
!pip install torchmetrics

In [None]:
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.fid import FrechetInceptionDistance

# Samples per side. Small counts give noisy (and, for FID, biased) estimates;
# raise for final numbers.
N_EVAL = 1000
# Images per Inception forward pass; keeps peak memory bounded.
EVAL_BATCH = 100

# normalize=True: inputs are float images in [0, 1] (the default expects
# uint8 in [0, 255]). The feature extractor resizes inputs to Inception's
# resolution itself, so 28x28 batches can be passed directly.
is_metric = InceptionScore(normalize=True).to(device)
fid_metric = FrechetInceptionDistance(normalize=True).to(device)

with torch.no_grad():
    for start in range(0, N_EVAL, EVAL_BATCH):
        n_batch = min(EVAL_BATCH, N_EVAL - start)

        # Generated digits: Tanh output [-1, 1] -> [0, 1], grey -> 3-channel.
        noise = torch.randn(n_batch, z_dim, device=device)
        gen_batch = generator(noise).view(-1, 1, 28, 28)
        gen_batch = ((gen_batch + 1) / 2).clamp(0, 1).repeat(1, 3, 1, 1)

        # Real digits: raw uint8 MNIST pixels scaled to [0, 1], same layout.
        real_batch = (mnist.data[start:start + n_batch]
                      .to(device).float().div(255)
                      .unsqueeze(1).repeat(1, 3, 1, 1))

        is_metric.update(gen_batch)
        fid_metric.update(real_batch, real=True)
        fid_metric.update(gen_batch, real=False)

is_score, is_std = is_metric.compute()  # Higher is better
print('Inception Score:', is_score.item(), '+/-', is_std.item())

fid_score = fid_metric.compute()  # Lower is better
print('FID:', fid_score.item())


In [None]:
!pip install torcheval

In [None]:
# Cross-check FID with torcheval. NOTE(review): torcheval's documented metric
# list has FrechetInceptionDistance but no Inception Score metric (TODO confirm
# for the installed version), so only FID is computed here; IS comes from the
# torchmetrics cell. The alias avoids shadowing torchmetrics' class of the
# same name.
from torcheval.metrics import FrechetInceptionDistance as TorchEvalFID

# Samples per side, and images per Inception forward pass (bounds memory).
N_EVAL = 1000
EVAL_BATCH = 100

fid_metric = TorchEvalFID(device=device)

with torch.no_grad():
    for start in range(0, N_EVAL, EVAL_BATCH):
        n_batch = min(EVAL_BATCH, N_EVAL - start)

        # torcheval expects float RGB batches (B, 3, H, W) with values in [0, 1].
        noise = torch.randn(n_batch, z_dim, device=device)
        gen_batch = generator(noise).view(-1, 1, 28, 28)
        gen_batch = ((gen_batch + 1) / 2).clamp(0, 1).repeat(1, 3, 1, 1)

        real_batch = (mnist.data[start:start + n_batch]
                      .to(device).float().div(255)
                      .unsqueeze(1).repeat(1, 3, 1, 1))

        # torcheval metrics are not callable; feed them with update(). Its
        # flag is named `is_real` (not `real` as in torchmetrics).
        fid_metric.update(real_batch, is_real=True)
        fid_metric.update(gen_batch, is_real=False)

fid_score = fid_metric.compute()  # Lower is better
print('FID: ', fid_score.item())