In [None]:
!pip install pytorch_lightning # To make code shorter

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.2.5-py3-none-any.whl (802 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m802.3/802.3 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.4.0.post0-py3-none-any.whl (868 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m868.8/868.8 kB[0m [31m31.3 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.11.2-py3-none-any.whl (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.13.0->pytorch_lightning)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.13.0->pytorch_lightning)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.13.0->

In [None]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import pytorch_lightning as pl

# Seed PyTorch's global RNG so runs are repeatable.
random_seed = 42
torch.manual_seed(random_seed)

# Mini-batch size used as the default by the data module.
BATCH_SIZE = 128
# 1 when at least one CUDA device is visible, otherwise 0.
AVAIL_GPUS = min(1, torch.cuda.device_count())
# Use half of the CPU cores for data loading. os.cpu_count() may return None
# when the core count cannot be determined, so fall back to 1 in that case.
NUM_WORKERS = (os.cpu_count() or 1) // 2



In [None]:
class MNISTDataModule(pl.LightningDataModule):
  """Data module that downloads MNIST and serves train/val/test loaders.

  Args:
    data_dir: Directory the MNIST files are downloaded to and read from.
    batch_size: Batch size used by every dataloader.
    num_workers: Number of worker processes used by every dataloader.
  """

  def __init__(self, data_dir = "./data", batch_size = BATCH_SIZE, num_workers = NUM_WORKERS):
    super().__init__()
    self.data_dir = data_dir
    self.batch_size = batch_size
    self.num_workers = num_workers
    # Convert images to tensors, then standardize the single channel.
    # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std — confirm.
    self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

  def prepare_data(self):
    """Download the train and test splits; nothing is stored on the instance."""
    MNIST(self.data_dir, train = True, download = True)
    MNIST(self.data_dir, train = False, download = True)

  def setup(self, stage = None):
    """Create the datasets needed for `stage` ("fit", "test", or None for both)."""
    # Assign train/val datasets
    if stage == "fit" or stage is None:
      mnist_full = MNIST(self.data_dir, train = True, transform = self.transform)
      # 55,000 images for training, 5,000 held out for validation.
      self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
    # Assign test dataset
    if stage == "test" or stage is None:
      self.mnist_test = MNIST(self.data_dir, train = False, transform = self.transform)

  def train_dataloader(self):
    """Return the training loader; batches are reshuffled every epoch."""
    return DataLoader(
        self.mnist_train,
        batch_size = self.batch_size,
        shuffle = True,  # without this every epoch sees identical batches
        num_workers = self.num_workers,
    )

  def val_dataloader(self):
    """Return the validation loader (fixed order)."""
    return DataLoader(self.mnist_val, batch_size = self.batch_size, num_workers = self.num_workers)

  def test_dataloader(self):
    """Return the test loader (fixed order)."""
    return DataLoader(self.mnist_test, batch_size = self.batch_size, num_workers = self.num_workers)


In [None]:
# Discriminator (the "detective"): decides whether an image is fake or real --> a single output in [0, 1]
class Discriminator(nn.Module):
  """Scores an image as real or fake.

  Two convolution + max-pool stages feed two fully connected layers, and a
  final sigmoid produces one value in [0, 1] per image. The flatten step
  hard-codes 320 features (20 channels * 4 * 4), which is what 28x28
  single-channel inputs produce.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(in_features=320, out_features=50)
    self.fc2 = nn.Linear(in_features=50, out_features=1)  # single real/fake score

  def forward(self, x):
    """Return a tensor of shape [n, 1] with values in [0, 1]."""
    # Stage 1: conv -> 2x2 max-pool -> ReLU
    feats = self.conv1(x)
    feats = F.relu(F.max_pool2d(feats, 2))

    # Stage 2: conv -> channel dropout -> 2x2 max-pool -> ReLU
    feats = self.conv2_drop(self.conv2(feats))
    feats = F.relu(F.max_pool2d(feats, 2))

    # Collapse the feature maps into one 320-long vector per image
    flat = feats.view(-1, 320)
    hidden = F.relu(self.fc1(flat))
    # Functional dropout must be told the mode explicitly
    hidden = F.dropout(hidden, training=self.training)
    score = self.fc2(hidden)
    return torch.sigmoid(score)




In [None]:
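# Generate fake data: the output has the same shape as the real data, [1, 28, 28].
# Starting from a latent vector of size latent_dim, we upsample to an output with the shape of the original images, [1, 28, 28].
# Note: the last layer has no activation, so the output values are unbounded rather than restricted to [-1, 1].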

class Generator(nn.Module):
  """Turns latent vectors into single-channel 28x28 images.

  A linear layer expands the latent vector into 64 feature maps of 7x7, two
  transposed convolutions upsample them to 16x16 and then 34x34, and a final
  7x7 convolution (no activation) reduces the result to one 28x28 map.

  Args:
    latent_dim: Length of each input latent vector.
  """

  def __init__(self, latent_dim):
    super().__init__()
    self.lin1 = nn.Linear(latent_dim, 64 * 7 * 7)  # reshaped to [n, 64, 7, 7]
    self.ct1 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2)  # [n, 32, 16, 16]
    self.ct2 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=4, stride=2)  # [n, 16, 34, 34]
    self.conv = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=7)  # [n, 1, 28, 28]

  def forward(self, x):
    """Map latent vectors of shape [n, latent_dim] to images of shape [n, 1, 28, 28]."""
    # Project the latent vector and fold it into 64 maps of 7x7
    base_maps = F.relu(self.lin1(x)).view(-1, 64, 7, 7)

    # First upsampling: 7x7 -> 16x16 (32 feature maps)
    mid_maps = F.relu(self.ct1(base_maps))

    # Second upsampling: 16x16 -> 34x34 (16 feature maps)
    wide_maps = F.relu(self.ct2(mid_maps))

    # 7x7 convolution trims 34x34 down to 28x28 (1 feature map)
    return self.conv(wide_maps)

In [None]:

class GAN(pl.LightningModule):
    """GAN that trains a Generator against a Discriminator.

    Optimization is manual: `training_step` steps the generator optimizer and
    then the discriminator optimizer on every batch.

    Args:
        latent_dim: Length of the noise vector fed to the generator.
        lr: Learning rate shared by both Adam optimizers.
    """

    def __init__(self, latent_dim=100, lr=0.0002):
        super().__init__()
        self.save_hyperparameters()
        self.generator = Generator(latent_dim=self.hparams.latent_dim)
        self.discriminator = Discriminator()
        # Fixed noise, reused by plot_imgs so successive plots are comparable.
        self.validation_z = torch.randn(6, self.hparams.latent_dim)
        self.automatic_optimization = False  # Enable manual optimization

    def forward(self, z):
        """Generate images from the latent vectors `z`."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between predicted probabilities and 0/1 targets."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Run one generator update followed by one discriminator update."""
        real_imgs, _ = batch
        opt_g, opt_d = self.optimizers()  # Access optimizers

        # Sample noise
        z = torch.randn(real_imgs.shape[0], self.hparams.latent_dim).type_as(real_imgs)

        # Train Generator: max log(D(G(z)))
        fake_imgs = self(z)  # Generate fake images
        y_hat = self.discriminator(fake_imgs)
        # The generator wants its fakes to be labelled as real (1).
        y = torch.ones(real_imgs.size(0), 1).type_as(real_imgs)
        g_loss = self.adversarial_loss(y_hat, y)

        # Optimize Generator
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

        # Train Discriminator: max log D(x) + log (1 - D(G(z)))
        y_hat_real = self.discriminator(real_imgs)
        y_real = torch.ones(real_imgs.size(0), 1).type_as(real_imgs)
        real_loss = self.adversarial_loss(y_hat_real, y_real)

        # detach() keeps the discriminator loss from back-propagating into the generator.
        y_hat_fake = self.discriminator(fake_imgs.detach())
        y_fake = torch.zeros(real_imgs.size(0), 1).type_as(real_imgs)
        fake_loss = self.adversarial_loss(y_hat_fake, y_fake)

        d_loss = (real_loss + fake_loss) / 2

        # Optimize Discriminator. zero_grad() also clears the discriminator
        # gradients that the generator's backward pass accumulated above.
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

        log_dict = {"g_loss": g_loss, "d_loss": d_loss}
        self.log_dict(log_dict)

    def configure_optimizers(self):
        """Return one Adam optimizer per network, in the order [generator, discriminator]."""
        lr = self.hparams.lr
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
        return [opt_g, opt_d]

    def on_train_epoch_end(self):
        """Plot samples from the fixed noise at the end of every training epoch."""
        self.plot_imgs()

    def plot_imgs(self):
        """Plot the images generated from `self.validation_z` in a 2x3 grid."""
        # Match the generator's device and dtype. Generator has no `model`
        # attribute, so take the reference tensor from its parameters.
        z = self.validation_z.type_as(next(self.generator.parameters()))
        # Plotting needs no gradients.
        with torch.no_grad():
            sample_imgs = self(z).cpu()

        print("epoch", self.current_epoch)
        plt.figure()
        for i in range(sample_imgs.size(0)):
            plt.subplot(2, 3, i + 1)
            plt.tight_layout()
            plt.imshow(sample_imgs[i, 0, :, :], cmap="gray_r", interpolation="none")
            plt.title("Generated Data")
            plt.xticks([])
            plt.yticks([])
            plt.axis("off")
        plt.show()


In [None]:
# Build the MNIST data module and the GAN, both with their default arguments.
Data_Module = MNISTDataModule()
model = GAN()

In [None]:
# NOTE(review): there are no call parentheses here, so this cell only displays
# the bound method and plots nothing — presumably `model.plot_imgs()` was intended.
model.plot_imgs

In [None]:
# AVAIL_GPUS is 0 on a machine without CUDA, and Lightning 2.x rejects
# `devices=0`. Always request at least one device and let Lightning choose the
# accelerator (GPU when available, otherwise CPU).
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=max(1, AVAIL_GPUS),
)
trainer.fit(model, Data_Module)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 358 K 
1 | discriminator | Discriminator | 21.4 K
------------------------------------------------
379 K     Trainable params
0         Non-trainable params
379 K     Total params
1.520     Total estimated model params size (MB)
  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.
