In [1]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.2.0.post0-py3-none-any.whl (800 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m800.9/800.9 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.3.1-py3-none-any.whl (840 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.4/840.4 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.10.1-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics, pytorch_lightning
Successfully installed lightning-utilities-0.10.1 pytorch_lightning-2.2.0.post0 torchmetrics-1.3.1


In [55]:
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torchvision.utils import save_image

# Output directory for the generated sample grids that GAN saves as
# "images/<epoch>.png"; exist_ok=True makes re-running this cell safe.
os.makedirs("images", exist_ok=True)

class Generator(pl.LightningModule):
    """Fully connected generator that maps latent vectors to images.

    Args:
        latent_dim: Length of the input noise vector.
        img_size: Height and width (in pixels) of the generated image.
        channels: Number of channels of the generated image.

    Attributes:
        latent_dim: Stored so callers can sample noise of the right size.
        img_shape: ``(channels, img_size, img_size)`` shape of one output image.
        model: The MLP producing a flat image squashed to ``[-1, 1]`` by Tanh.
    """

    def __init__(self, latent_dim=100, img_size=28, channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = (channels, img_size, img_size)

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear -> [BatchNorm] -> LeakyReLU stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # The second positional parameter of BatchNorm1d is `eps`, not
                # `momentum`; passing 0.8 positionally made eps=0.8, which
                # flattens the normalisation. Pass the value by keyword.
                # NOTE(review): 0.8 is assumed to be the intended momentum.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: Tensor whose first dimension is the batch and whose last
                dimension matches the first Linear layer (``latent_dim``).

        Returns:
            Tensor of shape ``(batch, *img_shape)``.
        """
        img = self.model(z)
        # Un-flatten the MLP output back into image layout.
        img = img.view(img.size(0), *self.img_shape)
        return img

class Discriminator(pl.LightningModule):
    """MLP that scores an image with a single sigmoid output in ``[0, 1]``.

    Args:
        img_size: Height and width (in pixels) of the input images.
        channels: Number of channels of the input images.
    """

    def __init__(self, img_size=28, channels=1):
        super().__init__()
        self.img_shape = (channels, img_size, img_size)

        # Number of scalar inputs once an image has been flattened.
        flat_features = int(np.prod(self.img_shape))

        stages = [
            nn.Linear(flat_features, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, img):
        """Flatten each image in the batch and return its score, shape ``(batch, 1)``."""
        batch = img.size(0)
        return self.model(img.view(batch, -1))

class GAN(pl.LightningModule):
    """GAN trained with Lightning's manual optimization.

    Args:
        latent_dim: Length of the noise vector fed to the generator.
        img_size: Height and width of the images.
        channels: Number of image channels.
        lr: Learning rate used by both Adam optimizers.
        b1: First Adam beta coefficient.
        b2: Second Adam beta coefficient.
        batch_size: Stored on the module; not read by this class.
        n_epochs: Stored on the module; not read by this class.
        sample_interval: A sample grid is saved every ``sample_interval`` epochs.
    """

    def __init__(self, latent_dim=100, img_size=28, channels=1, lr=0.0002, b1=0.5, b2=0.999, batch_size=64, n_epochs=200, sample_interval=200):
        super().__init__()
        self.generator = Generator(latent_dim, img_size, channels)
        self.discriminator = Discriminator(img_size, channels)
        self.adversarial_loss_fn = torch.nn.BCELoss()
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        # Two optimizers are stepped by hand inside training_step.
        self.automatic_optimization = False
        self.sample_interval = sample_interval

    def forward(self, z):
        """Generate images from the noise batch ``z``."""
        return self.generator(z)

    def calculate_adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator scores and target labels."""
        return self.adversarial_loss_fn(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Run one generator update followed by one discriminator update.

        With ``automatic_optimization`` disabled Lightning does not call
        ``backward`` or ``step`` itself, so both are done explicitly here;
        merely returning a loss would leave every parameter untouched.

        Args:
            batch: ``(images, labels)`` pair; the labels are ignored.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            None. Losses are reported through ``self.log``.
        """
        imgs, _ = batch
        n_samples = imgs.size(0)
        optimizer_G, optimizer_D = self.optimizers()

        # Create targets and noise like the input batch (same device and
        # dtype) so the step also works when training on an accelerator.
        valid = torch.ones(n_samples, 1).type_as(imgs)
        fake = torch.zeros(n_samples, 1).type_as(imgs)
        z = torch.randn(n_samples, self.generator.latent_dim).type_as(imgs)

        # ---- Generator: make the discriminator label fakes as real ----
        gen_imgs = self(z)
        g_loss = self.calculate_adversarial_loss(self.discriminator(gen_imgs), valid)
        optimizer_G.zero_grad()
        self.manual_backward(g_loss)
        optimizer_G.step()

        # ---- Discriminator: separate real images from detached fakes ----
        real_loss = self.calculate_adversarial_loss(self.discriminator(imgs), valid)
        fake_loss = self.calculate_adversarial_loss(self.discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        # zero_grad also discards the gradients the generator's backward
        # pass left on the discriminator's parameters.
        optimizer_D.zero_grad()
        self.manual_backward(d_loss)
        optimizer_D.step()

        self.log('g_loss', g_loss, prog_bar=True)
        self.log('d_loss', d_loss, prog_bar=True)

    def configure_optimizers(self):
        """Return the generator and discriminator Adam optimizers, in that order."""
        optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        return optimizer_G, optimizer_D

    def on_train_epoch_end(self):
        """Save a grid of 64 generated samples every ``sample_interval`` epochs."""
        if self.current_epoch % self.sample_interval == 0:
            # Sample on the module's device; no graph is needed for saving.
            z = torch.randn(64, self.generator.latent_dim, device=self.device)
            with torch.no_grad():
                gen_imgs = self.generator(z)
            save_image(gen_imgs, f"images/{self.current_epoch}.png", normalize=True)

    def on_epoch_end(self):
        """Backward-compatible alias for ``on_train_epoch_end``.

        Lightning 2.x does not call a hook with this name, which is why the
        sampling logic lives in ``on_train_epoch_end``.
        """
        self.on_train_epoch_end()

class MNISTDataModule(pl.LightningDataModule):
    """Data module serving MNIST resized to ``img_size`` and scaled with mean 0.5 / std 0.5.

    Args:
        batch_size: Batch size used by every dataloader.
        img_size: Target size passed to ``transforms.Resize``.
    """

    def __init__(self, batch_size=64, img_size=28):
        super().__init__()
        self.batch_size = batch_size
        self.img_size = img_size
        preprocessing = [
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
        self.transform = transforms.Compose(preprocessing)

    def prepare_data(self):
        """Download the train and test splits into ``data/mnist``."""
        for is_train in (True, False):
            MNIST(root="data/mnist", train=is_train, download=True)

    def _make_loader(self, train, shuffle):
        """Build a DataLoader over the requested MNIST split."""
        dataset = MNIST(root="data/mnist", train=train, transform=self.transform)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(train=True, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the test split."""
        return self._make_loader(train=False, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the test split."""
        return self._make_loader(train=False, shuffle=False)

# Set hyperparameters
seed = 42  # fixed so weight init, noise sampling and shuffling repeat across runs
latent_dim = 100
img_size = 28
channels = 1
lr = 0.0002
b1 = 0.5
b2 = 0.999
batch_size = 64
n_epochs = 2
sample_interval = 1  # Save generated images every epoch

# Seed Python, NumPy and torch before any weights are created
pl.seed_everything(seed, workers=True)

# Init model and datamodule
model = GAN(
    latent_dim=latent_dim,
    img_size=img_size,
    channels=channels,
    lr=lr,
    b1=b1,
    b2=b2,
    batch_size=batch_size,
    n_epochs=n_epochs,
    sample_interval=sample_interval,
)
dm = MNISTDataModule(batch_size=batch_size, img_size=img_size)

# Define Trainer
trainer = pl.Trainer(
    max_epochs=n_epochs,
)

# Train the model
trainer.fit(model, dm)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                | Type          | Params
------------------------------------------------------
0 | generator           | Generator     | 1.5 M 
1 | discriminator       | Discriminator | 533 K 
2 | adversarial_loss_fn | BCELoss       | 0     
------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.174     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.
