In [28]:
import torch
import torch.nn as nn
import torch.optim as optim 
from torchvision import datasets, transforms 
from torchvision.utils import save_image, make_grid
import os
import matplotlib.pyplot as plt

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [29]:
batch_size = 128
z_dim = 100 
num_classes = 10
image_size = 28
channels = 1 
epochs = 50 
lr = 0.0002
beta1 = 0.5

os.makedirs('cgan_generated_images', exist_ok=True)

In [30]:
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(), 
    transforms.Normalize([0.5], [0.5])
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('.', train=True, download=True, transform=transform),
    batch_size= batch_size, shuffle=True
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to .\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting .\MNIST\raw\train-images-idx3-ubyte.gz to .\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to .\MNIST\raw\train-labels-idx1-ubyte.gz


100.0%


Extracting .\MNIST\raw\train-labels-idx1-ubyte.gz to .\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to .\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting .\MNIST\raw\t10k-images-idx3-ubyte.gz to .\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to .\MNIST\raw\t10k-labels-idx1-ubyte.gz


100.0%

Extracting .\MNIST\raw\t10k-labels-idx1-ubyte.gz to .\MNIST\raw






In [31]:
class Generator(nn.Module): 
    def __init__(self, z_dim, num_classes, img_shape): 
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.img_shape = img_shape
        input_dim = z_dim + num_classes 

        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.BatchNorm1d(256), # for smaller networks --> can be avoided for deeper nn
            nn.ReLU(True), 

            nn.Linear(256, 512), 
            nn.BatchNorm1d(512), 
            nn.ReLU(True), 

            nn.Linear(512, 1024), 
            nn.BatchNorm1d(1024), 
            nn.ReLU(True), 

            nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))), 
            nn.Tanh()
        )

    def forward(self, noise, labels): 
        x = torch.cat([noise, self.label_emb(labels)], dim=1)
        img = self.model(x)
        img = img.view(x.size(0), *self.img_shape)
        return img

In [32]:
class Discriminator(nn.Module): 
    def __init__(self, num_classes, img_shape): 
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        input_dim = int(torch.prod(torch.tensor(img_shape))) + num_classes
        
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2, inplace=True), 

            nn.Linear(512, 256),  
            nn.LeakyReLU(0.2, inplace=True), 

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels): 
        img_flat = img.view(img.size(0), -1)
        x = torch.cat([img_flat, self.label_emb(labels)], dim=1)
        return self.model(x)

In [33]:
img_shape = (channels, image_size, image_size)

generator = Generator(z_dim, num_classes, img_shape).to(device)
discriminator = Discriminator(num_classes, img_shape).to(device)

criterion = nn.BCELoss()

