## PyTorch Lightning example 2: MNIST digit classification

In [1]:
import os
import random

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

# Record the installed Lightning version so the outputs below can be reproduced.
print(f"Lightning version: {pl.__version__}")

Lightning version: 2.3.1


In [2]:
# -----------------
# MODEL
# -----------------
class LightningMNISTClassifier(pl.LightningModule):
    """A simple fully connected network for classifying MNIST digits using PyTorch Lightning.

    Architecture: 784 -> 128 -> 256 -> 10, with ReLU between layers. ``forward`` returns
    log-probabilities, so the training/validation loss is the negative log-likelihood of
    those log-probabilities (i.e. the cross-entropy of the underlying scores).

    Args:
        learning_rate (float): Learning rate for the Adam optimizer. Defaults to 1e-3,
            the value previously hard-coded in ``configure_optimizers``.
    """

    def __init__(self, learning_rate=1e-3):
        """Initialize the model with three linear layers."""
        super().__init__()
        # Store constructor arguments in ``self.hparams`` and in checkpoints so the model
        # can be rebuilt with ``load_from_checkpoint`` without re-passing them.
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Batch of images. Every dimension after the first is
                flattened, so both (batch_size, 1, 28, 28) and (batch_size, 784) work.

        Returns:
            torch.Tensor: Log-softmax probabilities of shape (batch_size, 10).
        """
        # flatten (unlike view) also accepts non-contiguous inputs.
        x = torch.flatten(x, start_dim=1)  # (batch_size, 784)
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        x = self.layer_3(x)
        return torch.log_softmax(x, dim=1)

    def cross_entropy_loss(self, logits, labels):
        """
        Compute the cross-entropy loss from the model's log-probabilities.

        ``forward`` already applies ``log_softmax``, so the matching loss is ``nll_loss``.
        ``cross_entropy`` would apply ``log_softmax`` a second time. That gives the same
        value (log_softmax is idempotent) but is redundant work and hides the fact that
        the inputs are not raw logits.

        Args:
            logits (torch.Tensor): Log-probabilities returned by ``forward``, shape
                (batch_size, num_classes). The name is kept for backward compatibility.
            labels (torch.Tensor): Ground-truth class indices, shape (batch_size,).

        Returns:
            torch.Tensor: Scalar mean cross-entropy loss.
        """
        return torch.nn.functional.nll_loss(logits, labels)

    def training_step(self, train_batch, batch_idx):
        """
        Training step.

        Args:
            train_batch (tuple): Batch of data and labels.
            batch_idx (int): Batch index.

        Returns:
            torch.Tensor: Training loss.
        """
        x, y = train_batch
        loss = self.cross_entropy_loss(self(x), y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """
        Validation step: log the loss, aggregated over the epoch, as ``val_loss``.

        Args:
            val_batch (tuple): Batch of data and labels.
            batch_idx (int): Batch index.
        """
        x, y = val_batch
        loss = self.cross_entropy_loss(self(x), y)
        # Epoch-level only. EarlyStopping/ModelCheckpoint monitor the epoch value;
        # per-batch validation losses would just add a noisy duplicate series.
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """
        Configure the optimizer.

        Returns:
            torch.optim.Optimizer: Adam optimizer using ``self.learning_rate``.
        """
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [3]:
# ----------------
# DATA
# ----------------
# Data preparation
# Normalize with (mean, std) = (0.1307, 0.3081), presumably the MNIST training-set pixel statistics.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)

# Hold out VAL_SIZE training images for validation (55,000 / 5,000 for MNIST's 60,000).
# Use a dedicated seeded generator: this cell runs before `pl.seed_everything`, so the
# global RNG is not yet seeded and an unseeded split would differ on every run.
VAL_SIZE = 5000
split_generator = torch.Generator().manual_seed(42)
mnist_train, mnist_val = random_split(
    mnist_train, [len(mnist_train) - VAL_SIZE, VAL_SIZE], generator=split_generator
)

BATCH_SIZE = 64
# Only the training loader is shuffled, so batches are reordered every epoch.
# Evaluation order does not affect the metrics.
train_dataloader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)  # Training data loader
val_loader = DataLoader(mnist_val, batch_size=BATCH_SIZE)  # Validation data loader
test_loader = DataLoader(mnist_test, batch_size=BATCH_SIZE)  # Test data loader


In [4]:
# ----------------
# MODEL INITIALIZATION AND TRAINING
# ----------------
# Callback definition
class PrintCallback(Callback):
    """Lightning callback that announces the start and the end of training.

    Only ``on_train_start`` and ``on_train_end`` are overridden; both simply print a
    fixed message and ignore the ``trainer`` / ``pl_module`` arguments.
    """

    def on_train_start(self, trainer, pl_module):
        """Print a banner when the training loop begins."""
        banner = "Training is started!"
        print(banner)

    def on_train_end(self, trainer, pl_module):
        """Print a banner when the training loop finishes."""
        banner = "Training is done."
        print(banner)

# Seed *before* building the model so the weight initialization below, as well as
# shuffling during training, is reproducible.
pl.seed_everything(42, workers=True)

# Create a wandb logger. The entity (user/team) comes from the WANDB_ENTITY environment
# variable instead of a hard-coded personal account, so others can run this notebook.
# When unset, None is passed and wandb picks the default entity (confirm in wandb docs).
wandb_logger = WandbLogger(project='mnist-project', entity=os.environ.get('WANDB_ENTITY'))

# Stop when val_loss has not improved for 3 consecutive validation checks
# (one check per epoch here, because check_val_every_n_epoch=1).
early_stopping_callback = EarlyStopping(
    monitor='val_loss',  # Metric to monitor
    patience=3,  # Validation checks with no improvement before stopping
    verbose=True,  # Verbosity mode
    mode='min'  # Lower loss is better
)

# Early stopping halts `patience` epochs *after* the best one, so the weights left in
# memory at the end of fit are not the best. Keep the best checkpoint by val_loss instead.
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)

# Instantiate the model
model = LightningMNISTClassifier()

# Define the trainer
trainer = pl.Trainer(
    logger=wandb_logger,  # Logger for experiment tracking
    min_epochs=1,  # Minimum number of epochs
    max_epochs=10,  # Maximum number of epochs
    log_every_n_steps=100,  # Log metrics every 100 steps
    deterministic=True,  # Ensure deterministic training
    devices=1,  # Number of devices to use
    accelerator="auto",  # Use available accelerator (GPU if available)
    # strategy="ddp",  # Uncomment if you have multiple GPUs and want to use distributed training
    accumulate_grad_batches=4,  # Accumulate gradients over 4 batches (effective batch = 4 x loader batch size)
    callbacks=[PrintCallback(), early_stopping_callback, checkpoint_callback],  # List of callbacks
    check_val_every_n_epoch=1,  # Check validation metrics every epoch
    enable_checkpointing=True  # Enable checkpointing (uses checkpoint_callback above)
)

# train_dataloader and val_loader come from the DATA cell above.
trainer.fit(model, train_dataloader, val_loader)  # Train the model

# Reload the best (lowest val_loss) weights for the inference cells below. Load on CPU,
# because the dataset tensors fed to the model at inference time are CPU tensors.
if checkpoint_callback.best_model_path:
    model = LightningMNISTClassifier.load_from_checkpoint(
        checkpoint_callback.best_model_path, map_location='cpu'
    )

Seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mfabiocat93[0m. Use [1m`wandb login --relogin`[0m to force relogin



  | Name    | Type   | Params | Mode 
-------------------------------------------
0 | layer_1 | Linear | 100 K  | train
1 | layer_2 | Linear | 33.0 K | train
2 | layer_3 | Linear | 2.6 K  | train
-------------------------------------------
136 K     Trainable params
0         Non-trainable params
136 K     Total params
0.544     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/fabiocat/miniconda3/envs/pl/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training is started!


/Users/fabiocat/miniconda3/envs/pl/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 0.181


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.053 >= min_delta = 0.0. New best score: 0.128


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.021 >= min_delta = 0.0. New best score: 0.107


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.099


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val_loss did not improve in the last 3 records. Best score: 0.099. Signaling Trainer to stop.


Training is done.


In [None]:
# ----------------
# FUNCTION TO SHOW A PREDICTION
# ----------------
def show_random_example(model, dataset, idx=None):
    """
    Show one example from the dataset along with its predicted and actual labels.

    The model is run in eval mode under ``torch.no_grad()``. Its previous train/eval
    mode is restored afterwards, so calling this helper has no lasting effect on it.

    Args:
        model (torch.nn.Module): Trained model; the argmax of its output over dim 1
            is taken as the predicted label.
        dataset (torch.utils.data.Dataset): Dataset of ``(image_tensor, label)`` pairs.
        idx (int, optional): Index of the example to show. If ``None`` (default), a
            random index is drawn with Python's ``random`` module.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("dataset is empty")
    if idx is None:
        idx = random.randint(0, len(dataset) - 1)

    img, label = dataset[idx]

    # Feed the input on the same device as the model's parameters to avoid a
    # device mismatch when the model lives on a GPU/MPS device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else img.device

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(img.unsqueeze(0).to(device))  # Add batch dimension
            prediction = torch.argmax(logits, dim=1).item()  # Get predicted label
    finally:
        model.train(was_training)  # Restore the caller's train/eval mode

    fig, ax = plt.subplots()
    ax.imshow(img.cpu().numpy().squeeze(), cmap='gray')
    ax.set_title(f'Actual Label: {label}, Predicted Label: {prediction}')
    ax.axis('off')
    plt.show()

In [None]:
# ----------------
# INFERENCE
# ----------------
# Show a random example from the test set along with its predicted and actual labels
show_random_example(model=model, dataset=mnist_test)  # Random test image with actual vs. predicted label


In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# Set seed for reproducibility
pl.seed_everything(42, workers=True)

# Define the logger. Use the same W&B project as the training run above (overridable
# via WANDB_PROJECT) rather than leaving the project unset, which would send these
# runs to whatever default project WandbLogger/wandb chooses.
wandb_logger = WandbLogger(project=os.environ.get('WANDB_PROJECT', 'mnist-project'))

