## PyTorch Lightning ML Training with Weights & Biases Checkpoints

> This is a quick demo of running ML training locally with PyTorch Lightning and its Weights & Biases integration.

Weights & Biases (W&B) is a popular platform for ML experiment tracking, lineage tracking, and checkpoint storage. PyTorch Lightning provides a native integration with W&B, which makes it easy to log training and validation metrics and to store checkpoints.

In this notebook, we show how the integration works and how to resume training from checkpoints. This serves as a reference when moving to large-scale ML training with Amazon SageMaker Training Jobs.

### Setup

We use a `.env` file to store environment variables. In our use case, it holds `WANDB_API_KEY`, which `wandb.login()` uses to authenticate with W&B.

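```python
from dotenv import load_dotenv
import wandb

# Load variables such as WANDB_API_KEY from the .env file into os.environ
load_dotenv("../.env")

# Log in to W&B; wandb.login() picks up WANDB_API_KEY from the environment
wandb.login()
```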
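To make the `.env` step concrete, the sketch below parses simple `KEY=VALUE` lines and copies them into `os.environ`. It is a simplified illustration, not python-dotenv's implementation: `parse_dotenv` and `load_into_environ` are hypothetical helpers, and real `.env` files support more syntax (multi-line values, variable expansion). Like `load_dotenv`'s default (`override=False`), it does not overwrite variables that are already set.

```python
import os


def parse_dotenv(text: str) -> dict[str, str]:
    """Parse simple KEY=VALUE lines (a simplified sketch, not python-dotenv)."""
    values = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        # Skip blank lines and comments
        if not line or line.startswith("#"):
            continue
        # Allow an optional leading "export "
        if line.startswith("export "):
            line = line[len("export "):].strip()
        key, sep, value = line.partition("=")
        if not sep:
            continue  # not a KEY=VALUE line
        value = value.strip()
        # Strip one pair of matching surrounding quotes
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
            value = value[1:-1]
        values[key.strip()] = value
    return values


def load_into_environ(values: dict[str, str], override: bool = False) -> None:
    """Copy parsed values into os.environ, keeping existing ones unless override=True."""
    for key, value in values.items():
        if override or key not in os.environ:
            os.environ[key] = value


example = 'WANDB_API_KEY="<your-api-key>"\n# comments and blank lines are ignored\n'
print(parse_dotenv(example))  # {'WANDB_API_KEY': '<your-api-key>'}
```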

### Dataloader for ML Training

We use MNIST, an open-source image dataset, for an image classification problem. This section focuses on creating the `DataLoader`s that feed the ML training.

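```python
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

# Convert images to tensors, then normalize with the MNIST mean (0.1307) and std (0.3081)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# The MNIST training split has 60,000 images: 55,000 for training, 5,000 for validation.
# A fixed seed keeps the split identical when training is resumed in a new process.
dataset = MNIST(root="./data/MNIST", download=True, transform=transform)
training_set, validation_set = random_split(
    dataset, [55_000, 5000], generator=torch.Generator().manual_seed(42)
)

training_loader = DataLoader(training_set, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=64)
```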
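The split and batch size above determine how many batches each epoch contains, which sets the scale of the step counter that training metrics are logged against. The sketch below does that arithmetic in plain Python, assuming the DataLoader default `drop_last=False` (so a final partial batch still counts) and the `max_epochs=5` used for the Trainer in this notebook. It also evaluates what `Normalize` does to pure black and pure white pixels after `ToTensor` has scaled them to [0, 1].

```python
import math

TRAIN_SIZE, VAL_SIZE = 55_000, 5_000
BATCH_SIZE = 64
MAX_EPOCHS = 5

# drop_last=False keeps the final partial batch, hence the ceiling
train_batches = math.ceil(TRAIN_SIZE / BATCH_SIZE)
val_batches = math.ceil(VAL_SIZE / BATCH_SIZE)
total_train_steps = train_batches * MAX_EPOCHS

# Normalize computes (x - mean) / std for each channel
MEAN, STD = 0.1307, 0.3081
black = (0.0 - MEAN) / STD
white = (1.0 - MEAN) / STD

print(train_batches, val_batches, total_train_steps)  # 860 79 4300
print(round(black, 4), round(white, 4))  # -0.4242 2.8215
```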

### Defining the model architecture

This section defines the model architecture for the image classification problem. We create a class `MNIST_LitModule` that takes hyperparameters such as the number of classes, the hidden-layer sizes, and the learning rate.

This architecture and configuration work but are not optimized for image classification; we use them only to demonstrate the W&B integration.

**Tips**:
* Call `self.save_hyperparameters()` in `__init__` to store the constructor arguments in `self.hparams`; `WandbLogger` then logs them to **W&B** automatically.
* Call `self.log` in `training_step` and `validation_step` to log metrics.
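As an illustration of what `save_hyperparameters()` records, the sketch below uses the standard library's `inspect` module to collect the arguments a constructor was called with, filling in defaults. This is a simplified take on the idea, not Lightning's implementation; `capture_init_args` and `TinyModule` are hypothetical names, with `TinyModule` sharing `MNIST_LitModule`'s constructor signature so the result mirrors what ends up in the W&B run config.

```python
import inspect


def capture_init_args(cls, *args, **kwargs) -> dict:
    """Return the __init__ arguments of `cls` with defaults filled in.

    Illustrative only: a simplified take on what save_hyperparameters() records,
    not Lightning's implementation.
    """
    signature = inspect.signature(cls.__init__)
    # Bind "self" to a placeholder so positional arguments line up correctly
    bound = signature.bind(None, *args, **kwargs)
    bound.apply_defaults()
    arguments = dict(bound.arguments)
    arguments.pop("self")
    return arguments


class TinyModule:
    # Hypothetical stand-in with the same constructor signature as MNIST_LitModule
    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        pass


print(capture_init_args(TinyModule, n_layer_1=128, n_layer_2=128))
# {'n_classes': 10, 'n_layer_1': 128, 'n_layer_2': 128, 'lr': 0.001}
```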

```python
import lightning.pytorch as pl

import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy

class MNIST_LitModule(pl.LightningModule):

    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        '''
        method used to define our model parameters
        '''
        super().__init__()

        # MNIST images are (1, 28, 28) (channels, height, width)
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        # loss
        self.loss = CrossEntropyLoss()

        # optimizer parameters
        self.lr = lr

        # save hyperparameters to self.hparams (logged to W&B by WandbLogger)
        self.save_hyperparameters()

    def forward(self, x):
        '''method used for inference: input -> output (logits)'''

        batch_size = x.size(0)

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # 2 x (linear + relu), then a final linear layer that outputs logits
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        return x

    def _get_preds_loss_accuracy(self, batch):
        '''convenience function since train/valid/test steps are similar'''
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y, task="multiclass", num_classes=self.hparams.n_classes)
        return preds, loss, acc

    def training_step(self, batch, batch_idx):
        '''needs to return a loss from a single batch'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # log loss and metric
        self.log('train_loss', loss)
        self.log('training_accuracy', acc)

        return loss

    def validation_step(self, batch, batch_idx):
        '''used for logging metrics'''
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # log (val_accuracy is the metric monitored by ModelCheckpoint)
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)

        return preds

    def test_step(self, batch, batch_idx):
        '''used for logging metrics'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # log loss and metric
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.lr)
```
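`_get_preds_loss_accuracy` turns logits into predictions with `argmax` and then scores them against the labels. The plain-Python sketch below computes the simplest form of that metric, the fraction of samples whose highest-scoring class equals the label (micro-averaged accuracy), on hypothetical toy logits. torchmetrics also supports other averaging modes through its `average` argument, so treat this as an illustration of the idea rather than a re-implementation of the library.

```python
def multiclass_accuracy(logits: list[list[float]], targets: list[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    # argmax per row; max() returns the first index on ties
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == t) for p, t in zip(preds, targets))
    return correct / len(targets)


# Hypothetical logits for 4 samples over 3 classes
logits = [
    [2.0, 0.1, -1.0],  # predicts class 0
    [0.2, 1.5, 0.3],   # predicts class 1
    [0.9, 0.1, 0.4],   # predicts class 0
    [-0.5, 0.0, 3.0],  # predicts class 2
]
targets = [0, 1, 2, 2]
print(multiclass_accuracy(logits, targets))  # 0.75
```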



Instantiate the model:

```python
model = MNIST_LitModule(n_layer_1=128, n_layer_2=128)
```
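A fully connected layer with bias has `in_features * out_features + out_features` parameters, so the size of this model follows directly from the hidden-layer sizes. The sketch below counts trainable parameters for the configuration instantiated above and for the constructor defaults; `linear_params` and `mlp_params` are hypothetical helpers. In PyTorch, `sum(p.numel() for p in model.parameters())` should give the same total for the real model.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer (bias=True, the default)."""
    return in_features * out_features + out_features


def mlp_params(n_layer_1: int, n_layer_2: int, n_classes: int = 10) -> int:
    """Total trainable parameters of the 784 -> n_layer_1 -> n_layer_2 -> n_classes MLP."""
    return (
        linear_params(28 * 28, n_layer_1)
        + linear_params(n_layer_1, n_layer_2)
        + linear_params(n_layer_2, n_classes)
    )


print(mlp_params(128, 128))  # 118282 (configuration used in this notebook)
print(mlp_params(128, 256))  # 136074 (constructor defaults)
```

Most of the parameters sit in the first layer (100,480 of 118,282), because it connects all 784 input pixels.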

### Training Callbacks

The callbacks below configure model checkpointing and show how to record W&B checkpoint metadata on a SageMaker Training Job using `SageMakerTrainingJobTaggingCallback`.


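```python
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    filename="{epoch:03d}",  # e.g. "epoch=004.ckpt" (the metric name is auto-inserted)
    monitor="val_accuracy",  # metric logged in validation_step
    save_top_k=-1,           # keep every checkpoint, not just the best k
    mode="max",              # higher val_accuracy is better
)
```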
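`monitor`, `mode`, and `save_top_k` together decide which checkpoints survive: `-1` keeps all of them, `0` keeps none, and a positive `k` keeps the `k` best by the monitored metric. The sketch below models that selection in plain Python on hypothetical per-epoch accuracies, together with how the `"{epoch:03d}"` template expands when Lightning auto-inserts the metric name. `kept_checkpoints` and `checkpoint_filename` are hypothetical helpers, not Lightning APIs.

```python
def kept_checkpoints(scores: dict[int, float], save_top_k: int, mode: str = "max") -> list[int]:
    """Return the epochs whose checkpoints would be kept (simplified sketch).

    save_top_k=-1 keeps everything, 0 keeps nothing, k > 0 keeps the k best
    according to `mode`.
    """
    if save_top_k == -1:
        return sorted(scores)
    if save_top_k == 0:
        return []
    ranked = sorted(scores, key=scores.get, reverse=(mode == "max"))
    return sorted(ranked[:save_top_k])


def checkpoint_filename(epoch: int) -> str:
    """Mimic how "{epoch:03d}" expands when the metric name is auto-inserted."""
    return f"epoch={epoch:03d}.ckpt"


# Hypothetical per-epoch validation accuracies, for illustration only
val_accuracy = {0: 0.951, 1: 0.962, 2: 0.958, 3: 0.970, 4: 0.968}

print([checkpoint_filename(e) for e in kept_checkpoints(val_accuracy, -1)])
print([checkpoint_filename(e) for e in kept_checkpoints(val_accuracy, 2)])
# ['epoch=003.ckpt', 'epoch=004.ckpt']
```

Keeping every checkpoint (`save_top_k=-1`) costs storage but lets you resume from any epoch, which suits this demo.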

> Because this example trains locally, no SageMaker training job is associated with it. The utility class is included for demonstration only and is not used in this run.

When a SageMaker Training Job is integrated with W&B, the training script reads these tags and the environment variables to locate the checkpoint and resume ML training.

```python
from lightning.pytorch.callbacks import Callback

import boto3


class SageMakerTrainingJobTaggingCallback(Callback):
    """Tags a SageMaker training job with the W&B checkpoint artifact name (once)."""

    def __init__(self, training_job_arn: str):
        self.tagging_done = False
        self.training_job_arn = training_job_arn
        # Created lazily so the notebook runs without an AWS region/credentials
        self._sagemaker_client = None

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        print("on_save_checkpoint() is invoked.")
        # Private WandbLogger attribute: it is set only after the first checkpoint
        # has been logged, so it is still None when the epoch-0 checkpoint is saved.
        checkpoint_name = trainer.logger._checkpoint_name
        if checkpoint_name is None or self.tagging_done or not self.training_job_arn:
            # Name not generated yet, already tagged, or no training job (local run)
            return
        self.tagging_done = True
        run = trainer.logger.experiment
        # put tags on the training job
        self.tag_training_job(self.training_job_arn, run.entity, run.project, checkpoint_name)

    def tag_training_job(self, training_job_arn: str, entity: str, project: str, checkpoint_name: str):
        print("tag_training_job() is invoked")
        if self._sagemaker_client is None:
            self._sagemaker_client = boto3.client("sagemaker")
        return self._sagemaker_client.add_tags(
            ResourceArn=training_job_arn,
            Tags=[
                {"Key": "WANDB_ENTITY", "Value": entity},
                {"Key": "WANDB_PROJECT", "Value": project},
                {"Key": "WANDB_CHECKPOINT_NAME", "Value": checkpoint_name},
            ],
        )


# Empty because this notebook trains locally; set it to a real training job ARN on SageMaker
training_job_arn = ""
tagging_callback = SageMakerTrainingJobTaggingCallback(training_job_arn)
```
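On the resume side, a training script could turn these three tags back into a W&B artifact reference of the form `entity/project/name:alias`. The sketch below assumes the tags arrive as a list of `{"Key": ..., "Value": ...}` dicts, the same shape the callback above writes (and the shape of the `Tags` list returned by a SageMaker `list_tags` call). `checkpoint_reference_from_tags` is a hypothetical helper, and the tag values are made up. It defaults to the `latest` alias, which W&B assigns to the most recent artifact version.

```python
def checkpoint_reference_from_tags(tags: list[dict[str, str]], alias: str = "latest") -> str:
    """Build a W&B artifact reference "entity/project/name:alias" from job tags.

    Hypothetical helper for illustration; not part of SageMaker or W&B.
    """
    values = {tag["Key"]: tag["Value"] for tag in tags}
    required = ("WANDB_ENTITY", "WANDB_PROJECT", "WANDB_CHECKPOINT_NAME")
    missing = [key for key in required if key not in values]
    if missing:
        raise KeyError(f"missing tags: {missing}")
    return (
        f"{values['WANDB_ENTITY']}/{values['WANDB_PROJECT']}/"
        f"{values['WANDB_CHECKPOINT_NAME']}:{alias}"
    )


# Hypothetical tag values, for illustration only
tags = [
    {"Key": "WANDB_ENTITY", "Value": "my-team"},
    {"Key": "WANDB_PROJECT", "Value": "MNIST"},
    {"Key": "WANDB_CHECKPOINT_NAME", "Value": "model-abc123"},
]
print(checkpoint_reference_from_tags(tags))  # my-team/MNIST/model-abc123:latest
```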

### ML Training

```python
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer

# log_model="all" logs every checkpoint saved by ModelCheckpoint as a W&B artifact version
wandb_logger = WandbLogger(project="MNIST", log_model="all")

trainer = Trainer(
    logger=wandb_logger,
    callbacks=[checkpoint_callback],  # add tagging_callback here to try tagging a SageMaker training job
    accelerator="auto",  # uses the GPU when one is available, otherwise the CPU
    max_epochs=5
)

trainer.fit(model, training_loader, validation_loader)
```



```python
# Close the W&B run so its metrics and artifacts are finalized
wandb.finish()
```

### Resuming ML training from a checkpoint

This section shows how to download a model artifact by its checkpoint reference and resume ML training from that checkpoint.
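A checkpoint reference such as `{entity}/MNIST/model-hri3trgq:v14` packs four parts: entity, project, artifact name, and version (versions count from `v0`). The sketch below splits a reference into those parts; `parse_artifact_reference` is a hypothetical helper. It also recovers the run id, assuming the default `WandbLogger` naming of model artifacts as `model-<run id>`.

```python
def parse_artifact_reference(reference: str) -> dict[str, str]:
    """Split "entity/project/name:version" into its parts (simplified sketch)."""
    path, _, version = reference.rpartition(":")
    entity, project, name = path.split("/")
    parts = {"entity": entity, "project": project, "name": name, "version": version}
    # Assumes the default "model-<run id>" artifact naming used by WandbLogger
    if name.startswith("model-"):
        parts["run_id"] = name[len("model-"):]
    return parts


print(parse_artifact_reference("{entity}/MNIST/model-hri3trgq:v14"))
# {'entity': '{entity}', 'project': 'MNIST', 'name': 'model-hri3trgq', 'version': 'v14', 'run_id': 'hri3trgq'}
```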

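```python
import glob

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

# Replace {entity} with your W&B entity (user or team name)
checkpoint_reference = "{entity}/MNIST/model-hri3trgq:v14"
download_artefact_path = WandbLogger.download_artifact(checkpoint_reference, artifact_type="model")
print(download_artefact_path)

model_files = glob.glob(f"{download_artefact_path}/*.ckpt")
assert model_files, f"no .ckpt file found in {download_artefact_path}"

# The previous run was closed with wandb.finish(), so start a new W&B run and a fresh checkpoint callback
wandb_logger = WandbLogger(project="MNIST", log_model="all")
checkpoint_callback = ModelCheckpoint(
    filename="{epoch:03d}", monitor="val_accuracy", save_top_k=-1, mode="max")

trainer = Trainer(
    logger=wandb_logger,
    callbacks=[checkpoint_callback],  # add tagging_callback here to try tagging a SageMaker training job
    accelerator="auto",
    max_epochs=5
)

# Restores weights and hyperparameters only; the optimizer state and epoch counter start fresh
model = MNIST_LitModule.load_from_checkpoint(model_files[0])
trainer.fit(model, training_loader, validation_loader)
```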
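`load_from_checkpoint` restores the model weights and hyperparameters, but the Trainer still counts epochs from 0. Lightning can also restore the full training state (optimizer, epoch and step counters, callback state) when the checkpoint is passed to fit, e.g. `trainer.fit(model, training_loader, validation_loader, ckpt_path=model_files[0])`. The difference matters for `max_epochs`, as the plain-Python sketch below illustrates. It is a simplified model that assumes the checkpoint was written at the end of an epoch, and `epochs_to_run` is a hypothetical helper.

```python
from typing import Optional


def epochs_to_run(max_epochs: int, checkpoint_epoch: Optional[int] = None) -> list[int]:
    """Epoch indices that fit() would run (simplified sketch).

    checkpoint_epoch=None models load_from_checkpoint(): weights are restored but
    the Trainer starts counting from epoch 0. An integer models
    trainer.fit(..., ckpt_path=...) with a checkpoint saved at the end of that
    epoch, where training continues from the next epoch.
    """
    start = 0 if checkpoint_epoch is None else checkpoint_epoch + 1
    return list(range(start, max_epochs))


print(epochs_to_run(5))                      # [0, 1, 2, 3, 4]
print(epochs_to_run(5, checkpoint_epoch=2))  # [3, 4]
print(epochs_to_run(5, checkpoint_epoch=4))  # []
```

With `ckpt_path`, a checkpoint from the last epoch leaves nothing to train unless `max_epochs` is raised.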

```python
wandb.finish()
```

### Next step

Next, we will experiment with the W&B integration on SageMaker Training Jobs. Please refer to [SageMaker Training](./2.sagemaker-training/README.md).