In [1]:
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch import nn
from torchvision import transforms
from torchvision import utils as vutils
from torch.utils.data import DataLoader

In [2]:
# ---------------------------------------------------------------------------
# Run configuration: output/data locations, optimisation settings and model
# sizes shared by the cells below.
# ---------------------------------------------------------------------------

# Directory that receives the per-epoch sample grids and the final checkpoints.
save_dir = "./runs/GAN"
# MNIST root directory. The default is the original machine-specific path; set
# the MNIST_DATA_DIR environment variable to point somewhere else without
# editing the notebook.
data_dir = os.environ.get("MNIST_DATA_DIR", "/home/pervinco/Datasets/torch_mnist")

# Optimisation settings.
epochs = 100
batch_size = 64
d_lr = 0.0001  # discriminator learning rate
g_lr = 0.0001  # generator learning rate

# The generator is updated once every `disc_iter` discriminator updates.
disc_iter = 5
# Passed as the channel count when generated samples are reshaped for saving.
input_dim = 1
# Width of the noise vector fed to the generator.
latent_dim = 64
# Width of the hidden layers in both networks.
hidden_dim = 256
# Flattened image length (discriminator input / generator output width).
image_size = 784

# os.cpu_count() may return None when the count cannot be determined; fall back to 0.
num_workers = os.cpu_count() or 0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# exist_ok=True makes this safe to re-run and avoids the check-then-create race.
os.makedirs(save_dir, exist_ok=True)

