In [1]:
%pip install pytorch-lightning
%pip install torchvision

Collecting pytorch-lightning
  Obtaining dependency information for pytorch-lightning from https://files.pythonhosted.org/packages/a8/e3/b8b52f9f2913774eaed72a580645255e8c06004c52b2e87643b30d622abc/pytorch_lightning-2.0.9.post0-py3-none-any.whl.metadata
  Downloading pytorch_lightning-2.0.9.post0-py3-none-any.whl.metadata (23 kB)
Collecting PyYAML>=5.4 (from pytorch-lightning)
  Obtaining dependency information for PyYAML>=5.4 from https://files.pythonhosted.org/packages/84/4d/82704d1ab9290b03da94e6425f5e87396b999fd7eb8e08f3a92c158402bf/PyYAML-6.0.1-cp39-cp39-win_amd64.whl.metadata
  Downloading PyYAML-6.0.1-cp39-cp39-win_amd64.whl.metadata (2.1 kB)
Collecting fsspec[http]>2021.06.0 (from pytorch-lightning)
  Obtaining dependency information for fsspec[http]>2021.06.0 from https://files.pythonhosted.org/packages/fe/d3/e1aa96437d944fbb9cc95d0316e25583886e9cd9e6adc07baad943524eda/fsspec-2023.9.2-py3-none-any.whl.metadata
  Downloading fsspec-2023.9.2-py3-none-any.whl.metadata (6.7 kB)
Co



Collecting torchvisionNote: you may need to restart the kernel to use updated packages.





  Obtaining dependency information for torchvision from https://files.pythonhosted.org/packages/ec/36/1ecc19249def521b3b948baee32903148b1f399d2dd5a9a5692942e8383c/torchvision-0.16.0-cp39-cp39-win_amd64.whl.metadata
  Downloading torchvision-0.16.0-cp39-cp39-win_amd64.whl.metadata (6.6 kB)
Collecting torch==2.1.0 (from torchvision)
  Obtaining dependency information for torch==2.1.0 from https://files.pythonhosted.org/packages/67/0a/b6dddafbb64d3ca13078a2616a2ea02c595da832586898a7eb414cf7ad10/torch-2.1.0-cp39-cp39-win_amd64.whl.metadata
  Downloading torch-2.1.0-cp39-cp39-win_amd64.whl.metadata (24 kB)
Downloading torchvision-0.16.0-cp39-cp39-win_amd64.whl (1.3 MB)
   ---------------------------------------- 0.0/1.3 MB ? eta -:--:--
   --------- ------------------------------ 0.3/1.3 MB 6.3 MB/s eta 0:00:01
   ------------------ --------------------- 0.6/1.3 MB 7.5 MB/s eta 0:00:01
   ------------------------------ --------- 1.0/1.3 MB 7.7 MB/s eta 0:00:01
   --------------------------

In [2]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import pytorch_lightning as pl

In [3]:
# Reproducibility and run configuration.
random_seed = 42
torch.manual_seed(random_seed)

BATCH_SIZE = 128
# Use at most one GPU (0 when CUDA is unavailable).
AVAIL_GPUS = min(1, torch.cuda.device_count())
# os.cpu_count() may return None when the CPU count cannot be determined;
# fall back to 0 workers (load data in the main process) in that case.
NUM_WORKERS = (os.cpu_count() or 0) // 2

