In [4]:
# Imports
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np


# Load MNIST Dataset

In [5]:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

train_data = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)


# Generator and Discriminator

In [6]:

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(20 + 10, 128),
            nn.ReLU(),
            nn.Linear(128, 784),
            nn.Tanh()
        )

    def forward(self, z, labels):
        x = torch.cat([z, labels], dim=1)
        return self.fc(x).view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(784 + 10, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        x = torch.cat([img.view(img.size(0), -1), labels], dim=1)
        return self.fc(x)


# Setup

In [7]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G = Generator().to(device)
D = Discriminator().to(device)

criterion = nn.BCELoss()
g_opt = torch.optim.Adam(G.parameters(), lr=0.0002)
d_opt = torch.optim.Adam(D.parameters(), lr=0.0002)

def one_hot(labels, num_classes=10):
    return torch.nn.functional.one_hot(labels, num_classes).float()


# Training Loop

In [8]:

epochs = 20

for epoch in range(epochs):
    for real_imgs, labels in train_loader:
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        labels_oh = one_hot(labels).to(device)

        # Train Discriminator
        z = torch.randn(batch_size, 20).to(device)
        fake_imgs = G(z, labels_oh)

        real_targets = torch.ones(batch_size, 1).to(device)
        fake_targets = torch.zeros(batch_size, 1).to(device)

        d_real = D(real_imgs, labels_oh)
        d_fake = D(fake_imgs.detach(), labels_oh)

        d_loss = criterion(d_real, real_targets) + criterion(d_fake, fake_targets)
        D.zero_grad()
        d_loss.backward()
        d_opt.step()

        # Train Generator
        z = torch.randn(batch_size, 20).to(device)
        fake_imgs = G(z, labels_oh)
        d_fake = D(fake_imgs, labels_oh)
        g_loss = criterion(d_fake, real_targets)

        G.zero_grad()
        g_loss.backward()
        g_opt.step()

    print(f"Epoch {epoch+1}/{epochs} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")


Epoch 1/20 | D Loss: 0.3240 | G Loss: 1.6435
Epoch 2/20 | D Loss: 0.5577 | G Loss: 1.3676
Epoch 3/20 | D Loss: 0.9451 | G Loss: 1.0842
Epoch 4/20 | D Loss: 1.2069 | G Loss: 0.9357
Epoch 5/20 | D Loss: 0.6683 | G Loss: 1.5308
Epoch 6/20 | D Loss: 0.7360 | G Loss: 1.3629
Epoch 7/20 | D Loss: 1.6264 | G Loss: 0.6346
Epoch 8/20 | D Loss: 1.1611 | G Loss: 0.8632
Epoch 9/20 | D Loss: 1.2017 | G Loss: 0.9393
Epoch 10/20 | D Loss: 0.9213 | G Loss: 1.2468
Epoch 11/20 | D Loss: 1.3102 | G Loss: 0.8191
Epoch 12/20 | D Loss: 0.9992 | G Loss: 1.0873
Epoch 13/20 | D Loss: 1.3433 | G Loss: 0.8450
Epoch 14/20 | D Loss: 1.1037 | G Loss: 1.1446
Epoch 15/20 | D Loss: 0.8263 | G Loss: 1.3187
Epoch 16/20 | D Loss: 0.9621 | G Loss: 1.0154
Epoch 17/20 | D Loss: 1.1934 | G Loss: 0.8922
Epoch 18/20 | D Loss: 1.2365 | G Loss: 0.8237
Epoch 19/20 | D Loss: 0.9887 | G Loss: 1.1300
Epoch 20/20 | D Loss: 1.1443 | G Loss: 1.4128


In [9]:
# Save Generator Model
torch.save(G.state_dict(), "digit_generator.pth")
print("Model saved as digit_generator.pth")


Model saved as digit_generator.pth
