In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
import torchvision
import os
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


In [2]:
# Hyperparameters
batch_size = 128    # samples per training mini-batch
z_dim = 100         # number of channels of the latent noise fed to the generator
image_size = 28     # side length (pixels) that images are resized to
channels = 1        # number of image channels
epochs = 50         # full passes over the training set
lr = 0.0002         # Adam learning rate
beta1 = 0.5         # Adam optimizer beta1
seed = 42           # RNG seed so that a Restart & Run All gives comparable runs

# Seed PyTorch's random number generators (weight init, noise sampling).
torch.manual_seed(seed)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device : {device}")

# Create output folder (no error if it already exists)
os.makedirs("generated_imgs", exist_ok=True)

Devide : cuda


In [3]:
# Preprocessing: resize, convert to a tensor, then normalize with mean 0.5 / std 0.5
# so images lie in [-1, 1], the same range as the generator's Tanh output.
_norm_mean, _norm_std = (0.5,), (0.5,)
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]
)

# MNIST train / test splits (downloaded into ./data on first use)
_mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = torchvision.datasets.MNIST(train=True, **_mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **_mnist_kwargs)

100%|██████████| 9.91M/9.91M [00:00<00:00, 17.6MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 477kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.37MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.73MB/s]


In [4]:
# Mini-batch iterator over the training split, reshuffled every epoch
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [5]:
class Generator(nn.Module):
    """DCGAN-style generator: latent noise -> 28x28 image with values in [-1, 1].

    Args:
        z_dim: Number of channels of the latent input, which must be shaped
            (N, z_dim, 1, 1).
        channels: Number of channels of the generated image. Defaults to 1,
            which reproduces the original single-channel network.
    """

    def __init__(self, z_dim, channels=1):
        super().__init__()
        self.net = nn.Sequential(
            # (N, z_dim, 1, 1) -> (N, 256, 7, 7)
            nn.ConvTranspose2d(z_dim, 256, kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # (N, 256, 7, 7) -> (N, 128, 14, 14)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # (N, 128, 14, 14) -> (N, channels, 28, 28); Tanh bounds the output to [-1, 1]
            nn.ConvTranspose2d(128, channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map latent noise `z` of shape (N, z_dim, 1, 1) to images of shape (N, channels, 28, 28)."""
        return self.net(z)

In [6]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator: 28x28 image -> probability in [0, 1].

    The final linear layer is sized for 7x7 feature maps, so inputs must be
    28x28 pixels.

    Args:
        channels: Number of channels of the input image. Defaults to 1,
            which reproduces the original single-channel network.
    """

    def __init__(self, channels=1):
        super().__init__()
        self.net = nn.Sequential(
            # (N, channels, 28, 28) -> (N, 64, 14, 14)
            nn.Conv2d(channels, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # (N, 64, 14, 14) -> (N, 128, 7, 7)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # (N, 128, 7, 7) -> (N, 128*7*7) -> (N, 1); Sigmoid makes it usable with BCELoss
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return a (N, 1) tensor of per-image scores in [0, 1] for `img` of shape (N, channels, 28, 28)."""
        return self.net(img)

In [7]:
# Build both networks, then move them to the selected device
generator = Generator(z_dim)
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

# One Adam optimizer per network, sharing the same learning rate and betas
adam_betas = (beta1, 0.999)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)

# Binary cross-entropy on the discriminator's sigmoid output
criterion = nn.BCELoss()

In [8]:
def generate_and_save_images(epoch):
    """Sample 64 images from the global `generator` and save them as an 8-per-row grid.

    The grid is written to ``generated_imgs/sample_epoch_{epoch}.png``.

    Args:
        epoch: Value interpolated into the output file name.

    The generator is put in eval mode while sampling and is returned to
    whatever mode it was in before the call, even if sampling or saving fails.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            # Create the noise directly on the target device (no host-to-device copy).
            z = torch.randn(64, z_dim, 1, 1, device=device)
            fake_images = generator(z)
            fake_images = fake_images * 0.5 + 0.5  # Denormalize from [-1, 1] to [0, 1]
            save_image(fake_images, f"generated_imgs/sample_epoch_{epoch}.png", nrow=8)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        generator.train(was_training)

In [9]:
# Number of generator updates performed per training iteration
k = 3
# Number of discriminator updates performed per training iteration
p = 1

In [11]:
# Training Loop
# For every mini-batch: p discriminator updates, then k generator updates.
# After every epoch a grid of 64 generated samples is written to generated_imgs/.
batch_size_curr = batch_size  # not read in this cell; kept in case other cells use it
for epoch in range(1, epochs+1):
    for i, (real_imgs, _) in enumerate(train_loader):
        # Size of *this* batch. Stored under its own name so the global
        # `batch_size` hyperparameter is not overwritten by the loop.
        cur_batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)

        # Use the actual batch size for creating target tensors
        real = torch.ones(cur_batch_size, 1, device=device)
        fake = torch.zeros(cur_batch_size, 1, device=device)

        ### ----- Train Discriminator p times ----- ###
        for _ in range(p):
            z = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
            fake_imgs = generator(z)

            # Real
            real_validity = discriminator(real_imgs)
            d_real_loss = criterion(real_validity, real)

            # Fake (detached so this backward pass does not reach the generator)
            fake_validity = discriminator(fake_imgs.detach())
            d_fake_loss = criterion(fake_validity, fake)

            d_loss = d_real_loss + d_fake_loss

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

        ### ----- Train Generator k times ----- ###
        for _ in range(k):
            z = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
            fake_imgs = generator(z)

            validity = discriminator(fake_imgs)
            g_loss = criterion(validity, real)  # fool D -> label as real

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        # Log the most recent losses every 200 batches
        if i % 200 == 0:
            print(f"Epoch [{epoch}/{epochs}], Step or Batch [{i}/{len(train_loader)}], "
                    f"D_loss: {d_loss.item():.4f} | G_loss: {g_loss.item():.4f}")

    # Save sample images
    generator.eval()
    with torch.no_grad():
        z = torch.randn(64, z_dim, 1, 1, device=device)
        samples = generator(z)
        samples = samples * 0.5 + 0.5  # Denormalize to [0,1]
        save_image(samples, f"generated_imgs/epoch_{epoch}.png", nrow=8)
    generator.train()

Epoch [1/50], Step or Batch [0/469], D_loss: 1.3600 | G_loss: 0.5878
Epoch [1/50], Step or Batch [200/469], D_loss: 1.2518 | G_loss: 0.7615
Epoch [1/50], Step or Batch [400/469], D_loss: 1.2640 | G_loss: 0.9243
Epoch [2/50], Step or Batch [0/469], D_loss: 1.2030 | G_loss: 1.0682
Epoch [2/50], Step or Batch [200/469], D_loss: 1.1677 | G_loss: 0.8792
Epoch [2/50], Step or Batch [400/469], D_loss: 1.2151 | G_loss: 0.8165
Epoch [3/50], Step or Batch [0/469], D_loss: 1.1769 | G_loss: 0.7157
Epoch [3/50], Step or Batch [200/469], D_loss: 1.1885 | G_loss: 0.9161
Epoch [3/50], Step or Batch [400/469], D_loss: 1.1251 | G_loss: 0.8118
Epoch [4/50], Step or Batch [0/469], D_loss: 1.1756 | G_loss: 0.7406
Epoch [4/50], Step or Batch [200/469], D_loss: 1.2145 | G_loss: 0.7074
Epoch [4/50], Step or Batch [400/469], D_loss: 1.0967 | G_loss: 0.9615
Epoch [5/50], Step or Batch [0/469], D_loss: 1.1558 | G_loss: 0.8886
Epoch [5/50], Step or Batch [200/469], D_loss: 1.1462 | G_loss: 0.5883
Epoch [5/50], St