# Training a TensorFlow Classifier

This tutorial demonstrates how to train an image classifier using the [Ray AI Runtime](air) (AIR).

You should be familiar with TensorFlow before starting this tutorial. If you need a refresher, read TensorFlow's [Convolutional Neural Network](https://www.tensorflow.org/tutorials/images/cnn) tutorial.

## Before you begin

* Install the [Ray AI Runtime](air). You'll need Ray 1.13 or later to run this example.

In [1]:
!pip install 'ray[air]'



* Install `tensorflow` and `tensorflow-datasets`.

In [2]:
!pip install tensorflow tensorflow-datasets



## Load and normalize CIFAR-10

We'll train our classifier on a popular image dataset called [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html).

First, let's load CIFAR-10 into a Ray {py:class}`Dataset <ray.data.Dataset>`.
Storing the data in a Ray Dataset lets us distribute training across multiple machines.

To load a TensorFlow dataset into a Ray {py:class}`Dataset <ray.data.Dataset>`, we:

1. Define a factory function that creates and returns a TensorFlow dataset.
2. Call {py:func}`read_datasource <ray.data.read_datasource>` and pass in our factory function.

In [3]:
import ray
from ray.data.datasource import SimpleTensorFlowDatasource
import tensorflow as tf

from tensorflow.keras import layers, models
import tensorflow_datasets as tfds


def train_dataset_factory():
    return tfds.load("cifar10", split=["train"], as_supervised=True)[0]


def test_dataset_factory():
    return tfds.load("cifar10", split=["test"], as_supervised=True)[0]


train_dataset = ray.data.read_datasource(
    SimpleTensorFlowDatasource(), dataset_factory=train_dataset_factory
)
test_dataset = ray.data.read_datasource(
    SimpleTensorFlowDatasource(), dataset_factory=test_dataset_factory
)

train_dataset

2022-06-06 12:40:22,821	INFO services.py:1477 -- View the Ray dashboard at http://127.0.0.1:8265
(_execute_read_task pid=7172) 2022-06-06 12:40:27.818305: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
(_execute_read_task pid=7172) To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Dataset(num_blocks=1, num_rows=50000, schema=<class 'tuple'>)

Note that {py:class}`SimpleTensorFlowDatasource <ray.data.datasource.SimpleTensorFlowDatasource>`
loads all data into memory, so you shouldn't use it with larger datasets.
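To see why this matters, here is a rough, back-of-the-envelope estimate of how much memory the CIFAR-10 training images occupy. It assumes each image is a 32×32×3 array (the input shape the model uses later in this tutorial) and ignores labels and per-object overhead, so treat the numbers as a lower bound rather than a measurement.

```python
# Rough memory estimate for holding the CIFAR-10 training images in RAM.
# Assumption: each image is 32 x 32 x 3; the row count comes from the
# Dataset repr shown above.
num_rows = 50_000
values_per_image = 32 * 32 * 3

uint8_bytes = num_rows * values_per_image * 1    # raw pixels, 1 byte each
float32_bytes = num_rows * values_per_image * 4  # after casting to float32

print(f"uint8:   {uint8_bytes / 2**20:.1f} MiB")
print(f"float32: {float32_bytes / 2**20:.1f} MiB")
```

Under these assumptions, the raw images take roughly 146 MiB and float32 copies roughly 586 MiB. That fits on most machines, but the same approach would not scale to datasets many times larger.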

Our model will expect float arrays, so let's normalize pixel values to be between 0 and 1.

In [4]:
def normalize_images(batch):
    return [(tf.cast(image, tf.float32) / 255.0, label) for image, label in batch]


train_dataset = train_dataset.map_batches(normalize_images)
test_dataset = test_dataset.map_batches(normalize_images)

Read->Map_Batches: 100%|██████████| 1/1 [00:12<00:00, 12.66s/it]
Read->Map_Batches: 100%|██████████| 1/1 [00:02<00:00,  2.40s/it]
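If you'd like to check the normalization arithmetic without Ray or TensorFlow, the following NumPy sketch applies the same cast-and-divide step to a single image. The random image is a hypothetical stand-in for a CIFAR-10 sample, not real data.

```python
import numpy as np

# Hypothetical stand-in for one CIFAR-10 image: uint8 pixels in [0, 255].
rng = np.random.default_rng(seed=0)
image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)

# Same arithmetic as normalize_images: cast to float32, then divide by 255.
normalized = image.astype(np.float32) / 255.0

print(normalized.dtype, normalized.min(), normalized.max())
```

Casting before dividing matters: the result stays a float32 array with values between 0 and 1, which is what the model expects.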


Next, let's represent our data using Pandas DataFrames instead of tuples. This lets us call methods like {py:meth}`Dataset.to_tf <ray.data.Dataset.to_tf>` later in the tutorial.

In [5]:
import pandas as pd
from ray.data.extensions import TensorArray


def convert_batch_to_pandas(batch):
    images = TensorArray([image.numpy() for image, _ in batch])
    labels = [label.numpy() for _, label in batch]

    df = pd.DataFrame({"image": images, "label": labels})

    return df


train_dataset = train_dataset.map_batches(convert_batch_to_pandas)
test_dataset = test_dataset.map_batches(convert_batch_to_pandas)

test_dataset

Map_Batches: 100%|██████████| 1/1 [00:04<00:00,  4.14s/it]
Map_Batches: 100%|██████████| 1/1 [00:00<00:00,  1.68it/s]


Dataset(num_blocks=1, num_rows=10000, schema={image: TensorDtype, label: int64})

## Train a convolutional neural network

Now that we've created our datasets, let's define the model and the training logic. We start with the model.

In [6]:
def build_model():
    model = models.Sequential()
    model.add(layers.Conv2D(6, (5, 5), activation="relu", input_shape=(32, 32, 3)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(16, (5, 5), activation="relu"))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(120, activation="relu"))
    model.add(layers.Dense(84, activation="relu"))
    model.add(layers.Dense(10))
    return model
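To see how the layer sizes in `build_model` fit together, the sketch below traces the tensor shape through the network and counts trainable parameters in plain Python. It assumes Keras defaults (no padding and a stride of 1 for `Conv2D`, non-overlapping windows for `MaxPooling2D`); the helper functions are hypothetical and exist only for this illustration.

```python
def conv2d(shape, filters, kernel):
    """Output shape and parameter count for an unpadded, stride-1 Conv2D."""
    height, width, channels = shape
    out_shape = (height - kernel + 1, width - kernel + 1, filters)
    params = kernel * kernel * channels * filters + filters  # weights + biases
    return out_shape, params


def max_pool(shape, size):
    """Output shape for non-overlapping max pooling (no parameters)."""
    height, width, channels = shape
    return (height // size, width // size, channels)


def dense(units_in, units_out):
    """Parameter count for a fully connected layer."""
    return units_in * units_out + units_out


shape = (32, 32, 3)
total = 0

shape, params = conv2d(shape, filters=6, kernel=5)   # -> (28, 28, 6)
total += params
shape = max_pool(shape, size=2)                      # -> (14, 14, 6)
shape, params = conv2d(shape, filters=16, kernel=5)  # -> (10, 10, 16)
total += params
shape = max_pool(shape, size=2)                      # -> (5, 5, 16)

flat = shape[0] * shape[1] * shape[2]                # Flatten -> 400 values
for units_in, units_out in [(flat, 120), (120, 84), (84, 10)]:
    total += dense(units_in, units_out)

print(shape, flat, total)
```

Under these assumptions the network has 62,006 trainable parameters, and the final `Dense(10)` layer produces one output per CIFAR-10 class.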

We define our training logic in a function called `train_loop_per_worker`.

`train_loop_per_worker` contains regular TensorFlow code with a few notable exceptions:
* We build and compile our model in the [`MultiWorkerMirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy) context.
* We call {py:func}`train.get_dataset_shard <ray.train.get_dataset_shard>` to get a subset of our training data, and call {py:meth}`Dataset.to_tf <ray.data.Dataset.to_tf>` with {py:func}`prepare_dataset_shard <ray.train.tensorflow.prepare_dataset_shard>` to convert the subset to a TensorFlow dataset.
* We save model state using {py:func}`train.save_checkpoint <ray.train.save_checkpoint>`.

In [7]:
from ray import train
from ray.train.tensorflow import prepare_dataset_shard


def train_loop_per_worker(config):
    dataset_shard = train.get_dataset_shard("train")
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_model()
        model.compile(
            optimizer="adam",
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
        )

    for epoch in range(2):
        tf_dataset = prepare_dataset_shard(
            dataset_shard.to_tf(
                feature_columns=["image"],
                label_column="label",
                output_signature=(
                    tf.TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32),
                    tf.TensorSpec(shape=(None), dtype=tf.uint8),
                ),
                batch_size=config["batch_size"],
            )
        )
        model.fit(tf_dataset)
        train.save_checkpoint(epoch=epoch, model=model.get_weights())

Finally, we can train our model. This should take a few minutes to run.

In [8]:
from ray.air.train.integrations.tensorflow import TensorflowTrainer

trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"batch_size": 2},
    datasets={"train": train_dataset},
    scaling_config={"num_workers": 2},
)
result = trainer.fit()
latest_checkpoint = result.checkpoint

| Trial name | status | loc |
|---|---|---|
| TensorflowTrainer_95ef0_00000 | TERMINATED | 127.0.0.1:7286 |

(BaseWorkerMixin pid=7291) Instructions for updating:
(BaseWorkerMixin pid=7291) use distribute.MultiWorkerMirroredStrategy instead
(BaseWorkerMixin pid=7291) 2022-06-06 12:41:09.019303: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
(BaseWorkerMixin pid=7291) To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
(BaseWorkerMixin pid=7291) 2022-06-06 12:41:09.023889: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> 127.0.0.1:51331, 1 -> 127.0.0.1:51332}
(BaseWorkerMixin pid=7291) 2022-06-06 12:41:09.023981: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> 127.0.0.1:51331

      1/Unknown - 3s 3s/step - loss: 2.3304 - sparse_categorical_accuracy: 0.0000e+00
     18/Unknown - 3s 6ms/step - loss: 2.3186 - sparse_categorical_accuracy: 0.1111
     35/Unknown - 3s 6ms/step - loss: 2.3118 - sparse_categorical_accuracy: 0.1000
     53/Unknown - 3s 6ms/step - loss: 2.3130 - sparse_categorical_accuracy: 0.1038
     71/Unknown - 3s 6ms/step - loss: 2.3128 - sparse_categorical_accuracy: 0.1056
     89/Unknown - 4s 6ms/step - loss: 2.3099 - sparse_categorical_accuracy: 0.1067

(BaseWorkerMixin pid=7291) 2022-06-06 12:43:48.030589: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
(BaseWorkerMixin pid=7292) 2022-06-06 12:43:48.025996: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


      1/Unknown - 0s 19ms/step - loss: 1.6563 - sparse_categorical_accuracy: 0.5000
     18/Unknown - 0s 6ms/step - loss: 1.7888 - sparse_categorical_accuracy: 0.3889
     36/Unknown - 0s 6ms/step - loss: 1.6368 - sparse_categorical_accuracy: 0.4028
     54/Unknown - 0s 6ms/step - loss: 1.5942 - sparse_categorical_accuracy: 0.4722
     70/Unknown - 0s 6ms/step - loss: 1.6229 - sparse_categorical_accuracy: 0.4643
     87/Unknown - 1s 6ms/step - loss: 1.6228 - sparse_categorical_accuracy: 0.4655

2022-06-06 12:46:29,017	ERROR checkpoint_manager.py:189 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key of the result dict. Valid keys are ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'config', 'done']


Trial TensorflowTrainer_95ef0_00000 completed. Last result: 


2022-06-06 12:46:29,128	INFO tune.py:741 -- Total run time: 331.84 seconds (331.01 seconds for the tuning loop).


To scale your training script, create a [Ray Cluster](deployment-guide) and increase the number of workers. If your cluster contains GPUs, add `"use_gpu": True` to your scaling config.

```{code-block} python
scaling_config={"num_workers": 8, "use_gpu": True}
```
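When you change the number of workers, the amount of work each worker does per epoch changes too. The sketch below is a rough estimate only: it assumes the training rows are split evenly across workers and that `batch_size` applies to each worker's shard, neither of which this tutorial states explicitly. The `steps_per_epoch` helper and the batch size of 32 are hypothetical.

```python
import math


def steps_per_epoch(num_rows, num_workers, batch_size):
    """Estimate batches per worker per epoch, assuming an even split."""
    rows_per_worker = num_rows // num_workers
    return math.ceil(rows_per_worker / batch_size)


# The configuration used in this tutorial: 2 workers, batch size 2.
print(steps_per_epoch(50_000, num_workers=2, batch_size=2))

# A hypothetical larger setup: 8 workers, batch size 32.
print(steps_per_epoch(50_000, num_workers=8, batch_size=32))
```

Under these assumptions, the tutorial's configuration runs 12,500 steps per worker per epoch, while the hypothetical 8-worker setup runs 196.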

## Test the network on the test data

Let's see how our model performs.

To classify images in the test dataset, we'll need to create a {py:class}`Predictor <ray.ml.predictor.Predictor>`.

{py:class}`Predictors <ray.ml.predictor.Predictor>` load models from checkpoints and perform inference. A {py:class}`TensorflowPredictor <ray.ml.predictors.integrations.tensorflow.TensorflowPredictor>` performs inference on a single batch, whereas a {py:class}`BatchPredictor <ray.ml.batch_predictor.BatchPredictor>` performs inference on an entire dataset. Because we want to classify all of the images in the test dataset, we'll use a {py:class}`BatchPredictor <ray.ml.batch_predictor.BatchPredictor>`.

In [9]:
from ray.air.predictors.integrations.tensorflow import TensorflowPredictor
from ray.air.batch_predictor import BatchPredictor

batch_predictor = BatchPredictor.from_checkpoint(
    checkpoint=latest_checkpoint,
    predictor_cls=TensorflowPredictor,
    model_definition=build_model,
)


outputs: ray.data.Dataset = batch_predictor.predict(
    data=test_dataset, feature_columns=["image"]
)

(BaseWorkerMixin pid=7292) 2022-06-06 12:46:29.145770: E tensorflow/core/common_runtime/base_collective_executor.cc:249] BaseCollectiveExecutor::StartAbort UNAVAILABLE: failed to connect to all addresses
(BaseWorkerMixin pid=7292) Additional GRPC error information from remote target /job:worker/replica:0/task:0:
(BaseWorkerMixin pid=7292) :{"created":"@1654544789.145672000","description":"Failed to pick subchannel","file":"external/com_github_grpc_grpc/src/core/ext/filters/client_channel/client_channel.cc","file_line":3941,"referenced_errors":[{"created":"@1654544789.140738000","description":"failed to connect to all addresses","file":"external/com_github_grpc_grpc/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc","file_line":393,"grpc_status":14}]}
Map_Batches:   0%|          | 0/1 [00:00<?, ?it/s]
(BaseWorkerMixin pid=7292) Exception ignored in: <function Pool.__del__ at 0x1bbb14430>

Our model outputs a list of energies (logits), one for each class. To classify an image, we
choose the class with the highest energy.

In [10]:
import numpy as np


def convert_logits_to_classes(df):
    best_class = df["predictions"].map(lambda x: x.argmax())
    df["prediction"] = best_class
    return df[["prediction"]]


predictions = outputs.map_batches(convert_logits_to_classes, batch_format="pandas")

predictions.show(1)

Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 13.04it/s]

{'prediction': 3}
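The argmax step can also be seen in isolation. The NumPy sketch below uses hypothetical logits for two images and three classes (not output from our model) and shows that converting logits to probabilities with softmax doesn't change which class wins.

```python
import numpy as np

# Hypothetical logits for two images over three classes (not model output).
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])

# Softmax turns logits into probabilities. Subtracting the row maximum first
# keeps the exponentials numerically stable without changing the result.
shifted = logits - logits.max(axis=1, keepdims=True)
exps = np.exp(shifted)
probs = exps / exps.sum(axis=1, keepdims=True)

# Softmax preserves ordering, so both argmax calls pick the same classes.
print(logits.argmax(axis=1))
print(probs.argmax(axis=1))
```

Because the ordering is preserved, we can take the argmax of the raw logits directly, as `convert_logits_to_classes` does.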





Now that we've classified all of the images, let's figure out which images were
classified correctly. The `predictions` dataset contains the predicted labels and
`test_dataset` contains the true labels. To determine whether an image
was classified correctly, we zip the two datasets together and check whether the
predicted label matches the true label.

In [11]:
def calculate_prediction_scores(df):
    df["correct"] = df["prediction"] == df["label"]
    return df[["prediction", "label", "correct"]]


scores = test_dataset.zip(predictions).map_batches(calculate_prediction_scores)

scores.show(1)

Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 18.01it/s]

{'prediction': 3, 'label': 7, 'correct': False}





To compute our test accuracy, we'll count how many images the model classified
correctly and divide that number by the total number of test images.

In [12]:
scores.sum(on="correct") / scores.count()

Shuffle Map: 100%|██████████| 1/1 [00:00<00:00, 62.70it/s]
Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 103.05it/s]


0.4764
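With 10,000 test images, an accuracy of 0.4764 means the model classified 4,764 images correctly. The same calculation on a tiny, hypothetical set of predictions and labels looks like this in NumPy:

```python
import numpy as np

# Hypothetical predictions and true labels for five images.
predictions = np.array([3, 8, 0, 6, 1])
labels = np.array([3, 8, 8, 6, 6])

correct = predictions == labels          # True where the prediction matches
accuracy = correct.sum() / correct.size  # fraction of correct predictions

print(accuracy)
```

Here three of the five predictions match, so the accuracy is 0.6.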

## Deploy the network and make a prediction

Our model classifies about 48% of the test images correctly, well above the 10%
expected from random guessing across CIFAR-10's ten classes, so let's deploy it to an
endpoint. This lets us request predictions over HTTP.

In [13]:
from ray import serve
from ray.serve.model_wrappers import ModelWrapperDeployment

serve.start(detached=True)
deployment = ModelWrapperDeployment.options(name="my-deployment")
deployment.deploy(
    TensorflowPredictor,
    latest_checkpoint,
    batching_params=False,
    model_definition=build_model,
)

(ServeController pid=7834) INFO 2022-06-06 12:46:35,253 controller 7834 checkpoint_path.py:17 - Using RayInternalKVStore for controller checkpoint and recovery.
(ServeController pid=7834) INFO 2022-06-06 12:46:35,257 controller 7834 http_state.py:112 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:SERVE_PROXY_ACTOR-node:127.0.0.1-0' on node 'node:127.0.0.1-0' listening on '127.0.0.1:8000'
(HTTPProxyActor pid=7842) INFO:     Started server process [7842]
(ServeController pid=7834) INFO 2022-06-06 12:46:39,068 controller 7834 deployment_state.py:1220 - Adding 1 replicas to deployment 'my-deployment'.


Let's classify a test image.

In [14]:
batch = test_dataset.take(1)
array = np.expand_dims(np.array(batch[0]["image"]), axis=0)

You can perform inference against a deployed model by posting a dictionary with an `"array"` key. To learn more about the default input schema, read the {py:class}`NdArray <ray.serve.http_adapters.NdArray>` documentation.
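As a sketch of what that payload looks like on the wire, the example below builds the request body by hand and decodes it again. The all-zeros image is a hypothetical stand-in for a normalized test image, and no request is actually sent.

```python
import json

import numpy as np

# Hypothetical stand-in for one normalized test image.
image = np.zeros((32, 32, 3), dtype=np.float32)

# Add a leading batch dimension, as the tutorial does with np.expand_dims.
array = np.expand_dims(image, axis=0)

# JSON has no array type, so the batch is sent as nested Python lists.
body = json.dumps({"array": array.tolist()})

# Decoding the body recovers an array with the same shape.
decoded = np.array(json.loads(body)["array"])
print(array.shape, decoded.shape)
```

The batch dimension matters: the payload holds a batch containing one image, not a bare image.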

In [15]:
import requests

payload = {"array": array.tolist()}
response = requests.post(deployment.url, json=payload)
response.json()

(HTTPProxyActor pid=7842) INFO 2022-06-06 12:46:43,215 http_proxy 127.0.0.1 http_proxy.py:315 - POST /my-deployment 307 5.4ms
(my-deployment pid=7853) INFO 2022-06-06 12:46:43,213 my-deployment my-deployment#GPJOfT replica.py:479 - HANDLE __call__ OK 0.4ms


{'predictions': {'0': [-0.9284152984619141,
   -1.5676860809326172,
   -0.9705678224563599,
   0.6415643692016602,
   0.16386678814888,
   0.04367314279079437,
   -0.26028507947921753,
   -0.5868486166000366,
   -0.9341723322868347,
   -1.9672366380691528]}}
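To turn this response into a class, we pick the index of the largest logit, just as we did for the batch predictions. The sketch below copies the response shown above and maps the index to a name using the standard CIFAR-10 class order, which is an assumption here because this tutorial doesn't define the class names.

```python
# Response body copied from the output above.
response_json = {"predictions": {"0": [
    -0.9284152984619141, -1.5676860809326172, -0.9705678224563599,
    0.6415643692016602, 0.16386678814888, 0.04367314279079437,
    -0.26028507947921753, -0.5868486166000366, -0.9341723322868347,
    -1.9672366380691528,
]}}

# Assumption: the standard CIFAR-10 class names in label order.
class_names = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

logits = response_json["predictions"]["0"]

# Index of the largest logit, computed without NumPy.
best = max(range(len(logits)), key=logits.__getitem__)
print(best, class_names[best])
```

The largest logit is at index 3, which corresponds to "cat" under this class order.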

[2m[36m(HTTPProxyActor pid=7842)[0m INFO 2022-06-06 12:46:43,321 http_proxy 127.0.0.1 http_proxy.py:315 - POST /my-deployment 200 103.7ms
[2m[36m(my-deployment pid=7853)[0m 2022-06-06 12:46:43.240502: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
[2m[36m(my-deployment pid=7853)[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
[2m[36m(my-deployment pid=7853)[0m INFO 2022-06-06 12:46:43,320 my-deployment my-deployment#GPJOfT replica.py:479 - HANDLE __call__ OK 100.9ms
