In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [2]:
# Custom Dataset
class HandwrittenDigitsDataset(Dataset):
    """Unlabelled dataset of grayscale digit images stored as PNG files.

    Images are collected from ``<root_dir>/<d>/<d>/`` for every digit ``d`` in
    0-9 (each digit folder is nested inside a folder of the same name). Only
    file names ending in ``.png`` are kept. No label is returned: items are
    images only.

    Args:
        root_dir: Directory that contains the ten per-digit folders.
        transform: Optional callable applied to each PIL image before it is
            returned from ``__getitem__``.

    Raises:
        FileNotFoundError: Propagated from ``os.listdir`` when one of the ten
            digit directories does not exist.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        for digit in range(10):
            digit_dir = os.path.join(root_dir, str(digit), str(digit))
            # os.listdir yields entries in arbitrary order; sorting keeps the
            # index -> file mapping identical across runs and machines.
            self.image_paths.extend(
                os.path.join(digit_dir, img_name)
                for img_name in sorted(os.listdir(digit_dir))
                if img_name.endswith('.png')
            )

    def __len__(self):
        """Return the number of PNG files found under ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as single-channel ('L') and apply the transform, if any."""
        img_path = self.image_paths[idx]
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the with-block exits.
        with Image.open(img_path) as raw:
            image = raw.convert('L')
        if self.transform:
            image = self.transform(image)
        return image




In [3]:
# Generator
class Generator(nn.Module):
    """Fully connected generator.

    Maps a batch of latent vectors of size ``latent_dim`` through three
    Linear+ReLU stages (256, 512, 1024 units) and a final Linear+Tanh stage,
    then reshapes the 784 outputs to ``(batch, 1, 28, 28)``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = (latent_dim, 256, 512, 1024)
        stack = []
        # Hidden stages: Linear followed by ReLU, built in order so the
        # Sequential indices (and therefore state_dict keys) stay 0..7.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output stage: one value per pixel, squashed by Tanh.
        stack += [nn.Linear(widths[-1], 28 * 28), nn.Tanh()]
        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Return generated images of shape ``(z.size(0), 1, 28, 28)``."""
        flat = self.model(z)
        return flat.view(flat.size(0), 1, 28, 28)

# Discriminator
class Discriminator(nn.Module):
    """Fully connected discriminator.

    Flattens each input image to 784 values, passes it through two
    Linear+LeakyReLU(0.2) stages (512, 256 units) and a final Linear+Sigmoid
    stage, producing one score per image.
    """

    def __init__(self):
        super().__init__()
        sizes = (28 * 28, 512, 256)
        layers = []
        for n_in, n_out in zip(sizes, sizes[1:]):
            layers.extend((nn.Linear(n_in, n_out), nn.LeakyReLU(0.2)))
        # Single output unit passed through Sigmoid.
        layers.extend((nn.Linear(sizes[-1], 1), nn.Sigmoid()))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return a ``(img.size(0), 1)`` tensor of Sigmoid outputs."""
        n_imgs = img.size(0)
        return self.model(img.view(n_imgs, -1))

In [6]:
# Hyperparameters
# Seed PyTorch's global RNG first so that everything created afterwards
# (weight initialisation, latent noise) is repeatable between runs.
seed = 42
torch.manual_seed(seed)

latent_dim = 100   # length of the noise vector fed to the generator
lr = 0.0002        # Adam learning rate, shared by both optimizers
beta1 = 0.5        # first Adam beta, shared by both optimizers
num_epochs = 50
batch_size = 64    # DataLoader batch size
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# Dataset and DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5: recentres pixel values to match the generator's Tanh output range.
    transforms.Normalize([0.5], [0.5])
])
# NOTE(review): no resize is applied, so the images are assumed to already be
# 28x28 (the discriminator's first layer expects 28 * 28 inputs) -- confirm.

