In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
import os

# Seed PyTorch's default random number generator before anything random runs,
# so that the draws made further down (layer initialisation, DataLoader
# shuffling, torch.randn latent noise) start from the same state on every
# "Restart Kernel -> Run All".
# NOTE(review): some GPU kernels may still be non-deterministic; see the
# PyTorch reproducibility notes if bit-exact reruns are required.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
# Data transformations
# Each image is converted to a tensor and then normalized per channel with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5 on every channel.
_rgb_half = (0.5, 0.5, 0.5)  # For 3-channel RGB images
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(_rgb_half, _rgb_half)]
)

# Load CIFAR-10
# Training split only; the archive is downloaded into ./data when missing.
dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)

# Shuffled mini-batches of 128 images, loaded by two worker processes.
_loader_options = {"batch_size": 128, "shuffle": True, "num_workers": 2}
dataloader = torch.utils.data.DataLoader(dataset, **_loader_options)


In [3]:
# Generator
class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent vector to an image.

    The latent vector is treated as a 1x1 feature map with ``latent_dim``
    channels and upsampled 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32. The final
    ``Tanh`` keeps every output value in [-1, 1].

    Args:
        latent_dim: number of channels of the latent input.
        img_channels: number of channels of the generated image.
    """

    def __init__(self, latent_dim=100, img_channels=3):
        super().__init__()

        # Start: latent_dim x 1 x 1 -> 512 x 4 x 4 feature map.
        layers = [
            nn.ConvTranspose2d(latent_dim, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        ]

        # Two stride-2 stages, each doubling the spatial size:
        # 4x4x512 -> 8x8x256 -> 16x16x128.
        for in_ch, out_ch in ((512, 256), (256, 128)):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        # Output stage: 16x16x128 -> 32x32 x img_channels, squashed by tanh.
        layers += [
            nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]

        # Layers are registered in the order listed above, so the Sequential
        # indices (model.0 ... model.10) follow that order.
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from latent input ``z``.

        ``z`` is viewed as (z.size(0), z.size(1), 1, 1) before entering the
        transposed-convolution stack, so a 2-D (batch, latent_dim) tensor is
        accepted.
        """
        latent_map = z.view(z.size(0), z.size(1), 1, 1)
        return self.model(latent_map)

# Discriminator
class Discriminator(nn.Module):
    """Convolutional discriminator producing one sigmoid score per image.

    Three stride-2 convolutions take a 32x32 input down to 4x4
    (32 -> 16 -> 8 -> 4); a final 4x4 convolution collapses that to a single
    value, which is flattened to shape (batch, 1) and passed through a
    sigmoid.

    Args:
        img_channels: number of channels of the input images.
    """

    def __init__(self, img_channels=3):
        super().__init__()

        def downsample(in_ch, out_ch):
            # 4x4 kernel, stride 2, padding 1: halves height and width.
            return nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)

        self.model = nn.Sequential(
            # 32x32 x img_channels -> 16x16x128 (no batch norm on the first stage)
            downsample(img_channels, 128),
            nn.LeakyReLU(negative_slope=0.2),
            # 16x16x128 -> 8x8x256
            downsample(128, 256),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2),
            # 8x8x256 -> 4x4x512
            downsample(256, 512),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2),
            # 4x4x512 -> 1x1x1, then (batch, 1) score squashed by the sigmoid
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Flatten(),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return the sigmoid score for each image in ``img`` as a (batch, 1) tensor."""
        return self.model(img)


In [4]:
# Model hyperparameters: latent vector size and number of image channels.
latent_dim, img_channels = 100, 3

# Instantiate both networks (generator first, then discriminator) and move
# them onto the selected device.
G = Generator(latent_dim, img_channels).to(device)
D = Discriminator(img_channels).to(device)

# Loss and optimizers
# Binary cross-entropy is applied to the discriminator's sigmoid output.
criterion = nn.BCELoss()
lr = 0.0002
# Both networks use Adam with the same learning rate and betas.
_adam_settings = {"lr": lr, "betas": (0.5, 0.999)}
opt_G = optim.Adam(G.parameters(), **_adam_settings)
opt_D = optim.Adam(D.parameters(), **_adam_settings)


In [5]:
os.makedirs('output', exist_ok=True)
# One fixed batch of 16 latent vectors, reused for every snapshot so that the
# saved grids can be compared across epochs.
fixed_noise = torch.randn(16, latent_dim, device=device)

n_epochs = 50
for epoch in range(n_epochs):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)

        # ---- Discriminator: Maximize log(D(x)) + log(1 - D(G(z))) ----
        D.zero_grad()

        # Real images are scored against an all-ones target.
        output_real = D(real_imgs)
        label_real = torch.ones_like(output_real)
        loss_D_real = criterion(output_real, label_real)

        # Generated images are scored against an all-zeros target. The
        # detach() keeps this backward pass from reaching the generator.
        noise = torch.randn(batch_size, latent_dim, device=device)
        fake_imgs = G(noise)
        output_fake = D(fake_imgs.detach())
        label_fake = torch.zeros_like(output_fake)
        loss_D_fake = criterion(output_fake, label_fake)

        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        opt_D.step()

        # ---- Generator: Minimize log(1 - D(G(z))) <=> Maximize log(D(G(z))) ----
        G.zero_grad()
        # Re-score the same fakes with the updated discriminator, this time
        # without detach(), using the all-ones target so the generator is
        # rewarded when its images are scored as real.
        output_fake = D(fake_imgs)
        loss_G = criterion(output_fake, label_real)
        loss_G.backward()
        opt_G.step()

    # Save generated images and print progress
    # NOTE(review): snapshots happen on epochs 0, 5, 10, ...; with
    # n_epochs = 50 the last one is epoch 45, so epochs 46-49 are neither
    # saved nor logged. The printed losses are those of the epoch's last batch.
    if epoch % 5 != 0:
        continue

    with torch.no_grad():
        fake = G(fixed_noise).cpu()
    # Map the tanh range [-1, 1] back to [0, 1] before writing the 4x4 grid.
    save_image(0.5 * fake + 0.5, f'output/epoch_{epoch}.png', nrow=4)
    print(f'Epoch {epoch} | Loss_D: {loss_D.item():.4f} | Loss_G: {loss_G.item():.4f}')


Epoch 0 | Loss_D: 0.5934 | Loss_G: 5.1711
Epoch 5 | Loss_D: 0.6531 | Loss_G: 1.9687
Epoch 10 | Loss_D: 0.6235 | Loss_G: 3.4140
Epoch 15 | Loss_D: 0.2120 | Loss_G: 2.8819
Epoch 20 | Loss_D: 0.1533 | Loss_G: 4.5150
Epoch 25 | Loss_D: 0.0867 | Loss_G: 3.7201
Epoch 30 | Loss_D: 0.1291 | Loss_G: 3.9520
Epoch 35 | Loss_D: 0.3536 | Loss_G: 5.9168
Epoch 40 | Loss_D: 0.1392 | Loss_G: 4.2727
Epoch 45 | Loss_D: 0.1662 | Loss_G: 5.1184


In [6]:
# Show which device was selected in the first cell.
print(str(device))

cuda
