# Training a Torch Classifier

This tutorial shows you how to train an image classifier using the [Ray AI Runtime](air) (AIR).

You should be familiar with [PyTorch](https://pytorch.org/) before starting the tutorial. If you need a refresher, read PyTorch's [training a classifier](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) tutorial.

## Before you begin

* Install the [Ray AI Runtime](air). You need Ray 2.0 or later to run this example.

In [1]:
!pip install 'ray[air]'



* Install `requests`, `torch`, and `torchvision`.

In [2]:
!pip install requests torch torchvision



## Load and normalize CIFAR-10

We'll train our classifier on a popular image dataset called [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html).

First, let's load CIFAR-10 into a Ray Dataset.

In [3]:
import ray
from ray.data.datasource import SimpleTorchDatasource
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

def train_dataset_factory():
    return torchvision.datasets.CIFAR10(root="./data", download=True, train=True, transform=transform)

def test_dataset_factory():
    return torchvision.datasets.CIFAR10(root="./data", download=True, train=False, transform=transform)

train_dataset: ray.data.Dataset = ray.data.read_datasource(SimpleTorchDatasource(), dataset_factory=train_dataset_factory)
test_dataset: ray.data.Dataset = ray.data.read_datasource(SimpleTorchDatasource(), dataset_factory=test_dataset_factory)

2022-08-18 16:06:42,259 INFO worker.py:1510 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265

(_execute_read_task pid=5902) Files already downloaded and verified
(_execute_read_task pid=5902) Files already downloaded and verified


In [4]:
train_dataset


{py:class}`SimpleTorchDatasource <ray.data.datasource.SimpleTorchDatasource>` performs poorly on large datasets, so use it only with small datasets such as CIFAR-10.
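
To make the effect of the `transform` defined above concrete, here is a minimal, dependency-free sketch of the per-pixel arithmetic. It assumes the usual behavior of `ToTensor` (scaling 8-bit values to the range [0, 1]) and `Normalize` (subtracting the mean and dividing by the standard deviation). The helper name `normalize_pixel` is hypothetical and isn't part of torchvision.

```python
def normalize_pixel(value: int, mean: float = 0.5, std: float = 0.5) -> float:
    """Mimic ToTensor followed by Normalize for one 8-bit channel value."""
    scaled = value / 255.0          # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std    # Normalize: [0.0, 1.0] -> [-1.0, 1.0]


darkest = normalize_pixel(0)
brightest = normalize_pixel(255)
print(darkest, brightest)  # -1.0 1.0
```

With a mean and standard deviation of 0.5 for every channel, pixel values end up centered around zero in the range [-1, 1].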

Next, let's represent our data using pandas dataframes instead of tuples. This lets us call {py:meth}`Dataset.iter_torch_batches <ray.data.Dataset.iter_torch_batches>` later in the tutorial.

In [5]:
from typing import List, Tuple
import pandas as pd
import torch


def convert_batch_to_pandas(batch: List[Tuple[torch.Tensor, int]]) -> pd.DataFrame:
    # Each record is an (image, label) tuple; split the tuples into two columns.
    images = [image.numpy() for image, _ in batch]
    labels = [label for _, label in batch]
    return pd.DataFrame({"image": images, "label": labels})


train_dataset = train_dataset.map_batches(convert_batch_to_pandas)
test_dataset = test_dataset.map_batches(convert_batch_to_pandas)

Read->Map_Batches:   0%|          | 0/1 [00:00<?, ?it/s]
(_map_block_nosplit pid=5902) Files already downloaded and verified
Read->Map_Batches: 100%|██████████| 1/1 [00:04<00:00,  4.17s/it]
Read->Map_Batches:   0%|          | 0/1 [00:00<?, ?it/s]
(_map_block_nosplit pid=5902) Files already downloaded and verified
Read->Map_Batches: 100%|██████████| 1/1 [00:01<00:00,  1.43s/it]


In [6]:
train_dataset


## Train a convolutional neural network

Now that we've created our datasets, let's define the training logic.

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
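
The `16 * 5 * 5` input size of `fc1` follows from how each layer shrinks a 32x32 CIFAR-10 image. The sketch below traces the spatial size through the network using the standard output-size formula for unpadded convolution and pooling layers. The helper `conv_output_size` is a hypothetical name used only for this illustration.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial size after a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                            # CIFAR-10 images are 32x32
size = conv_output_size(size, kernel=5)              # conv1 -> 28
size = conv_output_size(size, kernel=2, stride=2)    # pool  -> 14
size = conv_output_size(size, kernel=5)              # conv2 -> 10
size = conv_output_size(size, kernel=2, stride=2)    # pool  -> 5

flattened_features = 16 * size * size                # 16 channels from conv2
print(flattened_features)  # 400
```

If you change a kernel size or the input resolution, the first argument to `fc1` has to change accordingly.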

We define our training logic in a function called `train_loop_per_worker`. This function contains regular PyTorch code with a few notable exceptions:
* We wrap our model with {py:func}`train.torch.prepare_model <ray.train.torch.prepare_model>`.
* We call {py:func}`session.get_dataset_shard <ray.air.session.get_dataset_shard>` and {py:meth}`Dataset.iter_torch_batches <ray.data.Dataset.iter_torch_batches>` to get a subset of our training data.
* We report metrics and save model state using {py:func}`session.report <ray.air.session.report>`.
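
The first two points amount to data-parallel training: each worker computes gradients on its own shard of the data, and the DistributedDataParallel wrapper that `prepare_model` applies when there are multiple workers averages those gradients before the optimizer step. The toy sketch below uses plain Python rather than Ray or PyTorch to illustrate that idea for a one-parameter model. The function `shard_gradient` and the sample data are hypothetical.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (input x, target y)


def shard_gradient(w: float, shard: List[Sample]) -> float:
    """Gradient of the mean squared error of y = w * x over one shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


data: List[Sample] = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
shards = [data[0::2], data[1::2]]   # two workers, equally sized shards
w = 0.0

# Each worker computes a gradient on its shard; averaging them reproduces
# the gradient over the full dataset because the shards have equal size.
worker_grads = [shard_gradient(w, shard) for shard in shards]
averaged_grad = sum(worker_grads) / len(worker_grads)
full_grad = shard_gradient(w, data)
print(averaged_grad, full_grad)
```

Because the averaged gradient matches the full-data gradient, every worker applies the same update and the model replicas stay in sync.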

In [8]:
from ray import train
from ray.air import session, Checkpoint
from ray.train.torch import TorchCheckpoint
import torch.nn as nn
import torch.optim as optim
import torchvision


def train_loop_per_worker(config):
    model = train.torch.prepare_model(Net())

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    train_dataset_shard = session.get_dataset_shard("train")

    for epoch in range(2):
        running_loss = 0.0
        train_dataset_batches = train_dataset_shard.iter_torch_batches(
            batch_size=config["batch_size"],
        )
        for i, batch in enumerate(train_dataset_batches):
            # get the inputs and labels
            inputs, labels = batch["image"], batch["label"]

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0

        metrics = dict(running_loss=running_loss)
        checkpoint = TorchCheckpoint.from_state_dict(model.module.state_dict())
        session.report(metrics, checkpoint=checkpoint)

Finally, we can train our model. This should take a few minutes to run.

In [9]:
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"batch_size": 2},
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=2)
)
result = trainer.fit()
latest_checkpoint = result.checkpoint

| Trial name | status | loc | iter | total time (s) | running_loss | _timestamp | _time_this_iter_s |
|---|---|---|---|---|---|---|---|
| TorchTrainer_74c2c_00000 | TERMINATED | 127.0.0.1:5917 | 2 | 44.585 | 605.379 | 1660864062 | 20.4078 |


(RayTrainWorker pid=5925) 2022-08-18 16:06:58,973 INFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=2]
(RayTrainWorker pid=5925) 2022-08-18 16:07:01,092 INFO train_loop_utils.py:300 -- Moving model to device: cpu
(RayTrainWorker pid=5925) 2022-08-18 16:07:01,092 INFO train_loop_utils.py:347 -- Wrapping provided model in DDP.
(RayTrainWorker pid=5926)   return torch.as_tensor(ndarray, dtype=dtype, device=device)
(RayTrainWorker pid=5925)   return torch.as_tensor(ndarray, dtype=dtype, device=device)


(RayTrainWorker pid=5926) [1,  2000] loss: 2.185
(RayTrainWorker pid=5925) [1,  2000] loss: 2.174
(RayTrainWorker pid=5926) [1,  4000] loss: 1.847
(RayTrainWorker pid=5925) [1,  4000] loss: 1.815
(RayTrainWorker pid=5926) [1,  6000] loss: 1.660
(RayTrainWorker pid=5925) [1,  6000] loss: 1.667
(RayTrainWorker pid=5926) [1,  8000] loss: 1.574
(RayTrainWorker pid=5925) [1,  8000] loss: 1.610
(RayTrainWorker pid=5926) [1, 10000] loss: 1.554
(RayTrainWorker pid=5925) [1, 10000] loss: 1.504
(RayTrainWorker pid=5926) [1, 12000] loss: 1.466
(RayTrainWorker pid=5925) [1, 12000] loss: 1.466
Result for TorchTrainer_74c2c_00000:
  _time_this_iter_s: 20.898104190826416
  _timestamp: 1660864041
  _training_iteration: 1
  date: 2022-08-18_16-07-22
  done: false
  experiment_id: 21f096bb9c144c67b0834951a0bf0506
  hostname: Balajis-MBP.local.meter
 

2022-08-18 16:07:43,240	INFO tune.py:758 -- Total run time: 46.72 seconds (46.59 seconds for the tuning loop).
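
The training output above is consistent with a quick back-of-the-envelope calculation. The sketch below assumes that the 50,000 CIFAR-10 training images are split evenly between the two workers, which is an assumption about how the shards are created rather than something the output states directly.

```python
train_images = 50_000   # size of the CIFAR-10 training set
num_workers = 2         # ScalingConfig(num_workers=2)
batch_size = 2          # train_loop_config={"batch_size": 2}
print_every = 2_000     # the training loop prints every 2000 mini-batches

images_per_worker = train_images // num_workers          # 25000
batches_per_epoch = images_per_worker // batch_size      # 12500
last_printed_batch = (batches_per_epoch // print_every) * print_every  # 12000

# running_loss is reset at each print, so the reported metric sums the loss
# over the batches that remain after the last print in the final epoch.
leftover_batches = batches_per_epoch - last_printed_batch  # 500
reported_running_loss = 605.379                            # from the trial summary
mean_leftover_loss = reported_running_loss / leftover_batches
print(f"{mean_leftover_loss:.3f}")  # 1.211
```

Under that assumption, the last progress line each worker prints per epoch is batch 12000, and the reported `running_loss` of 605.379 corresponds to an average loss of roughly 1.21 over the final 500 batches rather than a per-batch loss.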


To scale your training script, create a [Ray Cluster](cluster-index) and increase the number of workers. If your cluster contains GPUs, add `use_gpu=True` to your scaling config.

```{code-block} python
scaling_config=ScalingConfig(num_workers=8, use_gpu=True)
```

## Test the network on the test data

Let's see how our model performs.

To classify images in the test dataset, we'll need to create a {py:class}`Predictor <ray.train.predictor.Predictor>`.

{py:class}`Predictors <ray.train.predictor.Predictor>` load models from checkpoints and efficiently perform inference. In contrast to {py:class}`TorchPredictor <ray.train.torch.TorchPredictor>`, which performs inference on a single batch, {py:class}`BatchPredictor <ray.train.batch_predictor.BatchPredictor>` performs inference on an entire dataset. Because we want to classify all of the images in the test dataset, we'll use a {py:class}`BatchPredictor <ray.train.batch_predictor.BatchPredictor>`.

In [10]:
from ray.train.torch import TorchPredictor
from ray.train.batch_predictor import BatchPredictor

batch_predictor = BatchPredictor.from_checkpoint(
    checkpoint=latest_checkpoint,
    predictor_cls=TorchPredictor,
    model=Net(),
)

outputs: ray.data.Dataset = batch_predictor.predict(
    data=test_dataset, dtype=torch.float, feature_columns=["image"], keep_columns=["label"]
)

(BlockWorker pid=5938) A value is trying to be set on a copy of a slice from a DataFrame.
(BlockWorker pid=5938) Try using .loc[row_indexer,col_indexer] = value instead
(BlockWorker pid=5938) 
(BlockWorker pid=5938) See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
(BlockWorker pid=5938)   df.loc[:, col_name] = TensorArray(col)
(BlockWorker pid=5938) A value is trying to be set on a copy of a slice from a DataFrame.
(BlockWorker pid=5938) Try using .loc[row_indexer,col_indexer] = value instead
(BlockWorker pid=5938) 
(BlockWorker pid=5938) See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
(BlockWorker pid=5938)   df.loc[:, col_name] = TensorArray(col)
(BlockWorker pid=5938)

Our model outputs a list of energies (unnormalized scores, also called logits), one for
each class. To classify an image, we choose the class that has the highest energy.
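
As a concrete illustration, the sketch below applies this rule to the ten scores that the deployed model returns for the first test image at the end of this tutorial. The softmax step is optional and only converts the scores to probabilities. The `CIFAR10_CLASSES` list follows the standard CIFAR-10 label order and is written out by hand here rather than read from the dataset.

```python
import math

CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

logits = [
    -1.1190943717956543, -1.9333593845367432, 1.5133508443832397,
    3.2576074600219727, -0.256170392036438, 1.491028904914856,
    1.0580297708511353, 0.04123634099960327, -2.3149514198303223,
    -2.1277053356170654,
]

# The predicted class is the index of the largest score.
predicted_index = max(range(len(logits)), key=lambda i: logits[i])

# Softmax (shifted by the maximum for numerical stability) turns the scores
# into probabilities without changing which class ranks highest.
shifted = [math.exp(value - max(logits)) for value in logits]
probabilities = [value / sum(shifted) for value in shifted]

print(predicted_index, CIFAR10_CLASSES[predicted_index])  # 3 cat
```

The predicted index, 3, matches the `prediction` value shown for the first test image.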

In [11]:
import numpy as np

def convert_logits_to_classes(df):
    best_class = df["predictions"].map(lambda x: x.argmax())
    df["prediction"] = best_class
    return df[["prediction", "label"]]

predictions = outputs.map_batches(
    convert_logits_to_classes
)

predictions.show(1)

Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 45.84it/s]

{'prediction': 3, 'label': 3}





Now that we've classified all of the images, let's figure out which images were
classified correctly. Because we passed `keep_columns=["label"]` to `predict`, the
`predictions` dataset contains both the predicted labels and the true labels. To
determine whether an image was classified correctly, we check whether its predicted
label matches its actual label.

In [12]:
def calculate_prediction_scores(df):
    df["correct"] = df["prediction"] == df["label"]
    return df

scores = predictions.map_batches(calculate_prediction_scores)

scores.show(1)

Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 136.70it/s]

{'prediction': 3, 'label': 3, 'correct': True}





To compute our test accuracy, we'll count how many images the model classified 
correctly and divide that number by the total number of test images.
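
The same calculation can be sketched in plain Python. Booleans count as 1 and 0 when summed, which is why adding up a `correct` column yields the number of correct predictions. The `accuracy` helper and the sample rows below are hypothetical and only mirror the shape of the records in `predictions`.

```python
from typing import Dict, List


def accuracy(rows: List[Dict[str, int]]) -> float:
    """Fraction of rows whose prediction equals the label."""
    correct = sum(row["prediction"] == row["label"] for row in rows)
    return correct / len(rows)


# Hypothetical rows shaped like the records in `predictions`.
rows = [
    {"prediction": 3, "label": 3},
    {"prediction": 8, "label": 8},
    {"prediction": 1, "label": 8},
    {"prediction": 0, "label": 0},
]
print(accuracy(rows))  # 0.75
```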

In [13]:
scores.sum(on="correct") / scores.count()

Shuffle Map: 100%|██████████| 1/1 [00:00<00:00, 120.38it/s]
Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 182.42it/s]


0.5491

## Deploy the network and make a prediction

Our model classifies about 55% of the test images correctly, well above the 10%
expected from random guessing across ten classes, so let's deploy the model to an
endpoint. This allows us to make predictions over HTTP.

In [14]:
from ray import serve
from ray.serve import PredictorDeployment
from ray.serve.http_adapters import json_to_ndarray


serve.run(
    PredictorDeployment.bind(
        TorchPredictor,
        latest_checkpoint,
        model=Net(),
        http_adapter=json_to_ndarray,
    )
)

(ServeController pid=5940) INFO 2022-08-18 16:07:45,540 controller 5940 http_state.py:129 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:SERVE_PROXY_ACTOR-b4cd06353d4f06d2ee0c3050e08cdd766d9e23b96c84fb2bacb6455c' on node 'b4cd06353d4f06d2ee0c3050e08cdd766d9e23b96c84fb2bacb6455c' listening on '127.0.0.1:8000'
(HTTPProxyActor pid=5941) INFO:     Started server process [5941]
(ServeController pid=5940) INFO 2022-08-18 16:07:46,180 controller 5940 deployment_state.py:1232 - Adding 1 replica to deployment 'PredictorDeployment'.


RayServeSyncHandle(deployment='PredictorDeployment')

Let's classify a test image.

In [15]:
image = test_dataset.take(1)[0]["image"]

You can perform inference against a deployed model by posting a dictionary with an `"array"` key. To learn more about the default input schema, read the {py:class}`NdArray <ray.serve.http_adapters.NdArray>` documentation.
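
To see what such a payload looks like before sending it, the sketch below builds one from a placeholder image and round-trips it through JSON. The all-zero image is a hypothetical stand-in for the nested list returned by `image.tolist()`, and the 3x32x32 shape assumes one normalized CIFAR-10 image in channel-first layout.

```python
import json

# Hypothetical stand-in for one normalized CIFAR-10 image: 3 channels of
# 32x32 values. In the tutorial this nested list comes from image.tolist().
image_as_list = [[[0.0 for _ in range(32)] for _ in range(32)] for _ in range(3)]

payload = {"array": image_as_list, "dtype": "float32"}
body = json.dumps(payload)   # roughly the request body sent for json=payload

decoded = json.loads(body)
shape = (
    len(decoded["array"]),
    len(decoded["array"][0]),
    len(decoded["array"][0][0]),
)
print(shape)  # (3, 32, 32)
```

Nested lists survive JSON serialization with their shape intact, which is why the image is converted with `tolist()` before it is posted.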

In [16]:
import requests

payload = {"array": image.tolist(), "dtype": "float32"}
response = requests.post("http://localhost:8000/", json=payload)
response.json()

[-1.1190943717956543,
 -1.9333593845367432,
 1.5133508443832397,
 3.2576074600219727,
 -0.256170392036438,
 1.491028904914856,
 1.0580297708511353,
 0.04123634099960327,
 -2.3149514198303223,
 -2.1277053356170654]

(HTTPProxyActor pid=5941) INFO 2022-08-18 16:07:47,222 http_proxy 127.0.0.1 http_proxy.py:315 - POST / 200 12.7ms
(ServeReplica:PredictorDeployment pid=5942) INFO 2022-08-18 16:07:47,221 PredictorDeployment PredictorDeployment#hSLFCr replica.py:482 - HANDLE __call__ OK 8.1ms