In [3]:
class Discriminator(nn.Module):
    """Fully connected network that maps each input row to a single sigmoid score.

    Layer stack: Linear -> LeakyReLU(0.2) -> Linear -> LeakyReLU(0.2) -> Linear -> Sigmoid.
    The hidden width is read from the module-level ``hidden_dim`` setting.

    Args:
        input_dim: Number of features in each input row.
    """

    def __init__(self, input_dim):
        super().__init__()

        # Two hidden Linear + LeakyReLU stages, created in order.
        widths = (input_dim, hidden_dim, hidden_dim)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]

        # Output head: one unit squashed into (0, 1) by the sigmoid.
        layers += [nn.Linear(hidden_dim, 1), nn.Sigmoid()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return one score per row."""
        return self.model(x)

In [4]:
class Generator(nn.Module):
    """Fully connected network that maps a latent vector to a flattened image.

    Layer stack: Linear -> ReLU -> Linear -> ReLU -> Linear -> Tanh.
    The hidden width and the output width are read from the module-level
    ``hidden_dim`` and ``image_size`` settings.

    Args:
        latent_dim: Number of features in each latent input row.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Two hidden Linear + ReLU stages, created in order.
        widths = (latent_dim, hidden_dim, hidden_dim)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]

        # Output head: image_size values bounded to [-1, 1] by tanh.
        layers += [nn.Linear(hidden_dim, image_size), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Run ``z`` through the layer stack and return the generated rows."""
        return self.model(z)

In [5]:
# Preprocessing: convert each image to a tensor, then apply (x - 0.5) / 0.5
# via Normalize(mean=0.5, std=0.5).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# MNIST training split rooted at data_dir (download=True is requested).
train_dataset = torchvision.datasets.MNIST(
    root=data_dir, train=True, transform=transform, download=True
)

# Shuffled mini-batches; drop_last=True discards a final incomplete batch so
# every batch holds exactly batch_size samples.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [6]:
def save_fake_images(epoch, G, fixed_noise, channels, height, width, num_images=64):
    """Render generator samples for ``fixed_noise`` as an image grid and save it as a PNG.

    The image is written to ``{save_dir}/Epoch_{epoch}_Fake.png``, where ``save_dir``
    is the module-level output directory. The function draws on its own figure and
    always closes it, so it neither touches a figure that is already open nor leaks
    one if saving fails.

    Args:
        epoch: Epoch number used in the figure title and in the file name.
        G: Generator; called as ``G(fixed_noise)`` with gradients disabled.
        fixed_noise: Latent batch passed to ``G``.
        channels: Channel count used when reshaping the generator output.
        height: Image height used when reshaping the generator output.
        width: Image width used when reshaping the generator output.
        num_images: Maximum number of samples placed in the grid.
    """
    with torch.no_grad():  # sampling only, no gradients required
        fake_images = G(fixed_noise).detach().cpu()

    # Reshape flat generator rows to (N, channels, height, width).
    fake_images = fake_images.view(fake_images.size(0), channels, height, width)
    # Never request more images than were generated.
    fake_images = fake_images[:min(num_images, fake_images.size(0))]
    # Tile the samples into a single image tensor.
    grid = vutils.make_grid(fake_images, padding=2, normalize=True)

    # Use an explicit figure/axes pair instead of pyplot's implicit "current figure".
    fig, ax = plt.subplots()
    try:
        # The grid tensor is channel-first; imshow expects channel-last.
        ax.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
        ax.axis("off")
        ax.set_title(f"Fake Images at Epoch {epoch}")
        fig.savefig(f"{save_dir}/Epoch_{epoch}_Fake.png")
    finally:
        # Close this figure even if drawing or saving raised.
        plt.close(fig)

In [7]:
# Build both networks on the selected device.
D = Discriminator(input_dim=image_size).to(device)
G = Generator(latent_dim=latent_dim).to(device)

# Binary cross-entropy on the discriminator's sigmoid outputs; one Adam optimizer per network.
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(G.parameters(), lr=g_lr)

# Fixed latent batch reused every epoch for the saved sample grids.
fixed_noise = torch.randn(batch_size, latent_dim, device=device)

# Placeholder so the epoch summary can still be printed if no generator step
# has run yet (e.g. fewer than `disc_iter` batches per epoch).
g_loss = torch.tensor(float("nan"))

# NOTE(review): the generator is updated only once per `disc_iter` discriminator
# updates. The output recorded below this cell shows D_loss near 0 and fake
# scores near 0 for most epochs, which suggests the discriminator dominates —
# consider revisiting `disc_iter` / the loss setup.
for epoch in range(epochs):
    for idx, (images, _) in enumerate(tqdm(train_dataloader, desc="Train", leave=False)):
        # Use the actual batch length instead of assuming `batch_size`, so the
        # loop does not depend on drop_last=True in the dataloader.
        n = images.size(0)
        real_images = images.reshape(n, -1).to(device)
        # Targets are created directly on the device (no CPU allocation + copy).
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # ---- Discriminator step (every batch) ----
        # Reset the discriminator's gradients only.
        d_optimizer.zero_grad()

        # Loss on real images (target 1).
        real_outputs = D(real_images)
        d_real_loss = criterion(real_outputs, real_labels)

        # Loss on generated images (target 0). The generator is not updated in
        # this step, so its forward pass runs without building a graph.
        z = torch.randn(n, latent_dim, device=device)
        with torch.no_grad():
            fake_images = G(z)
        fake_outputs = D(fake_images)
        d_fake_loss = criterion(fake_outputs, fake_labels)

        # Backward pass and optimize.
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step (every `disc_iter` batches) ----
        if (idx + 1) % disc_iter == 0:
            # Reset the generator's gradients.
            g_optimizer.zero_grad()

            # Fresh noise; this time gradients flow through D back into G.
            z = torch.randn(n, latent_dim, device=device)
            fake_images = G(z)
            # To fool the discriminator, generated samples are scored against the "real" target.
            outputs = D(fake_images)
            g_loss = criterion(outputs, real_labels)

            # Backward pass and optimize (only g_optimizer steps here).
            g_loss.backward()
            g_optimizer.step()

    save_fake_images(epoch + 1, G, fixed_noise, input_dim, 28, 28)
    # Epoch summary from the last batch; epoch is reported 1-based to match the saved image names.
    print(f"Epoch[{epoch + 1}/{epochs}], Step[{idx+1}/{len(train_dataloader)}]")
    print(f"D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")
    print(f"Real Score: {real_outputs.mean().item():.4f}, Fake Score: {fake_outputs.mean().item():.4f}")

# Save the model checkpoints
torch.save(G.state_dict(), f'{save_dir}/G.ckpt')
torch.save(D.state_dict(), f'{save_dir}/D.ckpt')

                                                         

Epoch[0/100], Step[937/937]
D_loss: 0.0057, G_loss: 5.9318
Real Score: 1.0000, Fake Score: 0.0057


                                                         

Epoch[1/100], Step[937/937]
D_loss: 0.0009, G_loss: 8.1174
Real Score: 1.0000, Fake Score: 0.0008


                                                         

Epoch[2/100], Step[937/937]
D_loss: 0.0004, G_loss: 9.7543
Real Score: 1.0000, Fake Score: 0.0004


                                                         

Epoch[3/100], Step[937/937]
D_loss: 0.0003, G_loss: 10.8637
Real Score: 0.9999, Fake Score: 0.0002


                                                         

Epoch[4/100], Step[937/937]
D_loss: 0.0075, G_loss: 9.1340
Real Score: 0.9941, Fake Score: 0.0011


                                                         

Epoch[5/100], Step[937/937]
D_loss: 0.0046, G_loss: 6.5216
Real Score: 0.9991, Fake Score: 0.0037


                                                         

Epoch[6/100], Step[937/937]
D_loss: 0.0416, G_loss: 5.3355
Real Score: 0.9819, Fake Score: 0.0074


                                                         

Epoch[7/100], Step[937/937]
D_loss: 0.0197, G_loss: 6.2031
Real Score: 0.9988, Fake Score: 0.0181


                                                         

Epoch[8/100], Step[937/937]
D_loss: 0.0405, G_loss: 7.4313
Real Score: 0.9789, Fake Score: 0.0039


                                                         

Epoch[9/100], Step[937/937]
D_loss: 0.0034, G_loss: 5.9060
Real Score: 0.9999, Fake Score: 0.0032


                                                         

Epoch[10/100], Step[937/937]
D_loss: 0.0151, G_loss: 5.0694
Real Score: 0.9961, Fake Score: 0.0109


                                                         

Epoch[11/100], Step[937/937]
D_loss: 0.0098, G_loss: 7.9342
Real Score: 0.9922, Fake Score: 0.0008


                                                         

Epoch[12/100], Step[937/937]
D_loss: 0.0062, G_loss: 6.3931
Real Score: 0.9996, Fake Score: 0.0058


                                                         

Epoch[13/100], Step[937/937]
D_loss: 0.0107, G_loss: 9.4196
Real Score: 0.9907, Fake Score: 0.0001


                                                         

Epoch[14/100], Step[937/937]
D_loss: 0.0057, G_loss: 7.0133
Real Score: 0.9961, Fake Score: 0.0015


                                                         

Epoch[15/100], Step[937/937]
D_loss: 0.0012, G_loss: 8.4405
Real Score: 0.9994, Fake Score: 0.0006


                                                         

Epoch[16/100], Step[937/937]
D_loss: 0.0007, G_loss: 9.6763
Real Score: 0.9995, Fake Score: 0.0003


                                                         

Epoch[17/100], Step[937/937]
D_loss: 0.0012, G_loss: 8.1091
Real Score: 1.0000, Fake Score: 0.0012


                                                         

Epoch[18/100], Step[937/937]
D_loss: 0.0034, G_loss: 8.4433
Real Score: 0.9999, Fake Score: 0.0033


                                                         

Epoch[19/100], Step[937/937]
D_loss: 0.0143, G_loss: 8.2384
Real Score: 0.9927, Fake Score: 0.0053


                                                         

Epoch[20/100], Step[937/937]
D_loss: 0.0014, G_loss: 8.6660
Real Score: 1.0000, Fake Score: 0.0014


                                                         

Epoch[21/100], Step[937/937]
D_loss: 0.0005, G_loss: 10.7158
Real Score: 1.0000, Fake Score: 0.0005


                                                         

Epoch[22/100], Step[937/937]
D_loss: 0.0107, G_loss: 8.9658
Real Score: 0.9925, Fake Score: 0.0005


                                                         

Epoch[23/100], Step[937/937]
D_loss: 0.0031, G_loss: 5.2660
Real Score: 0.9999, Fake Score: 0.0030


                                                         

Epoch[24/100], Step[937/937]
D_loss: 0.0344, G_loss: 32.5733
Real Score: 0.9857, Fake Score: 0.0000


                                                         

Epoch[25/100], Step[937/937]
D_loss: 0.0106, G_loss: 7.5916
Real Score: 1.0000, Fake Score: 0.0101


                                                         

Epoch[26/100], Step[937/937]
D_loss: 0.0000, G_loss: 17.3095
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[27/100], Step[937/937]
D_loss: 0.0000, G_loss: 16.4667
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[28/100], Step[937/937]
D_loss: 0.0066, G_loss: 8.5343
Real Score: 1.0000, Fake Score: 0.0065


                                                         

Epoch[29/100], Step[937/937]
D_loss: 0.0056, G_loss: 5.7647
Real Score: 1.0000, Fake Score: 0.0055


                                                         

Epoch[30/100], Step[937/937]
D_loss: 0.0003, G_loss: 10.5102
Real Score: 1.0000, Fake Score: 0.0003


                                                         

Epoch[31/100], Step[937/937]
D_loss: 0.0011, G_loss: 12.5552
Real Score: 1.0000, Fake Score: 0.0011


                                                         

Epoch[32/100], Step[937/937]
D_loss: 0.0000, G_loss: 12.0766
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[33/100], Step[937/937]
D_loss: 0.0082, G_loss: 12.4679
Real Score: 1.0000, Fake Score: 0.0080


                                                         

Epoch[34/100], Step[937/937]
D_loss: 0.0036, G_loss: 11.0986
Real Score: 1.0000, Fake Score: 0.0036


                                                         

Epoch[35/100], Step[937/937]
D_loss: 0.0002, G_loss: 11.6792
Real Score: 1.0000, Fake Score: 0.0002


                                                         

Epoch[36/100], Step[937/937]
D_loss: 0.0014, G_loss: 9.7238
Real Score: 1.0000, Fake Score: 0.0013


                                                         

Epoch[37/100], Step[937/937]
D_loss: 0.0000, G_loss: 30.7651
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[38/100], Step[937/937]
D_loss: 0.0000, G_loss: 40.6558
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[39/100], Step[937/937]
D_loss: 0.0000, G_loss: 15.0179
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[40/100], Step[937/937]
D_loss: 0.0045, G_loss: 9.7482
Real Score: 0.9998, Fake Score: 0.0042


                                                         

Epoch[41/100], Step[937/937]
D_loss: 0.0001, G_loss: 12.6034
Real Score: 1.0000, Fake Score: 0.0001


                                                         

Epoch[42/100], Step[937/937]
D_loss: 0.0000, G_loss: 16.9848
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[43/100], Step[937/937]
D_loss: 0.0002, G_loss: 13.8747
Real Score: 1.0000, Fake Score: 0.0002


                                                         

Epoch[44/100], Step[937/937]
D_loss: 0.0004, G_loss: 10.0178
Real Score: 1.0000, Fake Score: 0.0004


                                                         

Epoch[45/100], Step[937/937]
D_loss: 0.0012, G_loss: 12.0683
Real Score: 1.0000, Fake Score: 0.0012


                                                         

Epoch[46/100], Step[937/937]
D_loss: 0.0000, G_loss: 13.8328
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[47/100], Step[937/937]
D_loss: 0.0002, G_loss: 11.7005
Real Score: 1.0000, Fake Score: 0.0002


                                                         

Epoch[48/100], Step[937/937]
D_loss: 0.0000, G_loss: 45.6710
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[49/100], Step[937/937]
D_loss: 0.0000, G_loss: 36.5357
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[50/100], Step[937/937]
D_loss: 0.0000, G_loss: 20.2019
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[51/100], Step[937/937]
D_loss: 0.0009, G_loss: 10.8818
Real Score: 1.0000, Fake Score: 0.0009


                                                         

Epoch[52/100], Step[937/937]
D_loss: 0.0147, G_loss: 11.1207
Real Score: 1.0000, Fake Score: 0.0102


                                                         

Epoch[53/100], Step[937/937]
D_loss: 0.0056, G_loss: 12.2174
Real Score: 0.9965, Fake Score: 0.0016


                                                         

Epoch[54/100], Step[937/937]
D_loss: 0.0019, G_loss: 9.5434
Real Score: 1.0000, Fake Score: 0.0019


                                                         

Epoch[55/100], Step[937/937]
D_loss: 0.0000, G_loss: 35.8317
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[56/100], Step[937/937]
D_loss: 0.0001, G_loss: 12.8212
Real Score: 1.0000, Fake Score: 0.0001


                                                         

Epoch[57/100], Step[937/937]
D_loss: 0.0033, G_loss: 14.6280
Real Score: 0.9971, Fake Score: 0.0000


                                                         

Epoch[58/100], Step[937/937]
D_loss: 0.0372, G_loss: 38.1420
Real Score: 0.9835, Fake Score: 0.0000


                                                         

Epoch[59/100], Step[937/937]
D_loss: 0.0242, G_loss: 9.5778
Real Score: 1.0000, Fake Score: 0.0141


                                                         

Epoch[60/100], Step[937/937]
D_loss: 0.0047, G_loss: 9.6316
Real Score: 0.9963, Fake Score: 0.0007


                                                         

Epoch[61/100], Step[937/937]
D_loss: 0.0006, G_loss: 13.4637
Real Score: 1.0000, Fake Score: 0.0006


                                                         

Epoch[62/100], Step[937/937]
D_loss: 0.0006, G_loss: 10.2802
Real Score: 1.0000, Fake Score: 0.0006


                                                         

Epoch[63/100], Step[937/937]
D_loss: 0.0001, G_loss: 13.0316
Real Score: 1.0000, Fake Score: 0.0000


                                                         

Epoch[64/100], Step[937/937]
D_loss: 0.0001, G_loss: 18.6753
Real Score: 1.0000, Fake Score: 0.0001


                                                         

Epoch[65/100], Step[937/937]
D_loss: 0.0049, G_loss: 9.4286
Real Score: 0.9968, Fake Score: 0.0014


                                                         

Epoch[66/100], Step[937/937]
D_loss: 0.0010, G_loss: 7.7379
Real Score: 1.0000, Fake Score: 0.0010


                                                         

Epoch[67/100], Step[937/937]
D_loss: 0.0078, G_loss: 11.7302
Real Score: 0.9939, Fake Score: 0.0004


                                                         

Epoch[68/100], Step[937/937]
D_loss: 0.0094, G_loss: 9.5192
Real Score: 0.9952, Fake Score: 0.0035


                                                         

Epoch[69/100], Step[937/937]
D_loss: 0.0030, G_loss: 8.6264
Real Score: 0.9998, Fake Score: 0.0027


                                                         

Epoch[70/100], Step[937/937]
D_loss: 0.0376, G_loss: 12.8541
Real Score: 0.9759, Fake Score: 0.0000


                                                         

Epoch[71/100], Step[937/937]
D_loss: 0.0052, G_loss: 7.9333
Real Score: 1.0000, Fake Score: 0.0050


                                                         

Epoch[72/100], Step[937/937]
D_loss: 0.0032, G_loss: 9.4014
Real Score: 0.9993, Fake Score: 0.0025


                                                         

Epoch[73/100], Step[937/937]
D_loss: 0.0133, G_loss: 6.0859
Real Score: 0.9923, Fake Score: 0.0027


                                                         

Epoch[74/100], Step[937/937]
D_loss: 0.0789, G_loss: 10.0151
Real Score: 0.9741, Fake Score: 0.0115


                                                         

Epoch[75/100], Step[937/937]
D_loss: 0.0030, G_loss: 9.4570
Real Score: 0.9999, Fake Score: 0.0028


                                                         

Epoch[76/100], Step[937/937]
D_loss: 0.0048, G_loss: 7.9696
Real Score: 0.9980, Fake Score: 0.0027


                                                         

Epoch[77/100], Step[937/937]
D_loss: 0.0010, G_loss: 11.1459
Real Score: 0.9998, Fake Score: 0.0008


                                                         

Epoch[78/100], Step[937/937]
D_loss: 0.0078, G_loss: 7.1857
Real Score: 1.0000, Fake Score: 0.0071


                                                         

Epoch[79/100], Step[937/937]
D_loss: 0.0113, G_loss: 7.4484
Real Score: 0.9998, Fake Score: 0.0099


                                                         

Epoch[80/100], Step[937/937]
D_loss: 0.0091, G_loss: 10.3114
Real Score: 0.9934, Fake Score: 0.0017


                                                         

Epoch[81/100], Step[937/937]
D_loss: 0.0004, G_loss: 10.4810
Real Score: 0.9999, Fake Score: 0.0004


                                                         

Epoch[82/100], Step[937/937]
D_loss: 0.1020, G_loss: 9.7844
Real Score: 0.9775, Fake Score: 0.0033


                                                         

Epoch[83/100], Step[937/937]
D_loss: 0.0779, G_loss: 8.0603
Real Score: 0.9763, Fake Score: 0.0019


                                                         

Epoch[84/100], Step[937/937]
D_loss: 0.0123, G_loss: 8.5339
Real Score: 0.9921, Fake Score: 0.0018


                                                         

Epoch[85/100], Step[937/937]
D_loss: 0.0274, G_loss: 7.9256
Real Score: 0.9987, Fake Score: 0.0146


                                                         

Epoch[86/100], Step[937/937]
D_loss: 0.0866, G_loss: 5.1156
Real Score: 0.9948, Fake Score: 0.0620


                                                         

Epoch[87/100], Step[937/937]
D_loss: 0.0047, G_loss: 9.5579
Real Score: 0.9959, Fake Score: 0.0005


                                                         

Epoch[88/100], Step[937/937]
D_loss: 0.0985, G_loss: 7.6440
Real Score: 0.9818, Fake Score: 0.0038


                                                         

Epoch[89/100], Step[937/937]
D_loss: 0.0477, G_loss: 5.2235
Real Score: 0.9779, Fake Score: 0.0083


                                                         

Epoch[90/100], Step[937/937]
D_loss: 0.1132, G_loss: 6.9869
Real Score: 0.9538, Fake Score: 0.0048


                                                         

Epoch[91/100], Step[937/937]
D_loss: 0.0193, G_loss: 5.7310
Real Score: 0.9985, Fake Score: 0.0170


                                                         

Epoch[92/100], Step[937/937]
D_loss: 0.0020, G_loss: 6.6950
Real Score: 0.9998, Fake Score: 0.0018


                                                         

Epoch[93/100], Step[937/937]
D_loss: 0.0029, G_loss: 7.1699
Real Score: 0.9995, Fake Score: 0.0024


                                                         

Epoch[94/100], Step[937/937]
D_loss: 0.0519, G_loss: 4.4566
Real Score: 0.9864, Fake Score: 0.0266


                                                         

Epoch[95/100], Step[937/937]
D_loss: 0.0644, G_loss: 5.2233
Real Score: 0.9652, Fake Score: 0.0082


                                                         

Epoch[96/100], Step[937/937]
D_loss: 0.1017, G_loss: 6.1498
Real Score: 0.9659, Fake Score: 0.0182


                                                         

Epoch[97/100], Step[937/937]
D_loss: 0.0244, G_loss: 6.0906
Real Score: 0.9920, Fake Score: 0.0145


                                                         

Epoch[98/100], Step[937/937]
D_loss: 0.4304, G_loss: 11.7591
Real Score: 0.9112, Fake Score: 0.0000


                                                         

Epoch[99/100], Step[937/937]
D_loss: 0.0555, G_loss: 4.2978
Real Score: 0.9825, Fake Score: 0.0085


