<a href="https://colab.research.google.com/github/rhithikashinodpk/workshop-git/blob/main/GAN1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
from torch import nn, optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Optimisation settings (shared by both networks).
LR = 1e-4
BS = 64
EPOCHS = 10

# Model sizes: noise-vector length and flattened 28x28 image length.
Z_DIM = 64
IMG_DIM = 28 * 28

# Output folder for per-epoch sample grids; created up front if missing.
SAVE_DIR = "generated_images"
os.makedirs(SAVE_DIR, exist_ok=True)

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring each input row as real or fake.

    Args:
        in_feature: number of features in each (flattened) input sample.
    """

    def __init__(self, in_feature):
        super().__init__()
        hidden = 256
        slope = 0.1
        layers = [
            nn.Linear(in_feature, hidden),
            nn.LeakyReLU(slope),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(slope),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),  # single output squashed into (0, 1)
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return one probability in (0, 1) per input row (last dim of size 1)."""
        return self.disc(x)


In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping noise vectors to flattened images.

    Args:
        img_dim: number of output features per sample (flattened image size).
        z_dim: length of each input noise vector.
    """

    def __init__(self, img_dim, z_dim):
        super().__init__()
        widths = (z_dim, 256, 512)
        layers = []
        # Hidden stack: Linear + LeakyReLU(0.1) for each consecutive width pair.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.1)]
        # Output head: Tanh bounds every generated value to [-1, 1].
        layers += [nn.Linear(widths[-1], img_dim), nn.Tanh()]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise ``x`` (last dim ``z_dim``) to samples with last dim ``img_dim``."""
        return self.gen(x)

In [None]:
# Input pipeline only. The discriminator, generator and fixed visualization
# noise are (re)created in later cells before they are used, so building them
# here as well only produced throw-away objects.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),      # light augmentation
    transforms.ToTensor(),                  # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),   # [0, 1] -> [-1, 1], the generator's Tanh range
])

dataset = torchvision.datasets.FashionMNIST(
    root='dataset/',
    transform=transform,
    download=True
)
loader = DataLoader(dataset, batch_size=BS, shuffle=True)

In [None]:
# Build both networks on DEVICE, each with its own Adam optimizer at rate LR.
disc = Discriminator(in_feature=IMG_DIM).to(DEVICE)
gen = Generator(img_dim=IMG_DIM, z_dim=Z_DIM).to(DEVICE)

opt_disc = optim.Adam(params=disc.parameters(), lr=LR)
opt_gen = optim.Adam(params=gen.parameters(), lr=LR)

# Binary cross-entropy on probabilities (the discriminator ends in a Sigmoid).
loss = nn.BCELoss()

In [None]:
import matplotlib.pyplot as plt

# Preview one batch of real training images on a 5x5 grid.
real, _ = next(iter(loader))

fig, ax = plt.subplots(5, 5, figsize=(10, 10))
plt.suptitle("Some real images")

# ax.flat walks the grid row by row.
for idx, cell in enumerate(ax.flat):
    cell.axis('off')
    if idx < len(real):  # tolerate a batch smaller than the 25-cell grid
        # The loader's transform normalizes pixels to [-1, 1]; pin the colour
        # limits so every tile shares one intensity scale instead of being
        # autoscaled per image.
        cell.imshow(real[idx][0].cpu(), cmap="gray", vmin=-1, vmax=1)

plt.tight_layout()
plt.show()



In [None]:
def save_generated_images(fake, epoch):
    """Save the first 25 generated samples as a 5-column grid PNG in SAVE_DIR.

    Args:
        fake: generator output; reshaped to (-1, 1, 28, 28), so it must hold a
            multiple of 28*28 values. Values are expected in [-1, 1] and are
            mapped to [0, 1] for display.
        epoch: used in the figure title and in the file name ``epoch_{epoch}.png``.
    """
    images = (fake.reshape(-1, 1, 28, 28) + 1) / 2
    grid = torchvision.utils.make_grid(images[:25], nrow=5)
    # matplotlib wants channels-last arrays on the host, outside the autograd graph.
    pixels = grid.permute(1, 2, 0).cpu().detach().numpy()

    plt.figure(figsize=(5, 5))
    plt.axis("off")
    plt.title(f"Epoch {epoch}")
    plt.imshow(pixels)
    plt.savefig(f"{SAVE_DIR}/epoch_{epoch}.png")
    plt.close()  # free the figure so repeated calls don't accumulate open figures

In [None]:
# Fixed noise, re-used every epoch so the saved sample grids are comparable.
fix_noise = torch.randn(25, Z_DIM, device=DEVICE)

# Mean loss per epoch, kept for plotting after training.
g_losses = []
d_losses = []

for epoch in range(EPOCHS):
    # Accumulate detached losses as tensors; calling .item() on every batch
    # would block on the device each step when training on CUDA.
    d_loss_sum = torch.zeros((), device=DEVICE)
    g_loss_sum = torch.zeros((), device=DEVICE)
    num_batches = 0

    for real, _ in loader:
        real = real.view(-1, IMG_DIM).to(DEVICE)
        batch_size = real.size(0)

        # Generate fake images
        noise = torch.randn(batch_size, Z_DIM, device=DEVICE)
        fake = gen(noise)

        # -------------------------
        # Train Discriminator
        # -------------------------
        # fake is detached so this step does not backpropagate into G.
        disc_real = disc(real).view(-1)
        disc_fake = disc(fake.detach()).view(-1)

        lossD_real = loss(disc_real, torch.ones_like(disc_real))
        lossD_fake = loss(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2

        opt_disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # -------------------------
        # Train Generator
        # -------------------------
        output = disc(fake).view(-1)
        lossG = loss(output, torch.ones_like(output))  # Generator tries to fool discriminator

        opt_gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        d_loss_sum += lossD.detach()
        g_loss_sum += lossG.detach()
        num_batches += 1

    # Epoch averages (max(..., 1) guards against an empty loader).
    mean_d = d_loss_sum.item() / max(num_batches, 1)
    mean_g = g_loss_sum.item() / max(num_batches, 1)
    d_losses.append(mean_d)
    g_losses.append(mean_g)

    # Report epoch-average losses; a single last batch is very noisy.
    print(f"Epoch [{epoch+1}/{EPOCHS}] | Loss D: {mean_d:.4f} | Loss G: {mean_g:.4f}")

    # -------------------------
    # Save a grid of samples from the fixed noise (no gradients needed).
    # -------------------------
    with torch.no_grad():
        samples = gen(fix_noise)
    save_generated_images(samples, epoch + 1)


In [None]:

# Plot the training losses per epoch.
# Use the per-epoch history in g_losses / d_losses if an earlier cell recorded
# one; otherwise the only real data available is the last training batch's loss,
# so plot that single point rather than repeating it to fake a curve.
try:
    g_history, d_history = list(g_losses), list(d_losses)
except NameError:
    g_history, d_history = [lossG.item()], [lossD.item()]

plt.figure(figsize=(10, 5))
# Epochs are numbered from 1; markers keep a single-point series visible.
plt.plot(range(1, len(g_history) + 1), g_history, marker="o", label="Generator Loss")
plt.plot(range(1, len(d_history) + 1), d_history, marker="o", label="Discriminator Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Generator and Discriminator Loss Over Epochs")
plt.legend()
plt.grid(True)
plt.show()