# Dataset location. Set HANDWRITTEN_DIGITS_DIR to run on another machine; the
# original hard-coded path remains the default.
data_dir = os.environ.get("HANDWRITTEN_DIGITS_DIR", "D:\\HandwrittenDigitsDataset\\dataset")
dataset = HandwrittenDigitsDataset(root_dir=data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize models
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Loss and optimizers
# BCELoss expects probabilities, which the discriminator's final Sigmoid provides.
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

In [8]:
# Training
def _save_sample_grid(epoch):
    """Save a 4x4 grid of freshly generated digits as generated_digits_epoch_<epoch>.png.

    Reads the module-level ``generator``, ``latent_dim`` and ``device``.
    """
    with torch.no_grad():
        samples = generator(torch.randn(16, latent_dim, device=device))
    samples = samples.cpu().numpy()
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for ax, sample in zip(axes.flatten(), samples):
        ax.imshow(sample[0], cmap='gray')  # sample[0] is the single image channel
        ax.axis('off')
    fig.savefig(f'generated_digits_epoch_{epoch}.png')
    plt.close(fig)  # release the figure so repeated saves do not accumulate


for epoch in range(num_epochs):
    for i, imgs in enumerate(dataloader):
        # Size of this particular batch (the last one may be smaller). Kept
        # under its own name so the global `batch_size` hyperparameter is not
        # overwritten, which would corrupt a re-run of the DataLoader cell.
        n_in_batch = imgs.size(0)
        real_imgs = imgs.to(device)
        real_label = torch.ones(n_in_batch, 1, device=device)
        fake_label = torch.zeros(n_in_batch, 1, device=device)

        # Train Discriminator
        optimizer_D.zero_grad()
        z = torch.randn(n_in_batch, latent_dim, device=device)
        fake_imgs = generator(z)
        real_validity = discriminator(real_imgs)
        # detach(): the discriminator step must not back-propagate into the generator.
        fake_validity = discriminator(fake_imgs.detach())
        d_loss = adversarial_loss(real_validity, real_label) + adversarial_loss(fake_validity, fake_label)
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        # Re-score the same fakes (not detached) against the "real" target so
        # the generator is rewarded for fooling the updated discriminator.
        fake_validity = discriminator(fake_imgs)
        g_loss = adversarial_loss(fake_validity, real_label)
        g_loss.backward()
        optimizer_G.step()

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

    # Save generated images
    if epoch % 10 == 0:
        _save_sample_grid(epoch)

[Epoch 0/50] [Batch 0/1684] [D loss: 1.4275] [G loss: 0.6846]
[Epoch 0/50] [Batch 100/1684] [D loss: 1.4351] [G loss: 0.7305]
[Epoch 0/50] [Batch 200/1684] [D loss: 1.4320] [G loss: 0.6981]
[Epoch 0/50] [Batch 300/1684] [D loss: 1.4085] [G loss: 0.6684]
[Epoch 0/50] [Batch 400/1684] [D loss: 1.3829] [G loss: 0.6755]
[Epoch 0/50] [Batch 500/1684] [D loss: 1.3854] [G loss: 0.6817]
[Epoch 0/50] [Batch 600/1684] [D loss: 1.3797] [G loss: 0.6970]
[Epoch 0/50] [Batch 700/1684] [D loss: 1.3818] [G loss: 0.7031]
[Epoch 0/50] [Batch 800/1684] [D loss: 1.3891] [G loss: 0.6956]
[Epoch 0/50] [Batch 900/1684] [D loss: 1.3927] [G loss: 0.6800]
[Epoch 0/50] [Batch 1000/1684] [D loss: 1.3872] [G loss: 0.6839]
[Epoch 0/50] [Batch 1100/1684] [D loss: 1.3520] [G loss: 0.7114]
[Epoch 0/50] [Batch 1200/1684] [D loss: 1.3860] [G loss: 0.6926]
[Epoch 0/50] [Batch 1300/1684] [D loss: 1.3914] [G loss: 0.7771]
[Epoch 0/50] [Batch 1400/1684] [D loss: 1.3861] [G loss: 0.6900]
[Epoch 0/50] [Batch 1500/1684] [D los

In [9]:
# Save models
# Persist only the learned weights (state_dict) of each network, generator first.
for net, weights_file in ((generator, 'generator.pth'), (discriminator, 'discriminator.pth')):
    torch.save(net.state_dict(), weights_file)