# Hello Image Data

This tutorial demonstrates how to train an image classifier using TensorFlow and the [Ray AI Runtime](https://docs.ray.io/en/latest/ray-air/getting-started.html).

You should be familiar with TensorFlow before starting this tutorial. If you need a refresher, read TensorFlow's [Convolutional Neural Network](https://www.tensorflow.org/tutorials/images/cnn) tutorial.

## Before you begin

* Install the [Ray AI Runtime](https://docs.ray.io/en/latest/ray-air/getting-started.html). You'll need Ray 1.13 later to run this example.

```
pip instsall 'ray[data,tune]'
```

* Install `tensorflow` and `tensorflow-datasets`

```
pip install tensorflow tensorflow-datasets
```


# Load and normalize CIFAR-10

In [20]:
import ray
from ray.data.datasource import SimpleTensorFlowDatasource
import tensorflow as tf

from tensorflow.keras import layers, models
import tensorflow_datasets as tfds

def train_dataset_factory():
    return tfds.load("cifar10", split=["train"], as_supervised=True)[0]

def test_dataset_factory():
    return tfds.load("cifar10", split=["test"], as_supervised=True)[0]

train_dataset = ray.data.read_datasource(  
    SimpleTensorFlowDatasource(), dataset_factory=train_dataset_factory
)
test_dataset = ray.data.read_datasource(SimpleTensorFlowDatasource(), dataset_factory=test_dataset_factory)




In [21]:
def normalize_images(batch):
    return [(tf.cast(image, tf.float32) / 255.0, label) for image, label in batch]

train_dataset = train_dataset.map_batches(normalize_images)
test_dataset = test_dataset.map_batches(normalize_images)

Read->Map_Batches: 100%|██████████| 1/1 [00:15<00:00, 15.09s/it]
Read->Map_Batches: 100%|██████████| 1/1 [00:02<00:00,  2.64s/it]


In [3]:
import pandas as pd
from ray.data.extensions import TensorArray


def convert_batch_to_pandas(batch):
    images = [image.numpy() for image, _ in batch]
    labels = [label.numpy() for _, label in batch]

    df = pd.DataFrame({"image": images, "label": labels})

    return df
    

train_dataset = train_dataset.map_batches(convert_batch_to_pandas)
test_dataset = test_dataset.map_batches(convert_batch_to_pandas)

test_dataset

Map_Batches: 100%|██████████| 1/1 [00:04<00:00,  4.41s/it]
Map_Batches: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]


Dataset(num_blocks=1, num_rows=10000, schema={image: object, label: int64})

## Train a convolutional neural network

In [22]:
def build_model():
    model = models.Sequential()
    # def squeeze(input):
    #     print(input.shape)
    #     return tf.squeeze(input, axis=1)
    # model.add(layers.Lambda(squeeze))
    model.add(layers.Conv2D(6, (5, 5), activation='relu', input_shape=(32, 32, 3)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(16, (5, 5), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(120, activation='relu'))
    model.add(layers.Dense(84, activation='relu'))
    model.add(layers.Dense(10))
    return model

In [10]:
from ray import train
from ray.train.tensorflow import prepare_dataset_shard


# Slower than Torch?

def train_loop_per_worker(config):
    dataset_shard = train.get_dataset_shard("train")
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_model()
        model.compile(optimizer='adam',
                    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                    metrics=['accuracy'])
    
    for epoch in range(2):  # TODO: Change to 2 epochs
        tf_dataset = prepare_dataset_shard(
            dataset_shard.to_tf(
                feature_columns=["image"],
                label_column="label",
                output_signature=(
                    tf.TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32),
                    tf.TensorSpec(shape=(None, 1), dtype=tf.uint8),
                ),
                batch_size=config["batch_size"],
                unsqueeze_label_tensor=True,
            )
        )
        model.fit(tf_dataset)
        train.save_checkpoint(epoch=epoch, model_weights=model.get_weights())

In [19]:
from ray.ml.train.integrations.tensorflow import TensorflowTrainer

trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"batch_size": 2},
    datasets={"train": train_dataset},
    scaling_config={"num_workers": 2}
)
result = trainer.fit()
latest_checkpoint = result.checkpoint

Trial name,status,loc
TensorflowTrainer_cde45_00000,ERROR,127.0.0.1:15129

