In [2]:
!pip install torch torchvision matplotlib



Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [3]:
# Basic GAN Training on MNIST
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# Generator
class Generator(nn.Module):
    def __init__(self, z_dim=100, num_classes=10):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, z_dim)
        self.model = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 784),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        x = noise * self.label_embedding(labels)
        return self.model(x).view(-1, 1, 28, 28)

# Discriminator
class Discriminator(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, 784)
        self.model = nn.Sequential(
            nn.Linear(784, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        x = img.view(img.size(0), -1) * self.label_embedding(labels)
        return self.model(x)

G = Generator().to(device)
D = Discriminator().to(device)

# Loss and optimizers
loss_fn = nn.BCELoss()
optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0002)

# Training loop (keep short to meet exam limits)
epochs = 20
z_dim = 100

for epoch in range(epochs):
    for real_imgs, labels in dataloader:
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        labels = labels.to(device)

        real = torch.ones(batch_size, 1).to(device)
        fake = torch.zeros(batch_size, 1).to(device)

        # Train Discriminator
        noise = torch.randn(batch_size, z_dim).to(device)
        fake_imgs = G(noise, labels)
        D_real = D(real_imgs, labels)
        D_fake = D(fake_imgs.detach(), labels)

        loss_D = loss_fn(D_real, real) + loss_fn(D_fake, fake)
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # Train Generator
        fake_imgs = G(noise, labels)
        D_fake = D(fake_imgs, labels)
        loss_G = loss_fn(D_fake, real)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

    print(f"Epoch {epoch+1}/{epochs}, Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")

# Save the generator
torch.save(G.state_dict(), "mnist_generator.pth")


100%|██████████| 9.91M/9.91M [00:02<00:00, 4.64MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 135kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.26MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.4MB/s]


Epoch 1/20, Loss D: 0.5417, Loss G: 1.2817
Epoch 2/20, Loss D: 0.4947, Loss G: 1.6226
Epoch 3/20, Loss D: 0.4836, Loss G: 1.8727
Epoch 4/20, Loss D: 0.4427, Loss G: 2.1556
Epoch 5/20, Loss D: 0.5337, Loss G: 2.0687
Epoch 6/20, Loss D: 0.6278, Loss G: 1.8968
Epoch 7/20, Loss D: 0.6394, Loss G: 1.9758
Epoch 8/20, Loss D: 0.5635, Loss G: 2.2148
Epoch 9/20, Loss D: 0.5555, Loss G: 1.9385
Epoch 10/20, Loss D: 0.6249, Loss G: 1.9853
Epoch 11/20, Loss D: 0.8294, Loss G: 1.7643
Epoch 12/20, Loss D: 0.7197, Loss G: 1.9804
Epoch 13/20, Loss D: 0.8216, Loss G: 1.7236
Epoch 14/20, Loss D: 0.9377, Loss G: 1.7347
Epoch 15/20, Loss D: 0.7095, Loss G: 1.8758
Epoch 16/20, Loss D: 1.0583, Loss G: 1.5537
Epoch 17/20, Loss D: 0.6972, Loss G: 2.0225
Epoch 18/20, Loss D: 0.7422, Loss G: 1.8167
Epoch 19/20, Loss D: 0.8202, Loss G: 1.6506
Epoch 20/20, Loss D: 0.7952, Loss G: 1.7157


In [4]:
from google.colab import files
files.download("mnist_generator.pth")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>