PyTorch Lightning is a wrapper around PyTorch that lets ML researchers write less boilerplate:

1. No need to switch between model.train() and model.eval() yourself.

2. No need to manage devices by hand; moving to a GPU or TPU is just a Trainer argument.

3. No explicit optimizer.zero_grad(), loss.backward() or optimizer.step() calls.

4. No torch.no_grad() blocks or .detach() calls around evaluation and logging.

In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch.nn.functional as F
from pytorch_lightning import Trainer as light_trainer


# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Check whether PyTorch can see a CUDA GPU (the cell output below shows True).
torch.cuda.is_available()

True

In [3]:
"""hyperparameters"""

# Model shape: MNIST images are 28x28 pixels, flattened into one feature vector.
input_size = 28 * 28
hidden_size = 500
num_classes = 10  # one logit per digit class

# Optimisation schedule, read by later cells.
num_epochs, batch_size = 2, 100
learning_rate = 1e-3

In [4]:
# MNIST
# train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)

# test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

In [5]:
# train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [6]:
# examples = iter(test_loader)
# samples, labels = next(examples)
# samples.shape, labels.shape
# 100 images, only one color channel, resolution 28x28

In [7]:
# for i in range(3):
#     plt.subplot(1, 3, i+1)
#     plt.imshow(samples[i][0], cmap='gray')
#     # plt.show()

In [8]:
"""pl.LightningModule instead of nn.Module"""


class LitNet(pl.LightningModule):
    """Two-layer MLP classifier for flattened MNIST digit images.

    Subclassing ``pl.LightningModule`` (instead of ``nn.Module``) lets the
    Trainer run the train/validation loops, place tensors on the selected
    device, and drive the optimizer's zero_grad/backward/step calls.

    Args:
        input_size: Features per flattened image (784 for 28x28 MNIST).
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits (one per class).
        learning_rate: Adam learning rate. It is kept on ``self.learning_rate``
            so Lightning's learning-rate finder can locate and override it.
            Defaults to the notebook-level ``learning_rate`` as it was when
            this class was defined.
        batch_size: Batch size for both dataloaders, kept on
            ``self.batch_size``. Defaults to the notebook-level
            ``batch_size`` as it was when this class was defined.
    """

    def __init__(self, input_size, hidden_size, num_classes,
                 learning_rate=learning_rate, batch_size=batch_size):
        super().__init__()
        self.input_size = input_size
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw class logits for a batch of flattened images.

        No softmax is applied: ``F.cross_entropy`` applies log-softmax itself.
        """
        x = self.l1(x)
        x = self.relu(x)
        return self.l2(x)

    def _compute_loss(self, batch):
        """Flatten an ``(images, labels)`` batch and return its cross-entropy loss."""
        images, labels = batch
        # Flatten (N, 1, 28, 28) images to (N, input_size) for the linear layers.
        images = images.reshape(-1, self.input_size)
        logits = self(images)
        return F.cross_entropy(logits, labels)

    def _mnist_loader(self, train):
        """Build a DataLoader over the MNIST train split (shuffled) or test split.

        Both splits pass ``download=True``. Lightning may request the
        validation loader (for its sanity check) before the training loader,
        so the test split must not depend on the training loader having
        fetched the data first.
        """
        dataset = torchvision.datasets.MNIST(
            root='./data', train=train, transform=transforms.ToTensor(), download=True)
        return torch.utils.data.DataLoader(
            dataset=dataset, batch_size=self.batch_size, num_workers=4, shuffle=train)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._compute_loss(batch)
        # With the default logger this can be browsed in TensorBoard under lightning_logs/.
        self.log("train_loss", loss)
        return {'loss': loss}

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self):
        """Shuffled MNIST training split."""
        return self._mnist_loader(train=True)

    # The MNIST test split is used as the validation set.

    def validation_step(self, batch, batch_idx):
        """Compute the loss on one validation batch."""
        return {'val_loss': self._compute_loss(batch)}

    def val_dataloader(self):
        """Unshuffled MNIST test split, used here for validation."""
        return self._mnist_loader(train=False)

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and log the result.

        NOTE: this hook exists in Lightning 1.x but was removed in Lightning
        2.0; port it to ``on_validation_epoch_end`` when upgrading.
        """
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        self.log("validation_loss", avg_loss)
        return {'val_loss': avg_loss}

In [9]:
# 'high' allows reduced-precision internal math (e.g. TF32) for float32 matmuls:
# faster on supported GPUs, at slightly lower precision.
torch.set_float32_matmul_precision('high')
model = LitNet(input_size=input_size, hidden_size=hidden_size, num_classes=num_classes)
# accelerator='auto' uses the GPU when one is available and falls back to CPU
# instead of failing on machines without CUDA.
# NOTE: auto_lr_find only takes effect through trainer.tune(model); trainer.fit(model)
# on its own trains with the learning rate set in configure_optimizers.
# fast_dev_run=True would run a single train/val batch as a quick smoke test.
trainer = light_trainer(auto_lr_find=True, max_epochs=num_epochs, fast_dev_run=False,
                        accelerator='auto', devices=1)
trainer.fit(model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 392 K 
1 | relu | ReLU   | 0     
2 | l2   | Linear | 5.0 K 
--------------------------------
397 K     Trainable params
0         Non-trainable params
397 K     Total params
1.590     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


In [10]:
# model = net(input_size, hidden_size, num_classes)
# model = model.to(device)

In [11]:
# loss function & optimizer

# criterion = nn.CrossEntropyLoss()

# optimizer = torch.optim.Adam(model.parameters(),lr = learning_rate)


In [12]:
# training loop
# n_total_step = len(train_loader)


# for epoch in range(num_epochs):
#     for i, (images, labels) in enumerate(train_loader):
#         # reshape from (100, 1, 28, 28) to (100, 784)
#         images = images.reshape(-1, 28*28).to(device)
#         labels = labels.to(device)

#         # forward pass
#         pred = model(images)
#         loss = criterion(pred, labels)

#         # backward pass
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         if (i+1) % 100 == 0:
#             print(f'eopch {epoch+1} / {num_epochs}, step {i+1} / {n_total_step}, loss = {loss.item():.4f}')

In [13]:
# # test

# with torch.no_grad():
#     n_correct = 0
#     n_samples = 0
#     for images, labels in test_loader:
#         images = images.reshape(-1, 28*28).to(device)
#         labels = labels.to(device)
#         output = model(images)

#         _, pred = torch.max(output, 1)
#         n_samples += labels.shape[0]
#         n_correct += (pred == labels).sum().item()

In [14]:
# acc = 100.0 * n_correct / n_samples
# acc

In [15]:
# Browse the training/validation curves in lightning_logs with TensorBoard: tensorboard --logdir lightning_logs