optimizer_G = optim.Adam(generator.parameters(), lr = lr, betas = (beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr = lr, betas = (beta1, 0.999))

In [34]:
k, p = 3, 1 
k = 3 # genrator updates per iterations
p =1 # discriminator updates per iterations
# train generator more than discriminator 
for epoch in range(1, epochs+1) : 
    for i, (real_imgs, real_labels) in enumerate(train_loader) :
        batch_size_curr = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        real_labels = real_labels.to(device)

        real = torch.ones(batch_size_curr, 1, device=device)
        fake = torch.zeros(batch_size_curr, 1, device=device)

        for _ in range(p) : 
            z = torch.randn(batch_size_curr, z_dim, device=device)
            fake_labels = torch.randint(0, num_classes, (batch_size_curr,), device=device)

            with torch.no_grad(): 
                gen_imgs = generator(z, fake_labels)

            real_validity = discriminator(real_imgs, real_labels)
            d_real_loss = criterion(real_validity, real)

            fake_validity = discriminator(gen_imgs.detach(), fake_labels)
            d_fake_loss = criterion(fake_validity, fake)

            d_loss = d_real_loss + d_fake_loss
            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()
        
        for _ in range(k) : 
            z = torch.randn(batch_size_curr, z_dim, device=device)
            gen_labels = torch.randint(0, num_classes, (batch_size_curr,), device=device)
            gen_imgs = generator(z, gen_labels)
            
            validity = discriminator(gen_imgs, gen_labels)
            g_loss = criterion(validity, real)

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()
        
        if i%200 == 0: 
            print(
                f"[Epoch {epoch}/{epoch}] [Batch {i}/{len(train_loader)}]"
                f"D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}"
            )

    generator.eval()
    with torch.no_grad():
        z = torch.randn(10, z_dim, device=device)
        labels = torch.arange(0, 10, dtype=torch.long, device=device)
        samples = generator(z, labels)
        samples = samples*0.5 + 0.5
        save_image(samples, f"cgan_generated_images/epochs_{epoch}.png", nrow=10)
    generator.train()

[Epoch 1/1] [Batch 0/469]D Loss: 1.3728 | G Loss: 0.5697
[Epoch 1/1] [Batch 200/469]D Loss: 1.4121 | G Loss: 0.6444
[Epoch 1/1] [Batch 400/469]D Loss: 1.3732 | G Loss: 0.6681
[Epoch 2/2] [Batch 0/469]D Loss: 1.3854 | G Loss: 0.6868
[Epoch 2/2] [Batch 200/469]D Loss: 1.4259 | G Loss: 0.6584
[Epoch 2/2] [Batch 400/469]D Loss: 1.3649 | G Loss: 0.7046
[Epoch 3/3] [Batch 0/469]D Loss: 1.4129 | G Loss: 0.6790
[Epoch 3/3] [Batch 200/469]D Loss: 1.4306 | G Loss: 0.6403
[Epoch 3/3] [Batch 400/469]D Loss: 1.3900 | G Loss: 0.6813
[Epoch 4/4] [Batch 0/469]D Loss: 1.3551 | G Loss: 0.7254
[Epoch 4/4] [Batch 200/469]D Loss: 1.3966 | G Loss: 0.6772
[Epoch 4/4] [Batch 400/469]D Loss: 1.3877 | G Loss: 0.6901
[Epoch 5/5] [Batch 0/469]D Loss: 1.4105 | G Loss: 0.7227
[Epoch 5/5] [Batch 200/469]D Loss: 1.3695 | G Loss: 0.7134
[Epoch 5/5] [Batch 400/469]D Loss: 1.4026 | G Loss: 0.7026
[Epoch 6/6] [Batch 0/469]D Loss: 1.3827 | G Loss: 0.7157
[Epoch 6/6] [Batch 200/469]D Loss: 1.3590 | G Loss: 0.6844
[Epoch 6/

In [35]:
def generate_digit_images(generator, digit, num_samples = 16, save_path = None) : 
    generator.eval()
    z = torch.randn(num_samples, z_dim).to(device=device)
    labels = torch.full((num_samples,), digit, dtype=torch.long).to(device)

    with torch.no_grad() : 
        gen_imgs = generator(z, labels)
        gen_imgs = gen_imgs*0.5 + 0.5

    if save_path : 
        save_image(gen_imgs, save_path, nrow=4)
        print(f"Saved to {save_path}")

    return gen_imgs

In [37]:
generate_digit_images(generator, digit=7, num_samples=16, save_path='cgan_generated_images/seven.png')

Saved to cgan_generated_images/seven.png


tensor([[[[3.5384e-04, 5.9098e-05, 9.1341e-04,  ..., 2.1955e-04,
           1.1718e-04, 8.7947e-05],
          [3.6332e-04, 2.9445e-05, 1.4573e-04,  ..., 9.9570e-05,
           1.0014e-05, 4.8280e-06],
          [1.0499e-04, 3.6680e-02, 1.4390e-03,  ..., 7.4506e-07,
           8.8811e-06, 3.6061e-05],
          ...,
          [3.5164e-04, 6.3897e-03, 1.5667e-03,  ..., 6.5681e-04,
           5.0575e-05, 1.5825e-05],
          [1.6093e-06, 2.3568e-04, 4.0731e-04,  ..., 1.7536e-04,
           5.0122e-04, 2.4549e-02],
          [1.4573e-05, 2.8413e-03, 2.9802e-07,  ..., 8.4089e-02,
           1.5031e-03, 6.1452e-05]]],


        [[[3.2851e-04, 5.4032e-05, 1.0688e-03,  ..., 2.6745e-04,
           1.2723e-04, 8.8692e-05],
          [3.6117e-04, 2.8998e-05, 1.5020e-04,  ..., 9.0867e-05,
           1.1295e-05, 4.4703e-06],
          [1.0625e-04, 3.6351e-02, 1.5407e-03,  ..., 8.3447e-07,
           8.4937e-06, 4.4882e-05],
          ...,
          [4.2653e-04, 8.1497e-03, 1.6486e-03,  ..., 8.53