# Workshop: Onboarding & optimizing AI/ML workloads on AWS with Amazon S3 and EKS

---

In the early phases of AI/ML project development, it’s common for individual contributors to work independently on laptops, personal servers, or cloud-based compute instances, each using their preferred IDE. A popular choice, and the one we’ll use in this workshop, is Jupyter Notebooks — a tool that offers an interactive environment for writing, testing, and refining code.

In this notebook, we’ll begin by programmatically exploring the dataset stored on Amazon S3. With Mountpoint for Amazon S3, we can use our standard Python libraries to interact with the dataset as if it were stored locally on the instance. From there, we’ll transition from data exploration to launching distributed ML training on the same S3-hosted dataset. To achieve this, we’ll initiate multiple ML jobs using Ray on Amazon EKS, demonstrating how this setup supports _efficient data I/O_ for both **reading training datasets from S3** and **writing model checkpoints back to S3**.

---

## Contents

1. Set everything up
2. Explore the dataset
3. Implement the benchmarking script
4. Launch the first dataloading benchmark on the Ray cluster
5. Shard the dataset
6. Re-run the dataloading benchmark with the sharded dataset
7. Checkpoint models with Mountpoint for Amazon S3
8. Summary and conclusions
  

---

# 1. Set everything up

<a id='sec-1'></a>

---

<font size='4' color='gray'>**_A few instructions before you start_** - </font><font size='4' color='red'>**PLEASE READ THIS!**</font>

<font size='3' color='green'>**_(1) Run each of the following code cells in turn with Shift+Enter._**</font>

_It may take a few seconds to a few minutes for a code cell to run. You can determine whether a cell is running by examining the `[]:` indicator in the left margin next to each cell: a cell will show `[*]:` when running, and `[<a number>]:` when complete. **Please read on while you wait**._

_Please feel free to review the code, but it is not essential for you to understand it as the important elements will be explained._

<font size='3' color='red'>**_(2) If any cell output is in red, this indicates an error._**</font>

_Check that the cells have been run in order, and seek help from a workshop assistant if needed._

---

## 1.1 Install and import the required libraries

In [None]:
# Standard library imports
import sys
import os
import json
import time
import random
from datetime import datetime

# Third-party imports
import boto3                    # AWS SDK for Python
import ray                      # Ray SDK for Python
import ray.job_submission       # Ray Jobs interface module of Ray SDK for Python
import pandas as pd
from PIL import Image

# Local imports
from utilities import utils     # Collection of helper functions

# Version check for Python compatibility
MIN_PYTHON_VERSION = (3, 7)
assert sys.version_info >= MIN_PYTHON_VERSION, f"Python version must be {MIN_PYTHON_VERSION[0]}.{MIN_PYTHON_VERSION[1]} or higher."

# Print SDK versions
print(f"Python version: {sys.version.split()[0]}")
print(f"Boto3 SDK version: {boto3.__version__}")
print(f"Ray SDK version: {ray.__version__}")


## 1.2 Initial setup for clients and global variables

In [None]:
# AWS region and S3 bucket configuration
aws_region = os.getenv("AWS_REGION")
s3_bucket_name = os.getenv("WORKSHOP_BUCKET")
s3_bucket_prefix = "dataset"

# S3 bucket mountpoints
local_mountpoint_dir = "/s3_data"       # S3 bucket mountpoint path on this Jupyter instance
eks_mountpoint_dir = "/mnt/s3_data"     # S3 bucket mountpoint path on EKS cluster, as seen by Ray workers

# Ray client configuration
ray_head_dns = os.getenv("RAY_HEAD_NLB_DNS")
ray_head_port = 8265
ray_address = f"http://{ray_head_dns}:{ray_head_port}"

# Initialize Ray client
ray_client = ray.job_submission.JobSubmissionClient(ray_address)

# Print configurations
print(f"AWS Region: {aws_region}")
print(f"S3 Bucket: {s3_bucket_name}")
print(f"Ray Head address: {ray_client.get_address()}")

# 2. Explore the dataset

<a id='sec-2'></a>

Let's start by exploring the dataset you uploaded to S3. As your S3 bucket is already mounted to this machine with **Mountpoint for Amazon S3**, you can interact with the dataset as if it were available locally. Later on, you will use the very same dataset on S3 to train your ML workloads on the Ray cluster.

<img src="./assets/pic_s3_data_mountpoints.png" width="1000" align="center"/>

## 2.1 Print key statistics of the dataset

In [None]:
# Print the local path for the dataset
dataset_path = os.path.join(local_mountpoint_dir, s3_bucket_prefix, '100k-samples-small-files')
print("Local path:", dataset_path)

The dataset consists of **100,000 JPG images**, where each file is one training sample and the files are distributed across **4 subfolders**. This dataset can be used to train an ML model for image classification, so the folder names also serve as the class labels of the images stored in them. Let's print the subfolders (that is, the class labels) we have:

In [None]:
# List subfolders under the dataset path, which are also the class names
dataset_classes = os.listdir(dataset_path)
dataset_classes

In [None]:
# Pick one of the classes in the dataset
class_name = dataset_classes[0]

# Print some file stats of a few random images from this class (or subfolder)
class_path = os.path.join(dataset_path, class_name)
selected_images = random.sample(os.listdir(class_path), 5)

print(f"A few random files of class '{class_name}':")
for image_name in selected_images:
    image_path = os.path.join(class_path, image_name)
    
    # Get the size on disk (in KB)
    size_on_disk_kb = os.path.getsize(image_path) / 1024
    
    # Open the image to get its dimensions
    with Image.open(image_path) as img:
        width, height = img.size
        print(f"  - {image_path}: {width}x{height} pixels, Size on disk: {size_on_disk_kb:.2f} KB")

In [None]:
# Count the number of images in each subfolder / class
total_count = 0

for class_name in dataset_classes:
    class_path = os.path.join(dataset_path, class_name)
    # Count only the .jpg images in the class directory
    image_count = sum(1 for f in os.listdir(class_path) if f.lower().endswith('.jpg'))
    total_count += image_count
    print(f"Class '{class_name}': {image_count} images")

print("-" * 30)
print(f"Total: {total_count} images")

## 2.2 Visualize a few training samples

The synthetic images in this dataset were generated from Gaussian noise, with the pixel values of each class centered around one of 4 distinct mean levels (one level per class). This is one of many techniques for rapidly creating synthetic image datasets that are suitable for training image classification models. Let's now display a few sample images from each class and overlay their average pixel values to verify both the distinction between classes and the variation between samples within each class:

> 🕥 _The operation may take 15 seconds to complete._

In [None]:
utils.visualize_images(dataset_path, num_images_per_class=5, image_size=(128, 128))
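
The generator used to build this dataset is not part of the notebook. As a rough illustration of the Gaussian-noise technique described above, the sketch below draws pixels around a class-specific mean level and checks that the average pixel value separates the classes. The four mean levels, the noise standard deviation, the 64x64 image size, and the helper name `make_synthetic_image` are all assumptions chosen for illustration; they are not the parameters of the workshop dataset.

```python
import numpy as np
from PIL import Image

# Hypothetical class-specific mean pixel levels (one per class); NOT the workshop's actual values
CLASS_MEANS = {0: 40, 1: 100, 2: 160, 3: 220}
NOISE_STD = 20  # assumed standard deviation of the Gaussian noise


def make_synthetic_image(label, size=(64, 64), rng=None):
    """Return an RGB PIL image whose pixels are Gaussian noise around CLASS_MEANS[label]."""
    rng = rng if rng is not None else np.random.default_rng()
    # Array shape is (height, width, channels); size is (width, height) as in PIL
    pixels = rng.normal(loc=CLASS_MEANS[label], scale=NOISE_STD, size=(size[1], size[0], 3))
    # Clip to the valid 8-bit range before converting to uint8
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)


rng = np.random.default_rng(seed=0)
for label in CLASS_MEANS:
    img = make_synthetic_image(label, rng=rng)
    mean_value = np.asarray(img).mean()
    print(f"class {label}: mean pixel value = {mean_value:.1f}")
```

Saving each image with `img.save(path, "JPEG")` into a per-class subfolder would produce a folder layout like the one explored above; JPEG compression slightly perturbs the pixel values but keeps the class means apart.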

<br>

# 3. Implement the benchmarking script

<a id='sec-3'></a>

Execute the following cell to create the custom benchmarking script, which Ray will execute remotely as a distributed ML training job. <br>
**We will discuss this script in more detail once the job is running.**

**<font color='red'>The</font> `benchmark.py` <font color='red'>script</font>**:

> ⚠️ **_NOTE:_**  The cell below creates the actual `benchmark.py` script that is going to be executed remotely by Ray on Amazon EKS. You are encouraged to study the code, but please **_DO NOT_** change any lines of code during the workshop.

In [None]:
%%writefile scripts/benchmark.py

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

import os
import glob
import json
import time
import argparse
import datetime

import numpy as np
from PIL import Image

import webdataset as wds
import s3torchconnector as s3pt

import torch
import torch.nn as nn
import torchdata
from torchvision.transforms import v2 as tvt

import ray.train
import ray.train.torch


################## BENCHMARK PARAMETERS DEFINITION ###################

def parse_args():
    
    def none_or_int(value):
        if str(value).upper() == 'NONE':
            return None
        return int(value)
    
    def none_or_str(value):
        if str(value).upper() == 'NONE':
            return None
        return str(value)
    
    def str_bool(value):
        if str(value).upper() == 'TRUE':
            return True
        elif str(value).upper() == 'FALSE':
            return False
        else:
            raise TypeError("Must be True or False.")
    
    parser = argparse.ArgumentParser()

    ### Parameters that define dataloader config
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--dataloader_workers', type=int, default=0)
    parser.add_argument('--dataloader_use_s3pt', type=str_bool, default=False)
    parser.add_argument('--prefetch_size', type=none_or_int, default=2)
    parser.add_argument('--input_dim', type=int, default=224)
    parser.add_argument('--pin_memory', type=str_bool, default=True)

    ### Parameters that define dataset config
    parser.add_argument('--dataset_path', type=str)
    parser.add_argument('--dataset_format', type=str)
    parser.add_argument('--dataset_num_samples', type=int, default=100_000)
    parser.add_argument('--dataset_region', type=none_or_str, default=os.getenv('AWS_REGION'))

    ### Parameters that define model parameters
    parser.add_argument('--model_compute_time', type=none_or_int, default=None) # in milliseconds
    parser.add_argument('--model_num_parameters', type=int, default=1) # in millions of parameters
    
    ### Parameters that define benchmark infrastructure
    parser.add_argument('--ray_workers', type=int, default=2)
    parser.add_argument('--ray_cpus_per_worker', type=int, default=8)
    parser.add_argument('--ray_use_gpu', type=str_bool, default=False)
   
    ### Parameters that define checkpointing config
    parser.add_argument('--ckpt_steps', type=int, default=0)
    parser.add_argument('--ckpt_mode', type=str, default='disk')
    parser.add_argument('--ckpt_path', type=str, default='checkpoints/')
    parser.add_argument('--ckpt_region', type=none_or_str, default=os.getenv('AWS_REGION'))

    ### Some other parameters for logging results
    parser.add_argument('--log_directory', type=none_or_str, default=os.path.join(os.getenv('EKS_MOUNTPOINT_DIR', '.'), 'logs'))
    parser.add_argument('--benchmark_name', type=none_or_str, default=f'benchmark-{datetime.datetime.now().strftime("%Y%m%d%H%M%S-%f")}')

    return parser.parse_known_args()


################## MODEL IMPLEMENTATION ###################

class ModelMock(torch.nn.Module):
    '''Model mock to emulate a computation of a training step'''
    def __init__(self, config):
        super().__init__()
        self.model = torch.nn.Linear(config.model_num_parameters * 1_000_000, 1)
        self.config = config
    
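    def forward(self, data, target, epoch, step):
        # Emulate the compute time of a training step (e.g., a GPU forward/backward pass)
        if self.config.model_compute_time > 0:
            time.sleep(self.config.model_compute_time / 1_000)

        # Optionally save a checkpoint from rank 0 every `ckpt_steps` steps;
        # returns the checkpoint save time in seconds, or None otherwise
        if (
            ray.train.get_context().get_world_rank() == 0 and
            self.config.ckpt_steps > 0 and
            step % self.config.ckpt_steps == 0
        ):
            return self.save_checkpoint(epoch, step)
        return None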

    def save_checkpoint(self, epoch, step):
        if self.config.ckpt_mode == 's3pt':
            return save_checkpoint_s3pt(self.model, self.config.ckpt_region, self.config.ckpt_path, epoch, step)
        elif self.config.ckpt_mode == 'disk':
            return save_checkpoint_disk(self.model, self.config.ckpt_path, epoch, step)
        else:
            raise NotImplementedError("Unknown checkpoint mode '%s'.." % self.config.ckpt_mode)


def save_checkpoint_s3pt(model, region, uri, epoch_id, step_id):
    path = os.path.join(uri, f"epoch-{epoch_id}-step-{step_id}.ckpt")
    checkpoint = s3pt.S3Checkpoint(region=region)
    start_time = time.perf_counter()
    with checkpoint.writer(path) as writer:
        torch.save(model.state_dict(), writer)
    end_time = time.perf_counter()
    save_time = end_time - start_time
    print_from_rank(f"Saving checkpoint to {path} took {save_time} seconds..")
    return save_time

def save_checkpoint_disk(model, uri, epoch_id, step_id):
    os.makedirs(uri, exist_ok=True)
    path = os.path.join(uri, f"epoch-{epoch_id}-step-{step_id}.ckpt")
    start_time = time.perf_counter()
    torch.save(model.state_dict(), path)
    end_time = time.perf_counter()
    save_time = end_time - start_time
    print_from_rank(f"Saving checkpoint to {path} took {save_time} seconds..")
    return save_time
            

################## DATASET IMPLEMENTATIONS ###################

class MapDataset(torch.utils.data.Dataset):
    def __init__(self, files, transform):
        self._files = np.array(files)
        self._transform = transform
   
    @staticmethod
    def _get_label(file):
        return file.split(os.path.sep)[-2]
    
    @staticmethod
    def _read(file):
        return Image.open(file).convert('RGB')
    
    def __len__(self):
        return len(self._files)
    
    def __getitem__(self, idx):
        file = self._files[idx]
        sample = self._transform(self._read(file))
        label = int(self._get_label(file))    # Labels in [0, MAX) range
        return sample, label

def _make_pt_dataset(config, transform):
    # Create a dataset from individual image files
    
    files = glob.glob(config.dataset_path + '/**/*.jpg')
    dataset = MapDataset(files, transform)
    return dataset

def _make_wds_dataset(config, transform):
    # Create a WebDataset from tar files and apply transformations
    
    def _create_sample(sample):
        label, img = sample['__key__'], sample['jpg']
        img = transform(img)
        label = int(label.split('/')[-2])
        return img, label
    
    files = glob.glob(config.dataset_path + '/*.tar')        
    dataset = wds.WebDataset(files, shardshuffle=True, resampled=True, nodesplitter=wds.split_by_node)
    dataset = dataset.decode('pil')
    dataset = dataset.map(_create_sample)
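    # Samples yielded per dataloader worker per epoch (max(1, ...) guards against dataloader_workers=0)
    dataset = dataset.with_epoch(config.dataset_num_samples // (config.ray_workers * max(1, config.dataloader_workers)))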
    return dataset

def _make_s3pt_dataset(config, transform):

    def _tar_to_tuple(s3object):
        return s3object.key, torchdata.datapipes.utils.StreamWrapper(s3object)
    
    def _create_sample(item):
        label, img = item
        img = transform(Image.open(img).convert('RGB'))
        label = int(label.split('/')[-2])
        return img, label

    dataset = s3pt.S3IterableDataset.from_prefix(config.dataset_path, region=config.dataset_region)
    dataset = torchdata.datapipes.iter.IterableWrapper(dataset)
    if config.dataloader_workers > 0:
        dataset = dataset.sharding_filter()
    dataset = dataset.map(_tar_to_tuple)
    dataset = dataset.load_from_tar()
    dataset = dataset.map(_create_sample)
    return dataset


################## BENCHMARK IMPLEMENTATIONS #################
def build_dataloader(config):
    # Define image transformations and build the dataloader based on dataset format
    transform = tvt.Compose([
        tvt.ToImage(),
        tvt.ToDtype(torch.uint8, scale=True),
        tvt.RandomResizedCrop(size=(config.input_dim, config.input_dim), antialias=False), #antialias=True
        tvt.ToDtype(torch.float32, scale=True),
        tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Build dataset
    if config.dataset_format == 'jpg':
        dataset = _make_pt_dataset(config, transform)
    elif config.dataset_format == 'tar':
        if config.dataloader_use_s3pt:
            dataset = _make_s3pt_dataset(config, transform)
        else:
            dataset = _make_wds_dataset(config, transform)
    else:
        raise NotImplementedError("Unknown dataset format '%s'.." % config.dataset_format)


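    # prefetch_factor is only valid with multiprocessing (num_workers > 0); PyTorch requires None otherwise
    return torch.utils.data.DataLoader(
        dataset,
        num_workers=config.dataloader_workers,
        batch_size=config.batch_size,
        prefetch_factor=config.prefetch_size if config.dataloader_workers > 0 else None,
        pin_memory=config.pin_memory
    )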


def build_model(config):
    # Build a model or a model mock based on provided config
    if config.model_compute_time is not None:
        model = ModelMock(config)
    else:
        raise NotImplementedError("Need to set compute time explicitly..")
    return model


def train_model(model, dataloader, config):
    # Train model and collect metrics
    metrics = {}
    img_tot_list, ep_times, ckpt_times = [], [], []
    t_train_start = t_epoch_start = time.perf_counter()

    for epoch in range(config.epochs):
        img_tot = 0
        
        for step, (images, labels) in enumerate(dataloader, 1):

            # Perform a training step and optionally save checkpoint
            batch_size = len(images)
            img_tot += batch_size

            result = model(images, labels, epoch, step)
            
            if result:
                ckpt_times.append(result)

            if step % 50 == 0:
                print_from_rank(f"Epoch = {epoch} | Step = {step}")

        # Record metrics for each epoch
        img_tot_list.append(img_tot)
        ep_times.append(time.perf_counter() - t_epoch_start)
        t_epoch_start = time.perf_counter()

    # Summarize training metrics
    t_train_tot = time.perf_counter() - t_train_start
    metrics['training_time'] = t_train_tot
    metrics['samples_per_second'] = sum(img_tot_list) / t_train_tot
    metrics['samples_processed_total'] = sum(img_tot_list)
    metrics.update({f't_epoch_{i}': t for i, t in enumerate(ep_times, 1)})
    metrics.update({f't_ckpt_{i}': t for i, t in enumerate(ckpt_times, 1)})
    if ckpt_times:
        metrics['t_ckpt_ave'] = sum(ckpt_times) / len(ckpt_times)
    return metrics


################## HELPER FUNCTIONS #################
def print_from_rank(msg, rank=0):
    if ray.train.get_context().get_world_rank() == rank:
        print(f'[r:{rank}]:', msg)


############ MAIN EXECUTABLE FUNCTION ##############
def main_fn(config):

    # Print debugging information and configuration
    print_from_rank("Benchmarking params:\n" + json.dumps(vars(config), indent=2))
    print_from_rank("Environment variables:\n")
    for k, v in os.environ.items():
        print_from_rank(f'{k}={v}')

    # Print example dataset files for debugging
    filelist_gen = glob.iglob(os.path.join(config.dataset_path, '**', '*'), recursive=True)
    print_from_rank(f"Files in {config.dataset_path}:")
    for i, f in enumerate(filelist_gen):
        print_from_rank(" - " + f)
        if i > 10: break
    
    # Step #1: Build dataloader and prepare it for Ray distributed environment
    dataloader = build_dataloader(config)
    dataloader = ray.train.torch.prepare_data_loader(dataloader)

    # Step #2: Build model and prepare it for Ray distributed environment
    model = build_model(config)
    model = ray.train.torch.prepare_model(model)

    # Step #3: Train the model and collect metrics
    metrics = train_model(model, dataloader, config)
    
    # Step #4: Log metrics to the S3-backed log directory (rank 0 only, to avoid concurrent writes to the same file)
    if ray.train.get_context().get_world_rank() == 0:
        os.makedirs(config.log_directory, exist_ok=True)
        log_file = os.path.join(config.log_directory, config.benchmark_name + '.json')
        with open(log_file, 'w') as f:
            json.dump(metrics, f)
        print_from_rank(f"Logged the following metrics to '{log_file}':\n" + json.dumps(metrics, indent=2))

    time.sleep(3)

    return
        

################## ENTRY POINT #################
if __name__ == '__main__':

    # Parse configuration arguments
    train_config, _ = parse_args()

    # Set up scaling configuration for Ray Trainer
    scaling_config = ray.train.ScalingConfig(
        num_workers=train_config.ray_workers,
        use_gpu=train_config.ray_use_gpu,
        resources_per_worker={
            'CPU': train_config.ray_cpus_per_worker
        })
    
    # Initialize Ray TorchTrainer with main function
    trainer = ray.train.torch.TorchTrainer(
        main_fn,
        scaling_config=scaling_config,
        train_loop_config=train_config)

    # Run the distributed training job
    result = trainer.fit()


<br>

# 4. Launch the first dataloading benchmark on the Ray cluster

<a id='sec-4'></a>

_It is benchmark time!_ You are about to submit your first remote distributed ML training job to the Ray cluster running on Amazon EKS. The Ray Jobs API is an easy, recommended way to submit locally developed applications to a Ray cluster for remote execution. All you need to do is:
1. compose an **entrypoint command** (like `python my_script.py`) that Ray runs remotely on the cluster as the job's driver process, which in turn launches the distributed Ray workers;
2. define the **runtime environment**, which specifies the runtime dependencies and requirements of our scripts;
3. **submit the job** to the Ray cluster via the Jobs API (`submit_job`).

To evaluate the dataloading performance of our distributed training job, you'll run our benchmarking script with several key configuration parameters that control the training behavior and resource utilization:

- `epochs` - number of complete passes through the training dataset (set to `3` here);
- `batch_size` - number of training samples processed in each iteration (set to `64`);
- `prefetch_size` - number of batches to prefetch in the data loading pipeline (set to `2`);
- `input_dim` - input image dimension for the model (set to `224`);
- `dataloader_workers` - number of parallel data loading processes (set to `16`);
- `dataset_path` - location of the training dataset in the mounted shared filesystem;
- `dataset_format` - format of the input dataset (either `jpg` or `tar`);
- `model_compute_time` - in ms, an artificial delay to mimic the GPU computation step (set to `0` here, meaning that we iterate through the dataset as fast as the CPUs allow);
- `ray_workers` - number of Ray workers for distributed training (set to `2`);
- `ray_cpus_per_worker` - CPU cores allocated to each Ray worker (set to `8`);
- `benchmark_name` - name of the benchmark for tracking (optional).

In [None]:
### --------
### STEP #1: Compose the entrypoint command for Ray workers
### -------

# Set dataset name, dataset format, and benchmark name
dataset_name = '100k-samples-small-files'
dataset_format = 'jpg'
dataset_path = os.path.join(eks_mountpoint_dir, s3_bucket_prefix, dataset_name)
benchmark_name = f'benchmark-dataloading-{dataset_name}-{dataset_format}-{datetime.now().strftime("%Y%m%d%H%M%S-%f")}'

# Compose entrypoint command string
entrypoint_command = "python benchmark.py" \
                     "  --epochs=3" \
                     "  --batch_size=64" \
                     "  --prefetch_size=2" \
                     "  --input_dim=224" \
                     "  --dataloader_workers=16" \
                    f"  --dataset_path={dataset_path}" \
                    f"  --dataset_format={dataset_format}" \
                     "  --model_compute_time=0" \
                     "  --ray_workers=2" \
                     "  --ray_cpus_per_worker=8" \
                    f"  --benchmark_name={benchmark_name}"


### --------
### STEP #2: Define the runtime environment parameters for Ray workers
### -------

runtime_environment = {
    "working_dir": "./scripts",        # <--- the working dir is copied over to each Ray worker
    "pip": [                           # <--- PYPI packages to be installed on each Ray worker before executing entrypoint command
        'torch',
        'torchvision',
        'torchdata',
        'webdataset',
        's3torchconnector'],
    "env_vars": {                      # <--- any custom env vars to be set in Ray worker runtime
        'AWS_REGION': aws_region,
        'EKS_MOUNTPOINT_DIR': eks_mountpoint_dir
    }
}

### --------
### STEP #3: Submit job to Ray cluster
### -------

job_id = ray_client.submit_job(entrypoint=entrypoint_command, runtime_env=runtime_environment)


### Print out the Ray Job ID and other details
print(f"Submitted a new Ray job with ID '{job_id}' and the following entrypoint command: \n")
for line in entrypoint_command.split('  '):
    print(line)
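
The entrypoint above is assembled by hand from string literals. As an alternative sketch (not part of the workshop code), the same kind of command can be generated from a dictionary, which keeps parameter names and values in one place and shell-quotes each argument with the standard `shlex` module. `build_entrypoint` is a hypothetical helper name.

```python
import shlex


def build_entrypoint(script, params):
    """Build a `python <script> --key=value ...` command string from a dict of parameters.

    Each `--key=value` argument is shell-quoted with shlex.quote, so values containing
    spaces or special characters stay a single argument when the shell parses the command.
    """
    args = [shlex.quote(f"--{key}={value}") for key, value in params.items()]
    return " ".join(["python", shlex.quote(script), *args])


params = {
    "epochs": 3,
    "batch_size": 64,
    "dataloader_workers": 16,
    "dataset_path": "/mnt/s3_data/dataset/100k-samples-small-files",
    "dataset_format": "jpg",
}
print(build_entrypoint("benchmark.py", params))
# python benchmark.py --epochs=3 --batch_size=64 --dataloader_workers=16 --dataset_path=/mnt/s3_data/dataset/100k-samples-small-files --dataset_format=jpg
```

The resulting string could be passed as `entrypoint=` to `ray_client.submit_job` in the same way as the hand-built command.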

## 4.1 Understanding the benchmarking script

<a id='sec-3'></a>

While our custom benchmarking script is being executed remotely by Ray, let’s quickly recap the high-level setup of the benchmark runtime environment. We will also outline key principles for designing an efficient data I/O implementation for ML training, as these concepts will be essential for our benchmark script implementation.

<img src="assets/pic_efficient_ml_training.png" width="1200"/>

**An efficient dataloading pipeline for ML workloads**

While CPU resources are relatively inexpensive and abundant, modern GPUs are relatively scarce and more expensive. Therefore, when designing an end-to-end dataloading pipeline for ML training, it is essential to avoid GPU starvation as much as possible.

The training data input pipeline implemented here incorporates best practices that keep I/O bottlenecks associated with data ingestion as low as possible, thereby minimizing GPU idle time. We leverage several features of the native PyTorch dataloader, including **dataloader parallelization, batch prefetching, and buffering**, to asynchronously overlap preprocessing tasks on the CPU with training steps on the GPU.
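
To make the overlap idea concrete, here is a minimal, framework-free sketch of parallel preprocessing with a bounded prefetch buffer: background workers keep preprocessing samples while the consumer (the "training step") works on earlier ones. It uses Python threads and a `queue.Queue` purely for brevity; it is not how the PyTorch `DataLoader` is implemented (the `DataLoader` uses separate worker processes, which sidesteps the GIL for CPU-heavy decoding). The function name `prefetching_loader` and its parameters are hypothetical.

```python
import queue
import threading
import time


def prefetching_loader(items, preprocess, num_workers=2, buffer_size=4):
    """Yield preprocessed items while background threads keep a bounded buffer filled.

    `num_workers` threads apply `preprocess` in parallel (like dataloader workers) and
    `buffer_size` bounds how many results wait in memory (like batch prefetching).
    Output order is not guaranteed.
    """
    if num_workers < 1:
        raise ValueError("num_workers must be >= 1")

    tasks = queue.Queue()
    for item in items:
        tasks.put(item)
    results = queue.Queue(maxsize=buffer_size)
    done = object()  # sentinel: one worker has run out of tasks

    def worker():
        while True:
            try:
                item = tasks.get_nowait()
            except queue.Empty:
                results.put(done)
                return
            results.put(preprocess(item))  # blocks while the buffer is full

    for _ in range(num_workers):
        threading.Thread(target=worker, daemon=True).start()

    finished = 0
    while finished < num_workers:
        result = results.get()
        if result is done:
            finished += 1
        else:
            yield result


def slow_preprocess(x):
    time.sleep(0.01)  # stand-in for JPEG decoding and resizing on the CPU
    return x * 2


start = time.perf_counter()
for sample in prefetching_loader(range(20), slow_preprocess, num_workers=4):
    time.sleep(0.01)  # stand-in for a training step consuming the sample
print(f"elapsed: {time.perf_counter() - start:.2f} s")
```

Because the workers fill the buffer while the consumer is busy, the total time tends toward the larger of the two stages rather than their sum, which is the effect the PyTorch dataloader features aim for.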

The **preprocessing steps** consist of JPEG decoding, a random resized crop to a tensor of dimensions 224x224x3, and normalization, before the training examples are batched into mini-batches of 64 samples. This minimal set of preprocessing operations lets us construct a complete training pipeline while keeping CPU-bound preprocessing overhead low.
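
To give a sense of how much data the CPU workers hand to each training step, the arithmetic below computes the in-memory size of one float32 mini-batch with the dimensions above, and of the batches the PyTorch `DataLoader` may hold in flight with the settings used in this workshop (`dataloader_workers=16`, `prefetch_size=2`). It ignores labels and framework overhead.

```python
# Size of one preprocessed mini-batch: 64 samples x 3 channels x 224 x 224 pixels,
# stored as float32 (4 bytes per value) after ToDtype(torch.float32, ...)
batch_size, channels, height, width = 64, 3, 224, 224
bytes_per_value = 4

num_values = batch_size * channels * height * width
batch_bytes = num_values * bytes_per_value
print(f"{num_values:,} values per batch, {batch_bytes / 2**20:.2f} MiB")
# 9,633,792 values per batch, 36.75 MiB

# PyTorch prefetches `prefetch_factor` batches per dataloader worker,
# i.e. up to 16 * 2 = 32 batches in flight per Ray worker here
prefetched_batches = 16 * 2
print(f"up to {prefetched_batches} batches, {prefetched_batches * batch_bytes / 2**30:.2f} GiB buffered")
# up to 32 batches, 1.15 GiB buffered
```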

Additionally, we will use the **caching capabilities** of _Mountpoint for Amazon S3_, which, when enabled, automatically caches training data on local storage and accelerates repeated read requests from the dataloader. This significantly speeds up our ML training process, as we will demonstrate later.

**Multi-node distributed ML training with Ray on Amazon EKS**

To set up our training script properly for a multi-node distributed training environment with Ray, we use several utility functions from the Ray Train SDK for PyTorch, such as `ray.train.torch.prepare_data_loader` and `ray.train.torch.prepare_model`. These utilities help manage resources, synchronize model weights, and distribute data loading across multiple nodes, as outlined in the [Ray Guide for PyTorch](https://docs.ray.io/en/latest/train/getting-started-pytorch.html). By leveraging these Ray capabilities, we can scale our training workload across hundreds of CPU and GPU nodes.


**❗ An important note**:

- In this workshop we are **running benchmarks on CPU-backed instances**, since we are primarily interested in the I/O behaviour of the ML training pipeline and do not need to train an actual model on GPUs to perform the ML storage benchmarks.
- We follow the common ML storage benchmark procedure of mocking (aka. simulating) the model computation step by sleeping for a pre-defined amount of time on each model training step (see, e.g. [DLIO](https://github.com/argonne-lcf/dlio_benchmark) or [MLPerf for Storage](https://github.com/mlcommons/storage) I/O benchmarking toolsets).
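
The sketch below shows the mocking idea in isolation: the model step is replaced by a sleep (as `ModelMock` does in `benchmark.py`), so the measured throughput depends only on the simulated compute time and on how fast batches arrive. Unlike the real benchmark, this loop loads batches sequentially without prefetching, so any time spent in `load_batch` directly lowers throughput below the `batch_size / compute_time` ceiling. The function name `mock_training_loop` and its parameters are hypothetical.

```python
import time


def mock_training_loop(num_steps, batch_size, compute_time_ms, load_batch):
    """Run a mocked training loop and return the observed samples per second.

    `load_batch()` stands in for fetching a batch from the dataloader; the model step is
    replaced by sleeping for `compute_time_ms` milliseconds, as in the benchmark's ModelMock.
    """
    start = time.perf_counter()
    samples = 0
    for _ in range(num_steps):
        batch = load_batch()
        samples += len(batch)
        time.sleep(compute_time_ms / 1_000)  # simulated GPU forward/backward pass
    elapsed = time.perf_counter() - start
    return samples / elapsed


# With an instant "dataloader", throughput is capped near batch_size / compute time
throughput = mock_training_loop(num_steps=10, batch_size=64, compute_time_ms=50, load_batch=lambda: [0] * 64)
print(f"{throughput:.0f} samples/s (upper bound: {64 / 0.050:.0f})")
```

Setting the compute time to `0`, as in the first benchmark run, removes the ceiling altogether, so the measured rate reflects how fast the data pipeline alone can deliver samples.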


## 4.2 Monitoring and observability for Ray jobs

The Ray job you just submitted will take <font color='red'>**up to 10 minutes**</font> to complete. Later, you will learn how to improve performance with this dataset.

While the job is running, you will connect to the Ray Dashboard, then the Grafana Dashboard, to monitor its progress.

### 4.2.1 Ray Dashboard

- Run the cell below to generate a link to the Ray Dashboard.
- Follow the link to open the dashboard in a new browser tab.
- On the Ray Dashboard, choose the **Jobs** tab to track the job's progress.

In [None]:
print(f"Ray Dashboard: http://{os.getenv('RAY_DASHBOARD_NLB_DNS')}")

Here’s what the **Jobs** tab looks like in the Ray Dashboard, where you will see any Ray job that is running in your cluster:

<img src="assets/pic_ray_jobs.png" width="900" align="center"/>

<br>

Each job starts in the **PENDING** state, as shown here:

<img src="assets/pic_ray_pending.png" width="400" align="center"/>

<br>

Because this is your first Ray job, bootstrapping the Ray environment takes about two minutes (subsequent jobs spend much less time in the _pending_ state). The job then transitions to the **RUNNING** state:

<img src="assets/pic_ray_running.png" width="400" align="center"/>

<br>

You can choose the **Job ID** to access the **Logs** window, where you can view detailed job progress and any output generated by the job. As the job will run through multiple epochs on the dataset, you can also monitor the progress in the job logs:

<img src="assets/pic_ray_logs.png" width="600" align="center"/>

> 💡 **_TIP:_** Use the **Refresh** button to keep the log output up to date

<br>

Additionally, the **Cluster** tab in the Ray Dashboard provides information on worker node utilization.

<img src="assets/pic_ray_cluster.png" width="900" align="center"/>

<br>

### 4.2.2 Grafana Monitoring Dashboard

The Grafana dashboard is deployed as a Pod in our EKS cluster. Grafana can also be embedded directly into the Ray Dashboard, but due to the limitations of the embedded Grafana dashboard, it is deployed _**separately**_ for this workshop.

- Run the cell below to generate a link to the Grafana Dashboard.
- Follow the link to open the dashboard in a new browser tab.

In [None]:
print(f"Grafana Dashboard: http://{os.getenv('GRAFANA_NLB_DNS')}")

When prompted for credentials, enter the username: **_admin_** and password: **_admin_**.

> ⚠️ _**NOTE:**  You will be prompted to select a new password. For simplicity, we suggest choosing **Skip** (under the **Submit** button)._

<br>
<img src="assets/pic_grafana_open.png" width="600" align="center"/>


Once logged in, select **Dashboards** on the left navigation pane, then choose the **Workshop** dashboard, which should be the only dashboard.

<img src="assets/pic_grafana_dashboard.png" width="600" align="center"/>

You will then be able to see statistics at 5-second intervals and monitor various task metrics of your Ray environment.

<img src="assets/pic_grafana_dashboard2.png" width="800" align="center"/>

When you're done exploring, return to the Ray Dashboard to confirm that the job has reached the **SUCCEEDED** state:

<img src="assets/pic_ray_succeeded.png" width="400" align="center"/>

Alternatively, run the cell in Section 4.3 below; it automatically waits until the Ray job completes.
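
`utils.wait_for_job_to_finish` is a workshop helper whose implementation is not shown here. If you want to poll a job yourself, a minimal sketch built on the Ray Jobs SDK's `JobSubmissionClient.get_job_status` could look like the following. The function name `wait_for_terminal_status`, the poll interval, and the timeout are assumptions; the status names are the terminal values of Ray's `JobStatus` enum.

```python
import time

# Terminal Ray job states (values of ray.job_submission.JobStatus)
TERMINAL_STATES = {"SUCCEEDED", "FAILED", "STOPPED"}


def wait_for_terminal_status(client, job_id, poll_interval_s=10, timeout_s=1800):
    """Poll client.get_job_status(job_id) until the job reaches a terminal state.

    Returns the final status as a plain string, or raises TimeoutError.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        status = client.get_job_status(job_id)
        # JobStatus is a str-valued Enum; normalize it to its plain string value
        status = getattr(status, "value", status)
        if status in TERMINAL_STATES:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Job {job_id} still {status} after {timeout_s} s")
        time.sleep(poll_interval_s)


# Usage in this notebook (requires the live Ray cluster):
# final_status = wait_for_terminal_status(ray_client, job_id)
# if final_status != "SUCCEEDED":
#     print(ray_client.get_job_logs(job_id))
```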



## 4.3 Analyze and plot results (dataloading with small files dataset)

Let's now plot the results of the benchmark that we have just run. Since the benchmark script wrote a log file to our S3 bucket for persistence, we can load it here from our locally mounted S3 bucket and plot it with a helper function:

> ⚠️ _**NOTE**: The Ray job you launched takes <font color='red'>**around 10 mins**</font> to complete. The cell below will automatically wait for the job to complete, and then plot the benchmark results._

In [None]:
# Wait for the job to finish, so that we can plot the results
utils.wait_for_job_to_finish(job_id, ray_client)

# Load the log file and plot the results
logfile_path = os.path.join(local_mountpoint_dir, 'logs', benchmark_name + '.json')
print(f"Plotting results from '{logfile_path}'..")
utils.plot_dataloading_results(logfile_path)

### Explanation

The benchmark results displayed in the plot above illustrate the time per epoch on our Ray cluster nodes.

The red line represents the measured epoch times with data caching enabled. The first epoch takes considerably longer than subsequent ones because the training script streams the dataset directly from S3. By the second epoch, the dataset has been fully cached on the local storage of the Ray worker instances, so the data is read from local instance storage, which significantly reduces file access times and epoch durations.

Streaming this dataset from S3 is relatively slow because it consists of a large number of small files, and every file read incurs the relatively high per-request latency characteristic of S3 general purpose buckets (see the discussion in the next section).

To understand the impact of caching, the graph compares this with a hypothetical scenario where the dataset is streamed from S3 in every epoch, without any local caching. Although this scenario wasn't benchmarked directly, we can estimate that each epoch would take approximately as long as the first epoch of our benchmark with caching enabled. This hypothetical scenario is depicted by the blue dashed line in the plot.
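
The estimate described above fits in a few lines. This is a sketch of the reasoning only, not the code behind `utils.plot_dataloading_results`. The function names are hypothetical, and the input is assumed to be a list of measured per-epoch times in seconds, in epoch order.

```python
from typing import List, Sequence


def estimate_uncached_epoch_times(cached_epoch_times: Sequence[float]) -> List[float]:
    """Estimate per-epoch times for a run without caching.

    Assumption: only the first epoch of the cached run streamed from S3, so every
    epoch of an uncached run would take about as long as that first epoch.
    """
    if not cached_epoch_times:
        raise ValueError("need at least one measured epoch time")
    return [cached_epoch_times[0]] * len(cached_epoch_times)


def caching_speedup(cached_epoch_times: Sequence[float]) -> float:
    """Estimated total uncached time divided by the measured total cached time."""
    uncached = estimate_uncached_epoch_times(cached_epoch_times)
    return sum(uncached) / sum(cached_epoch_times)
```

Under this model the benefit of caching grows with the number of epochs: each extra epoch adds a full streaming pass in the uncached case, but only a fast local read in the cached case.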

These results highlight the substantial efficiency gains from caching datasets locally, particularly when working with large collections of small files that can fit within the local storage capacity.

# 5. Shard the dataset

<a id='sec-5'></a>

In this section, we re-package our training dataset, which consists of many small JPG files, into a sharded dataset comprising a few larger TAR files. While the task is running, we will discuss why sharding is helpful.

## 5.1 Sharding our dataset with **Amazon S3 Tar Tool**
Among various methods available for dataset sharding into TAR format, we leverage the **[Amazon S3 Tar Tool](https://github.com/awslabs/amazon-s3-tar-tool)** for this workshop. This open-source utility provides an efficient and streamlined approach to creating tarballs from existing Amazon S3 objects and storing them back to S3 through a single CLI command. 

> ⚠️ _**NOTE:** We have pre-compiled an s3tar binary to save time in this workshop. In a production environment, you should install [make](https://www.gnu.org/software/make/) and [go](https://go.dev/), and build s3tar from the latest source in the [**Amazon S3 Tar Tool** GitHub repository](https://github.com/awslabs/amazon-s3-tar-tool). See [detailed instructions here](https://github.com/awslabs/amazon-s3-tar-tool?tab=readme-ov-file#installation)._

In what follows, we'll execute a sharding operation on our dataset of **100,000 JPG files**, each approximately **100KB** in size. The process will concatenate these files into roughly **40 TAR archives**, each approximately **250MB** in size and containing approximately 2,500 training samples. The entire operation is expected to complete in <font color='red'>**around 4 mins**</font>, producing a sharded dataset in your S3 bucket.
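
s3tar does the actual packing inside S3, so no data passes through the notebook. To make the arithmetic above concrete, here is a minimal greedy packing sketch that groups samples into shards under a size limit, using the 250MB target from the text. It is not s3tar's implementation, and `plan_shards` is a hypothetical helper. The cell below passes a slightly higher `--size-limit` of 260,000,000 bytes, and s3tar may apply it to real, variably sized files differently from this idealized model.

```python
from typing import List, Sequence


def plan_shards(sample_sizes: Sequence[int], size_limit: int) -> List[List[int]]:
    """Greedily group consecutive sample indices into shards of at most size_limit bytes.

    TAR headers and padding (512-byte blocks) are ignored, so real archives
    come out slightly larger than the sum of their members.
    """
    shards: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    for idx, size in enumerate(sample_sizes):
        # Close the current shard if adding this sample would exceed the limit.
        if current and current_bytes + size > size_limit:
            shards.append(current)
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += size
    if current:
        shards.append(current)
    return shards


# 100,000 samples of ~100KB each, packed into ~250MB shards.
shards = plan_shards([100_000] * 100_000, size_limit=250_000_000)
print(len(shards), "shards,", len(shards[0]), "samples per shard")  # 40 shards, 2500 samples per shard
```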

> 💡 _**TIP:** While you wait for the next cell to complete, please read through **Section 5.2** to learn why sharding can be a good idea when streaming datasets directly from S3 for ML training._

In [None]:
!time ./s3tar \
--size-limit 260000000 \
--region {aws_region} \
--concat-in-memory \
--storage-class STANDARD \
--goroutines 1000 \
-cvf s3://{s3_bucket_name}/{s3_bucket_prefix}/100k-samples-large-files/shard.tar \
s3://{s3_bucket_name}/{s3_bucket_prefix}/100k-samples-small-files/

## 5.2 The impact of data sharding on ML training efficiency

When training machine learning models with data stored in S3, the way you organize your dataset can dramatically impact your training performance. Let's explore why sharded datasets offer higher throughput than datasets consisting of many small files, and can thus help you alleviate potential I/O bottlenecks in your ML training pipelines.

Let's take our computer vision dataset as an example, and consider two approaches to storing it in S3:

1. **Individual files dataset** (one sample per file): Storing each sample as a separate 100KB JPG file
2. **Sharded dataset** (multiple samples per file): Combining multiple image samples into larger 250MB shards


<img src="assets/pic_sequential_vs_random.png" width="1200"/>

In the non-sharded approach, individual training samples (100KB JPG files) are stored as separate objects in S3. While this approach is straightforward, each sample retrieval requires a separate GET request to S3, incurring a Time-To-First-Byte (TTFB) latency that is typically ~50-200ms for S3 general purpose buckets. As the illustration above shows, reading many small files from S3 accumulates significant overhead from these TTFB latencies and can lead to suboptimal throughput.

> 💡 _**TIP**: The **S3 Express One Zone** storage class and the associated [S3 directory buckets](https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-express-one-zone.html) are designed for workloads or performance-critical applications that require consistent single-digit millisecond latency. While exploring this option is outside of the scope for this workshop, consider this bucket type when you require the lowest latency object storage in the cloud._
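
A simple back-of-the-envelope model makes the TTFB argument concrete. It assumes that each GET request pays a fixed TTFB and then streams at a constant per-connection rate, with requests issued one after another on a single connection. The 100ms TTFB (within the range quoted above) and the 100 MB/s streaming rate are illustrative assumptions, not measurements from this workshop.

```python
def effective_throughput_mb_s(object_bytes: int, ttfb_s: float, stream_mb_s: float) -> float:
    """Effective MB/s of one connection reading objects of a given size back to back.

    Each GET costs a fixed time-to-first-byte plus the time to stream the body.
    """
    transfer_s = object_bytes / (stream_mb_s * 1e6)
    return (object_bytes / 1e6) / (ttfb_s + transfer_s)


TTFB_S = 0.1         # assumed time-to-first-byte per GET (100 ms)
STREAM_MB_S = 100.0  # assumed streaming rate once bytes start flowing

small = effective_throughput_mb_s(100_000, TTFB_S, STREAM_MB_S)      # one 100KB JPG per GET
large = effective_throughput_mb_s(250_000_000, TTFB_S, STREAM_MB_S)  # one 250MB shard per GET
print(f"small files: {small:.2f} MB/s, shards: {large:.2f} MB/s")   # small files: 0.99 MB/s, shards: 96.15 MB/s
```

In this model a 100KB object spends about 99% of its request time waiting for the first byte, whereas a 250MB shard spends about 4%. Adding parallel connections raises both numbers, but it does not change that ratio.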

Conversely, by sharding our dataset, that is, aggregating multiple training samples into consolidated objects (such as TAR files) and then reading the samples sequentially from each shard, we drastically reduce the number of GET requests to S3 needed to consume the dataset. This amortizes the TTFB latency across many samples, substantially increasing the effective throughput.
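
To illustrate what reading samples sequentially from a shard looks like in code, here is a minimal sketch using Python's standard `tarfile` module in streaming mode. It is not the loader used by `benchmark.py` for `dataset_format=tar`, which is not shown here, and `iter_tar_samples` is a hypothetical helper. Because Mountpoint for S3 exposes shards as regular files, the same function could be pointed at any shard under the mounted `100k-samples-large-files` prefix.

```python
import tarfile
from typing import Iterator, Tuple


def iter_tar_samples(shard_path: str) -> Iterator[Tuple[str, bytes]]:
    """Yield (member_name, file_bytes) for each regular file in a TAR shard.

    Mode "r|*" opens the archive as a forward-only stream (with transparent
    decompression if needed), so the shard is read front to back in one pass
    instead of seeking to each member.
    """
    with tarfile.open(shard_path, mode="r|*") as archive:
        for member in archive:
            if not member.isfile():
                continue  # skip directories, links, etc.
            fileobj = archive.extractfile(member)
            if fileobj is not None:
                # Read the member now: in stream mode its data is gone once we move on.
                yield member.name, fileobj.read()
```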

This makes sharding particularly important for ML training workflows, where you need to process thousands or millions of samples efficiently. Sharded data enables higher throughput, which can translate into higher utilization of GPU resources.

## 5.3 Inspect the sharded dataset
Let's quickly inspect the sharded dataset that you have just created on S3. As you have already mounted your S3 bucket locally, you can use regular Linux commands as if the dataset were locally available:

In [None]:
!du -ha {local_mountpoint_dir}/{s3_bucket_prefix}/100k-samples-large-files
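
Because the bucket is mounted, standard Python file APIs work on it just as `du` does. Here is a minimal sketch that lists every shard with its size. The helper names `file_sizes` and `summarize` are hypothetical. You could call it as `summarize(file_sizes(os.path.join(local_mountpoint_dir, s3_bucket_prefix, '100k-samples-large-files')))`.

```python
import os
from typing import Dict


def file_sizes(root: str) -> Dict[str, int]:
    """Map each regular file under root (as a relative path) to its size in bytes."""
    sizes: Dict[str, int] = {}
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            full_path = os.path.join(dirpath, name)
            sizes[os.path.relpath(full_path, root)] = os.path.getsize(full_path)
    return sizes


def summarize(sizes: Dict[str, int]) -> str:
    """One-line summary: file count, total size and average size in MB."""
    total = sum(sizes.values())
    avg = total / len(sizes) if sizes else 0.0
    return f"{len(sizes)} files, {total / 1e6:.1f} MB total, {avg / 1e6:.1f} MB average"
```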

# 6. Re-run the dataloading benchmark with sharded dataset

<a id='sec-6'></a>

You will now compare performance between the sharded and original datasets, by executing the ML training job on the Ray cluster using your newly sharded dataset. To maintain experimental consistency, all benchmarking parameters remain identical to the previous run, with the only changes being the `dataset_path` and `dataset_format` parameters to accommodate the sharded dataset.
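
Since only two parameters change between runs, one option is to keep the benchmark parameters in a dictionary and render the entrypoint string from it. This is a refactoring suggestion, not how the workshop cells are written, and `build_entrypoint` and `changed_params` are hypothetical helpers. The rendered string uses the same two-space `--key=value` layout as the cells, so the existing `split('  ')` printing still works.

```python
from typing import Mapping, Set

_MISSING = object()  # sentinel so that a parameter present in only one run counts as changed


def build_entrypoint(script: str, params: Mapping[str, object]) -> str:
    """Render 'python <script>  --key=value  --key=value ...' in dict order."""
    flags = "".join(f"  --{key}={value}" for key, value in params.items())
    return f"python {script}{flags}"


def changed_params(before: Mapping[str, object], after: Mapping[str, object]) -> Set[str]:
    """Names of parameters whose values differ (or exist in only one run)."""
    keys = set(before) | set(after)
    return {k for k in keys if before.get(k, _MISSING) != after.get(k, _MISSING)}
```

For example, if a hypothetical `first_run_params` dict held the first benchmark's settings, `changed_params(first_run_params, {**first_run_params, 'dataset_path': dataset_path, 'dataset_format': 'tar'})` should return only `{'dataset_path', 'dataset_format'}`, which confirms the runs are otherwise identical.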

In [None]:
### --------
### STEP #1: Compose the entrypoint command for Ray workers
### -------

# Set dataset name and format that we want to use in benchmark
dataset_name = '100k-samples-large-files'
dataset_format = 'tar'
dataset_path = os.path.join(eks_mountpoint_dir, s3_bucket_prefix, dataset_name)
benchmark_name = f'benchmark-dataloading-{dataset_name}-{dataset_format}-{datetime.now().strftime("%Y%m%d%H%M%S-%f")}'

# Compose entrypoint command string
entrypoint_command = "python benchmark.py" \
                     "  --epochs=3" \
                     "  --batch_size=64" \
                     "  --prefetch_size=2" \
                     "  --input_dim=224" \
                     "  --dataloader_workers=16" \
                    f"  --dataset_path={dataset_path}" \
                    f"  --dataset_format={dataset_format}" \
                     "  --model_compute_time=0" \
                     "  --ray_workers=2" \
                     "  --ray_cpus_per_worker=8" \
                    f"  --benchmark_name={benchmark_name}"


### --------
### STEP #2: Define the runtime environment parameters for Ray workers
### -------

# Nothing to do! We use the same runtime environment definition as for the first benchmark.

### --------
### STEP #3: Submit job to Ray cluster
### -------

job_id = ray_client.submit_job(entrypoint=entrypoint_command, runtime_env=runtime_environment)


### Print out the Ray Job ID and other details
print(f"Submitted a new Ray job with ID '{job_id}' and the following entrypoint command: \n")
for line in entrypoint_command.split('  '):
    print(line)

## 6.1 Analyze and plot results (dataloading with sharded dataset)

Now plot the results of our second benchmark with sharded dataset.

> ⚠️ _**NOTE:** The Ray job that you have just launched will take <font color='red'>**around 2 minutes**</font> to complete. The cell below will automatically wait for the job completion and plot the benchmark results. **Feel free to return to your Ray and Grafana dashboard browser tabs to monitor job progress.**_

In [None]:
# Links to the Ray and Grafana dashboards
print(f"Ray Dashboard: http://{os.getenv('RAY_DASHBOARD_NLB_DNS')}")
print(f"Grafana Dashboard: http://{os.getenv('GRAFANA_NLB_DNS')}")
print('-' * 60)

# Wait for the job to finish, so that we can plot the results
utils.wait_for_job_to_finish(job_id, ray_client)

# Load the log file and plot the results
logfile_path = os.path.join(local_mountpoint_dir, 'logs', benchmark_name + '.json')
print(f"Plotting results from '{logfile_path}'..")
utils.plot_dataloading_results(logfile_path)

### Explanation

In this benchmark, you analyzed the impact of caching on training performance when using a sharded dataset. The red line shows actual epoch times with caching enabled, while the blue line represents a hypothetical scenario where data is streamed from S3 in every epoch without caching. Compared with the previous benchmark, the effect of caching on epoch time is much smaller, if not negligible.

This improvement is due to the advantages of streaming sharded data from S3, which we discussed in the previous section. Looking at the red line, even the first epoch (in which data is streamed directly from S3 before any caching has occurred) shows minimal I/O bottleneck, as it is just as fast as the subsequent epochs. This indicates that sharding has effectively addressed the latency issues seen with the non-sharded dataset, achieving near-optimal data transfer rates from the very start of training.

In summary, with sharding, there is little difference in performance between cached and uncached scenarios, as the S3 streaming throughput is sufficient to saturate the available compute resources in both cases. This finding suggests that while dataset caching remains an effective strategy for maximizing data pipeline throughput, sharding can also be a powerful tool for large datasets, especially those that exceed the storage capacity of local instances.

# 7. Checkpoint models with **Mountpoint for Amazon S3**

<a id='sec-7'></a>

Up to now we have focused only on efficient dataloading for ML training, which is essentially about how to _**get training data from S3 into the ML training cluster**_. Let us now turn our attention to model checkpointing, and see how we can _**get model data out of the ML training cluster into S3**_ in the most efficient manner.

## 7.1 The critical role of model checkpointing in large-scale training

Model checkpointing is a fundamental mechanism in machine learning workflows that periodically saves the complete training state, including model weights, optimizer states, and other training parameters. In large-scale distributed training environments, where computations run across hundreds or thousands of nodes for days or weeks, checkpointing becomes crucial for fault tolerance and experiment reproducibility. Without regular checkpoints, a single node failure could result in the loss of days of training progress, necessitating a complete restart. However, traditional checkpointing mechanisms often create a significant performance bottleneck: while state is being saved, all nodes typically must pause training until the checkpoint operation completes, which directly impacts training time and resource utilization. This challenge is particularly acute for modern deep learning models with billions of parameters, where checkpoint sizes can reach hundreds of gigabytes. Additionally, when checkpoints are first written to local storage, you need to implement auxiliary background mechanisms to sync them to Amazon S3 for persistent storage. Therefore, an efficient checkpointing solution that minimizes training interruption while ensuring reliable state preservation is essential for production-scale machine learning operations.

<img src="assets/pic_ckpting_to_local_vs_s3_storage.png" width="1200"/>


## 7.2 Using **Mountpoint for S3** for high-performance model checkpointing

[**Mountpoint for Amazon S3**](https://github.com/awslabs/mountpoint-s3) streamlines ML checkpointing directly to Amazon S3. This eliminates the traditional two-step process of first saving model snapshots to local storage and then uploading (or syncing) them to cloud storage for persistence, which adds to both I/O overhead and operational complexity. Mountpoint for S3 leverages [**AWS Common Runtime**](https://aws.amazon.com/blogs/storage/improving-amazon-s3-throughput-for-the-aws-cli-and-boto3-with-the-aws-common-runtime/) to distribute large file writes elastically across the S3 fleet, resulting in up to 60% faster model checkpointing performance than saving model snapshots to local NVMe instance storage. This superior performance is possible because, instead of being bottlenecked by the bandwidth of a single local disk, Mountpoint for Amazon S3 can parallelize the object uploads across the Amazon S3 fleet, and burst to hundreds of gigabits per second during the checkpointing process.

<img src="assets/pic_ckpting_with_crt.png" width="1200"/>

Beyond performance, the low cost of Amazon S3, particularly when using the [Intelligent Tiering](https://aws.amazon.com/s3/storage-classes/intelligent-tiering/) storage class, makes it ideal for storing your model checkpoints for the long term, so that you can quickly and easily return to earlier experiments. And because Mountpoint for S3 exposes your bucket as a file system, no code changes are required to your training scripts.
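
To illustrate the "no code changes" point, ordinary file-writing code can target the mounted path directly. The sketch below uses `pickle` and a plain dict as a stand-in for a real training state (with PyTorch you would typically call `torch.save(state, path)` instead). `save_checkpoint`, `load_checkpoint` and the file naming scheme are hypothetical. Each checkpoint goes to a new file in a single sequential pass, because Mountpoint for S3 is designed for sequential writes to new files and, by default, does not allow overwriting existing ones.

```python
import os
import pickle
from typing import Any, Dict


def save_checkpoint(state: Dict[str, Any], ckpt_dir: str, step: int) -> str:
    """Write the training state to a new, step-numbered file and return its path."""
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"ckpt-step-{step:07d}.pkl")
    # One sequential write to a fresh file: no seeking, no appending to old files.
    with open(path, "wb") as f:
        pickle.dump(state, f)
    return path


def load_checkpoint(path: str) -> Dict[str, Any]:
    """Read a checkpoint written by save_checkpoint (only unpickle files you trust)."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

In the benchmark below, the Mountpoint run sets `ckpt_path` to `os.path.join(eks_mountpoint_dir, 'checkpoints')`, so a call like `save_checkpoint(state, ckpt_path, step)` on a Ray worker would land in S3 without a separate upload step.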

## 7.3 Running model checkpointing benchmarks on Ray cluster


_It's benchmarking time again!_ To quantitatively evaluate the performance of checkpointing directly to S3 using **Mountpoint for Amazon S3**, let's now run a comparative benchmark against checkpointing to the local storage of the Ray workers (the attached EBS gp3 volume in our case). The benchmarks that we are about to run use our previous benchmarking script, with a few additional configuration parameters to control checkpointing behavior:

- `ckpt_steps` - defines number of steps between checkpoints (set to `100` in this benchmark, but setting to `0` will disable checkpointing);
- `ckpt_mode` - checkpointing backend, either `disk` (to write checkpoints to a file system path, which can be a local path or the Mountpoint for S3 mount path) or `s3pt` (to use the S3 Connector for PyTorch);
- `ckpt_path` - storage path (S3 URI or local filesystem path);
- `model_num_parameters` - model size in millions of parameters, which effectively defines the model snapshot size.

> ⚠️ _**A FEW NOTES:**_
> - _As we set `epochs=1` and `ckpt_steps=100` below, we will checkpoint exactly 7 times during our benchmark job and report the **average checkpointing time**. This is because each Ray worker processes ~780 batches per epoch (100,000 samples split across `ray_workers=2`, with `batch_size=64`);_
> - _As we set `model_num_parameters=1000`, the resulting model snapshots will be approx. **4GB** in size, because we save 1000M weights in `fp32` format (i.e. 4 bytes per model weight)._
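
The two numbers in these notes can be reproduced with a few lines of arithmetic. This sketch assumes that the dataset is split evenly across Ray workers, that partial batches are dropped, and that a checkpoint is written after every `ckpt_steps` completed batches. The helper names are hypothetical and not taken from `benchmark.py`.

```python
def checkpoints_per_worker(num_samples: int, ray_workers: int, batch_size: int,
                           ckpt_steps: int, epochs: int = 1) -> int:
    """Checkpoints written by each worker, assuming an even split and full batches only."""
    batches_per_worker = (num_samples // ray_workers) // batch_size
    return (batches_per_worker * epochs) // ckpt_steps


def snapshot_size_bytes(num_params_millions: int, bytes_per_param: int = 4) -> int:
    """Model snapshot size for fp32 weights (4 bytes each), ignoring optimizer state."""
    return num_params_millions * 1_000_000 * bytes_per_param


# 50,000 samples per worker // 64 = 781 batches -> 7 checkpoints at ckpt_steps=100
print(checkpoints_per_worker(100_000, ray_workers=2, batch_size=64, ckpt_steps=100))  # 7
print(snapshot_size_bytes(1000) / 1e9, "GB")  # 4.0 GB
```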

In [None]:
ckpt_benchmarks = {}

for ckpt_destination in ('mountpoint', 'local_disk'):

    ### --------
    ### STEP #1: Compose the entrypoint command for Ray workers
    ### -------

    # Set dataset name, dataset format, and benchmark name
    dataset_name = '100k-samples-large-files'
    dataset_format = 'tar'
    dataset_path = os.path.join(eks_mountpoint_dir, s3_bucket_prefix, dataset_name)
    benchmark_name = f'benchmark-checkpointing-{ckpt_destination}-{datetime.now().strftime("%Y%m%d%H%M%S-%f")}'

    # Set checkpoint path on Ray cluster: either S3 mount point path, or local volume path
    ckpt_path = os.path.join(eks_mountpoint_dir, 'checkpoints') if ckpt_destination == 'mountpoint' else 'checkpoints/'
    
    # Compose entrypoint command string
    entrypoint_command = "python benchmark.py" \
                         "  --epochs=1" \
                         "  --batch_size=64" \
                         "  --prefetch_size=2" \
                         "  --input_dim=224" \
                         "  --dataloader_workers=16" \
                        f"  --dataset_path={dataset_path}" \
                        f"  --dataset_format={dataset_format}" \
                         "  --model_compute_time=0" \
                         "  --ray_workers=2" \
                         "  --ray_cpus_per_worker=8" \
                        f"  --benchmark_name={benchmark_name}" \
                         "  --model_num_parameters=1000" \
                         "  --ckpt_steps=100" \
                         "  --ckpt_mode=disk" \
                        f"  --ckpt_path={ckpt_path}"
    
    
    ### --------
    ### STEP #2: Define the runtime environment parameters for Ray workers
    ### -------
    
    # Nothing to do! We use the same runtime environment definition as for the first benchmark.
    
    ### --------
    ### STEP #3: Submit job to Ray cluster
    ### -------
    
    job_id = ray_client.submit_job(entrypoint=entrypoint_command, runtime_env=runtime_environment)
    
    
    ### Print out the Ray Job ID and other details
    print(f"Submitted a new Ray job with ID '{job_id}' and the following entrypoint command: \n")
    for line in entrypoint_command.split('  '):
        print(line)
    print('^'*40, '\n')
    time.sleep(5)

    # Keep track of our benchmarks
    ckpt_benchmarks[job_id] = {'name': benchmark_name, 'tag': ckpt_destination}

### Go to the Ray and Grafana dashboards to observe the TWO jobs running and completing in the **SUCCEEDED** state

In [None]:
print(f"Grafana Dashboard: http://{os.getenv('GRAFANA_NLB_DNS')}")
print(f"Ray Dashboard: http://{os.getenv('RAY_DASHBOARD_NLB_DNS')}")

## 7.4 Analyze and plot results (model checkpointing)

Now plot the results of your checkpointing benchmarks.

> ⚠️ _**NOTE:** The **two** Ray jobs that we have just submitted will run in parallel, and the longer job will take <font color='red'>**around 4 minutes**</font> to complete. The cell below will automatically wait for both jobs to complete, and then plot the benchmark results._

In [None]:
# Wait for both jobs to finish, so that we can plot the results
for job_id in ckpt_benchmarks:
    utils.wait_for_job_to_finish(job_id, ray_client)

# Load the log files and plot the results
benchmark_files = {
    benchmark['tag'].replace('_', ' ').title(): os.path.join(local_mountpoint_dir, 'logs', benchmark['name'] + '.json')
    for benchmark in ckpt_benchmarks.values()
}

print(f"Plotting results for '{json.dumps(benchmark_files, indent=2)}'..")
utils.plot_checkpointing_results(benchmark_files)

### Explanation

The benchmark results above illustrate the time required to save periodic model checkpoints, either directly to S3 using Mountpoint for S3 (in red) or to a 'local' EBS volume (in blue). The results indicate a clear advantage when saving checkpoints directly to S3, as checkpointing times are consistently lower compared to saving to local storage.

Because writes are distributed across the S3 fleet, checkpoint throughput can approach the network bandwidth of the EC2 instance instead of being limited by the bandwidth of a single attached volume. This can provide significant time and cost savings, particularly for long-running distributed training jobs with frequent checkpointing requirements, where it reduces overall training overhead and allows more efficient resource utilization.
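
To check how close a run comes to the instance's network bandwidth, you can convert checkpoint durations into an effective write throughput and compare it with the bandwidth listed for your EC2 instance type. Here is a minimal sketch, assuming you already have the per-checkpoint durations in seconds and that each duration covers writing one snapshot of known size. The benchmark log's field names are not shown here, so extracting the durations is left out, and the helper name is hypothetical.

```python
from statistics import mean
from typing import Sequence


def checkpoint_throughput_gbps(snapshot_bytes: int, durations_s: Sequence[float]) -> float:
    """Average effective write throughput in gigabits per second.

    Uses the mean checkpoint duration, matching the "average checkpointing
    time" reported by the benchmark.
    """
    if not durations_s:
        raise ValueError("need at least one checkpoint duration")
    return snapshot_bytes * 8 / 1e9 / mean(durations_s)
```

For example, a 4GB snapshot (`4_000_000_000` bytes) written in an average of 4 seconds corresponds to 8 Gbps.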

>💡 _**TIP:** In this workshop, we have used EBS storage. If you use EC2 instances with instance store you will also have one or more physically attached ephemeral volumes. Instance store is ideal for temporary storage of information that changes frequently, such as buffers, caches, scratch data, and other temporary content._

<br>

# 8. Summary

<a id='sec-8'></a>

This workshop demonstrated practical techniques for optimizing data I/O for AI/ML workloads on AWS using Mountpoint for S3. Through hands-on experiments and benchmarks, you explored key aspects of ML infrastructure optimization. Let's recap what you have done in the workshop and summarize the identified best practices.

#### Shared storage
- After learning about the dataset on your local storage, you used the _AWS CLI_ and _s5cmd_, an open source tool, to migrate this dataset to Amazon S3;
- You installed Mountpoint for Amazon S3, and learned how to use it to interact with data in S3 as if it were a local file system.

#### Infrastructure integration
- You used the integration between Ray clusters on Amazon EKS and S3 storage using the [Mountpoint for Amazon S3 Container Storage Interface (CSI) Driver](https://github.com/awslabs/mountpoint-s3-csi-driver), which allows Kubernetes applications to access Amazon S3 objects through a file system interface;
- You also learned how to scale ML workloads across multiple compute nodes with Ray while maintaining efficient data access.

#### Efficient data loading
- We compared two approaches for organizing training data in S3 (individual files vs. sharded datasets), demonstrating that sharding samples into larger files (using, e.g., TAR or TFRecord formats) significantly improved data loading performance;
- The improvement stems from the reduced number of S3 API calls, and the associated TTFB latency overhead, when reading dataset samples sequentially from S3;
- In cases where sharding is not possible, we demonstrated that local caching is another robust strategy to incorporate, especially when running multiple epochs.

#### Model checkpointing
- Model checkpointing directly to S3 using Mountpoint for Amazon S3 can be significantly faster than checkpointing to local storage;
- This efficiency comes from Mountpoint for Amazon S3's ability to distribute writes across the S3 fleet;
- Additionally, this approach eliminates the traditional two-step process of saving to local storage followed by cloud storage syncing.



# 9. Next steps

<a id='sec-9'></a>

- If you have time now, and wish to learn about the **Amazon S3 Connector for PyTorch**, <font color='blue'>**_please proceed to the Bonus Notebook_**</font> by opening `2_bonus_notebook.ipynb` in the JupyterLab file browser panel (on the left).
The S3 Connector for PyTorch provides an even tighter integration between PyTorch ML training workloads and Amazon S3, offering both high-performance dataloading and model checkpointing.

- If you wish to share this workshop with colleagues, and/or run it in your own time in your own AWS Account, make a note of this link: https://s12d.com/stg406

- If you are at an AWS event, **please fill out the session survey provided by AWS staff**. Your feedback helps us improve, and justifies our efforts in creating content such as this. Thank you.