Trial name,# failures,error file
TensorflowTrainer_cde45_00000,1,/Users/balaji/ray_results/TensorflowTrainer_2022-05-21_16-14-46/TensorflowTrainer_cde45_00000_0_2022-05-21_16-14-46/error.txt


[2m[33m(raylet)[0m 2022-05-21 16:14:46,887	INFO context.py:70 -- Exec'ing worker with command: exec /Users/balaji/GitHub/ray/.venv/bin/python /Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=52605 --object-store-name=/tmp/ray/session_2022-05-21_16-05-51_439235_13824/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-21_16-05-51_439235_13824/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=64222 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:56352 --redis-password=5241590000000000 --startup-token=36 --runtime-env-hash=1215741992
[2m[33m(raylet)[0m 2022-05-21 16:14:52,082	INFO context.py:70 -- Exec'ing worker with command: exec /Users/balaji/GitHub/ray/.venv/bin/python /Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port

[2m[36m(BaseWorkerMixin pid=15153)[0m object
[2m[36m(BaseWorkerMixin pid=15153)[0m object
[2m[36m(BaseWorkerMixin pid=15153)[0m object
[2m[36m(BaseWorkerMixin pid=15153)[0m object


[2m[36m(BaseWorkerMixin pid=15153)[0m 2022-05-21 16:15:00.057129: W tensorflow/core/framework/op_kernel.cc:1733] INVALID_ARGUMENT: TypeError: `generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was [[array([[[0.2509804 , 0.2509804 , 0.24313726],
[2m[36m(BaseWorkerMixin pid=15153)[0m           [0.4627451 , 0.45882353, 0.41960785],
[2m[36m(BaseWorkerMixin pid=15153)[0m           [0.2509804 , 0.24705882, 0.23921569],
[2m[36m(BaseWorkerMixin pid=15153)[0m           ...,
[2m[36m(BaseWorkerMixin pid=15153)[0m           [0.39215687, 0.4117647 , 0.32156864],
[2m[36m(BaseWorkerMixin pid=15153)[0m           [0.4117647 , 0.42352942, 0.32156864],
[2m[36m(BaseWorkerMixin pid=15153)[0m           [0.24705882, 0.25882354, 0.22745098]],
[2m[36m(BaseWorkerMixin pid=15153)[0m 
[2m[36m(BaseWorkerMixin pid=15153)[0m          [[0.23529412, 0.23529412, 0.22745098],
[2m[36m(BaseWorkerMixin pid=1515

[2m[36m(BaseWorkerMixin pid=15152)[0m object
[2m[36m(BaseWorkerMixin pid=15152)[0m object
[2m[36m(BaseWorkerMixin pid=15152)[0m object
[2m[36m(BaseWorkerMixin pid=15152)[0m object
Result for TensorflowTrainer_cde45_00000:
  date: 2022-05-21_16-14-51
  experiment_id: ac2943f4b02e476aab61e6416ae944d3
  hostname: Balajis-MacBook-Pro-16-inch-2019
  node_ip: 127.0.0.1
  pid: 15129
  timestamp: 1653174891
  trial_id: cde45_00000
  


[2m[36m(TrainTrainable pid=15129)[0m           [0.42352942, 0.30980393, 0.19215687]]], dtype=float32)]
[2m[36m(TrainTrainable pid=15129)[0m  [array([[[1.        , 1.        , 1.        ],
[2m[36m(TrainTrainable pid=15129)[0m           [1.        , 1.        , 0.98039216],
[2m[36m(TrainTrainable pid=15129)[0m           [0.95686275, 0.94509804, 0.91764706],
[2m[36m(TrainTrainable pid=15129)[0m           ...,
[2m[36m(TrainTrainable pid=15129)[0m           [1.        , 1.        , 1.        ],
[2m[36m(TrainTrainable pid=15129)[0m           [1.        , 1.        , 1.        ],
[2m[36m(TrainTrainable pid=15129)[0m           [1.        , 1.        , 1.        ]],
[2m[36m(TrainTrainable pid=15129)[0m 
[2m[36m(TrainTrainable pid=15129)[0m          [[1.        , 1.        , 0.99215686],
[2m[36m(TrainTrainable pid=15129)[0m           [0.81960785, 0.80784315, 0.78039217],
[2m[36m(TrainTrainable pid=15129)[0m           [0.6392157 , 0.6039216 , 0.5647059 ],
[2m

RayTaskError(InvalidArgumentError): [36mray::TrainTrainable.train()[39m (pid=15129, ip=127.0.0.1, repr=TensorflowTrainer)
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/tune/trainable.py", line 360, in train
    result = self.step()
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/tune/function_runner.py", line 404, in step
    self._report_thread_runner_error(block=True)
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/tune/function_runner.py", line 574, in _report_thread_runner_error
    raise e
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/tune/function_runner.py", line 277, in run
    self._entrypoint()
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/tune/function_runner.py", line 349, in entrypoint
    return self._trainable_func(
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/ml/trainer.py", line 381, in _trainable_func
    super()._trainable_func(self._merged_config, reporter, checkpoint_dir)
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/tune/function_runner.py", line 645, in _trainable_func
    output = fn()
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/ml/trainer.py", line 356, in train_func
    trainer.training_loop()
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/ml/train/data_parallel_trainer.py", line 354, in training_loop
    for results in training_iterator:
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/train/trainer.py", line 752, in __next__
    self._final_results = self._run_with_error_handling(
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/train/trainer.py", line 713, in _run_with_error_handling
    return func()
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/train/trainer.py", line 824, in _finish_training
    return self._backend_executor.finish_training()
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/train/backend.py", line 498, in finish_training
    results = self.get_with_failure_handling(futures)
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/train/backend.py", line 517, in get_with_failure_handling
    success = check_for_failure(remote_values)
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/train/utils.py", line 50, in check_for_failure
    ray.get(object_ref)
ray.exceptions.RayTaskError(InvalidArgumentError): [36mray::BaseWorkerMixin._BaseWorkerMixin__execute()[39m (pid=15153, ip=127.0.0.1, repr=<ray.train.worker_group.BaseWorkerMixin object at 0x19acd4f40>)
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/train/worker_group.py", line 26, in __execute
    return func(*args, **kwargs)
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/train/backend.py", line 489, in end_training
    output = session.finish()
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/train/session.py", line 118, in finish
    func_output = self.training_thread.join()
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/train/utils.py", line 96, in join
    raise self.exc
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/train/utils.py", line 89, in run
    self.ret = self._target(*self._args, **self._kwargs)
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/ray/train/utils.py", line 138, in <lambda>
    return lambda: train_func(config)
  File "/var/folders/gx/t32twm6x54dftn9b2wxcbl100000gn/T/ipykernel_13824/3610062019.py", line 29, in train_loop_per_worker
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

TypeError: `generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was [[array([[[0.2509804 , 0.2509804 , 0.24313726],
          [0.4627451 , 0.45882353, 0.41960785],
          [0.2509804 , 0.24705882, 0.23921569],
          ...,
          [0.39215687, 0.4117647 , 0.32156864],
          [0.4117647 , 0.42352942, 0.32156864],
          [0.24705882, 0.25882354, 0.22745098]],

         [[0.23529412, 0.23529412, 0.22745098],
          [0.40784314, 0.4       , 0.36078432],
          [0.2784314 , 0.2784314 , 0.25882354],
          ...,
          [0.3137255 , 0.3137255 , 0.27450982],
          [0.42745098, 0.45490196, 0.3372549 ],
          [0.2627451 , 0.2784314 , 0.23529412]],

         [[0.2       , 0.19607843, 0.19215687],
          [0.33333334, 0.3254902 , 0.29411766],
          [0.24705882, 0.24313726, 0.22352941],
          ...,
          [0.2       , 0.1882353 , 0.18431373],
          [0.41960785, 0.44705883, 0.32156864],
          [0.2509804 , 0.27058825, 0.22352941]],

         ...,

         [[0.39215687, 0.40784314, 0.2627451 ],
          [0.30980393, 0.31764707, 0.22745098],
          [0.2901961 , 0.28235295, 0.21176471],
          ...,
          [0.6313726 , 0.47058824, 0.2627451 ],
          [0.56078434, 0.39215687, 0.21960784],
          [0.52156866, 0.39215687, 0.2509804 ]],

         [[0.45490196, 0.42745098, 0.2901961 ],
          [0.29803923, 0.29411766, 0.21568628],
          [0.24705882, 0.23921569, 0.19215687],
          ...,
          [0.5686275 , 0.42745098, 0.23137255],
          [0.46666667, 0.3137255 , 0.18431373],
          [0.4392157 , 0.3019608 , 0.20392157]],

         [[0.39215687, 0.34117648, 0.23921569],
          [0.29411766, 0.25490198, 0.1882353 ],
          [0.22745098, 0.21568628, 0.1882353 ],
          ...,
          [0.67058825, 0.52156866, 0.30588236],
          [0.57254905, 0.42745098, 0.25490198],
          [0.42352942, 0.30980393, 0.19215687]]], dtype=float32)]
 [array([[[1.        , 1.        , 1.        ],
          [1.        , 1.        , 0.98039216],
          [0.95686275, 0.94509804, 0.91764706],
          ...,
          [1.        , 1.        , 1.        ],
          [1.        , 1.        , 1.        ],
          [1.        , 1.        , 1.        ]],

         [[1.        , 1.        , 0.99215686],
          [0.81960785, 0.80784315, 0.78039217],
          [0.6392157 , 0.6039216 , 0.5647059 ],
          ...,
          [0.99215686, 0.99607843, 0.99607843],
          [0.99215686, 0.99215686, 0.99215686],
          [1.        , 1.        , 1.        ]],

         [[1.        , 1.        , 0.98039216],
          [0.9098039 , 0.9019608 , 0.88235295],
          [0.8980392 , 0.8862745 , 0.8666667 ],
          ...,
          [0.99607843, 1.        , 0.99607843],
          [0.99607843, 0.99607843, 0.99607843],
          [1.        , 1.        , 1.        ]],

         ...,

         [[0.99607843, 1.        , 1.        ],
          [0.99215686, 1.        , 0.99607843],
          [0.99215686, 0.99607843, 0.99607843],
          ...,
          [0.9098039 , 0.89411765, 0.8745098 ],
          [0.91764706, 0.9019608 , 0.88235295],
          [0.91764706, 0.9019608 , 0.88235295]],

         [[1.        , 1.        , 0.99215686],
          [0.99607843, 0.99215686, 0.99215686],
          [0.99607843, 0.99607843, 0.99215686],
          ...,
          [0.99607843, 0.99215686, 0.9882353 ],
          [1.        , 1.        , 0.9882353 ],
          [1.        , 1.        , 1.        ]],

         [[1.        , 1.        , 1.        ],
          [0.9882353 , 0.99607843, 0.99607843],
          [0.9843137 , 0.99607843, 0.99607843],
          ...,
          [0.99215686, 0.99607843, 0.9843137 ],
          [0.99215686, 0.99215686, 0.99607843],
          [1.        , 1.        , 1.        ]]], dtype=float32)]].
TypeError: only size-1 arrays can be converted to Python scalars


The above exception was the direct cause of the following exception:


[36mray::BaseWorkerMixin._BaseWorkerMixin__execute()[39m (pid=15153, ip=127.0.0.1, repr=<ray.train.worker_group.BaseWorkerMixin object at 0x19acd4f40>)

  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1049, in generator_py_func
    script_ops.FuncRegistry._convert(  # pylint: disable=protected-access

  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 229, in _convert
    result = np.asarray(value, dtype=dtype, order="C")

ValueError: setting an array element with a sequence.


During handling of the above exception, another exception occurred:


[36mray::BaseWorkerMixin._BaseWorkerMixin__execute()[39m (pid=15153, ip=127.0.0.1, repr=<ray.train.worker_group.BaseWorkerMixin object at 0x19acd4f40>)

  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 270, in __call__
    ret = func(*args)

  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)

  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1053, in generator_py_func
    six.reraise(

  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/six.py", line 718, in reraise
    raise value.with_traceback(tb)

  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1049, in generator_py_func
    script_ops.FuncRegistry._convert(  # pylint: disable=protected-access

  File "/Users/balaji/GitHub/ray/.venv/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 229, in _convert
    result = np.asarray(value, dtype=dtype, order="C")

TypeError: `generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was [[array([[[0.2509804 , 0.2509804 , 0.24313726],
          [0.4627451 , 0.45882353, 0.41960785],
          [0.2509804 , 0.24705882, 0.23921569],
          ...,
          [0.39215687, 0.4117647 , 0.32156864],
          [0.4117647 , 0.42352942, 0.32156864],
          [0.24705882, 0.25882354, 0.22745098]],

         [[0.23529412, 0.23529412, 0.22745098],
          [0.40784314, 0.4       , 0.36078432],
          [0.2784314 , 0.2784314 , 0.25882354],
          ...,
          [0.3137255 , 0.3137255 , 0.27450982],
          [0.42745098, 0.45490196, 0.3372549 ],
          [0.2627451 , 0.2784314 , 0.23529412]],

         [[0.2       , 0.19607843, 0.19215687],
          [0.33333334, 0.3254902 , 0.29411766],
          [0.24705882, 0.24313726, 0.22352941],
          ...,
          [0.2       , 0.1882353 , 0.18431373],
          [0.41960785, 0.44705883, 0.32156864],
          [0.2509804 , 0.27058825, 0.22352941]],

         ...,

         [[0.39215687, 0.40784314, 0.2627451 ],
          [0.30980393, 0.31764707, 0.22745098],
          [0.2901961 , 0.28235295, 0.21176471],
          ...,
          [0.6313726 , 0.47058824, 0.2627451 ],
          [0.56078434, 0.39215687, 0.21960784],
          [0.52156866, 0.39215687, 0.2509804 ]],

         [[0.45490196, 0.42745098, 0.2901961 ],
          [0.29803923, 0.29411766, 0.21568628],
          [0.24705882, 0.23921569, 0.19215687],
          ...,
          [0.5686275 , 0.42745098, 0.23137255],
          [0.46666667, 0.3137255 , 0.18431373],
          [0.4392157 , 0.3019608 , 0.20392157]],

         [[0.39215687, 0.34117648, 0.23921569],
          [0.29411766, 0.25490198, 0.1882353 ],
          [0.22745098, 0.21568628, 0.1882353 ],
          ...,
          [0.67058825, 0.52156866, 0.30588236],
          [0.57254905, 0.42745098, 0.25490198],
          [0.42352942, 0.30980393, 0.19215687]]], dtype=float32)]
 [array([[[1.        , 1.        , 1.        ],
          [1.        , 1.        , 0.98039216],
          [0.95686275, 0.94509804, 0.91764706],
          ...,
          [1.        , 1.        , 1.        ],
          [1.        , 1.        , 1.        ],
          [1.        , 1.        , 1.        ]],

         [[1.        , 1.        , 0.99215686],
          [0.81960785, 0.80784315, 0.78039217],
          [0.6392157 , 0.6039216 , 0.5647059 ],
          ...,
          [0.99215686, 0.99607843, 0.99607843],
          [0.99215686, 0.99215686, 0.99215686],
          [1.        , 1.        , 1.        ]],

         [[1.        , 1.        , 0.98039216],
          [0.9098039 , 0.9019608 , 0.88235295],
          [0.8980392 , 0.8862745 , 0.8666667 ],
          ...,
          [0.99607843, 1.        , 0.99607843],
          [0.99607843, 0.99607843, 0.99607843],
          [1.        , 1.        , 1.        ]],

         ...,

         [[0.99607843, 1.        , 1.        ],
          [0.99215686, 1.        , 0.99607843],
          [0.99215686, 0.99607843, 0.99607843],
          ...,
          [0.9098039 , 0.89411765, 0.8745098 ],
          [0.91764706, 0.9019608 , 0.88235295],
          [0.91764706, 0.9019608 , 0.88235295]],

         [[1.        , 1.        , 0.99215686],
          [0.99607843, 0.99215686, 0.99215686],
          [0.99607843, 0.99607843, 0.99215686],
          ...,
          [0.99607843, 0.99215686, 0.9882353 ],
          [1.        , 1.        , 0.9882353 ],
          [1.        , 1.        , 1.        ]],

         [[1.        , 1.        , 1.        ],
          [0.9882353 , 0.99607843, 0.99607843],
          [0.9843137 , 0.99607843, 0.99607843],
          ...,
          [0.99215686, 0.99607843, 0.9843137 ],
          [0.99215686, 0.99215686, 0.99607843],
          [1.        , 1.        , 1.        ]]], dtype=float32)]].


	 [[{{node PyFunc}}]]
	 [[MultiDeviceIteratorGetNextFromShard]]
	 [[RemoteCall]]
	 [[IteratorGetNextAsOptional]] [Op:__inference_train_function_1472]

## Test the network on the test data

In [None]:
from ray.ml.predictors.integrations.tensorflow import TensorflowPredictor
from ray.ml.batch_predictor import BatchPredictor
batch_predictor = BatchPredictor.from_checkpoint(
    checkpoint=latest_checkpoint,
    predictor_cls=TensorflowPredictor,
    model=Net(),
)
    
outputs: ray.data.Dataset = batch_predictor.predict(
    data=test_dataset, feature_columns=["image"], unsqueeze=False
)
outputs.show(1)

# Save checkpoint to file?

## What's next

TODO