In [43]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [44]:
class Discriminator(nn.Module):
    """Single-hidden-layer MLP that scores flattened images as real vs. fake.

    Args:
        img_dim: Number of input features per sample (the flattened image size).
        hidden_dim: Width of the hidden layer. Defaults to 128, the original value.
        negative_slope: Slope of the LeakyReLU for negative inputs. Defaults to 0.1.

    The last dimension of the input must equal ``img_dim``. The output has a
    trailing dimension of size 1 with values in [0, 1] (final Sigmoid), which
    is the range ``nn.BCELoss`` expects.
    """

    def __init__(self, img_dim, hidden_dim=128, negative_slope=0.1):
        super().__init__()
        # Keep the attribute name `disc` and the layer order so state_dict keys
        # ("disc.0.weight", "disc.2.bias", ...) stay compatible with saved weights.
        self.disc = nn.Sequential(
            nn.Linear(img_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a realness score in [0, 1] for each input row of ``x``."""
        return self.disc(x)

In [45]:
class Generator(nn.Module):
    """Single-hidden-layer MLP mapping latent noise vectors to flattened images.

    Args:
        z_dim: Size of the latent noise vector (last dimension of the input).
        img_dim: Number of output features per sample (the flattened image size).
        hidden_dim: Width of the hidden layer. Defaults to 256, the original value.
        negative_slope: Slope of the LeakyReLU for negative inputs. Defaults to 0.1.

    Outputs pass through Tanh, so values lie in [-1, 1]. Real images must be
    normalized to the same range for the discriminator to compare like with like.
    """

    def __init__(self, z_dim, img_dim, hidden_dim=256, negative_slope=0.1):
        super().__init__()
        # Keep the attribute name `gen` and the layer order so state_dict keys
        # ("gen.0.weight", "gen.2.bias", ...) stay compatible with saved weights.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map latent vectors ``x`` (last dim ``z_dim``) to images in [-1, 1]."""
        return self.gen(x)

In [46]:
# Training configuration: tune values here rather than inside the training cell.
SEED = 42
# Seed torch's RNGs before any model is built or noise is sampled, so weight
# init, latent noise and DataLoader shuffling repeat across Restart & Run All.
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64            # length of the generator's latent noise vector
image_dim = 28*28*1   # flattened image size: 28x28 pixels, 1 channel
batch_size = 32
num_epochs = 100

In [47]:
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

# Fixed latent vectors: the logged fake-image grid is generated from the same
# noise every time, so progress across epochs is directly comparable.
fixed_noise = torch.randn((batch_size, z_dim), device=device)

# Named `transform` (not `transforms`) so the torchvision.transforms module is
# not shadowed; shadowing it made this cell raise AttributeError on re-run.
# Normalize((0.5,), (0.5,)) maps ToTensor's [0, 1] pixels to [-1, 1], the same
# range as the generator's Tanh output.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

dataset = datasets.MNIST(root="datasets/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")
step = 0

try:
    for epoch in range(num_epochs):
        for batch_idx, (real, _) in enumerate(loader):
            # Flatten each image into a vector of length image_dim for the MLPs.
            real = real.view(-1, image_dim).to(device)
            # The final batch may be smaller than `batch_size`. Use a separate
            # name so the global config value is not overwritten.
            cur_batch_size = real.shape[0]

            ### Train Discriminator: minimize BCE(D(x), 1) + BCE(D(G(z)), 0)
            noise = torch.randn(cur_batch_size, z_dim, device=device)
            fake = gen(noise)
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            # fake.detach(): the discriminator step must not backprop into the
            # generator. This also leaves gen's autograd graph unused here, so
            # lossG below can backprop through it without retain_graph=True.
            disc_fake = disc(fake.detach()).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            lossD.backward()
            opt_disc.step()

            ### Train Generator: minimize BCE(D(G(z)), 1) using the updated D
            output = disc(fake).view(-1)
            lossG = criterion(output, torch.ones_like(output))
            # lossG.backward() also accumulates into disc's grads; those are
            # cleared by disc.zero_grad() before the next discriminator step.
            gen.zero_grad()
            lossG.backward()
            opt_gen.step()

            if batch_idx == 0:
                print(f"Epoch[{epoch}/{num_epochs}]")
                print(f"LossD:{lossD.item(): .4f}, LossG: {lossG.item(): .4f}")

                with torch.no_grad():
                    fake_imgs = gen(fixed_noise).reshape(-1, 1, 28, 28)
                    real_imgs = real.reshape(-1, 1, 28, 28)
                    img_grid_fake = torchvision.utils.make_grid(fake_imgs, normalize=True)
                    img_grid_real = torchvision.utils.make_grid(real_imgs, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist real Images", img_grid_real, global_step=step
                )
                step += 1
finally:
    # Flush pending TensorBoard events and release the event files, even when
    # training is interrupted (e.g. a KeyboardInterrupt from the notebook).
    writer_fake.close()
    writer_real.close()

Epoch[0/100]
LossD: 0.6788, LossG:  0.6759
Epoch[1/100]
LossD: 0.2556, LossG:  1.6694
Epoch[2/100]
LossD: 0.4484, LossG:  1.1059
Epoch[3/100]
LossD: 0.7603, LossG:  0.6134
Epoch[4/100]
LossD: 0.7763, LossG:  0.8099
Epoch[5/100]
LossD: 0.6454, LossG:  0.8432
Epoch[6/100]
LossD: 0.9390, LossG:  0.6872
Epoch[7/100]
LossD: 1.0188, LossG:  0.6264
Epoch[8/100]
LossD: 0.7591, LossG:  0.7947
Epoch[9/100]
LossD: 0.6911, LossG:  0.9852
Epoch[10/100]
LossD: 0.6343, LossG:  0.9749
Epoch[11/100]
LossD: 0.6710, LossG:  0.6978
Epoch[12/100]
LossD: 0.6792, LossG:  0.8974
Epoch[13/100]
LossD: 0.4535, LossG:  1.3261
Epoch[14/100]
LossD: 0.9367, LossG:  0.7486
Epoch[15/100]
LossD: 0.5061, LossG:  1.2024
Epoch[16/100]
LossD: 0.8755, LossG:  0.8172
Epoch[17/100]
LossD: 0.6001, LossG:  0.9288
Epoch[18/100]
LossD: 0.7858, LossG:  1.0303
Epoch[19/100]
LossD: 0.7074, LossG:  0.9336
Epoch[20/100]
LossD: 0.4587, LossG:  1.1521
Epoch[21/100]
LossD: 0.5521, LossG:  1.1424
Epoch[22/100]
LossD: 0.6587, LossG:  1.006