## Deep Learning ML training at local

The purpose of the notebook is to demo how to training a computer vision model at local using PyTorch Lightning, with Weights & Bias for full traceability and reproducibility.

In [None]:
from dotenv import load_dotenv
load_dotenv("../.env")

import wandb
wandb.login()

True

### Setup the dataloader

In [3]:
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

dataset = MNIST(root="./data/MNIST", download=True, transform=transform)
training_set, validation_set = random_split(dataset, [55_000, 5000])

In [4]:
training_loader = DataLoader(training_set, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=64)

### Defining the model

**Tips**:
* Call `self.save_hyperparameters()` in `__init__` to automatically log your hyperparameters to **W&B**.
* Call self.log in `training_step` and `validation_step` to log the metrics.

In [5]:
import lightning.pytorch as pl


In [None]:
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy

class MNIST_LitModule(pl.LightningModule):

    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        '''
        method used to define our model parameters
        '''
        super().__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        # loss
        self.loss = CrossEntropyLoss()

        # optimizer parameters
        self.lr = lr

        # save Hyperparameters to self.hparams (auto-logged by W&B)
        self.save_hyperparameters()

    
    def forward(self, x):
        '''method used for infernce input -> output'''

        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # let's do 3 x (linear + relu)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        return x

    def _get_preds_loss_accuracy(self, batch):
        '''convenience function since train/valid/test steps are similar'''
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y, 'multiclass', num_classes=10)
        return preds, loss, acc


    def training_step(self, batch, batch_idx):
        '''needs to return a loss from a single batch'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # log loss and metric
        self.log('train_loss', loss)
        self.log('training_accuracy', acc)

        return loss
    
    def validation_step(self, batch, batch_idx):
        '''used for logging metrics'''
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # log
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)

        return preds

    def test_step(self, batch, batch_idx):
        '''used for logging metrics'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.lr)



In [7]:
model = MNIST_LitModule(n_layer_1=128, n_layer_2=128)

### Save model checkpoints

In [8]:
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    filename="{epoch:03d}",
    monitor='val_accuracy', 
    save_top_k=-1,
    mode='max')

#### Logging images

In [None]:
from lightning.pytorch.callbacks import Callback

import boto3

sagemaker_client = boto3.client('sagemaker')

def tag_training_job(training_job_arn:str, entity: str, project: str, checkpoint_name: str):
    print("tag_training_job() is invoked")
    response = sagemaker_client.add_tags(
        ResourceArn=training_job_arn,
        Tags=[
            {'Key': "WANDB_ENTITY", "Value": entity},
            {'Key': "WANDB_PROJECT", "Value": project},
            {'Key': "WANDB_CHECKPOINT_NAME", "Value": checkpoint_name}
        ]
    )

class SageMakerTrainingJobTaggingCallback(Callback):

    def __init__(self, training_job_arn:str):
        self.tagging_done = False
        self.training_job_arn = training_job_arn

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        print("on_save_checkpoint() is invoked.")
        checkpoint_name = trainer.logger._checkpoint_name
        # for saving a checkpoint for epoch '0', the checkpoint name is not generated yet.
        if checkpoint_name is not None: 
            self.tagging_done = True
            entity = trainer.logger._experiment.entity
            project = trainer.logger._project

            # put tags on training job
            tag_training_job(self.training_job_arn, entity, project, checkpoint_name)
        else:
            # checkpoint name is not generated at wandb server side yet.
            pass
    
    
training_job_arn = ""
tagging_callback = SageMakerTrainingJobTaggingCallback(training_job_arn)

### Train Your Model

In [22]:
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer

wandb_logger = WandbLogger(project="MNIST", log_model="all")


trainer = Trainer(
    logger=wandb_logger,
    callbacks=[tagging_callback, checkpoint_callback],
    accelerator="gpu",
    max_epochs=5
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# import glob 

# checkpoint_reference = "{entity}/MNIST/model-hri3trgq:v14"
# download_artefact_path = wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")
# download_artefact_path


# model_files = glob.glob(f"{download_artefact_path}/*.ckpt")

# model_files[0]
# model = MNIST_LitModule.load_from_checkpoint(model_files[0]) 


In [23]:
trainer.fit(model, training_loader, validation_loader)

/home/ubuntu/workspace/deep-learning/sagemaker-training-job-wandb-samples/.venv/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/home/ubuntu/workspace/deep-learning/sagemaker-training-job-wandb-samples/.venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:701: Checkpoint directory ./MNIST/eslt91sx/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | layer_1 | Linear           | 100 K  | train
1 | layer_2 | Linear           | 16.5 K | train
2 | layer_3 | Linear           | 1.3 K  | train
3 | loss    | CrossEntropyLoss | 0      | train
-----------------------------------------------------
118 K     Trainable params
0      

                                                                            

/home/ubuntu/workspace/deep-learning/sagemaker-training-job-wandb-samples/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/ubuntu/workspace/deep-learning/sagemaker-training-job-wandb-samples/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 860/860 [00:11<00:00, 75.70it/s, v_num=91sx]on_save_checkpoint() is invoked.
!!!LightningModule-Checkpoint!!!
checkpoint name None
project MNIST
name None
entity tom-5610-aws
Epoch 1: 100%|██████████| 860/860 [00:11<00:00, 75.58it/s, v_num=91sx]on_save_checkpoint() is invoked.
tag_training_job() is invoked
!!!LightningModule-Checkpoint!!!
checkpoint name model-eslt91sx
project MNIST
name None
entity tom-5610-aws
Epoch 2: 100%|██████████| 860/860 [00:11<00:00, 76.30it/s, v_num=91sx]on_save_checkpoint() is invoked.
tag_training_job() is invoked
!!!LightningModule-Checkpoint!!!
checkpoint name model-eslt91sx
project MNIST
name None
entity tom-5610-aws
Epoch 3: 100%|██████████| 860/860 [00:11<00:00, 75.79it/s, v_num=91sx]on_save_checkpoint() is invoked.
tag_training_job() is invoked
!!!LightningModule-Checkpoint!!!
checkpoint name model-eslt91sx
project MNIST
name None
entity tom-5610-aws
Epoch 4: 100%|██████████| 860/860 [00:11<00:00, 75.54it/s, v_num=91sx]on_sav

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 860/860 [00:11<00:00, 73.69it/s, v_num=91sx]


In [None]:
wandb.finish()

In [None]:
wandb_logger._checkpoint_name, wandb_logger._save_dir

In [None]:
api = wandb.Api()

collections = [
    coll for coll in api.artifact_type(type_name="model", project="MNIST").collections()
]

aliases = set()

# for coll in collections:
    # print(coll.artifacts.)

print(collections)
print("aliases", aliases)

In [None]:
artifacts = api.artifacts(type_name="model", name="{entity}/MNIST/model-ano4gslu")

if artifacts:
    print(artifacts.next().source_qualified_name)
else:
    print('not found')

In [None]:
checkpoint_reference = "{entity}/MNIST/model-ano4gslu:latest"
wandb_logger.download_artifact(checkpoint_reference, artifact_type="model", save_dir="./checkpoint")


### processing training job tagging

In [None]:
import boto3

sagemaker_client = boto3.client('sagemaker')

def get_sagemaker_training_job_tags(training_job_arn:str):
    response = sagemaker_client.list_tags(ResourceArn=training_job_arn)
    tags = {}
    for tag in response['Tags']:
        tags[tag['Key']] = tag['Value']
    return tags

job_arn = ""
get_sagemaker_training_job_tags(job_arn)

{}