# Train a PyTorch model with a SageMaker Training Job


Script mode lets you train models with your own code when none of the SageMaker built-in algorithms fits your needs. It is called script mode because you write that custom code as a script, a text file with a `.py` extension, which SageMaker runs inside a pre-built framework container.

SageMaker supports most popular ML frameworks through pre-built containers, which are optimized for AWS compute and network infrastructure to achieve near-linear scaling efficiency. These containers also include common Python packages such as pandas and NumPy, so your training code can use them directly. You can install any additional package hosted on PyPI by adding a `requirements.txt` file alongside your training code, and you can include your own code directories as well.
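A minimal sketch of the `source_dir` layout described above, using only the standard library. The `script` directory name matches this notebook; the `make_source_dir` helper and the package listed in `requirements.txt` are hypothetical illustrations, not part of the SageMaker SDK.

```python
import tempfile
from pathlib import Path
from typing import List


def make_source_dir(root: Path, entry_point: str, requirements: List[str]) -> Path:
    """Hypothetical helper: create script/ with an entry point and a requirements.txt."""
    source_dir = root / "script"
    source_dir.mkdir(parents=True, exist_ok=True)
    # Placeholder entry point; in this notebook the real one is written with %%writefile.
    (source_dir / entry_point).write_text('print("training")\n')
    # One pip requirement specifier per line.
    (source_dir / "requirements.txt").write_text("\n".join(requirements) + "\n")
    return source_dir


# Write into a temporary directory so the notebook's real script/ folder is untouched.
with tempfile.TemporaryDirectory() as tmp:
    src = make_source_dir(Path(tmp), "mnist.py", ["scikit-learn"])
    print(sorted(p.name for p in src.iterdir()))  # ['mnist.py', 'requirements.txt']
```

Any extra packages listed this way would be installed in the training container before the entry point runs; this notebook's script only needs what the pre-built PyTorch image already provides.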

In this example, we will train a PyTorch MNIST model using a SageMaker Training Job.


**Note**: Select the image `PyTorch 1.12 Python 3.8 CPU Optimized image` and instance type `ml.t3.medium`.


## Setup

In [None]:
from sagemaker.session import Session
from sagemaker import get_execution_role

role = get_execution_role()
sm_session = Session()
region = sm_session.boto_session.region_name
default_bucket = sm_session.default_bucket()
print(default_bucket)


We will make the MNIST data accessible to our training job by uploading it to S3.
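The training script loads the data with `datasets.MNIST(args.data_dir, download=False)`, so the local `mnist_data` directory must already contain the files torchvision expects. A small pre-upload check, assuming the raw-file layout used by recent torchvision releases (the four extracted IDX files under `MNIST/raw/`); older releases can also read a legacy `processed/` folder, which this sketch does not check.

```python
from pathlib import Path
from typing import List

# File names torchvision's MNIST dataset looks for under <root>/MNIST/raw/
# (the .gz archives must already be extracted when download=False).
EXPECTED_MNIST_FILES = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
]


def missing_mnist_files(root: str) -> List[str]:
    """Return the expected MNIST files that are absent from <root>/MNIST/raw."""
    raw_dir = Path(root) / "MNIST" / "raw"
    return [name for name in EXPECTED_MNIST_FILES if not (raw_dir / name).is_file()]


missing = missing_mnist_files("mnist_data")
if missing:
    print("Missing before upload:", missing)
else:
    print("mnist_data looks ready to upload")
```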

In [None]:
base_prefix = "mnist_training_job_examples"
data_prefix = "data"
data_s3_uri = sm_session.upload_data(path="mnist_data", bucket=default_bucket, key_prefix=f"{base_prefix}/{data_prefix}")
print(data_s3_uri) # S3 location of the data
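When `path` is a directory, `upload_data` returns a URI of the form `s3://{bucket}/{key_prefix}`. The sketch below splits such a URI back into bucket and key prefix with `urllib.parse`; the bucket name is a made-up example following the default `sagemaker-{region}-{account_id}` naming, and `split_s3_uri` is a hypothetical helper.

```python
from typing import Tuple
from urllib.parse import urlparse


def split_s3_uri(uri: str) -> Tuple[str, str]:
    """Split 's3://bucket/key/prefix' into ('bucket', 'key/prefix')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3":
        raise ValueError(f"not an S3 URI: {uri}")
    # netloc holds the bucket; the path keeps a leading slash we strip off.
    return parsed.netloc, parsed.path.lstrip("/")


# Hypothetical bucket name, for illustration only.
bucket = "sagemaker-us-east-1-111122223333"
uri = f"s3://{bucket}/mnist_training_job_examples/data"
print(split_s3_uri(uri))
```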

Let's confirm that the data is in S3 using the AWS CLI.

In [None]:
! aws s3 ls {data_s3_uri}/

## Create model training script
Let's create `mnist.py`, the PyTorch script that trains our model.


In [None]:
!mkdir -p script

In [None]:
%%writefile ./script/mnist.py

import argparse
import logging
import os
import sys

import torch
from torchvision import datasets, transforms

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))

# Based on https://github.com/pytorch/examples/blob/master/mnist/main.py
class Net(torch.nn.Module):
    def __init__(self, hidden_channels, kernel_size, drop_out):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, hidden_channels, kernel_size=kernel_size)
        self.conv2 = torch.nn.Conv2d(hidden_channels, 20, kernel_size=kernel_size)
        self.conv2_drop = torch.nn.Dropout2d(p=drop_out)
        self.fc1 = torch.nn.Linear(320, 50)
        self.fc2 = torch.nn.Linear(50, 10)

    def forward(self, x):
        x = torch.nn.functional.relu(torch.nn.functional.max_pool2d(self.conv1(x), 2))
        x = torch.nn.functional.relu(
            torch.nn.functional.max_pool2d(self.conv2_drop(self.conv2(x)), 2)
        )
        x = x.view(-1, 320)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.dropout(x, training=self.training)
        x = self.fc2(x)
        return torch.nn.functional.log_softmax(x, dim=1)


def log_performance(model, data_loader, device, epoch, metric_type="Test"):
    model.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss += torch.nn.functional.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    loss /= len(data_loader.dataset)
    accuracy = 100.0 * correct / len(data_loader.dataset)
    # log metrics
    logger.info(
        "{} Average loss: {:.4f}, {} Accuracy: {:.4f}%;\n".format(
            metric_type, loss, metric_type, accuracy
        )
    )


def train_model(
    train_set, test_set, optimizer="sgd", epochs=10, hidden_channels=10
):
    """
    Function that trains the CNN classifier to identify the MNIST digits.
    Args:
        train_set (torchvision.datasets.mnist.MNIST): train dataset
        test_set (torchvision.datasets.mnist.MNIST): test dataset
        optimizer (str): the optimization algorithm to use for training your CNN;
                         available options are sgd and adam
        epochs (int): number of complete passes of the training dataset through the algorithm
        hidden_channels (int): number of hidden channels in your model
    """

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set the seed for generating random numbers
    torch.manual_seed(42)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000, shuffle=True)
    logger.info(
        "Processes {}/{} ({:.0f}%) of train data".format(
            len(train_loader.sampler),
            len(train_loader.dataset),
            100.0 * len(train_loader.sampler) / len(train_loader.dataset),
        )
    )

    logger.info(
        "Processes {}/{} ({:.0f}%) of test data".format(
            len(test_loader.sampler),
            len(test_loader.dataset),
            100.0 * len(test_loader.sampler) / len(test_loader.dataset),
        )
    )
    model = Net(hidden_channels, kernel_size=5, drop_out=0.5).to(device)
    model = torch.nn.DataParallel(model)
    momentum = 0.5
    lr = 0.01
    log_interval = 100
    if optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, epochs + 1):
        print("Training Epoch:", epoch)
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader, 1):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = torch.nn.functional.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                logger.info(
                    "Train Epoch: {} [{}/{} ({:.0f}%)], Train Loss: {:.6f};".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.sampler),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )
        log_performance(model, train_loader, device, epoch, "Train")
        log_performance(model, test_loader, device, epoch, "Test")
    return model


def save_model(model, model_dir):
    logger.info("Saving the model.")
    path = os.path.join(model_dir, "model.pth")
    # recommended way from http://pytorch.org/docs/master/notes/serialization.html
    torch.save(model.cpu().state_dict(), path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Hyperparameters
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        metavar="N",
        help="number of epochs to train (default: 10)",
    )
    parser.add_argument("--optimizer", type=str, default="sgd", help="optimizer for training.")
    parser.add_argument(
        "--hidden_channels",
        type=int,
        default=10,
        help="number of channels in hidden conv layer",
    )

    # Container environment
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--num-gpus", type=int, default=os.environ["SM_NUM_GPUS"])
    parser.add_argument("--data-dir", type=str, default=os.environ["SM_CHANNEL_TRAINING"])

    args = parser.parse_args()

    train_set = datasets.MNIST(
        args.data_dir,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
        download=False,
    )

    test_set = datasets.MNIST(
        args.data_dir,
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
        download=False,
    )
    
    model = train_model(
        train_set,
        test_set,
        optimizer=args.optimizer,
        epochs=args.epochs,
        hidden_channels=args.hidden_channels,
    )
    save_model(model, args.model_dir)
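`Net` hard-codes `320` in both `fc1` and `x.view(-1, 320)`. That value is the number of features left after the two convolution and pooling stages on a 28×28 MNIST image: with `kernel_size=5` (no padding, stride 1) and 2×2 max pooling, the spatial size goes 28 → 24 → 12 → 8 → 4, and `conv2` has 20 output channels, so 20 × 4 × 4 = 320. The plain-Python sketch below reproduces that arithmetic with the standard output-size formula for `Conv2d`/`MaxPool2d` (dilation 1); the function names are illustrative only.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d or MaxPool2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(image_size: int = 28, kernel_size: int = 5, conv2_channels: int = 20) -> int:
    size = conv_out(image_size, kernel_size)  # conv1: 28 -> 24
    size = conv_out(size, 2, stride=2)        # max_pool2d(x, 2): 24 -> 12
    size = conv_out(size, kernel_size)        # conv2: 12 -> 8
    size = conv_out(size, 2, stride=2)        # max_pool2d(x, 2): 8 -> 4
    return conv2_channels * size * size       # 20 * 4 * 4


print(flattened_features())  # 320
```

Because `hidden_channels` only sets the output channels of `conv1`, tuning it leaves 320 unchanged. Changing `kernel_size` would not: with `kernel_size=3` the same calculation gives 500, so `fc1` and the `view` call would need updating too.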

As you can see above, the training script is very similar to a training script you might run outside of SageMaker. During the training job, you can access useful properties about the training environment through various environment variables, such as:

* `SM_MODEL_DIR`: A string representing the container directory `/opt/ml/model`. This is the directory where you save your model outputs during training. When the training job completes, SageMaker packages the directory contents into a compressed `tar` archive and saves it to a specified S3 location (the estimator's `output_path`).

* `SM_CHANNEL_[CHANNEL_NAME]`: A string representing the path to the directory containing the data for a channel, with the channel name in upper case. In the script above, `SM_CHANNEL_TRAINING` represents the directory `/opt/ml/input/data/training`, which contains the data for the `"training"` channel.

**Note:** It is also possible to specify additional channels, such as `"testing"`. This would create the environment variable `SM_CHANNEL_TESTING` which points to `/opt/ml/input/data/testing`.
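The script reads these variables with `os.environ[...]`, which raises a `KeyError` when it runs outside a training container. A minimal sketch of looking them up with local fallbacks; the helper names and the `local_data`/`local_model` fallback directories are hypothetical choices, and the simulated environment simply mirrors the paths described above.

```python
import os
from typing import Mapping, Optional


def channel_dir(channel: str, env: Optional[Mapping[str, str]] = None, local_root: str = "local_data") -> str:
    """Return the data directory for a channel, e.g. SM_CHANNEL_TRAINING for 'training'."""
    env = os.environ if env is None else env
    # The variable name uses the channel name in upper case.
    return env.get("SM_CHANNEL_" + channel.upper(), os.path.join(local_root, channel))


def model_dir(env: Optional[Mapping[str, str]] = None, local_default: str = "local_model") -> str:
    """Return SM_MODEL_DIR inside a training job, or a local directory otherwise."""
    env = os.environ if env is None else env
    return env.get("SM_MODEL_DIR", local_default)


# Simulated container environment for the "training" and "testing" channels.
container_env = {
    "SM_MODEL_DIR": "/opt/ml/model",
    "SM_CHANNEL_TRAINING": "/opt/ml/input/data/training",
    "SM_CHANNEL_TESTING": "/opt/ml/input/data/testing",
}
print(channel_dir("testing", container_env))  # /opt/ml/input/data/testing
print(model_dir({}))                          # local_model
```

The same idea applied inside `mnist.py` would be an argparse default such as `os.environ.get("SM_MODEL_DIR", "local_model")`, which lets the script run on a laptop for quick debugging.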

## Start the Training Job

First, create a [PyTorch estimator](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/sagemaker.pytorch.html) with the following:

* `source_dir`: Path to a local directory that contains the training script. This directory can also contain other Python modules the script depends on, as well as a `requirements.txt` file listing additional dependencies to install. SageMaker copies all of the files in this directory to the container directory `/opt/ml/code`.

* `entry_point`: The training script which SageMaker will execute. This file should exist in `source_dir`.

* `role`: The ARN for the [SageMaker execution role](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html).

* `output_path`: An S3 path to save model artifacts to.

* `framework_version`: PyTorch version you want to use for executing your model training code.

* `py_version`: Python version you want to use for executing your model training code.

* `instance_type`: The instance type to use for training.

* `instance_count`: Number of instances you want to use for training.

* `hyperparameters`: The hyperparameters to use for training. SageMaker passes them to your script as command-line arguments, which you can read with an `argparse.ArgumentParser` instance.

* `input_mode`: How the training data is made available to the training container. When this is set to `"File"` (the default), SageMaker downloads all of the data from S3 to the EBS volume before training starts. For the complete set of options, see the documentation on [accessing training data](https://docs.aws.amazon.com/sagemaker/latest/dg/model-access-training-data.html).

* `volume_size`: The size of the EBS volume for the training job. This should be large enough to hold the training data which will be downloaded from S3.
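The `hyperparameters` dictionary reaches the script as command-line flags, which is why `mnist.py` declares `--epochs`, `--optimizer`, and `--hidden_channels` with `argparse`. The sketch below approximates that conversion as `--key value` pairs (the exact argument list SageMaker builds is an assumption here) and parses it with a parser mirroring the script's hyperparameter arguments. `to_cli_args` is a hypothetical helper.

```python
import argparse
from typing import Dict, List


def to_cli_args(hyperparameters: Dict[str, object]) -> List[str]:
    """Approximate how hyperparameters become '--key value' command-line flags."""
    argv: List[str] = []
    for key, value in hyperparameters.items():
        # Values arrive as strings; argparse's type= converts them back.
        argv += [f"--{key}", str(value)]
    return argv


# Mirrors the hyperparameter arguments declared in script/mnist.py.
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--optimizer", type=str, default="sgd")
parser.add_argument("--hidden_channels", type=int, default=10)

argv = to_cli_args({"epochs": 10, "hidden_channels": 5, "optimizer": "adam"})
print(argv)
args = parser.parse_args(argv)  # pass argv explicitly; don't read the kernel's sys.argv
print(args.epochs, args.hidden_channels, args.optimizer)  # 10 5 adam
```

The dictionary keys have to match the flag names exactly: a key such as `hidden-channels` would not match `--hidden_channels`, and `parse_args` would exit with an "unrecognized arguments" error.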

In [None]:
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    source_dir="script",
    entry_point="mnist.py",
    role=role,
    output_path=f"s3://{default_bucket}/{base_prefix}/models",
    framework_version="1.12",
    py_version="py38",
    instance_type="ml.c5.xlarge",
    instance_count=1,
    hyperparameters={"epochs": 10, "hidden_channels": 5, "optimizer": "adam"},
    input_mode="File",
    volume_size=10
)

Using the `fit()` method, we supply one input channel, `"training"`, and map it to our data in S3. Since the `input_mode` is set to `"File"`, SageMaker will download all of the data to `/opt/ml/input/data/training` before training begins.

In [None]:
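from sagemaker.utils import name_from_base
estimator.fit({"training": data_s3_uri}, job_name=name_from_base("my-first-training-job"))  # appends a timestamp so re-runs get a unique name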
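SageMaker requires each training job name to be unique within an AWS account and Region, so a fixed string such as `my-first-training-job` can only be used once. As we read the CreateTrainingJob API reference, names are also limited to 63 characters of letters, digits, and hyphens, starting and ending with a letter or digit. The sketch below appends a UTC timestamp and checks those constraints; `unique_job_name` is a hypothetical helper, and the SDK's own `sagemaker.utils.name_from_base` follows a similar base-plus-timestamp idea.

```python
import re
import time
from typing import Optional

# TrainingJobName constraints, as we read the CreateTrainingJob API reference.
JOB_NAME_RE = re.compile(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}")
MAX_JOB_NAME_LEN = 63


def unique_job_name(base: str, now: Optional[float] = None) -> str:
    """Hypothetical helper: append a UTC timestamp to base and validate the result."""
    stamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime(now))
    # Trim the base so that base + "-" + stamp fits within the length limit.
    trimmed = base[: MAX_JOB_NAME_LEN - len(stamp) - 1]
    name = f"{trimmed}-{stamp}"
    if len(name) > MAX_JOB_NAME_LEN or not JOB_NAME_RE.fullmatch(name):
        raise ValueError(f"invalid training job name: {name!r}")
    return name


print(unique_job_name("my-first-training-job"))
```

Underscores are rejected by the pattern, so a base such as `my_job` raises `ValueError` here instead of failing later when the job is submitted.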

Save variables for the next notebook exercise `3_pytorch_training_job_experiment.ipynb`.

In [None]:
%store base_prefix
%store data_s3_uri