In [1]:
import torch, torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

DEVICE  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LATENT  = 100
BATCH   = 128
EPOCHS  = 30          # stops in ~25 min on a T4
SMOOTH  = 0.9         # real-label smoothing


In [2]:
tfm = transforms.Compose([transforms.ToTensor(),
                          transforms.Normalize((0.5,), (0.5,))])
train_ds = datasets.MNIST(root="data", train=True, download=True, transform=tfm)
train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=2)


100%|██████████| 9.91M/9.91M [00:00<00:00, 11.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 344kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.18MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.67MB/s]


In [3]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 50)
        self.fc = nn.Sequential(
            nn.Linear(LATENT + 50, 128*7*7),
            nn.BatchNorm1d(128*7*7), nn.ReLU(True))
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 7→14
            nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1),    # 14→28
            nn.Tanh())
    def forward(self, z, y):
        x = torch.cat([z, self.embed(y)], 1)      # concat noise + label-emb
        x = self.fc(x).view(-1, 128, 7, 7)
        return self.conv(x)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 28*28)
        self.conv = nn.Sequential(
            nn.Conv2d(2, 64, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Flatten(),
            nn.Linear(128*7*7, 1))
    def forward(self, x, y):
        yimg = self.embed(y).view(-1, 1, 28, 28)
        x = torch.cat([x, yimg], 1)               # image-concat label mask
        return self.conv(x)


In [4]:
G, D = Generator().to(DEVICE), Discriminator().to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
opt_G   = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D   = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))


In [5]:
for epoch in range(EPOCHS):
    for real, labels in tqdm(train_dl, leave=False):
        real, labels = real.to(DEVICE), labels.to(DEVICE)
        bs = real.size(0)

        # ---- Train Discriminator ----
        noise  = torch.randn(bs, LATENT, device=DEVICE)
        fake   = G(noise, labels)
        real_y = torch.full((bs,1), SMOOTH,  device=DEVICE)  # 0.9
        fake_y = torch.zeros(bs,1, device=DEVICE)            # 0.0

        D_real = D(real, labels)
        D_fake = D(fake.detach(), labels)
        loss_D = loss_fn(D_real, real_y) + loss_fn(D_fake, fake_y)

        opt_D.zero_grad();  loss_D.backward();  opt_D.step()

        # ---- Train Generator ----
        trick = torch.ones(bs,1, device=DEVICE)              # wants D to say 'real'
        D_fake_for_G = D(fake, labels)
        loss_G = loss_fn(D_fake_for_G, trick)

        opt_G.zero_grad();  loss_G.backward();  opt_G.step()

    print(f"Epoch {epoch+1}/{EPOCHS}  |  D: {loss_D.item():.3f}  G: {loss_G.item():.3f}")




Epoch 1/30  |  D: 0.920  G: 1.637




Epoch 2/30  |  D: 0.927  G: 0.985




Epoch 3/30  |  D: 0.927  G: 1.428




Epoch 4/30  |  D: 1.013  G: 1.946




Epoch 5/30  |  D: 1.016  G: 1.832




Epoch 6/30  |  D: 1.121  G: 1.815




Epoch 7/30  |  D: 1.038  G: 1.091




Epoch 8/30  |  D: 1.073  G: 1.075




Epoch 9/30  |  D: 0.983  G: 1.126




Epoch 10/30  |  D: 0.977  G: 1.419




Epoch 11/30  |  D: 1.175  G: 0.725




Epoch 12/30  |  D: 1.032  G: 1.228




Epoch 13/30  |  D: 1.077  G: 0.882




Epoch 14/30  |  D: 1.191  G: 1.126




Epoch 15/30  |  D: 1.202  G: 1.705




Epoch 16/30  |  D: 1.009  G: 1.652




Epoch 17/30  |  D: 1.241  G: 1.257




Epoch 18/30  |  D: 1.037  G: 1.332




Epoch 19/30  |  D: 1.123  G: 1.760




Epoch 20/30  |  D: 0.993  G: 1.947




Epoch 21/30  |  D: 1.072  G: 1.320




Epoch 22/30  |  D: 1.077  G: 0.950




Epoch 23/30  |  D: 1.106  G: 1.490




Epoch 24/30  |  D: 1.096  G: 1.339




Epoch 25/30  |  D: 1.081  G: 1.098




Epoch 26/30  |  D: 0.937  G: 1.667




Epoch 27/30  |  D: 1.048  G: 1.152




Epoch 28/30  |  D: 1.229  G: 2.233




Epoch 29/30  |  D: 0.924  G: 1.611


                                                 

Epoch 30/30  |  D: 0.971  G: 1.526




In [6]:
torch.save(G.state_dict(), "cgan_mnist_G.pth")     # ~5 MB
