In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter

In [2]:
class Discriminator(nn.Module):
    """Small MLP that scores a flattened image with a probability of being real.

    Args:
        in_features: size of the last dimension of the input (the flattened
            image length, e.g. 28*28 for MNIST).
    """

    def __init__(self, in_features):
        super(Discriminator, self).__init__()
        hidden_units = 128
        layers = [
            nn.Linear(in_features, hidden_units),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),  # squashes the single logit into (0, 1) for BCELoss
        ]
        # The attribute name `mod` and the layer order define the state_dict keys
        # (mod.0.*, mod.2.*); keep them stable so saved checkpoints still load.
        self.mod = nn.Sequential(*layers)

    def forward(self, X):
        """Return sigmoid scores with a trailing dimension of size 1."""
        scores = self.mod(X)
        return scores
    
class Generator(nn.Module):
    """Small MLP that maps a latent noise vector to a flattened image.

    Args:
        z_dim: length of the latent noise vector fed to the network.
        img_dim: length of the flattened image it produces.
    """

    def __init__(self, z_dim, img_dim):
        super(Generator, self).__init__()
        hidden_units = 256
        stages = (
            nn.Linear(z_dim, hidden_units),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(hidden_units, img_dim),
            nn.Tanh(),  # output values lie in [-1, 1]
        )
        # The attribute name `gen` and the stage order define the state_dict keys
        # (gen.0.*, gen.2.*); keep them stable so saved checkpoints still load.
        self.gen = nn.Sequential(*stages)

    def forward(self, x):
        """Return generated flattened images for a batch of noise vectors."""
        images = self.gen(x)
        return images



In [3]:
# Hyperparams
seed = 42  # seed torch's RNG so weight init and fixed_noise are reproducible
torch.manual_seed(seed)

lr = 3e-4
z_dim = 64
image_dim = 28*28*1  # flattened MNIST image: 28x28 pixels, 1 channel
batch_size = 32
epochs = 50

disc = Discriminator(image_dim)
gene = Generator(z_dim, image_dim)
# The same noise is reused every time images are logged, so TensorBoard shows
# the generator's progress on identical inputs.
fixed_noise = torch.randn((batch_size, z_dim))

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generator's Tanh output.
# The fully qualified torchvision.transforms path is used on purpose: the result
# is bound to the name `transforms` (the training cell passes it to MNIST), which
# shadows the module imported in the first cell. Writing `transforms.Compose`
# here would raise AttributeError as soon as this cell is re-run.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

In [None]:


# Data pipeline: MNIST images preprocessed by the `transforms` pipeline from the
# hyperparameter cell.
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size, shuffle=True)

optim_disc = optim.Adam(disc.parameters(), lr=lr)
optim_gene = optim.Adam(gene.parameters(), lr=lr)
critereon = nn.BCELoss()  # name kept (sic) in case later cells reference it

writer_fake = SummaryWriter('runs/GAN_MNIST/fake')
writer_real = SummaryWriter('runs/GAN_MNIST/real')
step = 1  # TensorBoard global_step, advanced once per logged batch

for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image into a vector of length image_dim for the MLPs.
        real = real.view(-1, image_dim)
        # Size of this particular batch. Kept separate from the `batch_size`
        # hyperparameter so the global is not overwritten by the loop.
        cur_batch_size = real.shape[0]

        ### Training Disc: minimize BCE(D(real), 1) + BCE(D(G(z)), 0)
        noise = torch.randn(cur_batch_size, z_dim)
        fake = gene(noise)
        disc_real = disc(real).view(-1)
        lossD_real = critereon(disc_real, torch.ones_like(disc_real))

        # detach(): the discriminator loss must not backpropagate into the
        # generator. This also means the generator graph is not consumed here,
        # so no retain_graph is needed for the generator step below.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = critereon(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        optim_disc.step()

        ### Training Gene: minimize BCE(D(G(z)), 1) (non-saturating generator loss)
        # Re-score `fake` with the just-updated discriminator; gradients flow
        # through the still-intact generator graph.
        output = disc(fake).view(-1)
        lossG = critereon(output, torch.ones_like(output))
        gene.zero_grad()
        lossG.backward()
        optim_gene.step()

        ### Code for TENSORBOARD
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{epochs}]; LOSS_D: {lossD.item():.4f}; LOSS_G: {lossG.item():.4f} "
            )
            with torch.no_grad():
                fake = gene(fixed_noise).reshape(-1, 1, 28, 28)  # 28x28x1 is the picture dimensions
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )

                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

# Flush pending events to disk and release the event files. SummaryWriter
# recreates its file writer if a later cell logs to it again.
writer_fake.close()
writer_real.close()
        

Epoch [0/50]; LOSS_D: 0.6686; LOSS_G: 0.7528 
Epoch [1/50]; LOSS_D: 0.4858; LOSS_G: 0.9529 
Epoch [2/50]; LOSS_D: 0.6602; LOSS_G: 0.8707 
Epoch [3/50]; LOSS_D: 0.9369; LOSS_G: 0.5660 
Epoch [4/50]; LOSS_D: 0.4630; LOSS_G: 1.3248 
Epoch [5/50]; LOSS_D: 0.3622; LOSS_G: 1.4009 
Epoch [6/50]; LOSS_D: 0.3937; LOSS_G: 1.5424 
Epoch [7/50]; LOSS_D: 0.8544; LOSS_G: 0.8167 
Epoch [8/50]; LOSS_D: 0.3180; LOSS_G: 1.8629 
Epoch [9/50]; LOSS_D: 0.5092; LOSS_G: 1.1593 
Epoch [10/50]; LOSS_D: 0.6733; LOSS_G: 1.3894 
Epoch [11/50]; LOSS_D: 0.5260; LOSS_G: 1.5212 
Epoch [12/50]; LOSS_D: 0.6039; LOSS_G: 0.9596 
Epoch [13/50]; LOSS_D: 0.6464; LOSS_G: 0.9831 
Epoch [14/50]; LOSS_D: 0.5588; LOSS_G: 1.3560 
Epoch [15/50]; LOSS_D: 0.7834; LOSS_G: 0.8340 
Epoch [16/50]; LOSS_D: 0.5843; LOSS_G: 1.5093 
Epoch [17/50]; LOSS_D: 0.7564; LOSS_G: 1.0407 
Epoch [18/50]; LOSS_D: 0.5870; LOSS_G: 0.9592 
Epoch [19/50]; LOSS_D: 0.4748; LOSS_G: 1.2606 
Epoch [20/50]; LOSS_D: 0.6965; LOSS_G: 0.9699 
Epoch [21/50]; LOSS_D: 