In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
import torchvision
import os
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


In [2]:
# Hyperparameters
batch_size = 128
z_dim = 100
image_size = 28
num_classes = 10
channels = 1
epochs = 50
lr = 0.0002
beta1 = 0.5 # Adam optimizer beta1
# beta2 = 0.999

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Devide : {device}")

# Create output folder
# if not os.path.exists("generated_imgs"):
os.makedirs("c_gan_generated", exist_ok=True)

Devide : cuda


In [3]:
# Transform: Normalize images between [-1, 1] (because Tanh will be used as output)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))    # Normalize between [-1, 1]
])

# Load MNIST
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

100%|██████████| 9.91M/9.91M [00:02<00:00, 4.54MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 134kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.27MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.23MB/s]


In [4]:
# DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
class Generator(nn.Module):
    def __init__(self, z_dim, num_classes, img_shape):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.img_shape = img_shape
        input_dim = z_dim + num_classes

        self.model = nn.Sequential(
            # Input: (N, z_dim, 1, 1)
            nn.Linear(input_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),

            nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
            nn.Tanh(),
        )
    def forward(self, noise, labels):
        # Concatenate noise and label embedding
        x = torch.cat([noise, self.label_emb(labels)], dim=1)
        img = self.model(x)
        img = img.view(x.size(0), *self.img_shape)
        return img

In [6]:
class Discriminator(nn.Module):
    def __init__(self, num_classes, img_shape):
        super(Discriminator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        input_dim = int(torch.prod(torch.tensor(img_shape))) + num_classes

        self.model = nn.Sequential(
            # Input: (N, 1, 28, 28)
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, img, labels):
        # Flatten image and concatenate label
        img_flat = img.view(img.size(0), -1)
        x = torch.cat([img_flat, self.label_emb(labels)], dim=1)
        validity = self.model(x)
        return validity

In [7]:
img_shape = (channels, image_size, image_size)
# Models
generator = Generator(z_dim, num_classes, img_shape).to(device)
discriminator = Discriminator(num_classes, img_shape).to(device)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Loss - Binary Cross Entropy Loss
criterion = nn.BCELoss()

In [8]:
def generate_and_save_images(epoch):
    generator.eval()
    with torch.no_grad():
        z = torch.randn(64, z_dim, 1, 1).to(device)
        fake_images = generator(z)
        fake_images = fake_images * 0.5 + 0.5  # Denormalize to [0,1]
        save_image(fake_images, f"generated_imgs/sample_epoch_{epoch}.png", nrow=8)
    generator.train()

In [9]:
k = 3   # Generator updates per iteration
p = 1   # Discriminator updates per iterations

In [10]:
# Training Loop
for epoch in range(1, epochs+1):
    for i, (real_imgs, real_labels) in enumerate(train_loader):
        batch_size_curr = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        real_labels = real_labels.to(device)

        # Use the actual batch size for creating target tensors
        real = torch.ones(batch_size_curr, 1, device=device)
        fake = torch.zeros(batch_size_curr, 1, device=device)

        ### ----- Train Discriminator p times ----- ###
        for _ in range(p):
            z = torch.randn(batch_size_curr, z_dim, device=device)
            fake_labels = torch.randint(0, num_classes, (batch_size_curr,), device=device)
            with torch.no_grad():
                gen_imgs = generator(z, fake_labels)

            # Real
            real_validity = discriminator(real_imgs, real_labels)
            d_real_loss = criterion(real_validity, real)

            # Fake
            fake_validity = discriminator(gen_imgs.detach(), fake_labels)
            d_fake_loss = criterion(fake_validity, fake)

            d_loss = d_real_loss + d_fake_loss

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()
        ### ----- Train Generator k times ----- ###
        for _ in range(k):
            # Use batch_size_curr for generator input as well
            z = torch.randn(batch_size_curr, z_dim, device=device)
            gen_labels = torch.randint(0, num_classes, (batch_size_curr,), device=device)
            gen_imgs = generator(z, gen_labels)

            validity = discriminator(gen_imgs, gen_labels)
            g_loss = criterion(validity, real)  # fool D -> label as real

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()
        if i % 200 == 0:
            print(f"Epoch [{epoch}/{epochs}], Step or Batch [{i}/{len(train_loader)}], "
                    f"D_loss: {d_loss.item():.4f} | G_loss: {g_loss.item():.4f}")
    # Save sample images
    generator.eval()
    with torch.no_grad():
        # Use a fixed size (e.g., 10) for generating samples for visualization,
        # or match the batch size if needed for consistency
        z = torch.randn(10, z_dim, device=device)
        labels = torch.arange(0, 10, dtype=torch.long, device=device)
        samples = generator(z, labels)
        samples = samples * 0.5 + 0.5  # Denormalize to [0,1]
        save_image(samples, f"c_gan_generated/epoch_{epoch}.png", nrow=10)
    generator.train()

Epoch [1/50], Step or Batch [0/469], D_loss: 1.3591 | G_loss: 0.6147
Epoch [1/50], Step or Batch [200/469], D_loss: 1.3742 | G_loss: 0.6974
Epoch [1/50], Step or Batch [400/469], D_loss: 1.3698 | G_loss: 0.6760
Epoch [2/50], Step or Batch [0/469], D_loss: 1.3789 | G_loss: 0.6628
Epoch [2/50], Step or Batch [200/469], D_loss: 1.4065 | G_loss: 0.6738
Epoch [2/50], Step or Batch [400/469], D_loss: 1.4194 | G_loss: 0.6756
Epoch [3/50], Step or Batch [0/469], D_loss: 1.4000 | G_loss: 0.6753
Epoch [3/50], Step or Batch [200/469], D_loss: 1.3900 | G_loss: 0.7060
Epoch [3/50], Step or Batch [400/469], D_loss: 1.3601 | G_loss: 0.6967
Epoch [4/50], Step or Batch [0/469], D_loss: 1.3801 | G_loss: 0.6921
Epoch [4/50], Step or Batch [200/469], D_loss: 1.3689 | G_loss: 0.7365
Epoch [4/50], Step or Batch [400/469], D_loss: 1.3525 | G_loss: 0.7117
Epoch [5/50], Step or Batch [0/469], D_loss: 1.3951 | G_loss: 0.6999
Epoch [5/50], Step or Batch [200/469], D_loss: 1.3725 | G_loss: 0.7066
Epoch [5/50], St

In [11]:
def generate_digit_images(generator, digit, num_samples=16, save_path=None):
    generator.eval()
    # with torch.no_grad():
    z = torch.randn(num_samples, z_dim, device=device).to(device)
    labels = torch.full((num_samples,), digit, device=device, dtype=torch.long).to(device)
    with torch.no_grad():
        gen_imgs = generator(z, labels)
        gen_imgs = gen_imgs * 0.5 + 0.5
    if save_path:
        save_image(gen_imgs, save_path, nrow=4)
        print(f"Generated images saved to {save_path}")
    return gen_imgs

In [12]:
generate_digit_images(generator, digit=5, num_samples=16, save_path="cgan_generated_imgs/digit_5.png")

FileNotFoundError: [Errno 2] No such file or directory: 'cgan_generated_imgs/digit_5.png'