# Hello Image Data

This tutorial demonstrates how to train an image classifier using TensorFlow and the [Ray AI Runtime](https://docs.ray.io/en/latest/ray-air/getting-started.html).

You should be familiar with TensorFlow before starting this tutorial. If you need a refresher, read TensorFlow's [Convolutional Neural Network](https://www.tensorflow.org/tutorials/images/cnn) tutorial.

## Before you begin

* Install the [Ray AI Runtime](https://docs.ray.io/en/latest/ray-air/getting-started.html). You'll need Ray 1.13 or later to run this example.

```
pip install 'ray[data,tune]'
```

* Install `tensorflow`, `tensorflow-datasets`, and `matplotlib` (imported in the first code cell).

```
pip install tensorflow tensorflow-datasets matplotlib
```
```


## Load and normalize CIFAR-10

In [1]:
import ray
from ray.data.datasource import SimpleTensorFlowDatasource
import tensorflow as tf

from tensorflow.keras import datasets, layers, models
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

def _load_cifar10_split(split):
    """Load one CIFAR-10 split from TFDS as (image, label) pairs.

    Args:
        split: a TFDS split name such as "train" or "test".

    Returns:
        The ``tf.data.Dataset`` returned by ``tfds.load`` for that split.
    """
    # A single split name (rather than a one-element list) makes tfds.load
    # return the dataset itself, so no ``[0]`` indexing is needed.
    # as_supervised=True yields (image, label) tuples instead of feature dicts.
    return tfds.load("cifar10", split=split, as_supervised=True)

def train_dataset_factory():
    """Return the CIFAR-10 training split (intended as a Ray ``dataset_factory``)."""
    return _load_cifar10_split("train")

def test_dataset_factory():
    """Return the CIFAR-10 test split (intended as a Ray ``dataset_factory``)."""
    return _load_cifar10_split("test")

# Read the CIFAR-10 test split into a Ray Dataset of (image, label) records.
# Only the test split is read here; it is preprocessed in the cells below.
test_dataset = ray.data.read_datasource(
    SimpleTensorFlowDatasource(), dataset_factory=test_dataset_factory
)


2022-05-19 16:39:57,661	INFO services.py:1478 -- View the Ray dashboard at http://127.0.0.1:8266
(raylet) E0519 16:40:02.259298000 4728241664 fork_posix.cc:76]                  Other threads are currently calling into gRPC, skipping fork() handlers
(_execute_read_task pid=7913) 2022-05-19 16:40:11.253324: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
(_execute_read_task pid=7913) To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def normalize_images(batch):
    """Cast each image in ``batch`` to float32 and divide it by 255.

    Args:
        batch: an iterable of ``(image, label)`` pairs.

    Returns:
        A list of ``(scaled_image, label)`` tuples in the original order;
        labels are passed through untouched.
    """
    scaled_pairs = []
    for pixels, target in batch:
        scaled_pairs.append((tf.cast(pixels, tf.float32) / 255.0, target))
    return scaled_pairs

# Scale the test images: cast to float32 and divide by 255.
# NOTE: this rebinds `test_dataset` to a transformed copy of itself, so
# re-running only this cell would rescale the images a second time; re-run
# from the loading cell instead.
test_dataset = test_dataset.map_batches(normalize_images)

E0519 16:40:18.572226000 4695694848 fork_posix.cc:76]                  Other threads are currently calling into gRPC, skipping fork() handlers
Read->Map_Batches: 100%|██████████| 1/1 [00:08<00:00,  8.94s/it]


In [3]:
import pandas as pd
from ray.data.extensions import TensorArray


def convert_batch_to_pandas(batch):
    """Convert a batch of ``(image, label)`` tensor pairs into a pandas DataFrame.

    All images are stacked into one NumPy array and wrapped in a Ray
    ``TensorArray``, so they form a single tensor-valued ``"image"`` column.
    Each label tensor is converted with ``.numpy()`` to fill the ``"label"``
    column. Row order follows the order of ``batch``.
    """
    stacked = tf.stack([img for img, _unused in batch]).numpy()
    image_column = TensorArray(stacked)
    label_column = [lbl.numpy() for _unused, lbl in batch]
    frame = pd.DataFrame({"image": image_column, "label": label_column})
    return frame
    

# Turn each batch of (image, label) pairs into a DataFrame with an "image"
# tensor column and a "label" column.
# NOTE: this rebinds `test_dataset`; running it again on the already-converted
# dataset is expected to fail, so re-run from the loading cell instead.
test_dataset = test_dataset.map_batches(convert_batch_to_pandas)

test_dataset

Map_Batches: 100%|██████████| 1/1 [00:02<00:00,  2.09s/it]


Dataset(num_blocks=1, num_rows=10000, schema={image: TensorDtype, label: int64})

## Train a convolutional neural network

In [4]:
def build_model():
    """Build the (uncompiled) CNN used to classify CIFAR-10 images.

    The model takes image batches shaped ``(batch, 1, 32, 32, 3)``, i.e. with
    a length-1 axis after the batch axis. This matches the ``output_signature``
    used when the Ray dataset is converted with ``to_tf``. That axis is
    dropped before the first convolution. The last layer has no activation,
    so the model outputs 10 unnormalized logits per image.

    Returns:
        A ``tf.keras.Sequential`` model.
    """
    model = models.Sequential()
    # Declare the real input shape on the first layer so the model is built
    # immediately. Keras only uses an input shape given to the first layer.
    model.add(layers.Input(shape=(1, 32, 32, 3)))
    # (batch, 1, 32, 32, 3) -> (batch, 32, 32, 3); equivalent to squeezing
    # axis 1, but needs no Lambda and prints nothing during tracing.
    model.add(layers.Reshape((32, 32, 3)))
    model.add(layers.Conv2D(32, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    # NOTE(review): 32 * 32 * 3 hidden units looks like the input pixel count
    # and is much wider than the linked TensorFlow CNN tutorial's head; confirm
    # this width is intended.
    model.add(layers.Dense(32 * 32 * 3, activation='relu'))
    model.add(layers.Dense(10))
    return model

In [5]:
from ray import train
from ray.train.tensorflow import prepare_dataset_shard


def train_loop_per_worker(config):
    """Train the CNN on this worker's shard of the "train" dataset.

    Ray Train runs this function on every worker. Each worker builds and
    compiles the model inside a ``MultiWorkerMirroredStrategy`` scope. For
    every epoch it converts its dataset shard to a ``tf.data.Dataset``, fits
    one pass over it, and saves a checkpoint with the current weights.

    Args:
        config: the ``train_loop_config`` dict given to the trainer. Reads
            ``"batch_size"`` (required) and ``"num_epochs"`` (optional,
            default 2).
    """
    dataset_shard = train.get_dataset_shard("train")
    num_epochs = config.get("num_epochs", 2)

    # tf.distribute.experimental.MultiWorkerMirroredStrategy is a deprecated
    # alias of this class.
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_model()
        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'],
        )

    for epoch in range(num_epochs):
        # Rebuild the tf.data pipeline from the Ray shard for each epoch.
        tf_dataset = prepare_dataset_shard(
            dataset_shard.to_tf(
                feature_columns=["image"],
                label_column="label",
                output_signature=(
                    # Image batches carry a length-1 axis after the batch
                    # axis; the model is expected to drop it.
                    tf.TensorSpec(shape=(None, 1, 32, 32, 3), dtype=tf.float32),
                    # The dataset schema stores labels as int64.
                    tf.TensorSpec(shape=(None, 1), dtype=tf.int64),
                ),
                batch_size=config["batch_size"],
                unsqueeze_label_tensor=True,
            )
        )
        model.fit(tf_dataset)
        # Ray's TensorflowPredictor.from_checkpoint reads the weights from the
        # "model" checkpoint key, so save them under that name.
        train.save_checkpoint(epoch=epoch, model=model.get_weights())

In [6]:
from ray.ml.train.integrations.tensorflow import TensorflowTrainer

# The preprocessing cells above only transform the test split, so build the
# training split here with the same loader and batch transforms. Without this,
# `train_dataset` is undefined when the notebook runs on a fresh kernel.
train_dataset = (
    ray.data.read_datasource(
        SimpleTensorFlowDatasource(), dataset_factory=train_dataset_factory
    )
    .map_batches(normalize_images)
    .map_batches(convert_batch_to_pandas)
)

trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    # NOTE(review): a batch size of 2 looks like a debugging value; confirm.
    train_loop_config={"batch_size": 2},
    datasets={"train": train_dataset},
    scaling_config={"num_workers": 2},
)
result = trainer.fit()
latest_checkpoint = result.checkpoint

Trial name,status,loc
TensorflowTrainer_30d0f_00000,RUNNING,127.0.0.1:7436


[2m[33m(raylet)[0m 2022-05-19 16:34:16,213	INFO context.py:70 -- Exec'ing worker with command: exec /Users/balaji/GitHub/ray/.venv/bin/python /Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=51427 --object-store-name=/tmp/ray/session_2022-05-19_16-32-31_101745_7206/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_16-32-31_101745_7206/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=63854 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:51415 --redis-password=5241590000000000 --startup-token=17 --runtime-env-hash=1215741992
[2m[33m(raylet)[0m 2022-05-19 16:34:30,272	INFO context.py:70 -- Exec'ing worker with command: exec /Users/balaji/GitHub/ray/.venv/bin/python /Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=5

      1/Unknown - 5s 5s/step - loss: 2.2859 - accuracy: 0.0000e+00
      1/Unknown - 5s 5s/step - loss: 2.2859 - accuracy: 0.0000e+00
      2/Unknown - 5s 139ms/step - loss: 2.5290 - accuracy: 0.0000e+00
      2/Unknown - 5s 138ms/step - loss: 2.5290 - accuracy: 0.0000e+00
      3/Unknown - 6s 142ms/step - loss: 2.3747 - accuracy: 0.0000e+00
      3/Unknown - 6s 142ms/step - loss: 2.3747 - accuracy: 0.0000e+00
      4/Unknown - 6s 133ms/step - loss: 2.3678 - accuracy: 0.0000e+00
      4/Unknown - 6s 133ms/step - loss: 2.3678 - accuracy: 0.0000e+00
      5/Unknown - 6s 133ms/step - loss: 2.4282 - accuracy: 0.0000e+00
      5/Unknown - 6s 132ms/step - loss: 2.4282 - accuracy: 0.0000e+00
      6/Unknown - 6s 130ms/step - loss: 2.4417 - accuracy: 0.0000e+00
      6/Unknown - 6s 130ms/step - loss: 2.4417 - accuracy: 0.0000e+00
      7/Unknown - 6s 127ms/step - loss: 2.4206 - accuracy: 0.0714    
      7/Unknown - 6s 127ms/step - loss: 2.4206 - accuracy: 0.0714    
      8/Unknown - 6s 129ms

## Test the network on the test data

In [None]:
from ray.ml.predictors.integrations.tensorflow import TensorflowPredictor
from ray.ml.batch_predictor import BatchPredictor
# TensorflowPredictor is constructed from a `model_definition` callable that
# builds the Keras architecture; `build_model` is the one used for training.
# (`Net` was never defined in this notebook.)
batch_predictor = BatchPredictor.from_checkpoint(
    checkpoint=latest_checkpoint,
    predictor_cls=TensorflowPredictor,
    model_definition=build_model,
)

# NOTE(review): confirm that TensorflowPredictor.predict accepts `unsqueeze`
# (this kwarg is forwarded to it); drop it if not.
outputs: ray.data.Dataset = batch_predictor.predict(
    data=test_dataset, feature_columns=["image"], unsqueeze=False
)
outputs.show(1)

# TODO: show how to save the checkpoint to a file.

## What's next

TODO