PyTorch is the go-to choice for researchers and practitioners building deep learning models.
- However, plain PyTorch leaves several things for you to handle yourself:
    1. Managing training loops
    2. Logging
    3. Handling distributed training
    4. Debugging in a distributed setting
    5. Mixed precision training
    6. Running models on TPUs

In [1]:
import torch
import torch.nn as nn

In [2]:
# Define a synthetic dataset: random 2-feature inputs with random labels from 10 classes
train_dataset = torch.utils.data.TensorDataset(torch.randn(1000, 2), torch.randint(0, 10, (1000,)))
test_dataset = torch.utils.data.TensorDataset(torch.randn(1000, 2), torch.randint(0, 10, (1000,)))
val_dataset = torch.utils.data.TensorDataset(torch.randn(1000, 2), torch.randint(0, 10, (1000,)))

# Defining the model class inherited from PyTorch's nn.Module class
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(2, 10)
        self.fc2 = nn.Linear(10, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
    
# Define the loss function
criterion = nn.CrossEntropyLoss()

# Define the evaluate function with proper batching
def evaluate(model):
    model.eval()
    correct = 0
    total = 0
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

# The model only defines forward(); autograd derives the backward pass when loss.backward() is called
epochs = 5
model = MyModel()
# Defining the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for inputs, labels in trainloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    accuracy = evaluate(model)
    # print the loss and accuracy
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(trainloader)}, Accuracy: {accuracy}")

Epoch 1, Loss: 2.3688920587301254, Accuracy: 0.104
Epoch 2, Loss: 2.3498924151062965, Accuracy: 0.097
Epoch 3, Loss: 2.335986942052841, Accuracy: 0.099
Epoch 4, Loss: 2.3231282085180283, Accuracy: 0.101
Epoch 5, Loss: 2.3174316734075546, Accuracy: 0.099
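
The accuracy stays near 0.10 and the loss near 2.3 because the labels in this toy dataset are drawn at random from 10 classes, so there is nothing for the model to learn. As a quick sanity check (a sketch added for illustration, not part of the original run), a classifier that predicts a uniform distribution over 10 classes has a cross-entropy of ln(10):

```python
import math

num_classes = 10  # matches torch.randint(0, 10, ...) in the dataset above

# Cross-entropy of a uniform prediction: -log(1 / num_classes)
chance_loss = -math.log(1.0 / num_classes)
# Expected accuracy when guessing among equally likely classes
chance_accuracy = 1.0 / num_classes

print(f"chance loss: {chance_loss:.4f}, chance accuracy: {chance_accuracy:.2f}")
```

Both numbers line up with the epoch logs above, which suggests the model is doing no better than guessing on this data.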


### As discussed above, plain PyTorch leaves several challenges to the user

# PyTorch Lightning
- Addresses the challenges discussed above:

- Managing training loops: PyTorch Lightning simplifies this process by providing a high-level abstraction for defining the training loop, reducing the amount of boilerplate code required.
- Logging: PyTorch Lightning integrates with popular logging frameworks like TensorBoard and Comet, making it easier to log training metrics and visualize them in real-time.
- Handling distributed training: PyTorch Lightning simplifies distributed training by providing a unified interface. This abstracts away the complexity of the underlying implementation.
- Debugging in a distributed setting: PyTorch Lightning provides tools and utilities to facilitate debugging in a distributed setting, making it easier to identify and resolve issues.
- Mixed-precision training: PyTorch Lightning simplifies mixed-precision training by providing utilities to automatically handle the precision of operations based on user-defined settings.
- Running models on TPUs: PyTorch Lightning supports running models on TPUs, abstracting away the complexity of the underlying TPU architecture and allowing users to focus on their model implementation.
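
As a rough picture of the mixed-precision point, the sketch below uses PyTorch's own `torch.autocast` on CPU with `bfloat16`. It assumes a PyTorch build with CPU autocast support, and the layer sizes are arbitrary. Lightning's precision setting is understood to manage this kind of context for you, so read this as an illustration of the idea rather than of Lightning's internals.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 3)   # parameters are stored in float32
x = torch.randn(2, 4)     # float32 input

# Inside autocast, eligible ops such as linear layers run in bfloat16
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = layer(x)

print(layer.weight.dtype)  # parameters keep their original dtype
print(out.dtype)           # dtype of the output computed under autocast
```

The parameters stay in full precision while the matrix multiply runs in the lower-precision type, which is the "mixed" part of mixed precision.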

# PyTorch to PyTorch Lightning

In [3]:
# Defining the PyTorch model

# importing the required packages and libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

from torch.utils.data import DataLoader

In [4]:
## Load the dataset

# 1) Data transformer
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# 2) Create Train dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True, 
                                      download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# 3) Create Test dataset
testset = torchvision.datasets.MNIST(root='./data', train=False, 
                                     download=True, transform=transform)

# 4) Create the test DataLoader (evaluation does not need shuffling)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
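
`transforms.Normalize((0.5,), (0.5,))` applies `(x - mean) / std` to each channel. Since `ToTensor()` scales pixel values to [0, 1], the transform above maps them to [-1, 1]. A small plain-Python check of that arithmetic (the helper name `normalize` is made up for this illustration):

```python
def normalize(x, mean=0.5, std=0.5):
    """Same formula Normalize applies to each pixel: (x - mean) / std."""
    return (x - mean) / std

# Pixel values after ToTensor() lie in [0, 1]
for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
```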


In [5]:
# now that the data is loaded, we can define the model

class PyTorchNet(nn.Module):
    
    # Defining the architecture of the model
    def __init__(self):
        super(PyTorchNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 128)
        self.fc4 = nn.Linear(128, 10)

    # Defining the forward pass
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x
    
# Initialize the model and define the loss function and optimizer
model = PyTorchNet()

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define an evaluation function
def evaluate(model):
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [6]:
## Train the model

# Define the number of epochs
epochs = 5

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for data in trainloader:

        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    accuracy = evaluate(model)

    print(f"""Epoch {epoch + 1},
          Loss: {running_loss / len(trainloader)},
          Accuracy: {accuracy}
          """)



Epoch 1,
          Loss: 0.29247100143306165,
          Accuracy: 0.9276
          
Epoch 2,
          Loss: 0.14089426057470808,
          Accuracy: 0.9632
          
Epoch 3,
          Loss: 0.107698312252889,
          Accuracy: 0.9661
          
Epoch 4,
          Loss: 0.09020289394017587,
          Accuracy: 0.9727
          
Epoch 5,
          Loss: 0.0748992892322756,
          Accuracy: 0.9729
          


In [7]:
#%pip install lightning

## PyTorch Lightning Model

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim  # used below for the Adam optimizer
import pytorch_lightning as pl


class PyTorchLightningNet(pl.LightningModule):

    def __init__(self):
        super(PyTorchLightningNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 128)
        self.fc4 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x
    
    
# A LightningModule is still an nn.Module, so the hand-written loop below works unchanged
model = PyTorchLightningNet()
# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define an evaluation function
def evaluate(model):
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [9]:
## Train the model

# Define the number of epochs
epochs = 5

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for data in trainloader:

        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    accuracy = evaluate(model)

    print(f"""Epoch {epoch + 1},
          Loss: {running_loss / len(trainloader)},
          Accuracy: {accuracy}
          """)


Epoch 1,
          Loss: 0.30681939999352514,
          Accuracy: 0.9312
          
Epoch 2,
          Loss: 0.1400705497435876,
          Accuracy: 0.9451
          
Epoch 3,
          Loss: 0.10690825645710582,
          Accuracy: 0.9624
          
Epoch 4,
          Loss: 0.09134310164175101,
          Accuracy: 0.9683
          
Epoch 5,
          Loss: 0.07344570693729727,
          Accuracy: 0.9702
          


#### Looks like there is more variation in the accuracy

So far the training loop is still written by hand. To let Lightning run it, we can add the training logic to our Lightning model class.
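
Adding things to the model class means giving it hooks such as `training_step` and `configure_optimizers`, which the next cell defines. The module describes one step and its optimizer, while a trainer owns the loop. The sketch below imitates that split in plain PyTorch; `TinyModule` and `fit` are hypothetical names made up for this illustration and are not Lightning's implementation.

```python
import torch
import torch.nn as nn

class TinyModule(nn.Module):
    """Hypothetical module exposing Lightning-style hooks."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.linear(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

def fit(module, batches, max_epochs):
    """Toy stand-in for a trainer: it owns the loop, the module owns the step."""
    optimizer = module.configure_optimizers()
    epoch_losses = []
    for _ in range(max_epochs):
        module.train()
        total = 0.0
        for batch_idx, batch in enumerate(batches):
            optimizer.zero_grad()
            loss = module.training_step(batch, batch_idx)
            loss.backward()
            optimizer.step()
            total += loss.item()
        epoch_losses.append(total / len(batches))
    return epoch_losses

torch.manual_seed(0)
x = torch.linspace(-1, 1, 32).unsqueeze(1)
y = 2 * x + 1                      # toy regression target
losses = fit(TinyModule(), [(x, y)], max_epochs=20)
print(losses[0], losses[-1])
```

Nothing in `fit` knows about the model or the loss, which is why the same loop can be reused across models once the step logic lives in the module.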

In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torchmetrics
import os


# Empty MPS cache before training (helps free up memory)
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

# Data transformation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load datasets
# NOTE: MNIST ships only train and test splits; here the test split doubles as the validation set
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Dataloaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

class TrainingProgressCallback(pl.Callback):

    def on_train_start(self, trainer, pl_module):
        print("Starting training!!")

# Define Lightning model
class MNISTLightningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512),  # First FC layer
            nn.ReLU(),
            nn.Linear(512, 1024),  # Second FC layer
            nn.ReLU(),
            nn.Linear(1024, 256),  # Third FC layer
            nn.ReLU(),
            nn.Linear(256, 128),  # Fourth FC layer
            nn.ReLU(),
            nn.Linear(128, 10)  # Output layer
        )
        self.criterion = nn.CrossEntropyLoss()

        # Define torchmetrics for tracking
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.precision = torchmetrics.Precision(task="multiclass", num_classes=10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten input
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)

        # Compute metrics (convert logits to class predictions)
        preds = outputs.argmax(dim=1)

        acc = self.accuracy(preds, labels)
        prec = self.precision(preds, labels)

        # Log metrics separately
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_precision', prec, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)

        preds = outputs.argmax(dim=1)
        acc = self.accuracy(preds, labels)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)

        preds = outputs.argmax(dim=1)
        acc = self.accuracy(preds, labels)

        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        return train_loader

    def val_dataloader(self):
        return val_loader

    def test_dataloader(self):
        return test_loader


# Initialize model and trainer
model = MNISTLightningModel()
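trainer = pl.Trainer(max_epochs=4,
                     accelerator="cpu",
                     # fp16 AMP is not supported on CPU; Lightning falls back to 'bf16-mixed' (see the warning in the output)
                     precision=16,
                     log_every_n_steps=10,
                     callbacks=[TrainingProgressCallback()])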

# Train the model
trainer.fit(model, train_loader, val_loader)

# Test the model
trainer.test(model)


/Users/rakeshk94/Desktop/multi_gpu_training/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/rakeshk94/Desktop/multi_gpu_training/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name      | Type                | Params | Mode 
----------------------------------------------------------
0 | model     | Sequential          | 1.2 M  | train
1 | criterion | CrossEntropyLoss    | 0      | train
2 | accuracy  | MulticlassAccuracy  | 0      | train
3 | precision | MulticlassPrecision | 0      | train
--------------------

                                                                           

/Users/rakeshk94/Desktop/multi_gpu_training/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/rakeshk94/Desktop/multi_gpu_training/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Starting training!!
Epoch 3: 100%|██████████| 938/938 [00:55<00:00, 17.03it/s, v_num=17, train_loss_step=0.0485, train_acc_step=1.000, train_precision_step=1.000, val_loss=0.114, val_acc=0.966, train_loss_epoch=0.0923, train_acc_epoch=0.972, train_precision_epoch=0.972]

`Trainer.fit` stopped: `max_epochs=4` reached.




/Users/rakeshk94/Desktop/multi_gpu_training/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 157/157 [00:00<00:00, 181.52it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9656000137329102
        test_loss           0.11355363577604294
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.11355363577604294, 'test_acc': 0.9656000137329102}]
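
In the progress bar above, `train_precision_epoch` and `train_acc_epoch` are both 0.972. That is what micro averaging gives for single-label multiclass data: every wrong prediction is exactly one false positive, so micro precision reduces to accuracy. This assumes the `torchmetrics.Precision` default of `average="micro"` is in effect. The labels below are made up for illustration.

```python
labels = [0, 1, 2, 2, 1, 0, 2, 1]  # made-up ground truth
preds  = [0, 2, 2, 2, 1, 0, 1, 1]  # made-up predictions
classes = sorted(set(labels))

accuracy = sum(p == t for p, t in zip(preds, labels)) / len(labels)

# Per-class true positives and false positives
tp = {c: sum(p == c and t == c for p, t in zip(preds, labels)) for c in classes}
fp = {c: sum(p == c and t != c for p, t in zip(preds, labels)) for c in classes}

# Micro: pool the counts over classes, then divide
micro_precision = sum(tp.values()) / (sum(tp.values()) + sum(fp.values()))
# Macro: compute precision per class, then average
macro_precision = sum(tp[c] / (tp[c] + fp[c]) for c in classes) / len(classes)

print(accuracy, micro_precision, macro_precision)
```

If a precision figure that differs from accuracy is wanted, passing `average="macro"` when building the metric is one option to consider.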

In [21]:
# trainer.predict(model, val_loader)

In [15]:
#%pip install torchmetrics

## Convenient features
1. Callbacks
2. Profiling
3. Mixed precision training