In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
from collections import Counter


In [2]:
# ===== USER INPUTS =====
dataset_choice = "mnist"   # "mnist" or "fashion"
epochs = 30
batch_size = 64
noise_dim = 100
learning_rate = 0.0002
save_interval = 5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

if dataset_choice == "mnist":
    dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
elif dataset_choice == "fashion":
    dataset = datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)
else:
    raise ValueError("Invalid dataset choice")

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 508kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.70MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 13.2MB/s]


In [4]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 28*28),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(-1, 1, 28, 28)

generator = Generator().to(device)


In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)

discriminator = Discriminator().to(device)


In [6]:
criterion = nn.BCELoss()

optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)


In [7]:
os.makedirs("generated_samples", exist_ok=True)

for epoch in range(1, epochs + 1):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.to(device)
        batch_size_curr = real_imgs.size(0)

        real_labels = torch.ones(batch_size_curr, 1).to(device)
        fake_labels = torch.zeros(batch_size_curr, 1).to(device)

        # ----- Train Discriminator -----
        optimizer_D.zero_grad()

        real_preds = discriminator(real_imgs)
        d_real_loss = criterion(real_preds, real_labels)

        noise = torch.randn(batch_size_curr, noise_dim).to(device)
        fake_imgs = generator(noise)
        fake_preds = discriminator(fake_imgs.detach())
        d_fake_loss = criterion(fake_preds, fake_labels)

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        optimizer_D.step()

        d_acc = ((real_preds > 0.5).float().mean() +
                 (fake_preds < 0.5).float().mean()) / 2 * 100

        # ----- Train Generator -----
        optimizer_G.zero_grad()

        fake_preds = discriminator(fake_imgs)
        g_loss = criterion(fake_preds, real_labels)
        g_loss.backward()
        optimizer_G.step()

    print(f"Epoch {epoch}/{epochs} | D_loss: {d_loss.item():.2f} | D_acc: {d_acc:.2f}% | G_loss: {g_loss.item():.2f}")

    if epoch % save_interval == 0:
        save_image(fake_imgs[:25], f"generated_samples/epoch_{epoch:02d}.png", nrow=5, normalize=True)


Epoch 1/30 | D_loss: 0.14 | D_acc: 98.44% | G_loss: 3.19
Epoch 2/30 | D_loss: 1.65 | D_acc: 59.38% | G_loss: 2.24
Epoch 3/30 | D_loss: 1.02 | D_acc: 82.81% | G_loss: 2.41
Epoch 4/30 | D_loss: 0.18 | D_acc: 96.88% | G_loss: 3.19
Epoch 5/30 | D_loss: 0.66 | D_acc: 90.62% | G_loss: 2.31
Epoch 6/30 | D_loss: 0.47 | D_acc: 93.75% | G_loss: 4.20
Epoch 7/30 | D_loss: 0.53 | D_acc: 89.06% | G_loss: 1.86
Epoch 8/30 | D_loss: 0.58 | D_acc: 85.94% | G_loss: 2.33
Epoch 9/30 | D_loss: 0.69 | D_acc: 82.81% | G_loss: 2.70
Epoch 10/30 | D_loss: 0.49 | D_acc: 92.19% | G_loss: 2.94
Epoch 11/30 | D_loss: 0.35 | D_acc: 92.19% | G_loss: 2.64
Epoch 12/30 | D_loss: 0.55 | D_acc: 90.62% | G_loss: 2.41
Epoch 13/30 | D_loss: 0.63 | D_acc: 89.06% | G_loss: 3.87
Epoch 14/30 | D_loss: 0.68 | D_acc: 84.38% | G_loss: 2.59
Epoch 15/30 | D_loss: 0.47 | D_acc: 85.94% | G_loss: 2.81
Epoch 16/30 | D_loss: 0.72 | D_acc: 81.25% | G_loss: 1.64
Epoch 17/30 | D_loss: 0.89 | D_acc: 82.81% | G_loss: 2.47
Epoch 18/30 | D_loss: 0

In [8]:
os.makedirs("final_generated_images", exist_ok=True)

noise = torch.randn(100, noise_dim).to(device)
final_images = generator(noise).detach().cpu()

for i, img in enumerate(final_images):
    save_image(img, f"final_generated_images/img_{i}.png", normalize=True)

print("✅ 100 final images generated")


✅ 100 final images generated


In [9]:
from torchvision import models

classifier = models.resnet18(pretrained=True)
classifier.eval()

labels = []

for img in final_images:
    img = img.repeat(3,1,1).unsqueeze(0)
    with torch.no_grad():
        output = classifier(img)
        labels.append(output.argmax().item())

counts = Counter(labels)

print("\n--- Label Distribution ---")
for label, count in sorted(counts.items()):
    print(f"Label {label}: {count} images")




Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 195MB/s]



--- Label Distribution ---
Label 49: 1 images
Label 58: 1 images
Label 111: 1 images
Label 327: 3 images
Label 403: 5 images
Label 458: 48 images
Label 500: 3 images
Label 510: 2 images
Label 528: 1 images
Label 685: 4 images
Label 761: 1 images
Label 772: 4 images
Label 815: 12 images
Label 847: 1 images
Label 878: 7 images
Label 913: 1 images
Label 922: 2 images
Label 971: 3 images
