In [1]:
# Delete TensorBoard logs left by previous runs so the `runs/` log directory starts clean.
%rm -rf runs/

In [2]:
# Load the TensorBoard notebook extension
%reload_ext tensorboard

In [3]:
# Notebook banner: announce which model this notebook trains.
banner = "Simple GAN"
print(banner)

Simple GAN


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [5]:
class Discriminator(nn.Module):
    """MLP discriminator mapping a flattened image to the probability that it is real.

    Args:
        img_dim: Number of input features (flattened image size, e.g. 784 for MNIST).
        hidden_dim: Width of the single hidden layer. Defaults to 128 (the original size);
            increase it to experiment with a larger network.
        negative_slope: LeakyReLU slope for negative inputs. Defaults to 0.1.
    """

    def __init__(self, img_dim, hidden_dim=128, negative_slope=0.1):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, 1),
            # Sigmoid yields a probability, which is what nn.BCELoss expects as input.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return D(x): shape (N, 1) with values in [0, 1] for input of shape (N, img_dim)."""
        return self.disc(x)



In [6]:
class Generator(nn.Module):
    """MLP generator mapping a latent noise vector to a flattened image.

    Args:
        z_dim: Dimension of the input noise vector.
        img_dim: Number of output features (flattened image size, e.g. 784 for MNIST).
        hidden_dim: Width of the single hidden layer. Defaults to 256 (the original size);
            increase it to experiment with a larger network.
        negative_slope: LeakyReLU slope for negative inputs. Defaults to 0.1.
    """

    def __init__(self, z_dim, img_dim, hidden_dim=256, negative_slope=0.1):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, img_dim),
            # Tanh bounds outputs to [-1, 1], matching real images normalized with
            # mean 0.5 / std 0.5.
            nn.Tanh(),
        )

    def forward(self, x):
        """Return G(x): shape (N, img_dim) with values in [-1, 1] for input of shape (N, z_dim)."""
        return self.gen(x)

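Hyperparameters. GANs are very sensitive to these settings (learning rate, latent dimension, batch size); small changes can noticeably affect training stability and sample quality.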

In [7]:
# Seed torch's RNGs (all devices) so weight init, noise and DataLoader shuffling
# are reproducible under Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Prefer a GPU backend when one is available (CUDA, then Apple MPS), else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

lr = 1e-4                 # Adam learning rate for both networks
z_dim = 64                # latent noise dimension (also worth trying 128, 256)
image_dim = 28 * 28 * 1   # 784: flattened 1x28x28 MNIST image
batch_size = 32
num_epochs = 120

In [8]:
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

# Fixed latent vectors, so the generated image grid logged to TensorBoard is
# comparable across epochs.
fixed_noise = torch.randn((batch_size, z_dim), device=device)

# Named `transform` (not `transforms`) so the torchvision.transforms module is not
# shadowed; shadowing it made this cell raise AttributeError when re-run.
# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, mapping ToTensor's [0, 1]
# pixels to [-1, 1], the output range of the generator's Tanh.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()
# Separate writers so real and fake image grids show up as two runs in TensorBoard.
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")
step = 0  # global step for TensorBoard image logging




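BCELoss docs: https://docs.pytorch.org/docs/stable/generated/torch.nn.BCELoss.html. For a prediction $x$ and target $y$ it computes $\ell = -\left[y \log x + (1 - y)\log(1 - x)\right]$. A target of 1 therefore gives $-\log x$ and a target of 0 gives $-\log(1 - x)$; this is how the GAN min-max objective is expressed with BCE losses below.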

In [9]:
# Show TensorBoard inline for the GAN logs, re-reading the log directory every 5 seconds
# so images appear while training runs.
%tensorboard --logdir=runs/GAN_MNIST/ --reload_interval=5

Reusing TensorBoard on port 6007 (pid 82664), started 0:29:28 ago. (Use '!kill 82664' to kill it.)

In [10]:
try:
    for epoch in range(num_epochs):
        for batch_idx, (real, _) in enumerate(loader):
            # Flatten each image to image_dim features for the MLP models.
            real = real.view(-1, image_dim).to(device)
            # Size of *this* batch, kept local so the global `batch_size`
            # hyperparameter is not overwritten (the last batch may be smaller).
            cur_batch_size = real.shape[0]

            ### Train the Discriminator: max log(D(real)) + log(1 - D(G(z)))
            noise = torch.randn(cur_batch_size, z_dim, device=device)
            fake = gen(noise)
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            # detach(): the discriminator update must not backprop into the generator.
            # It also leaves the generator part of `fake`'s graph untouched for the
            # generator step below, so retain_graph=True is not needed.
            disc_fake = disc(fake.detach()).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            lossD.backward()
            opt_disc.step()

            ### Train the Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
            # Re-scores `fake` with the just-updated discriminator, using "real" (1)
            # targets so the loss is -log(D(G(z))).
            output = disc(fake).view(-1)
            lossG = criterion(output, torch.ones_like(output))
            gen.zero_grad()
            lossG.backward()
            opt_gen.step()

            ## Once per epoch: print losses and log image grids to TensorBoard
            if batch_idx == 0:
                print(
                    f"Epoch [{epoch}/{num_epochs}]\n"
                    f"Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}"
                )

                with torch.no_grad():
                    fake_images = gen(fixed_noise).reshape(-1, 1, 28, 28)
                    real_images = real.reshape(-1, 1, 28, 28)
                    img_grid_fake = torchvision.utils.make_grid(fake_images, normalize=True)
                    img_grid_real = torchvision.utils.make_grid(real_images, normalize=True)

                    writer_fake.add_image(
                        "MNIST Fake Images", img_grid_fake, global_step=step
                    )
                    writer_real.add_image(
                        "MNIST real Images", img_grid_real, global_step=step
                    )
                step += 1
finally:
    # Close (and flush) the writers even if training is interrupted, so the
    # events written so far are visible to TensorBoard.
    writer_real.close()
    writer_fake.close()


Epoch [0/120] 
 Loss D: 0.7052, Loss G: 0.6845
Epoch [1/120] 
 Loss D: 0.5969, Loss G: 0.7816
Epoch [2/120] 
 Loss D: 0.6187, Loss G: 0.8218
Epoch [3/120] 
 Loss D: 0.6739, Loss G: 0.6970
Epoch [4/120] 
 Loss D: 0.5258, Loss G: 0.9459
Epoch [5/120] 
 Loss D: 0.5844, Loss G: 0.7974
Epoch [6/120] 
 Loss D: 0.2140, Loss G: 1.8844
Epoch [7/120] 
 Loss D: 0.5103, Loss G: 1.0086
Epoch [8/120] 
 Loss D: 0.5927, Loss G: 0.8608
Epoch [9/120] 
 Loss D: 0.2537, Loss G: 1.6945
Epoch [10/120] 
 Loss D: 0.4924, Loss G: 1.1755
Epoch [11/120] 
 Loss D: 0.3635, Loss G: 1.4054
Epoch [12/120] 
 Loss D: 0.2828, Loss G: 1.6051
Epoch [13/120] 
 Loss D: 0.2997, Loss G: 1.6120
Epoch [14/120] 
 Loss D: 0.1684, Loss G: 2.1208
Epoch [15/120] 
 Loss D: 0.3102, Loss G: 1.8995
Epoch [16/120] 
 Loss D: 0.2243, Loss G: 1.8348
Epoch [17/120] 
 Loss D: 0.1277, Loss G: 2.4117
Epoch [18/120] 
 Loss D: 0.1780, Loss G: 2.4006
Epoch [19/120] 
 Loss D: 0.2562, Loss G: 2.3403
Epoch [20/120] 
 Loss D: 0.1545, Loss G: 2.5083
Ep

Things to try:
1. What happens if you use a larger network?
2. Better normalization with BatchNorm.
3. A different learning rate (is there a better one?).
4. Change the architecture to a CNN.

In [11]:
# To start TensorBoard from a terminal instead of inline, run:
#   tensorboard --logdir=runs/GAN_MNIST/
# Show which compute device training ran on.
print(f"{device}")

mps


In [12]:
# Re-display the inline TensorBoard view for the GAN logs (an already-running
# TensorBoard server on the same logdir is reused).
%tensorboard --logdir=runs/GAN_MNIST/ --reload_interval=5

Reusing TensorBoard on port 6007 (pid 82664), started 0:41:10 ago. (Use '!kill 82664' to kill it.)