In [56]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torchvision

In [57]:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc = nn.Sequential(nn.Linear(img_dim, 128),
                                 nn.LeakyReLU(0.1),
                                 nn.Linear(128, 1),
                                 nn.Sigmoid())
    def forward(self, x):
        return self.disc(x)
        
        
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(nn.Linear(z_dim, 256),
                                nn.LeakyReLU(0.1),
                                nn.Linear(256, img_dim),
                                nn.Tanh())
    def forward(self, x):
        return self.gen(x)

**HyperParameters**

In [58]:
device = "cuda" if torch.cuda.is_available() else "cpu"
learning_rate = 3e-4
z_dim = 64
img_dim = 28*28*1
batch_size = 32
num_epochs = 50

In [59]:
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)

In [60]:
fixed_noise = torch.randn(batch_size, z_dim).to(device)

In [61]:
transforms = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,)),])

In [62]:
dataset = datasets.MNIST(root=r"C:\Users\sankalp\Desktop\pytorch\Data", transform=transforms, download=True)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [63]:
loss_function = nn.BCELoss()
opt_disc = optim.Adam(disc.parameters(), lr=learning_rate)
opt_gen = optim.Adam(disc.parameters(), lr=learning_rate)

In [64]:
writer_fake = SummaryWriter(f'runs/Gan_MNIST/fake')
writer_real = SummaryWriter(f'runs/Gan_MNIST/real')
step = 0

In [65]:
for epoch in range(num_epochs):
    for batch_index, data in enumerate(loader):
        real, _ = data
        real = real.view(-1,28*28*1).to(device)
        batch_size = real.shape[0]
        
        #Training Discriminator   max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = loss_function(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = loss_function(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_fake + lossD_real) / 2
        
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()
        
        #Trainning Generator   min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).view(-1)
        lossG = loss_function(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()
        
        if batch_index == 0:
            print(f'Epoch[{epoch}/{num_epochs}], LossG:{lossG:4f}')
            
            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

Epoch[0/50], LossG:0.699986
Epoch[1/50], LossG:0.735812
Epoch[2/50], LossG:0.729273
Epoch[3/50], LossG:0.732023
Epoch[4/50], LossG:0.731040
Epoch[5/50], LossG:0.729534
Epoch[6/50], LossG:0.732191
Epoch[7/50], LossG:0.731639
Epoch[8/50], LossG:0.733897
Epoch[9/50], LossG:0.734918


KeyboardInterrupt: 