# Hello Image Data

This tutorial demonstrates how to train an image classifier using TensorFlow and the [Ray AI Runtime](https://docs.ray.io/en/latest/ray-air/getting-started.html).

You should be familiar with TensorFlow before starting this tutorial. If you need a refresher, read TensorFlow's [Convolutional Neural Network](https://www.tensorflow.org/tutorials/images/cnn) tutorial.

## Before you begin

* Install the [Ray AI Runtime](https://docs.ray.io/en/latest/ray-air/getting-started.html). You'll need Ray 1.13 or later to run this example.

```
pip install 'ray[data,tune]'
```

* Install `tensorflow` and `tensorflow-datasets`

```
pip install tensorflow tensorflow-datasets
```


## Load and normalize CIFAR-10

In [1]:
import ray
from ray.data.datasource import SimpleTensorFlowDatasource
import tensorflow as tf

from tensorflow.keras import layers, models
import tensorflow_datasets as tfds


def train_dataset_factory():
    """Load the CIFAR-10 "train" split with ``as_supervised=True``."""
    # A single string split yields the dataset itself, so no list indexing is needed.
    return tfds.load("cifar10", split="train", as_supervised=True)

def test_dataset_factory():
    """Load the CIFAR-10 "test" split with ``as_supervised=True``."""
    # A single string split yields the dataset itself, so no list indexing is needed.
    return tfds.load("cifar10", split="test", as_supervised=True)

# Read each CIFAR-10 split into a Ray Dataset via the TensorFlow datasource;
# the factory functions are passed along so the datasource can create the
# underlying TensorFlow dataset.
train_dataset = ray.data.read_datasource(
    SimpleTensorFlowDatasource(),
    dataset_factory=train_dataset_factory,
)
test_dataset = ray.data.read_datasource(
    SimpleTensorFlowDatasource(),
    dataset_factory=test_dataset_factory,
)

# Display the dataset summary as the cell output.
train_dataset

2022-05-22 21:39:13,297	INFO services.py:1478 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
[2m[33m(raylet)[0m E0522 21:39:15.159292000 4659019264 fork_posix.cc:76]                  Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(_execute_read_task pid=8050)[0m 2022-05-22 21:39:18.712024: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
[2m[36m(_execute_read_task pid=8050)[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def normalize_images(batch):
    """Cast each image to float32 and divide it by 255, leaving labels untouched.

    Args:
        batch: Iterable of ``(image, label)`` pairs.

    Returns:
        A list of ``(scaled_image, label)`` tuples in the same order as ``batch``.
    """
    scaled_pairs = []
    for image, label in batch:
        as_float = tf.cast(image, tf.float32)
        scaled_pairs.append((as_float / 255.0, label))
    return scaled_pairs

# Apply normalize_images to both splits. The names are rebound to the mapped
# datasets, so this cell is not idempotent: running it a second time would
# feed the already-mapped data through normalize_images again.
train_dataset = train_dataset.map_batches(normalize_images)
test_dataset = test_dataset.map_batches(normalize_images)

E0522 21:39:28.881343000 4745436672 fork_posix.cc:76]                  Other threads are currently calling into gRPC, skipping fork() handlers
Read->Map_Batches: 100%|██████████| 1/1 [00:13<00:00, 13.24s/it]
Read->Map_Batches: 100%|██████████| 1/1 [00:02<00:00,  2.78s/it]


In [3]:
import pandas as pd
from ray.data.extensions import TensorArray


def convert_batch_to_pandas(batch):
    """Turn a batch of ``(image, label)`` pairs into a two-column DataFrame.

    Args:
        batch: Collection of ``(image, label)`` pairs whose elements expose
            ``.numpy()``.

    Returns:
        A DataFrame with an "image" column (a ``TensorArray`` of the image
        arrays) and a "label" column (the label values), one row per pair.
    """
    columns = {
        "image": TensorArray([img.numpy() for img, _label in batch]),
        "label": [lbl.numpy() for _img, lbl in batch],
    }
    return pd.DataFrame(columns)
    

# Convert both splits into pandas batches with "image" and "label" columns.
# As with the previous cell, the dataset names are rebound here.
train_dataset = train_dataset.map_batches(convert_batch_to_pandas)
test_dataset = test_dataset.map_batches(convert_batch_to_pandas)

# Display the dataset summary to confirm the resulting schema.
test_dataset

Map_Batches: 100%|██████████| 1/1 [00:04<00:00,  4.41s/it]
Map_Batches: 100%|██████████| 1/1 [00:00<00:00,  1.31it/s]


Dataset(num_blocks=1, num_rows=10000, schema={image: TensorDtype, label: int64})

## Train a convolutional neural network

In [4]:
def build_model():
    """Create the uncompiled Keras CNN used for CIFAR-10 classification.

    Returns:
        A ``Sequential`` model made of a squeeze step, two
        Conv2D/MaxPooling2D stages, a Flatten, and Dense layers of 120, 84
        and 10 units. The final Dense layer has no activation.
    """

    def drop_unit_axis(batched_images):
        # Remove the length-1 axis at position 1 before the convolutions.
        return tf.squeeze(batched_images, axis=1)

    # NOTE(review): `input_shape` is given on the second layer rather than the
    # first; presumably Sequential only honours it on the first layer - confirm.
    return models.Sequential(
        [
            layers.Lambda(drop_unit_axis),
            layers.Conv2D(6, (5, 5), activation='relu', input_shape=(32, 32, 3)),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(16, (5, 5), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            layers.Flatten(),
            layers.Dense(120, activation='relu'),
            layers.Dense(84, activation='relu'),
            layers.Dense(10),
        ]
    )

In [5]:
from ray import train
from ray.train.tensorflow import prepare_dataset_shard


def train_loop_per_worker(config):
    """Train the CNN on this worker's shard of the "train" dataset.

    Args:
        config: Dict of loop settings. ``config["batch_size"]`` is required and
            is forwarded to ``to_tf``. ``config["num_epochs"]`` is optional and
            defaults to 2, the value this loop previously hard-coded.

    Side effects:
        Saves a checkpoint holding the epoch index and the model weights after
        every epoch.
    """
    num_epochs = config.get("num_epochs", 2)

    dataset_shard = train.get_dataset_shard("train")
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    # The model is created and compiled inside the distribution strategy scope.
    with strategy.scope():
        model = build_model()
        model.compile(
            optimizer='adam',
            # build_model ends in Dense(10) with no activation, so the loss is
            # configured to consume raw logits.
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'],
        )

    for epoch in range(num_epochs):
        # A fresh TensorFlow dataset is created from the shard for every epoch.
        tf_dataset = prepare_dataset_shard(
            dataset_shard.to_tf(
                feature_columns=["image"],
                label_column="label",
                output_signature=(
                    tf.TensorSpec(shape=(None, 1, 32, 32, 3), dtype=tf.float32),
                    tf.TensorSpec(shape=(None, 1), dtype=tf.uint8),
                ),
                batch_size=config["batch_size"],
                unsqueeze_label_tensor=True,
            )
        )
        model.fit(tf_dataset)
        # NOTE(review): no metrics are reported from this loop; the recorded run
        # logs "Result dict has no key: training_iteration" and an empty last
        # result, presumably for that reason - confirm and consider reporting
        # per-epoch metrics.
        train.save_checkpoint(epoch=epoch, model=model.get_weights())

In [6]:
from ray.ml.train.integrations.tensorflow import TensorflowTrainer

# Configure training with two workers; train_loop_per_worker receives
# train_loop_config as its `config` argument and reads the "train" dataset.
trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"batch_size": 2},
    datasets={"train": train_dataset},
    scaling_config={"num_workers": 2}
)
# Run the training job and keep the resulting checkpoint for inference below.
result = trainer.fit()
latest_checkpoint = result.checkpoint

Trial name,status,loc
TensorflowTrainer_61c87_00000,TERMINATED,127.0.0.1:8122


[2m[33m(raylet)[0m 2022-05-22 21:39:51,760	INFO context.py:70 -- Exec'ing worker with command: exec /Users/balaji/GitHub/ray/.venv/bin/python /Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=50416 --object-store-name=/tmp/ray/session_2022-05-22_21-39-10_906221_7984/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-22_21-39-10_906221_7984/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=60617 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:60010 --redis-password=5241590000000000 --startup-token=17 --runtime-env-hash=1215741992
[2m[33m(raylet)[0m 2022-05-22 21:39:57,049	INFO context.py:70 -- Exec'ing worker with command: exec /Users/balaji/GitHub/ray/.venv/bin/python /Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=5

      1/Unknown - 3s 3s/step - loss: 2.3049 - accuracy: 0.0000e+00
      1/Unknown - 3s 3s/step - loss: 2.3049 - accuracy: 0.0000e+00
     17/Unknown - 3s 7ms/step - loss: 2.3260 - accuracy: 0.0000e+00
     17/Unknown - 3s 7ms/step - loss: 2.3260 - accuracy: 0.0000e+00
     35/Unknown - 3s 6ms/step - loss: 2.3172 - accuracy: 0.0286   
     35/Unknown - 3s 6ms/step - loss: 2.3172 - accuracy: 0.0286   
     59/Unknown - 3s 6ms/step - loss: 2.3113 - accuracy: 0.0678
     59/Unknown - 3s 6ms/step - loss: 2.3113 - accuracy: 0.0678
     75/Unknown - 4s 6ms/step - loss: 2.3095 - accuracy: 0.0733
     75/Unknown - 4s 6ms/step - loss: 2.3095 - accuracy: 0.0733
     91/Unknown - 4s 6ms/step - loss: 2.3082 - accuracy: 0.0824
     91/Unknown - 4s 6ms/step - loss: 2.3082 - accuracy: 0.0824
    109/Unknown - 4s 6ms/step - loss: 2.3062 - accuracy: 0.0917
    109/Unknown - 4s 6ms/step - loss: 2.3062 - accuracy: 0.0917
    126/Unknown - 4s 6ms/step - loss: 2.3064 - accuracy: 0.0992
    126/Unknown - 4s

[2m[36m(BaseWorkerMixin pid=8138)[0m 2022-05-22 21:43:05.700030: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
[2m[36m(BaseWorkerMixin pid=8137)[0m 2022-05-22 21:43:05.691877: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


      8/Unknown - 0s 7ms/step - loss: 1.4151 - accuracy: 0.5000 
      8/Unknown - 0s 7ms/step - loss: 1.4151 - accuracy: 0.5000 
     24/Unknown - 0s 7ms/step - loss: 1.4879 - accuracy: 0.4792
     24/Unknown - 0s 7ms/step - loss: 1.4879 - accuracy: 0.4792
     39/Unknown - 0s 7ms/step - loss: 1.5126 - accuracy: 0.4487
     39/Unknown - 0s 7ms/step - loss: 1.5126 - accuracy: 0.4487
     54/Unknown - 0s 7ms/step - loss: 1.4701 - accuracy: 0.4722
     54/Unknown - 0s 7ms/step - loss: 1.4701 - accuracy: 0.4722
     70/Unknown - 0s 7ms/step - loss: 1.4467 - accuracy: 0.4714
     70/Unknown - 0s 7ms/step - loss: 1.4467 - accuracy: 0.4714
     78/Unknown - 1s 7ms/step - loss: 1.4312 - accuracy: 0.4872
     78/Unknown - 1s 7ms/step - loss: 1.4312 - accuracy: 0.4872
     94/Unknown - 1s 7ms/step - loss: 1.4171 - accuracy: 0.4840
     94/Unknown - 1s 7ms/step - loss: 1.4171 - accuracy: 0.4840
    111/Unknown - 1s 7ms/step - loss: 1.4217 - accuracy: 0.4955
    111/Unknown - 1s 7ms/step - loss: 

2022-05-22 21:46:04,256	ERROR checkpoint_manager.py:189 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key of the result dict. Valid keys are ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'config', 'done']


Trial TensorflowTrainer_61c87_00000 completed. Last result: 


2022-05-22 21:46:04,368	INFO tune.py:752 -- Total run time: 373.97 seconds (373.26 seconds for the tuning loop).


## Test the network on the test data

In [7]:
from ray.ml.predictors.integrations.tensorflow import TensorflowPredictor
from ray.ml.batch_predictor import BatchPredictor
# Create a batch predictor from the checkpoint produced by training;
# build_model supplies the model definition.
batch_predictor = BatchPredictor.from_checkpoint(
    checkpoint=latest_checkpoint,
    predictor_cls=TensorflowPredictor,
    model_definition=build_model,
)

# Predict on the test split, using only the "image" column as model input.
outputs: ray.data.Dataset = batch_predictor.predict(
    data=test_dataset,
    feature_columns=["image"],
)

# Preview the first row of predictions.
outputs.show(1)


Map_Batches:   0%|          | 0/1 [00:00<?, ?it/s][2m[36m(BaseWorkerMixin pid=8138)[0m Exception ignored in: <function Pool.__del__ at 0x1bea98310>
[2m[36m(BaseWorkerMixin pid=8138)[0m Traceback (most recent call last):
[2m[36m(BaseWorkerMixin pid=8138)[0m   File "/Users/balaji/.pyenv/versions/3.8.12/lib/python3.8/multiprocessing/pool.py", line 268, in __del__
[2m[36m(BaseWorkerMixin pid=8138)[0m     self._change_notifier.put(None)
[2m[36m(BaseWorkerMixin pid=8138)[0m   File "/Users/balaji/.pyenv/versions/3.8.12/lib/python3.8/multiprocessing/queues.py", line 368, in put
[2m[36m(BaseWorkerMixin pid=8138)[0m     self._writer.send_bytes(obj)
[2m[36m(BaseWorkerMixin pid=8138)[0m   File "/Users/balaji/.pyenv/versions/3.8.12/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
[2m[36m(BaseWorkerMixin pid=8138)[0m     self._send_bytes(m[offset:offset + size])
[2m[36m(BaseWorkerMixin pid=8138)[0m   File "/Users/balaji/.pyenv/versions/3.8.12/lib/pytho

{'predictions': array([-3.0800383 , -3.6460454 , -0.86869377,  0.7446567 , -0.67810845,
        0.90950084, -0.42219412, -0.3526248 , -2.575988  , -3.0471063 ],
      dtype=float32)}





In [8]:
def convert_logits_to_classes(df):
    """Reduce each row's logit vector to the index of its largest entry.

    Args:
        df: pandas DataFrame with a "predictions" column whose values expose
            ``argmax()`` (for example NumPy arrays of logits).

    Returns:
        A new single-column DataFrame named "prediction", indexed like ``df``.
        The input frame is left unmodified.
    """
    # Build the result from the derived Series rather than writing a new
    # column into the caller's batch.
    best_class = df["predictions"].map(lambda logits: logits.argmax())
    return best_class.to_frame(name="prediction")

# Convert every logit vector into a class index, processing the batches as
# pandas DataFrames.
predictions = outputs.map_batches(
    convert_logits_to_classes, batch_format="pandas"
)

# Preview the first predicted class.
predictions.show(1)

Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 28.83it/s]

{'prediction': 5}





In [9]:
def calculate_prediction_scores(df):
    """Flag, row by row, whether the predicted class equals the label.

    Args:
        df: pandas DataFrame containing "prediction" and "label" columns.

    Returns:
        A new DataFrame with columns "prediction", "label" and a boolean
        "correct", in that order. The input frame is left unmodified.
    """
    is_correct = df["prediction"] == df["label"]
    # Column selection yields a new frame and assign() returns another one, so
    # the caller's batch never gains a "correct" column.
    return df[["prediction", "label"]].assign(correct=is_correct)

# Pair every test row with its predicted class, then flag the matches. The
# pandas batch format is requested explicitly because
# calculate_prediction_scores indexes columns by name; this mirrors the
# convert_logits_to_classes call.
scores = test_dataset.zip(predictions).map_batches(
    calculate_prediction_scores, batch_format="pandas"
)

# Preview the first scored row.
scores.show(1)

Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 19.64it/s]

{'prediction': 5, 'label': 7, 'correct': False}





In [10]:
# Accuracy on the test split: the sum of the boolean "correct" column divided
# by the total number of rows.
scores.sum(on="correct") / scores.count()

Shuffle Map: 100%|██████████| 1/1 [00:00<00:00, 64.57it/s]
Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 119.22it/s]


0.5007