In [53]:
import os

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

import matplotlib.pyplot as plt
import pytorch_lightning as pl

## GAN - Discriminator Module

In [54]:
# Detective: fake or no fake -> 1 output [0, 1]
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Simple CNN
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 1)
  
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten the tensor so it can be fed into the FC layers
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return torch.sigmoid(x)

## GAN - Generator Module

In [55]:
# Generate Fake Data: output like real data [1, 28, 28] and values -1, 1
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.lin1 = nn.Linear(latent_dim, 7*7*64)  # [n, 256, 7, 7]
        self.ct1 = nn.ConvTranspose2d(64, 32, 4, stride=2) # [n, 64, 16, 16]
        self.ct2 = nn.ConvTranspose2d(32, 16, 4, stride=2) # [n, 16, 34, 34]
        self.conv = nn.Conv2d(16, 1, kernel_size=7)  # [n, 1, 28, 28]
    

    def forward(self, x):
        # Pass latent space input into linear layer and reshape
        x = self.lin1(x)
        x = F.relu(x)
        x = x.view(-1, 64, 7, 7)  #256
        
        # Upsample (transposed conv) 16x16 (64 feature maps)
        x = self.ct1(x)
        x = F.relu(x)
        
        # Upsample to 34x34 (16 feature maps)
        x = self.ct2(x)
        x = F.relu(x)
        
        # Convolution to 28x28 (1 feature map)
        return self.conv(x)

## GAN Module

In [56]:
class GAN(pl.LightningModule):
    def __init__(self, latent_dim=100, img_size=28, channels=1, batch_size=128, lr=0.0002, num_epochs=50):
        
        super(GAN, self).__init__()
        
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.channels = channels
        self.batch_size = batch_size
        self.lr = lr
        self.num_epochs = num_epochs        
        self.generator = Generator(latent_dim)
        self.discriminator = Discriminator()
        
        #random noise
        self.validation_z = torch.randn(batch_size, latent_dim)
        
    
    def forward(self, z):
        
        return self.generator(z)
    
    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        real_imgs, _ = batch
        
        #sample noise
        z = torch.randn(real_imgs.shape[0], self.latent_dim)
        z = z.type_as(real_imgs)
        
        #train generator max log(D(G(z)))
        if optimizer_idx == 0:
            #generate images
            gen_imgs = self.generator(z)
            y_hat = self.discriminator(gen_imgs)
            
            y = torch.ones(real_imgs.size(0), 1)
            y = y.type_as(real_imgs)
            
            g_loss = self.adversarial_loss(y_hat, y)
            self.log("g_loss", g_loss, prog_bar=True)
            
            return g_loss
        
        #train discriminator min log(D(x)) + log(1 - D(G(z)))
        if optimizer_idx == 1:
            #train with real images
            y_hat_real = self.discriminator(real_imgs)
            y_real = torch.ones(real_imgs.size(0), 1)
            y_real = y_real.type_as(real_imgs)
            
            d_real_loss = self.adversarial_loss(y_hat_real, y_real)
            
            y_hat_fake = self.discriminator(self.generator(z).detach())
            y_fake = torch.zeros(real_imgs.size(0), 1)
            y_fake = y_fake.type_as(real_imgs)
            
            d_fake_loss = self.adversarial_loss(y_hat_fake, y_fake)
            
            d_loss = (d_real_loss + d_fake_loss) / 2
            
            self.log("d_loss", d_loss, prog_bar=True)
            log_dict = {"loss": d_loss, "d_real_loss": d_real_loss, "d_fake_loss": d_fake_loss}
            return {"loss": d_loss, "progress_bar": log_dict, "log": log_dict}
    
    def configure_optimizers(self):
        lr = self.lr
        b1 = 0.5
        b2 = 0.999
        
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        
        return [opt_g, opt_d], []
    
    def on_epoch_end(self):
        self.validation_z = torch.randn(self.batch_size, self.latent_dim)
        
    def plot_images(self):
        z = self.validation_z.type_as(self.generator.fc1.weight)
        samples = self.generator(z)
        samples = samples.cpu().detach()
        samples = samples.view(samples.size(0), -1, self.img_size, self.img_size)
        
        print(f"epoch: {self.current_epoch}")
        for i in range(samples.size(0)):
            plt.subplot(4, 4, i+1)
            plt.imshow(samples[i].permute(1, 2, 0).cpu().numpy())
            plt.axis('off')
        plt.show()
        
    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)
    
    def test_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)

## DataLoader Module

In [57]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir: str = "./data",
        batch_size: int = 64 if torch.cuda.is_available() else 128,
        num_workers: int = 1,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [58]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataModule = MNISTDataModule()
gan = GAN(*dataModule.dims)

trainer = pl.Trainer(accelerator="auto", devices=1, max_epochs=5)

trainer.fit(gan, dataModule)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\sriuj\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.


RuntimeError: Training with multiple optimizers is only supported with manual optimization. Remove the `optimizer_idx` argument from `training_step`, set `self.automatic_optimization = False` and access your optimizers in `training_step` with `opt1, opt2, ... = self.optimizers()`.