In [4]:
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving normalized MNIST train/val/test splits.

    Args:
        data_dir: directory where torchvision downloads / reads MNIST.
        batch_size: batch size used by every dataloader.
        num_workers: number of DataLoader worker processes.
    """

    def __init__(self, data_dir="./data",
                 batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # ToTensor scales pixels to [0, 1]; Normalize then standardizes with
        # mean 0.1307 and std 0.3081.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

    def prepare_data(self):
        """Download the MNIST train and test splits into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        "fit" and "validate" need the 55,000 / 5,000 train/val split of the
        MNIST training set; "test" needs the MNIST test set; ``None`` builds all.
        """
        if stage in ("fit", "validate", None):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        if stage in ("test", None):
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def _make_loader(self, dataset, shuffle=False):
        # Shared construction so every split uses the same batch/worker settings.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        # Reshuffle training samples every epoch; eval splits keep a fixed order.
        return self._make_loader(self.mnist_train, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.mnist_val)

    def test_dataloader(self):
        return self._make_loader(self.mnist_test)

In [5]:
# Discriminator ("detective"): maps an image to a single probability in [0, 1]
# that the image is real rather than generated.
class Discriminator(nn.Module):
    """Small CNN classifier: real (-> 1) vs. fake (-> 0).

    Expects input of shape (N, 1, 28, 28); returns probabilities of shape (N, 1).
    """

    def __init__(self):
        super().__init__()
        # Two conv(5x5) + max-pool(2) stages: 28 -> 24 -> 12 -> 8 -> 4, so the
        # flattened feature vector has 20 * 4 * 4 = 320 entries.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample, keeping the batch dimension. Unlike view(-1, 320),
        # this never silently regroups features across samples: an input of the
        # wrong spatial size now fails loudly in fc1.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Sigmoid so the output is a probability (suitable for binary cross-entropy).
        return torch.sigmoid(x)

In [6]:
# Generator: maps a latent vector to a fake image shaped like real data, [1, 28, 28].
# NOTE: the final convolution has no activation, so pixel values are unbounded
# (they are not squashed into [-1, 1]).
class Generator(nn.Module):
    """Upsample a latent vector to a single-channel 28x28 image.

    Args:
        latent_dim: size of the last dimension of the input noise tensor.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.lin1 = nn.Linear(latent_dim, 7 * 7 * 64)       # reshaped to [n, 64, 7, 7]
        self.ct1 = nn.ConvTranspose2d(64, 32, 4, stride=2)   # -> [n, 32, 16, 16]
        self.ct2 = nn.ConvTranspose2d(32, 16, 4, stride=2)   # -> [n, 16, 34, 34]
        self.conv = nn.Conv2d(16, 1, kernel_size=7)          # -> [n, 1, 28, 28]

    def forward(self, x):
        # Project the latent code and fold it into 64 feature maps of 7x7.
        features = torch.relu(self.lin1(x)).view(-1, 64, 7, 7)

        # Two stride-2 transposed convolutions: 7 -> 16 -> 34 spatially.
        for upsample in (self.ct1, self.ct2):
            features = torch.relu(upsample(features))

        # A valid 7x7 convolution trims 34x34 down to a single 28x28 channel.
        return self.conv(features)

In [None]:
# GAN: couples the Generator and Discriminator and trains them adversarially.
class GAN(pl.LightningModule):
    """Adversarial training of ``Generator`` against ``Discriminator``.

    Uses Lightning's manual optimization. Lightning 2.x no longer supports
    multiple optimizers driven by an ``optimizer_idx`` argument to
    ``training_step``, so each step explicitly updates the generator and then
    the discriminator.

    Args:
        latent_dim: size of the noise vector fed to the generator.
        lr: Adam learning rate for both networks.
        b1, b2: Adam beta coefficients for both networks.
    """

    def __init__(self, latent_dim=32, lr=0.0002, b1=0.5, b2=0.999):
        super().__init__()
        self.save_hyperparameters()
        # Both optimizers are stepped by hand in training_step.
        self.automatic_optimization = False
        self.latent_dim = latent_dim
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.generator = Generator(self.latent_dim)
        self.discriminator = Discriminator()

    def forward(self, x):
        """Generate images from latent vectors ``x`` (prediction/inference path)."""
        return self.generator(x)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator probabilities and targets."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx):
        """One generator update followed by one discriminator update.

        Logs ``g_loss`` and ``d_loss`` to the progress bar.
        """
        real, _ = batch
        batch_size = real.size(0)
        opt_g, opt_d = self.optimizers()

        # Latent noise on the same device/dtype as the real images.
        z = torch.randn(batch_size, self.latent_dim).type_as(real)
        # Targets must match the discriminator output shape (N, 1), not the image shape.
        real_labels = torch.ones(batch_size, 1).type_as(real)
        fake_labels = torch.zeros(batch_size, 1).type_as(real)

        # --- Generator: try to make the discriminator label fakes as real ---
        # toggle_optimizer freezes parameters not owned by opt_g (the discriminator).
        self.toggle_optimizer(opt_g)
        fake = self(z)
        g_loss = self.adversarial_loss(self.discriminator(fake), real_labels)
        self.manual_backward(g_loss)
        opt_g.step()
        opt_g.zero_grad()
        self.untoggle_optimizer(opt_g)

        # --- Discriminator: separate real images from (detached) fakes ---
        self.toggle_optimizer(opt_d)
        real_loss = self.adversarial_loss(self.discriminator(real), real_labels)
        fake_loss = self.adversarial_loss(self.discriminator(fake.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2
        self.manual_backward(d_loss)
        opt_d.step()
        opt_d.zero_grad()
        self.untoggle_optimizer(opt_d)

        self.log_dict({"g_loss": g_loss, "d_loss": d_loss}, prog_bar=True)

    def configure_optimizers(self):
        """One Adam optimizer per network: generator first, discriminator second."""
        betas = (self.b1, self.b2)
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=betas)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=betas)
        return [opt_g, opt_d], []