# Hyperparameter Tuning using SageMaker PyTorch Container

This notebook demonstrates how to use [SageMaker Automatic Model Tuning](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html) (hyperparameter optimization) to find the best hyperparameters for training a PyTorch MNIST model.

### SageMaker Python SDK V3 APIs used

| Package | Class / Utility | Purpose |
|---------|----------------|----------|
| `sagemaker-train` | `ModelTrainer` | Unified training interface (replaces V2 `PyTorch` Estimator) |
| `sagemaker-train` | `HyperparameterTuner` | Automatic hyperparameter optimization |
| `sagemaker-train` | `SourceCode`, `InputData`, `Compute` | Structured training configuration |
| `sagemaker-core` | `Model`, `EndpointConfig`, `Endpoint` | Low-level resource management for deployment |
| `sagemaker-core` | `repack_model()` | Inject inference code into model artifacts |
| `sagemaker-core` | `image_uris.retrieve()` | Retrieve SageMaker Deep Learning Container URIs |
| `sagemaker-core` | `Session`, `get_execution_role` | Session management and IAM role resolution |

<div style="border: 2px solid #ff9900; border-radius: 8px; padding: 15px; background-color: #fff3e0; margin-bottom: 10px;">
<strong>⚠️ Compatibility Notice:</strong> This notebook has been tested using <strong>SageMaker Distribution Image 3.7.0</strong> and the <strong>SageMaker Python SDK version 3.4.0</strong>.
</div>

In [None]:
!pip install -q -U "sagemaker==3.4.0"

## Contents

1. [Background](#Background)
1. [Setup](#Setup)
1. [Data](#Data)
1. [Train](#Train)
1. [Host](#Host)

---

## Background

MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28 pixel grayscale images of handwritten digits, split into 60,000 training images and 10,000 test images. There are 10 classes, one for each digit. This tutorial shows how to train and test an MNIST model on SageMaker using PyTorch, and how to use SageMaker Automatic Model Tuning to find the best hyperparameters.

---

## Setup

We start by creating a `Session` (from `sagemaker.core`) and resolving the IAM execution role. The session manages interactions with S3 and the SageMaker service.

- `Session()` — initializes the SageMaker session with default bucket and region
- `get_execution_role()` — retrieves the IAM role ARN for training and hosting

In [None]:
from importlib.metadata import version as pkg_version
print(f"SageMaker SDK version: {pkg_version('sagemaker')}")

from sagemaker.core.helper.session_helper import Session, get_execution_role
from sagemaker.core import image_uris

sagemaker_session = Session()
region = sagemaker_session.boto_region_name

bucket = sagemaker_session.default_bucket()
prefix = "sagemaker/DEMO-pytorch-mnist"
default_bucket_prefix = sagemaker_session.default_bucket_prefix

# If a default bucket prefix is configured, prepend it to the S3 key prefix
if default_bucket_prefix:
    prefix = f"{default_bucket_prefix}/{prefix}"

role = get_execution_role()

## Data
### Getting the data

In [None]:
import os
import gzip
import urllib.request

local_dir = "data"
base_url = f"https://sagemaker-example-files-prod-{region}.s3.amazonaws.com/datasets/image/MNIST/"
files = [
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
]
os.makedirs("data/MNIST/raw", exist_ok=True)
for f in files:
    gz_path = f"data/MNIST/raw/{f}"
    raw_path = gz_path.replace(".gz", "")
    if not os.path.exists(gz_path):
        print(f"Downloading {f}...")
        urllib.request.urlretrieve(base_url + f, gz_path)
    # Extract .gz files (torchvision in the container expects uncompressed files)
    if not os.path.exists(raw_path):
        with gzip.open(gz_path, "rb") as f_in:
            with open(raw_path, "wb") as f_out:
                f_out.write(f_in.read())
print("MNIST data ready.")

### Uploading the data to S3
We use `Session.upload_data()` to upload the dataset to S3. The returned S3 URI is passed as an `InputData` source to the training job.

In [None]:
inputs = sagemaker_session.upload_data(path="data", bucket=bucket, key_prefix=prefix)
print("input spec (in this case, just an S3 path): {}".format(inputs))
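
To make the upload concrete, the sketch below shows how a directory upload is expected to map local files to S3 keys: each file keeps its path relative to `path` and is placed under `key_prefix`. This mirrors how `upload_data()` behaves in earlier SDK versions and is assumed to hold here. `planned_s3_keys` is a hypothetical helper that never contacts AWS.

```python
import os
import tempfile


def planned_s3_keys(local_dir, key_prefix):
    """Return the S3 keys a recursive directory upload would be expected to create.

    Hypothetical helper for illustration only; it does not call S3.
    """
    keys = []
    for dirpath, _dirnames, filenames in os.walk(local_dir):
        rel_dir = os.path.relpath(dirpath, start=local_dir)
        for name in filenames:
            # Files directly under local_dir get no extra sub-path.
            if rel_dir == ".":
                parts = [key_prefix]
            else:
                parts = [key_prefix, rel_dir.replace(os.sep, "/")]
            keys.append("/".join(parts + [name]))
    return sorted(keys)


# Build a throwaway directory shaped like data/MNIST/raw to try the helper.
with tempfile.TemporaryDirectory() as tmp:
    raw_dir = os.path.join(tmp, "MNIST", "raw")
    os.makedirs(raw_dir)
    for fname in ("train-images-idx3-ubyte", "train-labels-idx1-ubyte"):
        open(os.path.join(raw_dir, fname), "wb").close()
    example_keys = planned_s3_keys(tmp, "sagemaker/DEMO-pytorch-mnist")

for key in example_keys:
    print(key)
```

Under this assumption, the `MNIST/raw` layout created by the download cell is preserved below the key prefix, so the training channel sees the same directory structure.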

## Train
### Training script
The `mnist.py` script contains all the code needed to train and host the model, including the `model_fn` function that loads the model for inference.
It closely resembles a training script you might run outside of SageMaker, but it can read properties of the training environment from environment variables such as:

* `SM_MODEL_DIR`: A string representing the path to the directory to write model artifacts to. These artifacts are uploaded to S3 for model hosting.
* `SM_NUM_GPUS`: The number of GPUs available in the current container.
* `SM_CHANNEL_TRAINING`: A string representing the path to the directory containing data in the 'training' channel.

For more information about training environment variables, please visit [SageMaker Containers](https://github.com/aws/sagemaker-containers).
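
As a rough illustration of how a script can consume these variables, the sketch below wires them into `argparse` defaults. `mnist.py` itself is not reproduced here, so the argument names (borrowed from the hyperparameters used later in this notebook), the fallback values, and the assumption that hyperparameters arrive as command-line arguments are all illustrative.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters and SageMaker-provided paths (illustrative sketch)."""
    parser = argparse.ArgumentParser()
    # Hyperparameters, assumed to arrive as arguments such as --lr 0.01 --batch-size 64.
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=1)
    # Environment-derived settings; the fallbacks let the script run outside SageMaker.
    parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "./model"))
    parser.add_argument("--data-dir", default=os.environ.get("SM_CHANNEL_TRAINING", "./data"))
    parser.add_argument("--num-gpus", type=int, default=int(os.environ.get("SM_NUM_GPUS", "0")))
    return parser.parse_args(argv)


# Simulate typical values from a training container, then parse sample arguments.
os.environ["SM_MODEL_DIR"] = "/opt/ml/model"
os.environ["SM_CHANNEL_TRAINING"] = "/opt/ml/input/data/training"
os.environ["SM_NUM_GPUS"] = "0"
args = parse_args(["--lr", "0.05", "--batch-size", "128"])
print(args.model_dir, args.data_dir, args.num_gpus, args.lr, args.batch_size)
```

Reading the paths through defaults keeps the script runnable on a laptop, where none of the `SM_*` variables are set.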

In [None]:
!pygmentize mnist.py

### Set up hyperparameter tuning job
*Note: with the default settings below, the hyperparameter tuning job takes about 20 minutes to complete.*

We use the V3 `HyperparameterTuner` (from `sagemaker.train.tuner`) together with `ModelTrainer` to run automatic model tuning. The steps are:

1. Create a `ModelTrainer` with `SourceCode`, `Compute`, and the training image URI
2. Define hyperparameter ranges using `ContinuousParameter`, `CategoricalParameter`, or `IntegerParameter` (from `sagemaker.core.parameter`)
3. Define the objective metric and its regex pattern
4. Create a `HyperparameterTuner` with the trainer, ranges, metric, and resource budget

We define a `ModelTrainer` (from `sagemaker.train`) passing in:
- `training_image` — the container URI retrieved via `image_uris.retrieve()`
- `source_code` — a `SourceCode` config pointing to the local directory and entry script
- `compute` — a `Compute` config specifying instance type and count
- `role` — the IAM execution role
- `hyperparameters` — static hyperparameters (not tuned)

In [None]:
from sagemaker.train import ModelTrainer
from sagemaker.train.configs import SourceCode, InputData, Compute

training_image = image_uris.retrieve(
    framework="pytorch",
    region=region,
    version="2.0.1",
    py_version="py310",
    instance_type="ml.c5.2xlarge",
    image_scope="training",
)

trainer = ModelTrainer(
    training_image=training_image,
    source_code=SourceCode(source_dir=".", entry_script="mnist.py"),
    compute=Compute(instance_type="ml.c5.2xlarge", instance_count=1),
    role=role,
    hyperparameters={"epochs": "1", "backend": "gloo"},
    base_job_name="pytorch-mnist-hpo",
)

In [None]:
# Run a single training job first to verify the script and configuration before tuning

trainer.train(
    input_data_config=[InputData(channel_name="training", data_source=inputs)],
    wait=True,
    logs=True,
)

Once we've defined our `ModelTrainer`, we can specify the hyperparameters we'd like to tune and their possible values. There are three types of hyperparameter ranges:
- Categorical parameters take one value from a discrete set, defined by passing the list of possible values to `CategoricalParameter(list)`.
- Continuous parameters take any real value between a minimum and a maximum, defined by `ContinuousParameter(min, max)`.
- Integer parameters take any integer value between a minimum and a maximum, defined by `IntegerParameter(min, max)`.

*Note: where possible, it's almost always best to specify a hyperparameter as the least restrictive type. For example, tuning the learning rate as a continuous value between 0.01 and 0.2 is likely to yield a better result than tuning it as a categorical parameter with the values 0.01, 0.1, 0.15, and 0.2.*

In [None]:
from sagemaker.core.parameter import ContinuousParameter, CategoricalParameter

hyperparameter_ranges = {
    "lr": ContinuousParameter(0.001, 0.1),
    "batch-size": CategoricalParameter([32, 64, 128, 256, 512]),
}
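
To illustrate what each range type admits, the sketch below draws candidate values with plain Python. It is a stand-in for intuition only: the function names are hypothetical, and uniform random draws are not necessarily how the tuning service picks candidates.

```python
import random

rng = random.Random(0)  # fixed seed so the sketch is repeatable


def draw_continuous(low, high):
    """Any real value in [low, high]."""
    return rng.uniform(low, high)


def draw_integer(low, high):
    """Any whole number in [low, high], endpoints included."""
    return rng.randint(low, high)


def draw_categorical(values):
    """Exactly one member of a fixed set; no ordering is implied."""
    return rng.choice(values)


# Three candidate configurations over the same ranges as hyperparameter_ranges.
candidates = [
    {
        "lr": draw_continuous(0.001, 0.1),
        "batch-size": draw_categorical([32, 64, 128, 256, 512]),
    }
    for _ in range(3)
]
for candidate in candidates:
    print(candidate)
```

A continuous range lets the search propose any value in the interval, while a categorical set restricts it to the listed values. `draw_integer` is included for completeness, since this notebook tunes no integer parameter.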

Next, we specify the objective metric to tune and its definition, which includes the regular expression (regex) used to extract the metric from the training job's CloudWatch logs. Here, the script emits the average test loss, which we use as the objective metric. We set `objective_type` to `'Minimize'` so that the tuner searches for the hyperparameter setting with the lowest value of that metric.

In [None]:
objective_metric_name = "average test loss"
objective_type = "Minimize"
metric_definitions = [{"Name": "average test loss", "Regex": "Test set: Average loss: ([0-9\\.]+)"}]
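
The regex can be checked locally before launching any jobs. The sketch below applies the same pattern to a sample log line. The line is an assumption about what the script prints, so compare it against real training output.

```python
import re

# Same pattern as in metric_definitions, written as a raw string.
pattern = re.compile(r"Test set: Average loss: ([0-9\.]+)")

# Assumed sample line; the real script's output may differ in detail.
sample_log_line = "Test set: Average loss: 0.0473, Accuracy: 9850/10000 (98%)"

match = pattern.search(sample_log_line)
metric_value = float(match.group(1)) if match else None
print(metric_value)
```

The capture group is what gets reported as the metric value, so it must contain only the number.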

Now, we create a `HyperparameterTuner` (from `sagemaker.train.tuner`), passing in:
- The `ModelTrainer` instance
- Hyperparameter ranges
- Objective metric name and definition
- Tuning resource configurations: `max_jobs` (total training jobs) and `max_parallel_jobs` (concurrent jobs)
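
The two budget settings interact: `max_parallel_jobs` caps how many jobs run at once, so `max_jobs` training jobs need at least `ceil(max_jobs / max_parallel_jobs)` sequential rounds. A quick check with the values used in this notebook:

```python
import math

max_jobs = 9
max_parallel_jobs = 3

# Lower bound on sequential rounds if every round runs at full parallelism.
min_rounds = math.ceil(max_jobs / max_parallel_jobs)
print(min_rounds)
```

More parallelism shortens wall-clock time. With a strategy that learns from completed jobs, such as Bayesian optimization, less parallelism lets later jobs draw on more finished results.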

In [None]:
from sagemaker.train.tuner import HyperparameterTuner

# Create a fresh ModelTrainer for the tuner
tuner_trainer = ModelTrainer(
    training_image=training_image,
    source_code=SourceCode(source_dir=".", entry_script="mnist.py"),
    compute=Compute(instance_type="ml.c5.2xlarge", instance_count=1),
    role=role,
    hyperparameters={"epochs": "1", "backend": "gloo"},
    base_job_name="pytorch-mnist-hpo",
)

tuner = HyperparameterTuner(
    tuner_trainer,
    objective_metric_name,
    hyperparameter_ranges,
    metric_definitions,
    max_jobs=9,
    max_parallel_jobs=3,
    objective_type=objective_type,
)

### Launch hyperparameter tuning job
We start the tuning job by calling `tuner.tune()` with an `InputData` config pointing to our S3 training data. Setting `wait=True` blocks until all training jobs complete.

You can monitor progress in the SageMaker console under **Hyperparameter tuning jobs**.

In [None]:
tuner.tune(
    inputs=[InputData(channel_name="training", data_source=inputs)],
    wait=True,
)

## Host
### Create endpoint
After tuning completes, we deploy the best model to a real-time endpoint. The steps are:

1. **Get best training job**: `tuner.best_training_job()` returns the name of the job with the best objective metric
2. **Retrieve model artifacts**: `TrainingJob.get()` fetches the job metadata, including the S3 URI of `model.tar.gz`
3. **Repack model**: `repack_model()` (from `sagemaker.core.utils`) injects the inference script into the model archive. The PyTorch serving container expects a `code/` directory inside `model.tar.gz` containing the entry point script
4. **Create resources**: use the `sagemaker-core` resource classes (`Model.create()`, `EndpointConfig.create()`) to register the model and endpoint configuration, then create the endpoint with the boto3 `create_endpoint()` call and wait until it is in service

The `repack_model()` utility handles the full download → extract → inject → re-tar → upload workflow in a single call, replacing what would otherwise be ~15 lines of manual S3/tarfile code.
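
For intuition about what the repacked archive looks like, the sketch below reproduces the extract, inject, and re-tar steps locally with the standard library. It is not the implementation of `repack_model()`: the file name `model.pth` and the file contents are placeholders, and the S3 download and upload steps are left out.

```python
import io
import os
import tarfile
import tempfile


def add_bytes(tar, name, payload):
    """Add an in-memory file to an open tar archive."""
    info = tarfile.TarInfo(name=name)
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))


with tempfile.TemporaryDirectory() as tmp:
    original = os.path.join(tmp, "model.tar.gz")
    repacked = os.path.join(tmp, "repacked.tar.gz")

    # Stand-in for the training output: an archive holding only the weights.
    # "model.pth" is a placeholder file name.
    with tarfile.open(original, "w:gz") as tar:
        add_bytes(tar, "model.pth", b"fake-weights")

    # Copy the existing members, then inject the script under code/.
    with tarfile.open(original, "r:gz") as src, tarfile.open(repacked, "w:gz") as dst:
        for member in src.getmembers():
            dst.addfile(member, src.extractfile(member))
        add_bytes(dst, "code/mnist.py", b"def model_fn(model_dir): ...\n")

    with tarfile.open(repacked, "r:gz") as tar:
        repacked_names = sorted(tar.getnames())

print(repacked_names)
```

The resulting archive holds the original artifacts plus `code/mnist.py`, which is the layout described in step 3.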

In [None]:
import time
import boto3
from sagemaker.core.resources import TrainingJob, Model, EndpointConfig, Endpoint
from sagemaker.core.shapes import ContainerDefinition, ProductionVariant
from sagemaker.core.utils import repack_model

sm_client = boto3.client("sagemaker")
sm_runtime = boto3.client("sagemaker-runtime")

# Get best training job
best_job_name = tuner.best_training_job()
print(f"Best training job: {best_job_name}")

best_training_job = TrainingJob.get(training_job_name=best_job_name)
model_data = best_training_job.model_artifacts.s3_model_artifacts
print(f"Best model artifacts: {model_data}")

inference_image = image_uris.retrieve(
    framework="pytorch", region=region, version="2.0.1",
    py_version="py310", instance_type="ml.m5.xlarge", image_scope="inference",
)

# Repack model with inference code using V3 repack_model utility
repackaged_s3 = f"s3://{bucket}/{prefix}/repackaged-model/model.tar.gz"
repack_model(
    inference_script="mnist.py",
    source_directory=".",
    dependencies=[],
    model_uri=model_data,
    repacked_model_uri=repackaged_s3,
    sagemaker_session=sagemaker_session,
)
print(f"Repacked model: {repackaged_s3}")

timestamp = time.strftime("%Y%m%d-%H%M%S")
model_name = f"pytorch-mnist-hpo-model-{timestamp}"
endpoint_config_name = f"pytorch-mnist-hpo-epc-{timestamp}"
endpoint_name = f"pytorch-mnist-hpo-ep-{timestamp}"

sm_model = Model.create(
    model_name=model_name,
    primary_container=ContainerDefinition(
        image=inference_image,
        model_data_url=repackaged_s3,
        environment={"SAGEMAKER_PROGRAM": "mnist.py"},
    ),
    execution_role_arn=role,
)
print(f"Model created: {model_name}")

endpoint_config = EndpointConfig.create(
    endpoint_config_name=endpoint_config_name,
    production_variants=[
        ProductionVariant(
            variant_name="AllTraffic",
            model_name=model_name,
            instance_type="ml.m5.xlarge",
            initial_instance_count=1,
        )
    ],
)
print(f"EndpointConfig created: {endpoint_config_name}")

sm_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name)
print(f"Creating endpoint: {endpoint_name} ...")
waiter = sm_client.get_waiter("endpoint_in_service")
waiter.wait(EndpointName=endpoint_name, WaiterConfig={"Delay": 30, "MaxAttempts": 60})
print("Endpoint in service!")

### Evaluate
We can now use this endpoint to classify hand-written digits. We load test images, serialize them with `NumpySerializer` (from `sagemaker.core.serializers`), and invoke the endpoint via `boto3 invoke_endpoint()`.
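
The `application/x-npy` content type refers to NumPy's `.npy` binary format. Assuming `NumpySerializer` produces the same bytes as `np.save` (as it does in earlier SDK versions), the round trip can be sketched locally without an endpoint:

```python
import io

import numpy as np

# 16 placeholder grayscale images, shaped like the batch sampled from the test set.
batch = np.zeros((16, 28, 28), dtype=np.float32)

# Add a channel axis so the batch is (N, C, H, W), as in the invocation cell.
batch_nchw = np.expand_dims(batch, axis=1)

# Write the array in NumPy's .npy format to an in-memory buffer.
buffer = io.BytesIO()
np.save(buffer, batch_nchw)
npy_bytes = buffer.getvalue()

# Decode the bytes the same way the endpoint response is decoded.
restored = np.load(io.BytesIO(npy_bytes), allow_pickle=False)
print(restored.shape, restored.dtype)
```

Because the `.npy` header records shape and dtype, the receiving side can rebuild the array without any extra metadata.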

In [None]:
import numpy as np
import random
import io

data_dir = "data/MNIST/raw"
with gzip.open(os.path.join(data_dir, "t10k-images-idx3-ubyte.gz"), "rb") as f:
    images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28).astype(np.float32)

mask = random.sample(range(len(images)), 16)  # randomly select some of the test images
mask = np.array(mask, dtype=np.intp)
data = images[mask]
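
The `offset=16` in the cell above skips the IDX header: four big-endian 32-bit integers holding a magic number (2051 for image files), the image count, and the row and column sizes. The sketch below parses such a header from a synthetic buffer. `parse_idx_image_header` is a hypothetical helper, not part of the notebook.

```python
import struct


def parse_idx_image_header(raw):
    """Return (count, rows, cols) from the 16-byte header of an IDX image file."""
    magic, count, rows, cols = struct.unpack(">IIII", raw[:16])
    if magic != 2051:
        raise ValueError(f"unexpected magic number {magic}; not an IDX image file")
    return count, rows, cols


# Synthetic stand-in for a decompressed file: header plus two blank 28x28 images.
fake_file = struct.pack(">IIII", 2051, 2, 28, 28) + bytes(2 * 28 * 28)

count, rows, cols = parse_idx_image_header(fake_file)
pixels = fake_file[16:]  # one unsigned byte per pixel
print(count, rows, cols, len(pixels))
```

Checking the header first guards against reshaping the wrong file, for example a label file, whose header is shorter and carries a different magic number.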

In [None]:
from sagemaker.core.serializers import NumpySerializer

serializer = NumpySerializer()
payload = serializer.serialize(np.expand_dims(data, axis=1))

response = sm_runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/x-npy",
    Accept="application/x-npy",
    Body=payload,
)
result = np.load(io.BytesIO(response["Body"].read()), allow_pickle=True)

print("Raw prediction result:")
print(result)
print()

labeled_predictions = list(zip(range(10), result[0]))
print("Labeled predictions: ")
print(labeled_predictions)
print()

labeled_predictions.sort(key=lambda label_and_score: label_and_score[1], reverse=True)  # highest score first
print("Most likely answer: {}".format(labeled_predictions[0]))
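
The cell above inspects only the first image (`result[0]`). To read off a prediction for every image in the batch, take the index of the largest score in each row. The sketch below does this in plain Python on made-up scores. With the NumPy array returned by the endpoint, `result.argmax(axis=1)` should be equivalent, assuming `result` has one row per image.

```python
def predicted_digits(scores):
    """Return the index of the highest score in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


# Made-up scores for two images, ten classes each (higher means more likely).
fake_scores = [
    [-9.1, -8.7, -0.02, -7.5, -10.3, -9.9, -8.8, -6.4, -7.2, -9.0],
    [-8.2, -9.5, -7.7, -6.9, -8.1, -9.3, -11.0, -0.01, -7.8, -5.6],
]

digits = predicted_digits(fake_scores)
print(digits)
```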

### Cleanup

Delete the endpoint, endpoint configuration, and model to release resources and stop incurring charges.

In [None]:
sm_client.delete_endpoint(EndpointName=endpoint_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_model(ModelName=model_name)
print("Cleanup complete.")