**Credit:** 

This notebook is for practice; it is taken from the Mathematical Foundation for Generative AI course, IITM BS.

Original Tutorial Link: [W2T5: Tutorial: Implementation of GAN](https://www.youtube.com/watch?v=iOb8vmlJd8o&t=1833s) 


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Compute device: use CUDA when it is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Preprocessing pipeline: convert each image to a tensor, then normalize with
# mean 0.5 / std 0.5 so values lie in [-1, 1], matching the Tanh output of the
# generator.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST train split first, then the test split; both are stored under ./data
# (download=True) and share the same preprocessing.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(
        root='./data',
        train=is_train_split,
        download=True,
        transform=transform,
    )
    for is_train_split in (True, False)
)

In [None]:
class Generator(nn.Module):
    """Fully connected GAN generator mapping noise vectors to flattened images.

    The network is a stack of linear layers (noise_dim -> 256 -> 512 -> 1024 ->
    out_img_dim) with ReLU activations between them and a Tanh on the output,
    so every output value lies in [-1, 1].

    Args:
        noise_dim: Size of the last dimension of the input noise tensor.
        out_img_dim: Number of values in each generated (flattened) image.
    """

    def __init__(self, noise_dim=100, out_img_dim=784):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 256),
            # The activation class is spelled ``nn.ReLU``; ``nn.Relu`` does not
            # exist in torch.nn and raises AttributeError.
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, out_img_dim),
            nn.Tanh(),  # Because we normalize images to [-1, 1]
        )

    def forward(self, z):
        """Return generated samples for noise ``z`` (last dim ``noise_dim``)."""
        return self.model(z)

In [None]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator scoring flattened images.

    The network is a stack of linear layers (in_img_dim -> 512 -> 256 -> 1)
    with LeakyReLU(0.2) activations between them and a Sigmoid on the output,
    so each score lies in [0, 1].

    Args:
        in_img_dim: Number of values in each (flattened) input image.
    """

    def __init__(self, in_img_dim=784):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_img_dim, 512),
            # The activation class is spelled ``nn.LeakyReLU``; ``nn.LeakyRelu``
            # does not exist in torch.nn and raises AttributeError.
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # Outputs probability between 0 and 1
        )

    def forward(self, img):
        """Return one score per input row of ``img`` (last dim ``in_img_dim``)."""
        return self.model(img)

In [None]:
# Sizes: length of the latent noise vector and of a flattened 28x28 image.
noise_dim, img_dim = 100, 28 * 28

# Networks, moved to the selected device (generator is built first).
generator = Generator(noise_dim=noise_dim, out_img_dim=img_dim).to(device)
discriminator = Discriminator(in_img_dim=img_dim).to(device)

# One Adam optimizer per network, both with learning rate 2e-4.
g_optimizer = optim.Adam(generator.parameters(), lr=2e-4)
d_optimizer = optim.Adam(discriminator.parameters(), lr=2e-4)

# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

In [None]:
def show_generated_images(epoch, generator, fixed_noise):
    """Display a grid of images produced by ``generator`` from ``fixed_noise``.

    The generator output is reshaped to (-1, 1, 28, 28), mapped back from
    [-1, 1] to [0, 1], tiled 8 per row and shown with matplotlib.

    Args:
        epoch: Label shown in the figure title.
        generator: Module called on ``fixed_noise``; it is switched to eval
            mode for the forward pass and then returned to whatever mode it
            was in on entry.
        fixed_noise: Noise batch passed to ``generator``.

    NOTE(review): ``plt`` is not imported in the visible part of this file;
    ``import matplotlib.pyplot as plt`` must be in scope before this is called.
    """
    # Remember the caller's mode so it is restored rather than unconditionally
    # forcing train mode afterwards.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake_imgs = generator(fixed_noise).reshape(-1, 1, 28, 28)
            fake_imgs = fake_imgs * 0.5 + 0.5   # de-normalize
    finally:
        # Restore the mode even if the forward pass raised.
        generator.train(was_training)
    grid = torchvision.utils.make_grid(fake_imgs, nrow=8)
    plt.figure(figsize=(8, 8))
    plt.title(f"Epoch {epoch}")
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.show()

In [None]:
# 1. Standard schedule: one generator step per discriminator step.
#    (train_gan and train_loader are not defined in the visible cells.)
print("Training: One step gen, One step dis")
train_gan(
    train_loader,
    num_epochs=50,
    mode="one_one",
)

# 2.