In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
# Custom Dataset
class HandwrittenDigitsDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        for digit in range(10):
            digit_dir = os.path.join(root_dir, str(digit), str(digit))
            for img_name in os.listdir(digit_dir):
                if img_name.endswith('.png'):
                    self.image_paths.append(os.path.join(digit_dir, img_name))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('L')
        if self.transform:
            image = self.transform(image)
        return image

In [3]:
# Generator with batch normalization
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 1024),  # Larger initial layer
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),  # Using LeakyReLU instead of ReLU
            nn.Dropout(0.2),  # Less dropout
            
            nn.Linear(1024, 2048),  # Larger middle layer
            nn.BatchNorm1d(2048),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            
            nn.Linear(2048, 28 * 28),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img

# Discriminator
# Remove sigmoid from discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 256),  # smaller network
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),  # add dropout
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),  # add dropout
            nn.Linear(128, 1)
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

# Gradient penalty function
def compute_gradient_penalty(D, real_samples, fake_samples, device):
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = torch.ones(real_samples.size(0), 1, device=device)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [6]:
# Hyperparameters
latent_dim = 100
n_critic = 1  # Train D and G equally
lambda_gp = 1  # Reduce gradient penalty weight
lr_d = 0.000003  # Extremely low learning rate for discriminator
lr_g = 0.000015   # Generator can learn 5x faster
batch_size = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 50

# Dataset and DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
dataset = HandwrittenDigitsDataset(root_dir="D:\\HandwrittenDigitsDataset\\dataset", transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize models
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Hyperparameters for WGAN
n_critic = 2  # Train discriminator more often
lambda_gp = 5  # Gradient penalty weight
optimizer_G = optim.Adam(generator.parameters(), lr=lr_g, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9))
scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=5, gamma=0.5)
scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=5, gamma=0.5)

In [7]:
# Training
for epoch in range(num_epochs):
    for i, imgs in enumerate(dataloader):
        batch_size = imgs.size(0)
        real_imgs = imgs.to(device)

        noise_factor = max(0.1 * (1 - epoch/20), 0)  # Starts at 0.1, goes to 0 by epoch 20

        # Train Discriminator
        for _ in range(n_critic):
            optimizer_D.zero_grad()
            
            # Generate fake images
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_imgs = generator(z)
            
            # Calculate discriminator outputs
            real_imgs_noisy = real_imgs + noise_factor * torch.randn_like(real_imgs)
            fake_imgs_noisy = fake_imgs.detach() + noise_factor * torch.randn_like(fake_imgs)
            real_validity = discriminator(real_imgs_noisy)
            fake_validity = discriminator(fake_imgs_noisy)
            
            # Calculate gradient penalty
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs_noisy, fake_imgs_noisy, device)

            real_target = torch.ones_like(real_validity).to(device) * 0.9  # Smooth positive labels
            fake_target = torch.ones_like(fake_validity).to(device) * -0.9  # Smooth negative labels
            d_real_loss = torch.mean(nn.MSELoss()(real_validity, real_target))
            d_fake_loss = torch.mean(nn.MSELoss()(fake_validity, fake_target))
            d_loss = d_real_loss + d_fake_loss + lambda_gp * gradient_penalty            
            d_loss.backward()
            optimizer_D.step()
            
            # Remove weight clamping - this shouldn't be used with gradient penalty

        # Train Generator
        optimizer_G.zero_grad()
        
        # Generate new fake images (important to regenerate after D update)
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)
        
        fake_validity = discriminator(fake_imgs)
        g_loss = -torch.mean(fake_validity)
        
        g_loss.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
        optimizer_G.step()

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
            

    # Save generated images
    if epoch % 10 == 0:
        with torch.no_grad():
            fake_imgs = generator(torch.randn(16, latent_dim).to(device))
            fake_imgs = fake_imgs.cpu().numpy()
            fig, axes = plt.subplots(4, 4, figsize=(8, 8))
            for i, ax in enumerate(axes.flatten()):
                ax.imshow(fake_imgs[i, 0], cmap='gray')
                ax.axis('off')
            plt.savefig(f'generated_digits_epoch_{epoch}.png')
            plt.close()

    scheduler_G.step()
    scheduler_D.step()

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


[Epoch 0/50] [Batch 0/1684] [D loss: 5.3101] [G loss: 0.1106]
[Epoch 0/50] [Batch 100/1684] [D loss: 7.2247] [G loss: -0.8107]
[Epoch 0/50] [Batch 200/1684] [D loss: 6.3433] [G loss: -0.6213]
[Epoch 0/50] [Batch 300/1684] [D loss: 5.2045] [G loss: -0.2506]
[Epoch 0/50] [Batch 400/1684] [D loss: 4.9039] [G loss: -0.1336]
[Epoch 0/50] [Batch 500/1684] [D loss: 4.8631] [G loss: -0.1412]
[Epoch 0/50] [Batch 600/1684] [D loss: 4.8231] [G loss: -0.1510]
[Epoch 0/50] [Batch 700/1684] [D loss: 5.0398] [G loss: -0.2284]
[Epoch 0/50] [Batch 800/1684] [D loss: 5.1428] [G loss: -0.2538]
[Epoch 0/50] [Batch 900/1684] [D loss: 5.1062] [G loss: -0.2522]
[Epoch 0/50] [Batch 1000/1684] [D loss: 4.9142] [G loss: -0.1666]
[Epoch 0/50] [Batch 1100/1684] [D loss: 4.6879] [G loss: -0.1252]
[Epoch 0/50] [Batch 1200/1684] [D loss: 4.7756] [G loss: -0.1387]
[Epoch 0/50] [Batch 1300/1684] [D loss: 4.7602] [G loss: -0.1314]
[Epoch 0/50] [Batch 1400/1684] [D loss: 4.6766] [G loss: -0.1452]
[Epoch 0/50] [Batch 150

In [8]:
# Save models
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')