In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader  
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [2]:
class Discriminator(nn.Module):
  
  def __init__(self, img_dim):
    super().__init__()
    self.disc = nn.Sequential(
        nn.Linear(img_dim , 128),
        nn.LeakyReLU(0.1) , # Slope of 0.1 , leaky relu performs better in GANs
        nn.Linear(128 , 1),
        nn.Sigmoid(),
    )
  
  def forward(self,x):
    return self.disc(x)

class Generator(nn.Module):
  def __init__(self , z_dim , img_dim):
    super().__init__()
    self.gen = nn.Sequential(
      nn.Linear(z_dim , 256),
        nn.LeakyReLU(0.1) , # Slope of 0.1 , leaky relu performs better in GANs
        nn.Linear(256 , img_dim),
        nn.Tanh(),
    )
  
  def forward(self,x):
    return self.gen(x)


In [13]:
# Goes under config.py 


# Hyperparamas
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device : " , device)
lr = 3e-4
z_dim = 64 # 128 , 256 
img_dim = 28*28*1 # 784
batch_size = 32
num_epochs = 5

# defining variables
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim , img_dim).to(device)
fixed_noise = torch.randn((batch_size,z_dim)).to(device)

# transforms 
import torchvision.transforms as transforms
transforms = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.5,),(0.5,))
      ]
)

# datasets 
dataset = datasets.MNIST(root = "dataset/" , transform = transforms , download = True)
loader = DataLoader(dataset , batch_size = batch_size , shuffle = True) 

# optimizers 
optim_disc = optim.Adam(disc.parameters() , lr = lr)
optim_gen = optim.Adam(gen.parameters() , lr = lr)

# Loss Fn
criterion = nn.BCELoss()

# Tensorboard
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0 


Device :  cuda


In [5]:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

Instructions for updating:
non-resource variables are not supported in the long term


In [10]:
%load_ext tensorboard
%tensorboard --logdir /content/runs/GAN_MNIST/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


<IPython.core.display.Javascript object>

In [None]:
%ls

In [None]:

!tensorboard dev upload --logdir "/content/runs/GAN_MNIST/fake/events.out.tfevents.1672666248.408f0282d59e.267.0"--one_shot

Upload started and will continue reading any new data as it's added to the logdir.

To stop uploading, press Ctrl-C.

New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/vcxxsNSBS521nBMn9s4MTg/

[1m[2023-01-02T13:42:38][0m Started scanning logdir.


In [14]:
for epoch in range(num_epochs):
  for batch_idx , (real , _) in enumerate(loader):
    
    real = real.view(-1,784).to(device) # resizing images
    batch_size = real.shape[0]

    # Discriminator training : max : log(D(real)) + log(1-D(G(z)))
    
    noise = torch.randn(batch_size,z_dim).to(device)
    fake = gen(noise)
    disc_real = disc(real).view(-1)
    disc_fake = disc(fake).view(-1)

    lossD_real = criterion(disc_real , torch.ones_like(disc_real))
    lossD_fake = criterion(disc_fake , torch.zeros_like(disc_fake))

    lossD = (lossD_real + lossD_fake)/2

    disc.zero_grad()
    lossD.backward(retain_graph=True) # clears the gradients from cache
    optim_disc.step()

    # Generator training : min log(1-D(G(z))) -> max log(D(G(z)))
    
    output = disc(fake).view(-1)
    lossG = criterion(output , torch.ones_like(output))
    gen.zero_grad()
    lossG.backward() # don't need to retain since it's not needed anymore
    optim_gen.step()


    # Tensorboard things : 
    if batch_idx == 0 :
      print(f"""{epoch}/{num_epochs} completed 
            Loss D : {lossD} , lossG : {lossG} """)

      with torch.no_grad():
        fake = gen(fixed_noise).reshape(-1,1,28,28)
        data = real.reshape(-1,1,28,28)

        img_fake_grid = torchvision.utils.make_grid(fake,normalize = True)
        img_real_grid = torchvision.utils.make_grid(real,normalize = True)

        writer_fake.add_image("MNIST Fake images " , img_fake_grid , global_step = step)
        writer_real.add_image("MNIST Real images " , img_real_grid , global_step = step)

        step+=1


0/5 completed 
            Loss D : 0.6914538145065308 , lossG : 0.7370061278343201 
1/5 completed 
            Loss D : 0.3003813922405243 , lossG : 1.6012623310089111 
2/5 completed 
            Loss D : 0.48966288566589355 , lossG : 1.3795697689056396 
3/5 completed 
            Loss D : 0.5448020696640015 , lossG : 0.9958481788635254 
4/5 completed 
            Loss D : 0.6209837794303894 , lossG : 0.7911437153816223 


In [None]:
%load_ext tensorboard
%tensorboard --logdir logs

In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
plt.figure(figsize = (20,20))
plt.imshow(img_real_grid.permute(1, 2, 0))

In [None]:
img = torchvision.transforms.ToPILImage()(img_real_grid)
img.show()