In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Select the compute device: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Preprocessing pipeline for MNIST: convert each image to a tensor, then
# normalize the single channel with mean 0.5 and std 0.5.
# The pipeline is bound to its own name so the imported `transforms` module is
# not shadowed; rebinding `transforms` would make this cell fail with an
# AttributeError when re-run.
mnist_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
# MNIST stored under dataset/, downloaded on first use.
dataset = datasets.MNIST(root="dataset/", transform=mnist_transform, download=True)

In [None]:
# Training hyperparameters.
n_epoch = 5              # number of passes over the data loader
batch_size = 32          # samples per mini-batch
lr = 3e-4                # learning rate
noise = 64               # length of the generator's input noise vector
img_size = 1 * 28 * 28   # number of values in one flattened image

In [None]:
# Mini-batch iterator over the dataset; sample order is reshuffled each epoch.
loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
class discriminator(nn.Module):
    """Two-layer fully connected discriminator.

    Maps an input of ``input_size`` features to a single value in (0, 1)
    through Linear(input_size, 128) -> LeakyReLU(0.1) -> Linear(128, 1)
    -> Sigmoid.
    """

    def __init__(self, input_size):
        """Create the layers.

        Args:
            input_size: number of features in each input row.
        """
        super().__init__()
        self.input_size = input_size
        self.layer1 = nn.Linear(in_features=input_size, out_features=128)
        self.act1 = nn.LeakyReLU(negative_slope=0.1)
        self.layer2 = nn.Linear(in_features=128, out_features=1)
        self.act2 = nn.Sigmoid()

    def forward(self, X):
        """Return the sigmoid score for each row of ``X`` (last dim of size 1)."""
        hidden = self.act1(self.layer1(X))
        score = self.layer2(hidden)
        return self.act2(score)



In [None]:
class generator(nn.Module):
    """Two-layer fully connected generator.

    Maps a noise vector of ``input_size`` features to ``mnist_dim`` values in
    (-1, 1) through Linear(input_size, 256) -> LeakyReLU(0.1)
    -> Linear(256, mnist_dim) -> Tanh.
    """

    def __init__(self, input_size, mnist_dim):
        """Create the layers.

        Args:
            input_size: number of features in each input noise row.
            mnist_dim: number of output values per generated sample.
        """
        super().__init__()
        self.input_size = input_size
        self.layer1 = nn.Linear(in_features=input_size, out_features=256)
        self.act1 = nn.LeakyReLU(negative_slope=0.1)
        self.layer2 = nn.Linear(in_features=256, out_features=mnist_dim)
        self.act2 = nn.Tanh()

    def forward(self, X):
        """Return the generated sample for each row of ``X``."""
        hidden = self.act1(self.layer1(X))
        raw = self.layer2(hidden)
        return self.act2(raw)


  

In [None]:
# Networks, moved to the selected device. The discriminator consumes flattened
# images; the generator maps noise vectors to flattened images.
disc = discriminator(img_size).to(device)
gen = generator(noise, img_size).to(device)

# One Adam optimizer per network, sharing the same learning rate.
disc_optim = optim.Adam(params=disc.parameters(), lr=lr)
gen_optim = optim.Adam(params=gen.parameters(), lr=lr)

# Binary cross-entropy on the discriminator's probability outputs.
criterion = nn.BCELoss()

# Separate TensorBoard writers for generated and real image grids.
writer_fake = SummaryWriter(log_dir="logs/fake")
writer_real = SummaryWriter(log_dir="logs/real")

In [None]:
# Adversarial training loop. Every batch performs one discriminator update
# followed by one generator update. On the first batch of each epoch the
# losses are printed and grids of real and generated images are written to
# TensorBoard.
n_steps=len(loader)
step=0  # global_step for the TensorBoard image summaries; advanced once per logged batch
for epoch in range(n_epoch):
  for i ,(real,_) in enumerate(loader):
    # flatten each image to a vector so it matches the discriminator's Linear input
    real=torch.flatten(real,start_dim=1).to(device)
    # Loop-local name: the size of this batch. The `batch_size` hyperparameter
    # is deliberately left untouched.
    cur_batch_size=real.shape[0]
    # creating real and fake labels
    real_label=torch.ones(cur_batch_size,device=device)
    fake_label=torch.zeros(cur_batch_size,device=device)
    # output of discriminator on real images
    # (modules are called directly rather than via .forward so hooks run)
    out_disc_real=disc(real).reshape(cur_batch_size)
    # generating fake images
    fake_noise=torch.randn(cur_batch_size,noise).to(device)
    fake=gen(fake_noise)
    # output of discriminator on fake images; detach() stops the discriminator
    # loss from backpropagating into the generator, so no retain_graph is needed
    out_disc_fake=disc(fake.detach()).reshape(cur_batch_size)
    # discriminator loss: average of the real and fake terms
    loss_disc_real=criterion(out_disc_real,real_label)
    loss_disc_fake=criterion(out_disc_fake,fake_label)
    total_Dloss=(loss_disc_real+loss_disc_fake)/2
    # discriminator training
    disc_optim.zero_grad()
    total_Dloss.backward()
    disc_optim.step()
    # re-score the fake images with the just-updated discriminator, this time
    # keeping the graph back to the generator
    out_disc_fake=disc(fake).reshape(cur_batch_size)
    # generator loss: the generator wants its fakes to be classified as real
    loss_gen=criterion(out_disc_fake,real_label)
    # generator training
    gen_optim.zero_grad()
    loss_gen.backward()
    gen_optim.step()


    # console + tensorboard logging, first batch of every epoch only
    if i == 0:
            print(
                f"Epoch [{epoch}/{n_epoch}] Batch {i}/{len(loader)} \
                      Loss D: {total_Dloss:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                # reshape the flat vectors back to 1x28x28 images for display
                fake = gen(fake_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)


                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1



    
    

    

    
   
    

Epoch [0/5] Batch 0/1875                       Loss D: 1.0057, loss G: 0.7209
Epoch [1/5] Batch 0/1875                       Loss D: 0.3974, loss G: 1.1634
Epoch [2/5] Batch 0/1875                       Loss D: 0.6470, loss G: 1.0462
Epoch [3/5] Batch 0/1875                       Loss D: 1.1294, loss G: 0.4322
Epoch [4/5] Batch 0/1875                       Loss D: 0.8507, loss G: 0.5516
