## ML Training on SageMaker Training Jobs (with W&B)

This example notebook launches a SageMaker Training Job that trains an MNIST classifier.

In this notebook, we walk through creating a SageMaker execution role, creating a W&B secret (for the API key), and launching a SageMaker Training Job.

> To create the IAM role, S3 bucket, and secret, and to pass the created SageMaker execution role, your client-side identity needs the appropriate IAM permissions. For testing this example, you may attach an AWS managed policy to your identity principal (IAM user or role).


### Setup

Enable automatic reloading of local scripts, so that edits to them take effect without restarting the kernel.

In [None]:
# to load scripts dynamically
%load_ext autoreload
%autoreload 2

Load environment variables from the `.env` file.

In [None]:
import boto3
import json
import os

from botocore.exceptions import ClientError

from dotenv import load_dotenv
load_dotenv("../../.env")

Create the SageMaker execution role if it doesn't exist. The IAM role needs permissions to:
* pull the training image from ECR
* read the wandb secret
* add tags to the SageMaker training job, recording the wandb project and checkpoint for training resiliency
* write CloudWatch logs and metrics
* access the S3 output path

In [None]:
from scripts.utils import create_sagemaker_execution_role, create_wandb_secret, create_s3_bucket

iam_role = create_sagemaker_execution_role("sagemaker-execution-role")
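
The body of `create_sagemaker_execution_role` lives in `scripts/utils.py` and is not shown here. As a rough illustration only, a get-or-create helper for such a role might look like the sketch below. The function names `build_sagemaker_trust_policy` and `get_or_create_role` are hypothetical, and the sketch covers only the trust policy, not the permissions listed above.

```python
import json


def build_sagemaker_trust_policy() -> dict:
    """Trust policy that lets the SageMaker service assume the role."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {"Service": "sagemaker.amazonaws.com"},
                "Action": "sts:AssumeRole",
            }
        ],
    }


def get_or_create_role(iam_client, role_name: str) -> str:
    """Return the ARN of `role_name`, creating the role if it is missing.

    `iam_client` is assumed to be a boto3 IAM client, e.g. boto3.client("iam").
    """
    try:
        return iam_client.get_role(RoleName=role_name)["Role"]["Arn"]
    except iam_client.exceptions.NoSuchEntityException:
        response = iam_client.create_role(
            RoleName=role_name,
            AssumeRolePolicyDocument=json.dumps(build_sagemaker_trust_policy()),
        )
        return response["Role"]["Arn"]
```

The permissions policy (ECR, Secrets Manager, tagging, CloudWatch, S3) would still need to be attached separately; the actual helper in this repository may handle that differently.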

Create the WANDB secret in AWS Secrets Manager. The training job uses it to integrate with WANDB for ML experiment tracking and checkpoint storage.

In [None]:
# creating api key secret
wandb_secret_name = "wandb-secret"
create_wandb_secret(wandb_secret_name, os.environ.get("WANDB_API_KEY"))
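
For readers curious about what a helper like `create_wandb_secret` could do, here is a minimal create-or-update sketch. It is not the code from `scripts/utils.py`: the name `upsert_secret` is hypothetical, and storing the key as a plain `SecretString` (rather than JSON) is an assumption.

```python
def upsert_secret(secrets_client, name: str, value: str) -> str:
    """Create the secret, or overwrite its value if it already exists.

    `secrets_client` is assumed to be a boto3 Secrets Manager client,
    e.g. boto3.client("secretsmanager"). Returns "created" or "updated".
    """
    if not value:
        # os.environ.get(...) returns None when the variable is missing
        raise ValueError("secret value is empty; is WANDB_API_KEY set in .env?")
    try:
        secrets_client.create_secret(Name=name, SecretString=value)
        return "created"
    except secrets_client.exceptions.ResourceExistsException:
        secrets_client.put_secret_value(SecretId=name, SecretString=value)
        return "updated"
```

The empty-value guard is there because a missing `WANDB_API_KEY` would otherwise only surface later, inside the training job.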

Create an S3 bucket for the SageMaker Training Job output.
* Please ensure that the bucket name matches the S3 permissions of the IAM role created by the `create_sagemaker_execution_role` function.
* For reference, the policy statement below allows resources whose ARN contains the keyword `sagemaker`:

```
{
    "Sid": "AllowS3ObjectActions",
    "Effect": "Allow",
    "Action": [
        "s3:GetObject",
        "s3:PutObject",
        "s3:DeleteObject",
        "s3:AbortMultipartUpload"
    ],
    "Resource": [
        "arn:aws:s3:::*SageMaker*",
        "arn:aws:s3:::*Sagemaker*",
        "arn:aws:s3:::*sagemaker*"
    ]
}
```
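
A quick local check can catch a bucket name that the statement above would not cover. The sketch below approximates IAM wildcard matching with Python's case-sensitive `fnmatchcase`; this is an approximation (for example, `fnmatch` treats `[` specially while IAM does not), and the helper name `bucket_matches_policy` is hypothetical.

```python
from fnmatch import fnmatchcase

# Resource patterns copied from the policy statement above
ALLOWED_RESOURCE_PATTERNS = [
    "arn:aws:s3:::*SageMaker*",
    "arn:aws:s3:::*Sagemaker*",
    "arn:aws:s3:::*sagemaker*",
]


def bucket_matches_policy(bucket_name: str, patterns=ALLOWED_RESOURCE_PATTERNS) -> bool:
    """Return True if the bucket ARN matches at least one resource pattern.

    Every pattern ends in `*`, so when the bucket ARN matches, the ARN of
    any object inside that bucket matches as well.
    """
    bucket_arn = f"arn:aws:s3:::{bucket_name}"
    return any(fnmatchcase(bucket_arn, pattern) for pattern in patterns)


print(bucket_matches_policy("sagemaker-wandb-samples"))
print(bucket_matches_policy("my-training-output"))
```

With the bucket name used in this notebook the check passes, while a name without the keyword does not.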

In [None]:
# Without a region parameter, the bucket is created in us-east-1 by default.
bucket_name = "sagemaker-wandb-samples"
create_s3_bucket(bucket_name)
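
If the bucket should live in another region, the underlying S3 `create_bucket` call needs a `LocationConstraint`, except for `us-east-1`, where it must be omitted. The sketch below only builds the keyword arguments; `create_bucket_kwargs` is a hypothetical helper, and how `create_s3_bucket` handles regions internally is not shown in this notebook.

```python
from typing import Optional


def create_bucket_kwargs(bucket_name: str, region: Optional[str] = None) -> dict:
    """Build keyword arguments for the S3 `create_bucket` API call."""
    kwargs = {"Bucket": bucket_name}
    # us-east-1 is the default location and rejects an explicit constraint
    if region and region != "us-east-1":
        kwargs["CreateBucketConfiguration"] = {"LocationConstraint": region}
    return kwargs


print(create_bucket_kwargs("sagemaker-wandb-samples", "eu-west-1"))
```

With boto3, the result could be passed along as `boto3.client("s3", region_name=region).create_bucket(**create_bucket_kwargs(name, region))`.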

### Set Up the SageMaker Training Job

Note that we use the SageMaker pre-built PyTorch training container as the ML training runtime.

**Runtime customization**

In this example, the runtime needs additional Python packages, including [PyTorch Lightning](https://lightning.ai/) and [Weights & Biases](https://wandb.ai/) (WANDB). To install them, I specify `source_dir` in the SageMaker `PyTorch` estimator, which picks up [requirements.txt](./src/requirements.txt) and installs the listed packages before running the training script.

Alternatively, [docker](./docker/) provides a reference for building a custom image for ML training. This example uses [requirements.txt](./src/requirements.txt) for simplicity.

**WANDB Integration**

The integration supports ML experimentation on Amazon SageMaker, e.g. lineage tracking for ML experiments and checkpoints. To enable it, I've created a WANDB secret (API key) in [AWS Secrets Manager](https://docs.aws.amazon.com/secretsmanager/latest/userguide/intro.html). Before training starts, the WANDB integration is enabled for ML training with PyTorch Lightning.
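
As an illustration of the secret lookup, the sketch below shows one way a training script could fetch the API key and expose it to the `wandb` client through the `WANDB_API_KEY` environment variable. It is not the code from `src/train.py`: the function names are hypothetical, and it assumes the secret holds the key as a plain string.

```python
import os


def load_wandb_api_key(secrets_client, secret_name: str) -> str:
    """Read the API key from Secrets Manager.

    `secrets_client` is assumed to be a boto3 Secrets Manager client,
    e.g. boto3.client("secretsmanager").
    """
    response = secrets_client.get_secret_value(SecretId=secret_name)
    return response["SecretString"]


def configure_wandb_env(secrets_client, environ=os.environ) -> None:
    """Export WANDB_API_KEY so that the wandb client can authenticate."""
    secret_name = environ["WANDB_SECRET_NAME"]
    environ["WANDB_API_KEY"] = load_wandb_api_key(secrets_client, secret_name)
```

Because boto3 resolves its region from `AWS_DEFAULT_REGION`, the client created inside the container would look in the same region where the secret was stored.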

**Resume ML training after cluster repairs for GPU errors**

For deep learning model training, GPU errors are common, so you may want to enable checkpoints to resume training once the SageMaker cluster repair is done. To learn more, see [Cluster repairs for GPU errors](https://docs.aws.amazon.com/sagemaker/latest/dg/model-checkpoints-cluster-repair.html).

There are two options:

> This notebook showcases option 2.

1. Set the `checkpoint_s3_uri` parameter when creating the SageMaker Estimator object. You can then save ML training checkpoints to the local folder `/opt/ml/checkpoints/`, which is synced to the checkpoint S3 URI. When a hardware repair occurs, the SageMaker Training Job automatically pulls the latest checkpoint from the checkpoint S3 URI, so it can be loaded to resume ML training.

2. Download the latest checkpoint from WANDB to resume ML training. To do so, we use `SageMakerTrainingJobTaggingCallback` in the [training script](./src/train.py) to record WANDB experiment information, including `entity`, `project`, and `checkpoint_name`, as tags on the training job. These details are useful when resuming ML training from a WANDB checkpoint.
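
To make option 2 more concrete, the sketch below shows how experiment details could be written to and read back from training job tags using the SageMaker `add_tags` and `list_tags` APIs. It is not the implementation of `SageMakerTrainingJobTaggingCallback`: the helper names and the `wandb_*` tag keys are hypothetical.

```python
def build_wandb_tags(entity: str, project: str, checkpoint_name: str) -> list:
    """Build the tag list in the Key/Value shape that `add_tags` expects."""
    return [
        {"Key": "wandb_entity", "Value": entity},
        {"Key": "wandb_project", "Value": project},
        {"Key": "wandb_checkpoint_name", "Value": checkpoint_name},
    ]


def tag_training_job(sagemaker_client, training_job_arn: str, tags: list) -> None:
    """Attach tags to the job; assumes a boto3 SageMaker client."""
    sagemaker_client.add_tags(ResourceArn=training_job_arn, Tags=tags)


def read_wandb_tags(sagemaker_client, training_job_arn: str) -> dict:
    """Return the wandb_* tags of a job as a plain dict (first page only)."""
    response = sagemaker_client.list_tags(ResourceArn=training_job_arn)
    return {
        tag["Key"]: tag["Value"]
        for tag in response["Tags"]
        if tag["Key"].startswith("wandb_")
    }
```

Storing the details as tags keeps them attached to the job itself, so a resumed run can look them up without any extra state store. This is also why the execution role needs permission to add tags.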

In addition, to resume ML training from a specific WANDB checkpoint, you may specify the variables below in the `environment` parameter of the SageMaker Estimator:
* `WANDB_SECRET_NAME` - the secret name
* `AWS_DEFAULT_REGION` - the region where the secret is stored, e.g. us-east-1
* `WANDB_PROJECT` - the project in your WANDB workspace
* [optional] `WANDB_CHECKPOINT_NAME` - if specified, this checkpoint is used to resume ML training.
* [optional] `WANDB_CHECKPOINT_TAG` - if specified, the checkpoint with this tag is used; otherwise, the 'latest' checkpoint is used.
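
The sketch below illustrates how these variables could be combined into a WANDB artifact reference of the form `project/name:alias`. How `src/train.py` actually interprets them is not shown in this notebook, so the function `resolve_checkpoint_reference` and the example checkpoint name are assumptions.

```python
import os
from typing import Optional


def resolve_checkpoint_reference(environ=os.environ) -> Optional[str]:
    """Return `project/name:tag` for the checkpoint to resume from.

    Returns None when WANDB_CHECKPOINT_NAME is unset, meaning that
    training starts from scratch.
    """
    name = environ.get("WANDB_CHECKPOINT_NAME")
    if not name:
        return None
    project = environ["WANDB_PROJECT"]
    # fall back to the 'latest' alias when no tag is given
    tag = environ.get("WANDB_CHECKPOINT_TAG") or "latest"
    return f"{project}/{name}:{tag}"


example_env = {"WANDB_PROJECT": "MNIST", "WANDB_CHECKPOINT_NAME": "model-abc123"}
print(resolve_checkpoint_reference(example_env))
```

A reference in this form could then be handed to the WANDB artifact API to download the checkpoint file before training starts.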

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch

session = sagemaker.Session()
region = session.boto_region_name

instance_type = 'ml.g6.xlarge'
training_job_output = f"s3://{bucket_name}/training-jobs/"

# image uri when using Bring Your Own Container
# image_uri = f"{AWS_ACCOUNT_ID}.dkr.ecr.us-east-1.amazonaws.com/mnist-training:latest"

estimator = PyTorch(
    # image_uri=image_uri,
    framework_version="2.7",
    py_version="py312",
    entry_point="train.py",
    source_dir="./src",
    role=iam_role,
    instance_type=instance_type, 
    instance_count=1,
    volume_size=50,
    output_path=training_job_output,
    hyperparameters={
        "epochs": 5
    }, 
    environment={
        "WANDB_SECRET_NAME": wandb_secret_name,
        "WANDB_PROJECT": "MNIST",
        "AWS_DEFAULT_REGION": region, # for training script to access region-based resources - secret.
        # "WANDB_CHECKPOINT_NAME": 
        # "WANDB_CHECKPOINT_TAG": "latest"
    },
)

Kick off the SageMaker Training Job.

> We don't provide a data channel for training because this sample solves the `MNIST` classification problem, for which the training and evaluation data are downloaded from the torchvision dataset.

In [None]:
estimator.fit(wait=False)

To observe how a training job runs, e.g. by checking the related CloudWatch logs, refer to this [AWS re:Post article](https://repost.aws/knowledge-center/sagemaker-training-job-errors) for more details.
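
Because `fit(wait=False)` returns immediately, the job status can also be polled from the notebook. The following is a minimal sketch based on the SageMaker `describe_training_job` API; the helper name `wait_for_training_job` and its default polling values are hypothetical.

```python
import time

# Statuses after which the job no longer changes
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(sagemaker_client, job_name: str,
                          poll_seconds: float = 30.0, max_polls: int = 240) -> str:
    """Poll the job until it reaches a terminal status and return that status.

    `sagemaker_client` is assumed to be a boto3 SageMaker client,
    e.g. boto3.client("sagemaker").
    """
    for _ in range(max_polls):
        description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
        status = description["TrainingJobStatus"]
        if status in TERMINAL_STATUSES:
            return status
        time.sleep(poll_seconds)
    raise TimeoutError(f"{job_name} did not finish after {max_polls} polls")
```

In this notebook, the job name should be available as `estimator.latest_training_job.name` once `fit` has been called.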