In [10]:
import torch
import torch.nn as nn
import torch.optim as optim 
from torchvision import datasets, transforms 
from torchvision.utils import save_image, make_grid
import os
import matplotlib.pyplot as plt

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [11]:
# Training configuration.
batch_size = 128
z_dim = 100          # length of the generator's noise vector
num_classes = 10     # MNIST digits 0-9 (also the label-embedding width)
image_size = 28
channels = 1
epochs = 50
lr = 0.0002
beta1 = 0.5          # Adam beta1; 0.5 is the usual GAN choice (vs. the 0.9 default)

# Seed PyTorch's RNG so weight init, noise sampling and shuffling are reproducible
# on Restart & Run All. torch.manual_seed also seeds all CUDA devices.
seed = 42
torch.manual_seed(seed)

# Per-epoch sample grids are written here.
os.makedirs('cgan_generated_images', exist_ok=True)

In [12]:
# ToTensor yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5) maps them to
# [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

mnist_train = datasets.MNIST('.', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

In [23]:
class Generator(nn.Module):
    """Conditional MLP generator: (noise, class label) -> image in [-1, 1].

    The class label is looked up in a learned embedding of width ``num_classes``.
    That embedding is concatenated with the noise vector and passed through three
    Linear -> BatchNorm1d -> ReLU stages (256, 512, 1024 units). A final Linear
    projection to the flattened image size is followed by Tanh.

    Args:
        z_dim: length of each noise vector.
        num_classes: number of conditioning classes (also the embedding width).
        img_shape: shape of one output image, e.g. (channels, height, width).
    """

    def __init__(self, z_dim, num_classes, img_shape):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.img_shape = img_shape

        stages = []
        width_in = z_dim + num_classes
        for width_out in (256, 512, 1024):
            # BatchNorm is used because this network is small; it can be dropped
            # for deeper networks.
            stages.extend([
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(inplace=True),
            ])
            width_in = width_out

        pixel_count = int(torch.prod(torch.tensor(img_shape)))
        stages.extend([nn.Linear(width_in, pixel_count), nn.Tanh()])
        self.model = nn.Sequential(*stages)

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: tensor of shape (batch, z_dim).
            labels: integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, *img_shape) with values in [-1, 1].
        """
        conditioned = torch.cat([noise, self.label_emb(labels)], dim=1)
        flat_imgs = self.model(conditioned)
        return flat_imgs.view(conditioned.size(0), *self.img_shape)

In [24]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator: (image, class label) -> probability the pair is real.

    The flattened image is concatenated with a learned label embedding of width
    ``num_classes``. The result goes through two Linear -> LeakyReLU(0.2) stages
    (512, 256 units) and a final Linear -> Sigmoid producing one score per sample.

    Args:
        num_classes: number of conditioning classes (also the embedding width).
        img_shape: shape of one input image, e.g. (channels, height, width).
    """

    def __init__(self, num_classes, img_shape):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        pixel_count = int(torch.prod(torch.tensor(img_shape)))

        stages = []
        width_in = pixel_count + num_classes
        for width_out in (512, 256):
            stages.extend([nn.Linear(width_in, width_out), nn.LeakyReLU(0.2, inplace=True)])
            width_in = width_out
        # Sigmoid output pairs with the nn.BCELoss criterion used for training.
        stages.extend([nn.Linear(width_in, 1), nn.Sigmoid()])
        self.model = nn.Sequential(*stages)

    def forward(self, img, labels):
        """Score a batch of (image, label) pairs.

        Args:
            img: image batch whose first dimension is the batch size.
            labels: integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        batch = img.size(0)
        features = torch.cat([img.view(batch, -1), self.label_emb(labels)], dim=1)
        return self.model(features)

In [25]:
img_shape = (channels, image_size, image_size)

# Build both networks on the selected device (generator first, then discriminator).
generator = Generator(z_dim, num_classes, img_shape).to(device)
discriminator = Discriminator(num_classes, img_shape).to(device)

# Binary cross-entropy on the discriminator's sigmoid probability.
criterion = nn.BCELoss()

# Both players share the same Adam hyperparameters.
adam_kwargs = dict(lr=lr, betas=(beta1, 0.999))
optimizer_G = optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_kwargs)

In [None]:
# Update schedule per batch: the generator is deliberately stepped more often
# than the discriminator.
# NOTE(review): in the captured log the D loss settles near 1.386 = 2*ln(2), i.e.
# D outputs ~0.5 for real and fake alike. That may mean k=3 lets G overpower D;
# consider comparing against k = p = 1.
k = 3  # generator updates per iteration
p = 1  # discriminator updates per iteration

for epoch in range(1, epochs + 1):
    for i, (real_imgs, real_labels) in enumerate(train_loader):
        batch_size_curr = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        real_labels = real_labels.to(device)

        # BCE targets (1 = real, 0 = fake), sized per batch because the last batch
        # of an epoch can be smaller than batch_size.
        real = torch.ones(batch_size_curr, 1, device=device)
        fake = torch.zeros(batch_size_curr, 1, device=device)

        # Discriminator: real images with their true labels vs. generated images
        # conditioned on randomly sampled labels.
        for _ in range(p):
            z = torch.randn(batch_size_curr, z_dim, device=device)
            fake_labels = torch.randint(0, num_classes, (batch_size_curr,), device=device)

            # No autograd graph through G here: only D's parameters are updated.
            with torch.no_grad():
                gen_imgs = generator(z, fake_labels)

            d_real_loss = criterion(discriminator(real_imgs, real_labels), real)
            d_fake_loss = criterion(discriminator(gen_imgs, fake_labels), fake)
            d_loss = d_real_loss + d_fake_loss

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

        # Generator: push D to classify fresh, label-conditioned samples as real.
        for _ in range(k):
            z = torch.randn(batch_size_curr, z_dim, device=device)
            gen_labels = torch.randint(0, num_classes, (batch_size_curr,), device=device)
            gen_imgs = generator(z, gen_labels)

            g_loss = criterion(discriminator(gen_imgs, gen_labels), real)

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        if i % 200 == 0:
            print(
                f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(train_loader)}] "
                f"D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}"
            )

    # Save one sample per class using the generator in eval mode (BatchNorm uses
    # running statistics). Outputs are rescaled from [-1, 1] to [0, 1].
    generator.eval()
    with torch.no_grad():
        z = torch.randn(num_classes, z_dim, device=device)
        labels = torch.arange(num_classes, dtype=torch.long, device=device)
        samples = generator(z, labels) * 0.5 + 0.5
        save_image(samples, f"cgan_generated_images/epochs_{epoch}.png", nrow=num_classes)
    generator.train()

[Epoch 1/1] [Batch 0/469]D Loss: 1.3781 | G Loss: 0.6357
[Epoch 1/1] [Batch 200/469]D Loss: 0.2826 | G Loss: 2.3706
[Epoch 1/1] [Batch 400/469]D Loss: 0.0029 | G Loss: 7.0090
[Epoch 2/2] [Batch 0/469]D Loss: 3.7348 | G Loss: 2.0120
[Epoch 2/2] [Batch 200/469]D Loss: 1.3512 | G Loss: 0.6483
[Epoch 2/2] [Batch 400/469]D Loss: 1.3867 | G Loss: 0.6621
[Epoch 3/3] [Batch 0/469]D Loss: 1.3872 | G Loss: 0.6745
[Epoch 3/3] [Batch 200/469]D Loss: 1.3726 | G Loss: 0.6786
[Epoch 3/3] [Batch 400/469]D Loss: 1.3758 | G Loss: 0.7344
[Epoch 4/4] [Batch 0/469]D Loss: 1.3582 | G Loss: 0.7013
[Epoch 4/4] [Batch 200/469]D Loss: 1.3823 | G Loss: 0.7081
[Epoch 4/4] [Batch 400/469]D Loss: 1.3656 | G Loss: 0.7040
[Epoch 5/5] [Batch 0/469]D Loss: 1.3605 | G Loss: 0.7073
[Epoch 5/5] [Batch 200/469]D Loss: 1.3533 | G Loss: 0.7350
[Epoch 5/5] [Batch 400/469]D Loss: 1.3806 | G Loss: 0.7210
[Epoch 6/6] [Batch 0/469]D Loss: 1.3680 | G Loss: 0.7064
[Epoch 6/6] [Batch 200/469]D Loss: 1.3796 | G Loss: 0.7132
[Epoch 6/

In [None]:
def generate_digit_images(generator, digit, num_samples=16, save_path=None, noise_dim=None):
    """Sample ``num_samples`` images of a single class from a conditional generator.

    Sampling runs in eval mode, so BatchNorm uses its running statistics. The
    generator's previous train/eval mode is restored afterwards, so calling this
    mid-training does not silently leave it in eval mode.

    Args:
        generator: conditional generator called as ``generator(noise, labels)``.
            Its outputs are assumed to lie in [-1, 1] (Tanh), as with ``Generator``.
        digit: class index to condition every sample on.
        num_samples: number of images to generate.
        save_path: if given, the images are saved there as a grid with 4 per row.
            Missing parent directories are created.
        noise_dim: length of each noise vector. Defaults to the notebook-level ``z_dim``.

    Returns:
        Tensor of generated images rescaled to [0, 1], on the generator's device.
    """
    if noise_dim is None:
        noise_dim = z_dim
    # Sample on whatever device the generator lives on, rather than a global.
    gen_device = next(generator.parameters()).device

    was_training = generator.training
    generator.eval()
    try:
        z = torch.randn(num_samples, noise_dim, device=gen_device)
        labels = torch.full((num_samples,), digit, dtype=torch.long, device=gen_device)
        with torch.no_grad():
            gen_imgs = generator(z, labels) * 0.5 + 0.5  # [-1, 1] -> [0, 1]
    finally:
        generator.train(was_training)

    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        save_image(gen_imgs, save_path, nrow=4)
        print(f"Saved to {save_path}")

    return gen_imgs

In [None]:
# Save a 4x4 grid of generated 7s next to the per-epoch samples. The
# 'cgan_generated_images' directory is created in the config cell. The grid is
# also shown inline.
seven_imgs = generate_digit_images(generator, digit=7, num_samples=16, save_path='cgan_generated_images/seven.png')

seven_grid = make_grid(seven_imgs.cpu(), nrow=4)
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(seven_grid.permute(1, 2, 0).numpy())
ax.set_title("Generated digit: 7")
ax.axis("off");