In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import os

In [None]:
# Setup
# ----------------------------
# Configuration for a weight-clipped WGAN with DCGAN-style networks on MNIST.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's default RNG (all devices) so weight init, latent noise and the
# DataLoader's shuffle order repeat across Restart & Run All. Bit-exact GPU
# runs would additionally need torch.use_deterministic_algorithms(True).
seed = 0
torch.manual_seed(seed)

z_dim = 100          # length of the latent noise vector fed to the Generator
batch_size = 128
epochs = 20
# NOTE(review): the WGAN paper uses RMSprop with lr=5e-5 and n_critic=5;
# lr=5e-4 with one critic step per generator step is far more aggressive —
# confirm this setting is intentional.
lr = 5e-4
n_critic = 1         # critic updates per generator update
clip_value = 0.01    # weight-clipping bound for the critic's Lipschitz constraint
img_size = 28        # MNIST images are 28x28
channels = 1         # MNIST is grayscale
img_shape = (channels, img_size, img_size)

os.makedirs("wgan_dcgan", exist_ok=True)  # per-epoch sample grids are saved here

In [None]:
# Data Loader
# ----------------------------
# Preprocessing pipeline: resize to img_size, convert to a float tensor in
# [0, 1], then normalise with mean=0.5 / std=0.5 so pixels span [-1, 1] —
# the same range the Generator's Tanh output produces.
transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # Rescale to [-1, 1]
    ]
)

# MNIST training split, downloaded into the working directory on first use,
# served in shuffled minibatches of batch_size.
mnist_train = datasets.MNIST('.', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

In [None]:
# Generator
# ----------------------------
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to images in [-1, 1].

    A linear layer projects ``z`` to a ``(128, image_size // 4, image_size // 4)``
    feature map; two stride-2 transposed convolutions then double the spatial
    size twice, giving ``(out_channels, image_size, image_size)`` images.

    Args:
        latent_dim: Length of the input noise vector. Defaults to the
            notebook-level ``z_dim``.
        image_size: Output height/width; should be divisible by 4 so the two
            2x upsamplings land exactly on it. Defaults to ``img_size``.
        out_channels: Number of output image channels. Defaults to ``channels``.
    """

    def __init__(self, latent_dim=None, image_size=None, out_channels=None):
        super().__init__()
        # Fall back to the config-cell globals so `Generator()` keeps working.
        latent_dim = z_dim if latent_dim is None else latent_dim
        image_size = img_size if image_size is None else image_size
        out_channels = channels if out_channels is None else out_channels

        self.feat_channels = 128
        self.init_size = image_size // 4  # 7x7 for 28x28 MNIST
        self.fc = nn.Linear(latent_dim, self.feat_channels * self.init_size * self.init_size)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(self.feat_channels),
            # Without this ReLU, fc -> BatchNorm -> ConvTranspose is (almost)
            # affine and the network has only one real hidden nonlinearity.
            # DCGAN applies ReLU after every generator layer except the output.
            nn.ReLU(True),
            nn.ConvTranspose2d(self.feat_channels, 64, 4, stride=2, padding=1),  # 7→14
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, out_channels, 4, stride=2, padding=1),        # 14→28
            nn.Tanh(),  # squash to [-1, 1], matching the normalised training data
        )

    def forward(self, z):
        """Map noise ``z`` of shape ``(batch, latent_dim)`` to ``(batch, C, H, W)`` images."""
        x = self.fc(z)
        x = x.view(z.size(0), self.feat_channels, self.init_size, self.init_size)
        return self.conv_blocks(x)

In [None]:
# Critic
# ----------------------------
class Critic(nn.Module):
    """WGAN critic: scores each image with one unbounded real number.

    Two stride-2 convolutions shrink the side length from ``image_size`` to
    ``image_size // 4``; a linear layer reduces the flattened 128-channel
    feature map to a single score. There is no sigmoid, because the WGAN losses
    compare raw score means rather than probabilities.

    Args:
        in_channels: Channels of the input images. Defaults to the
            notebook-level ``channels``.
        image_size: Input height/width. Defaults to ``img_size``.
    """

    def __init__(self, in_channels=None, image_size=None):
        super().__init__()
        # Fall back to the config-cell globals so `Critic()` keeps working.
        in_channels = channels if in_channels is None else in_channels
        image_size = img_size if image_size is None else image_size
        # Each Conv2d(k=4, s=2, p=1) maps a side of length n to n // 2, so the
        # final map is image_size // 4 on a side (7 for 28x28 MNIST). Deriving
        # it keeps the Linear layer in sync when img_size changes.
        feat_size = image_size // 4
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),  # 28→14
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 14→7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(128 * feat_size * feat_size, 1),  # Output is a raw score
        )

    def forward(self, x):
        """Return critic scores of shape ``(batch, 1)`` for images ``x``."""
        return self.model(x)

In [None]:
# Initialize Models and Optimizers
# ----------------------------
# The Generator is constructed before the Critic, which fixes the order in
# which the two networks draw their initial weights from torch's RNG.
generator, critic = Generator().to(device), Critic().to(device)

# Both networks are optimised with RMSprop at the shared learning rate `lr`.
optimizer_G, optimizer_C = (
    optim.RMSprop(net.parameters(), lr=lr) for net in (generator, critic)
)

In [None]:
# Training Loop
# ----------------------------
# Each batch drives exactly one critic update. Every n_critic-th batch also
# drives a generator update. This matches WGAN Algorithm 1, where every critic
# iteration samples a fresh real minibatch instead of revisiting the same one
# n_critic times. With n_critic = 1 this is one critic step followed by one
# generator step per batch.
for epoch in range(1, epochs + 1):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        b_size = real_imgs.size(0)

        # === Train Critic === #
        z = torch.randn(b_size, z_dim, device=device)
        # detach(): the critic loss must not backpropagate into the generator.
        fake_imgs = generator(z).detach()

        # Negated estimate of E[C(real)] - E[C(fake)]; minimising it trains the
        # critic to score real images above generated ones.
        loss_C = -torch.mean(critic(real_imgs)) + torch.mean(critic(fake_imgs))

        optimizer_C.zero_grad()
        loss_C.backward()
        optimizer_C.step()

        # Weight clipping for the Lipschitz constraint, applied in place under
        # no_grad rather than through the legacy `.data` attribute.
        with torch.no_grad():
            for p in critic.parameters():
                p.clamp_(-clip_value, clip_value)

        # === Train Generator (every n_critic batches) === #
        # i == 0 always qualifies, so loss_G exists before the first log line.
        if i % n_critic == 0:
            z = torch.randn(b_size, z_dim, device=device)
            gen_imgs = generator(z)
            # The generator is rewarded for high critic scores on its samples.
            loss_G = -torch.mean(critic(gen_imgs))

            # This backward also leaves gradients on the critic's parameters;
            # optimizer_C.zero_grad() clears them before the next critic step.
            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[Critic: {loss_C.item():.4f}] [Gen: {loss_G.item():.4f}]")

    # Save an 8x8 grid of 64 fresh samples at the end of every epoch.
    # eval() makes the generator's BatchNorm layers use their running statistics.
    generator.eval()
    with torch.no_grad():
        z = torch.randn(64, z_dim, device=device)
        samples = generator(z)
        samples = samples * 0.5 + 0.5  # Denormalize from [-1, 1] to [0, 1]
        save_image(samples, f"wgan_dcgan/epoch_{epoch}.png", nrow=8)
    generator.train()