In [None]:
import torch
import torch.nn as nn 
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter



In [None]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator.

    Maps a flattened image to a single probability that it is real. The final
    Sigmoid makes the output suitable for ``nn.BCELoss``.

    Args:
        img_dim: number of input features (flattened image size).
        hidden_dim: width of the single hidden layer (default 128).
        negative_slope: negative slope of the LeakyReLU activation (default 0.1).
    """

    def __init__(self, img_dim, hidden_dim=128, negative_slope=0.1):
        # nn.Module must be initialised via __init__ before any submodule is
        # assigned; `super().init()` does not exist and raises AttributeError.
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score ``x`` (last dimension ``img_dim``); returns values in (0, 1) with last dimension 1."""
        return self.disc(x)



In [None]:
class Generator(nn.Module):
    """Fully connected GAN generator.

    Maps a latent noise vector of size ``z_dim`` to a flattened image of size
    ``img_dim``. The final Tanh bounds every output value to [-1, 1], so real
    images fed to the discriminator should be scaled to the same range.

    Args:
        z_dim: size of the latent noise vector.
        img_dim: number of output features (flattened image size).
        hidden_dim: width of the single hidden layer (default 256).
        negative_slope: negative slope of the LeakyReLU activation (default 0.1).
    """

    def __init__(self, z_dim, img_dim, hidden_dim=256, negative_slope=0.1):
        # Must be __init__; `super().init()` raises AttributeError.
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise ``x`` (last dimension ``z_dim``) to images with last dimension ``img_dim``."""
        return self.gen(x)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lr = 3e-4            # Adam learning rate, shared by generator and discriminator
z_dim = 64           # size of the generator's latent noise vector
image_dim = 28*28*1  # flattened MNIST image: 28x28 pixels, 1 channel
batch_size = 64
num_epochs = 50

# Seed torch's default RNG so weight init, sampled noise and DataLoader
# shuffling are reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)



In [None]:
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
# One fixed latent batch, reused at every logging step so that generated
# samples are comparable over the course of training.
fixed_noise = torch.randn((batch_size, z_dim), device=device)

# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# matching the generator's Tanh output range. The MNIST mean/std (0.1307,
# 0.3081) would give roughly [-0.42, 2.82], letting the discriminator tell
# real from fake by range alone.
# Fully-qualified names keep this cell re-runnable: the result is bound to
# `transforms` (read by the dataset cell), which shadows the imported module.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])


In [None]:
dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = torch.optim.Adam(disc.parameters(), lr=lr)
opt_gen = torch.optim.Adam(gen.parameters(), lr=lr)
# The discriminator ends in Sigmoid, so it outputs probabilities and plain
# BCELoss (not BCEWithLogitsLoss) is the matching criterion.
criterion = nn.BCELoss()
# Separate writers so real and generated grids appear as distinct runs in TensorBoard.
writer_fake = SummaryWriter("runs/fake")
writer_real = SummaryWriter("runs/real")
step = 0  # TensorBoard global step, advanced once per logging event

In [None]:
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image to an image_dim-long vector for the MLPs.
        real = real.view(-1, image_dim).to(device)
        # The last batch of an epoch can be smaller than batch_size. Use a local
        # name so the global `batch_size` (used for fixed_noise) is not clobbered.
        cur_batch_size = real.shape[0]

        # Train discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): keep D's loss from backpropagating into G; `fake` is reused
        # below for the generator step.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # Train generator: max log(D(G(z))) by labelling fakes as real.
        # This is a fresh forward pass through the just-updated discriminator.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx % 5 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}"
            )

            with torch.no_grad():
                # Separate names so the training `fake` batch is not shadowed.
                fake_images = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_images = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_images, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_images, normalize=True)
                writer_fake.add_image("fake", img_grid_fake, step)
                writer_real.add_image("real", img_grid_real, step)
            step += 1

