# Training a TensorFlow Classifier

This tutorial demonstrates how to train an image classifier using the [Ray AI Runtime](air) (AIR).

You should be familiar with TensorFlow before starting this tutorial. If you need a refresher, read TensorFlow's [Convolutional Neural Network](https://www.tensorflow.org/tutorials/images/cnn) tutorial.

## Before you begin

* Install the [Ray AI Runtime](air). You'll need Ray 1.13 or later to run this example.

In [1]:
!pip install 'ray[air]'



* Install `tensorflow` and `tensorflow-datasets`.

In [2]:
!pip install tensorflow tensorflow-datasets



## Load and normalize CIFAR-10

We'll train our classifier on a popular image dataset called [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html).

First, let's load CIFAR-10 into a Ray Dataset.

In [3]:
import ray
from ray.data.datasource import SimpleTensorFlowDatasource
import tensorflow as tf

from tensorflow.keras import layers, models
import tensorflow_datasets as tfds


def train_dataset_factory():
    return tfds.load("cifar10", split=["train"], as_supervised=True)[0]

def test_dataset_factory():
    return tfds.load("cifar10", split=["test"], as_supervised=True)[0]

train_dataset = ray.data.read_datasource(
    SimpleTensorFlowDatasource(), dataset_factory=train_dataset_factory
)
test_dataset = ray.data.read_datasource(
    SimpleTensorFlowDatasource(), dataset_factory=test_dataset_factory
)

train_dataset



Dataset(num_blocks=1, num_rows=50000, schema=<class 'tuple'>)

Note that {py:class}`SimpleTensorFlowDatasource <ray.data.datasource.SimpleTensorFlowDatasource>` loads all data into memory, so you shouldn't use it with larger datasets.

Our model will expect float arrays, so let's normalize pixel values to be between 0 and 1.

In [4]:
def normalize_images(batch):
    return [(tf.cast(image, tf.float32) / 255.0, label) for image, label in batch]

train_dataset = train_dataset.map_batches(normalize_images)
test_dataset = test_dataset.map_batches(normalize_images)

Read->Map_Batches: 100%|██████████| 1/1 [00:12<00:00, 12.63s/it]
Read->Map_Batches: 100%|██████████| 1/1 [00:02<00:00,  2.40s/it]
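
As a standalone illustration of the scaling step, the sketch below applies the same divide-by-255 rule to a plain nested list instead of a TensorFlow tensor. The helper `normalize_pixels` and the tiny sample image are hypothetical and exist only for this example.

```python
def normalize_pixels(image):
    """Scale every 0-255 channel value in a nested-list image to the range [0, 1]."""
    return [[[channel / 255.0 for channel in pixel] for pixel in row] for row in image]


# A hypothetical 1x2 RGB "image" holding raw 8-bit intensities.
raw_image = [[[0, 51, 255], [255, 102, 0]]]
normalized = normalize_pixels(raw_image)
print(normalized[0][0])  # [0.0, 0.2, 1.0]
```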


Next, let's represent our data using pandas dataframes instead of tuples. This lets us call methods like {py:meth}`Dataset.to_tf <ray.data.Dataset.to_tf>` later in the tutorial.

In [5]:
import pandas as pd
from ray.data.extensions import TensorArray


def convert_batch_to_pandas(batch):
    images = TensorArray([image.numpy() for image, _ in batch])
    labels = [label.numpy() for _, label in batch]

    df = pd.DataFrame({"image": images, "label": labels})

    return df


train_dataset = train_dataset.map_batches(convert_batch_to_pandas)
test_dataset = test_dataset.map_batches(convert_batch_to_pandas)

test_dataset

Map_Batches: 100%|██████████| 1/1 [00:04<00:00,  4.37s/it]
Map_Batches: 100%|██████████| 1/1 [00:00<00:00,  1.54it/s]


Dataset(num_blocks=1, num_rows=10000, schema={image: TensorDtype, label: int64})

## Train a convolutional neural network

In [6]:
def build_model():
    model = models.Sequential()

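    def squeeze(inputs):
        # Drop the size-1 axis in the input batches (see the `output_signature`
        # in the training loop): (batch, 1, 32, 32, 3) -> (batch, 32, 32, 3).
        return tf.squeeze(inputs, axis=1)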

    model.add(layers.Lambda(squeeze))
    model.add(layers.Conv2D(6, (5, 5), activation='relu', input_shape=(32, 32, 3)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(16, (5, 5), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(120, activation='relu'))
    model.add(layers.Dense(84, activation='relu'))
    model.add(layers.Dense(10))
    return model
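
To see how the layers in `build_model` fit together, the following sketch traces the spatial size of the feature maps and counts trainable parameters in plain Python. It assumes the Keras defaults that `build_model` relies on: unpadded ("valid") convolutions with stride 1, and pooling windows whose stride equals the pool size. The helper functions are hypothetical and not part of TensorFlow or Ray.

```python
def conv_output_size(size, kernel, stride=1):
    """Spatial size after an unpadded ('valid') convolution or pooling window."""
    return (size - kernel) // stride + 1


def conv_params(kernel, in_channels, filters):
    """Weights plus one bias per filter for a square-kernel Conv2D layer."""
    return kernel * kernel * in_channels * filters + filters


def dense_params(inputs, units):
    """Weights plus one bias per unit for a Dense layer."""
    return inputs * units + units


size = 32                                    # CIFAR-10 images are 32x32x3
size = conv_output_size(size, 5)             # Conv2D(6, (5, 5))     -> 28
size = conv_output_size(size, 2, stride=2)   # MaxPooling2D((2, 2))  -> 14
size = conv_output_size(size, 5)             # Conv2D(16, (5, 5))    -> 10
size = conv_output_size(size, 2, stride=2)   # MaxPooling2D((2, 2))  -> 5
flattened = size * size * 16                 # Flatten               -> 400

total_params = (
    conv_params(5, 3, 6)
    + conv_params(5, 6, 16)
    + dense_params(flattened, 120)
    + dense_params(120, 84)
    + dense_params(84, 10)
)
print(flattened, total_params)  # 400 62006
```

The final `Dense(10)` layer has no activation, so the model outputs one raw score (logit) per CIFAR-10 class.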

We define our training logic in a function called `train_loop_per_worker`.

`train_loop_per_worker` contains regular TensorFlow code with a few notable exceptions:
* We build and compile our model in the [`MultiWorkerMirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy) context.
* We call {py:func}`train.get_dataset_shard <ray.train.get_dataset_shard>` to get a subset of our training data, and call {py:meth}`Dataset.to_tf <ray.data.Dataset.to_tf>` with {py:func}`prepare_dataset_shard <ray.train.tensorflow.prepare_dataset_shard>` to convert the subset to a TensorFlow dataset.
* We save model state using {py:func}`train.save_checkpoint <ray.train.save_checkpoint>`.

In [7]:
from ray import train
from ray.train.tensorflow import prepare_dataset_shard


def train_loop_per_worker(config):
    dataset_shard = train.get_dataset_shard("train")
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_model()
        model.compile(
            optimizer="adam",
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=["accuracy"],
        )

    for epoch in range(2):
        tf_dataset = prepare_dataset_shard(
            dataset_shard.to_tf(
                feature_columns=["image"],
                label_column="label",
                output_signature=(
                    tf.TensorSpec(shape=(None, 1, 32, 32, 3), dtype=tf.float32),
                    tf.TensorSpec(shape=(None, 1), dtype=tf.uint8),
                ),
                batch_size=config["batch_size"],
                unsqueeze_label_tensor=True,
            )
        )
        model.fit(tf_dataset)
        train.save_checkpoint(epoch=epoch, model=model.get_weights())
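
Because each worker batches only its own shard, the batch size and the number of workers together determine how many steps an epoch takes. The sketch below is a back-of-the-envelope estimate, assuming rows are split evenly across workers; Ray may divide the data differently in practice. The function `steps_per_epoch` is a hypothetical helper, not a Ray API.

```python
import math


def steps_per_epoch(num_rows, num_workers, batch_size):
    """Batches each worker processes per epoch, assuming an even split of rows."""
    rows_per_worker = math.ceil(num_rows / num_workers)
    return math.ceil(rows_per_worker / batch_size)


# Values used in this tutorial: 50,000 training rows, 2 workers, batch size 2.
print(steps_per_epoch(50_000, 2, 2))   # 12500

# The same data with a hypothetical batch size of 32 needs far fewer steps.
print(steps_per_epoch(50_000, 2, 32))  # 782

# Each step combines one batch from every worker.
print(2 * 2)  # 4 examples per step across both workers
```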

Finally, we can train our model. This should take a few minutes to run.

In [8]:
from ray.ml.train.integrations.tensorflow import TensorflowTrainer

trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"batch_size": 2},
    datasets={"train": train_dataset},
    scaling_config={"num_workers": 2}
)
result = trainer.fit()
latest_checkpoint = result.checkpoint

Trial name,status,loc
TensorflowTrainer_af868_00000,TERMINATED,127.0.0.1:4344

      1/Unknown - 1s 902ms/step - loss: 2.4693 - accuracy: 0.0000e+00
     18/Unknown - 1s 6ms/step - loss: 2.4125 - accuracy: 0.0278
     36/Unknown - 1s 6ms/step - loss: 2.3539 - accuracy: 0.0556
     54/Unknown - 1s 6ms/step - loss: 2.3321 - accuracy: 0.0926
     72/Unknown - 1s 6ms/step - loss: 2.3302 - accuracy: 0.0972
     90/Unknown - 1s 6ms/step - loss: 2.3234 - accuracy: 0.1056
    108/Unknown - 2s 6ms/step - loss: 2.3173 - accuracy: 0.0972
    126/Unknown - 2s 6ms/step - loss: 2.3187 - accuracy: 0.1032

     10/Unknown - 0s 6ms/step - loss: 1.4938 - accuracy: 0.4000
     28/Unknown - 0s 6ms/step - loss: 1.6684 - accuracy: 0.3929
     46/Unknown - 0s 6ms/step - loss: 1.6534 - accuracy: 0.3696
     64/Unknown - 0s 6ms/step - loss: 1.6486 - accuracy: 0.3828
     79/Unknown - 1s 6ms/step - loss: 1.6483 - accuracy: 0.3861
     97/Unknown - 1s 6ms/step - loss: 1.6038 - accuracy: 0.4175
    115/Unknown - 1s 6ms/step - loss: 1.5585 - accuracy: 0.4348
    133/Unknown - 1s 6ms/step - loss: 1.5252 - accuracy: 0.4398

2022-05-24 13:23:59,695	ERROR checkpoint_manager.py:189 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key of the result dict. Valid keys are ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'config', 'done']


Trial TensorflowTrainer_af868_00000 completed. Last result: 




2022-05-24 13:23:59,805	INFO tune.py:752 -- Total run time: 326.05 seconds (325.11 seconds for the tuning loop).
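
The loss values at the start of the training output above hover around 2.3, with accuracy near 0.10. That is what you would expect from a 10-class model that has not learned anything yet. The sketch below reproduces that baseline with a hand-written, per-example version of sparse categorical cross-entropy on logits. It is an illustration of the formula, not TensorFlow's implementation, and the function name is hypothetical.

```python
import math


def sparse_categorical_crossentropy_from_logits(logits, label):
    """Cross-entropy for one example: -log(softmax(logits)[label])."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[label]


# An untrained 10-class model that is equally unsure about every class.
uniform_logits = [0.0] * 10
chance_loss = sparse_categorical_crossentropy_from_logits(uniform_logits, label=3)
print(round(chance_loss, 4))  # 2.3026, which is ln(10)
```

A loss that drops well below this value, as it does in the second epoch, indicates that the model is doing better than guessing.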


To scale your training script, create a [Ray Cluster](deployment-guide) and increase the number of workers. If your cluster contains GPUs, add `"use_gpu": True` to your scaling config.

```{code-block} python
scaling_config={"num_workers": 8, "use_gpu": True}
```

## Test the network on the test data

Let's see how our model performs.

To classify images in the test dataset, we'll need to create a {py:class}`Predictor <ray.ml.predictor.Predictor>`.

{py:class}`Predictors <ray.ml.predictor.Predictor>` load models from checkpoints and efficiently perform inference. In contrast to {py:class}`TensorflowPredictor <ray.ml.predictors.integrations.tensorflow.TensorflowPredictor>`, which performs inference on a single batch, {py:class}`BatchPredictor <ray.ml.batch_predictor.BatchPredictor>` performs inference on an entire dataset. Because we want to classify all of the images in the test dataset, we'll use a {py:class}`BatchPredictor <ray.ml.batch_predictor.BatchPredictor>`.

In [9]:
from ray.ml.predictors.integrations.tensorflow import TensorflowPredictor
from ray.ml.batch_predictor import BatchPredictor
batch_predictor = BatchPredictor.from_checkpoint(
    checkpoint=latest_checkpoint,
    predictor_cls=TensorflowPredictor,
    model_definition=build_model,
)


outputs: ray.data.Dataset = batch_predictor.predict(
    data=test_dataset, feature_columns=["image"])

outputs.show(1)


{'predictions': array([-0.72822404, -2.2726588 , -1.0713496 , -0.70978534, -1.4099735 ,
       -1.3684065 , -3.6283069 , -1.8969758 , -0.40979666, -2.5292277 ],
      dtype=float32)}





In [10]:
import numpy as np

def convert_logits_to_classes(df):
    best_class = df["predictions"].map(lambda x: x.argmax())
    df["prediction"] = best_class
    return df[["prediction"]]

predictions = outputs.map_batches(
    convert_logits_to_classes, batch_format="pandas"
)

predictions.show(1)

Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 28.51it/s]

{'prediction': 8}
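
To make the conversion concrete, here is a minimal pure-Python sketch that takes the logits from the sample prediction shown earlier and picks the highest-scoring class. The `softmax` helper is written by hand for illustration, and `CIFAR10_CLASSES` lists the standard CIFAR-10 label order; neither name comes from the tutorial's code.

```python
import math

# Standard CIFAR-10 label order (index -> class name).
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]

# Logits copied from the sample prediction printed earlier in this section.
logits = [-0.72822404, -2.2726588, -1.0713496, -0.70978534, -1.4099735,
          -1.3684065, -3.6283069, -1.8969758, -0.40979666, -2.5292277]


def softmax(values):
    """Convert logits to probabilities that sum to 1."""
    largest = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - largest) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


probabilities = softmax(logits)

# Softmax preserves ordering, so the argmax of the logits is the predicted class.
predicted_class = max(range(len(logits)), key=lambda i: logits[i])
print(predicted_class, CIFAR10_CLASSES[predicted_class])  # 8 ship
```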





In [11]:
def calculate_prediction_scores(df):
    df["correct"] = df["prediction"] == df["label"]
    return df[["prediction", "label", "correct"]]

scores = test_dataset.zip(predictions).map_batches(calculate_prediction_scores)

scores.show(1)

Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 18.65it/s]


{'prediction': 8, 'label': 7, 'correct': False}


In [12]:
scores.sum(on="correct") / scores.count()

Shuffle Map: 100%|██████████| 1/1 [00:00<00:00, 68.70it/s]
Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 120.66it/s]


0.4635
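
The value above is the fraction of test images whose predicted class matches the label. With 10,000 test rows, 0.4635 corresponds to 4,635 correct predictions, compared with roughly 10% for random guessing over 10 classes. The sketch below shows the same calculation on a handful of made-up values; the `accuracy` helper and the sample lists are hypothetical.

```python
def accuracy(predictions, labels):
    """Fraction of predictions that match their labels."""
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


# Hypothetical toy data: two of the four predictions are right.
print(accuracy([8, 3, 5, 1], [7, 3, 5, 0]))  # 0.5
```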

## Deploy the network and make a prediction

TODO

In [None]:
from ray import serve
from ray.serve.model_wrappers import ModelWrapperDeployment
serve.start(detached=True)
deployment = ModelWrapperDeployment.options(name="my-deployment")
deployment.deploy(TensorflowPredictor, latest_checkpoint, batching_params=False, model_definition=build_model)

batch = test_dataset.take(1)
array = np.array(batch[0]["image"])