# Ingesting image datasets for training with PyTorch

This example demonstrates how to work with image training datasets using [Ray Data](data) on a single node.

In this example, we'll use Ray Data to:
1. Load and preprocess raw images for training.
2. Load preprocessed images into PyTorch using [Ray Train](train) to train an image classification model.
3. Write preprocessed images back to cloud storage (such as Amazon S3) to improve ingest throughput.


## Before you begin

Install the following dependencies if you haven't already.

```bash
$ pip install "ray[data]" torchvision awscli
```

Let's first download an image dataset to use. If you already have an image dataset on local disk or in S3, you can skip the rest of this section.
Otherwise, follow along to download a dataset from Kaggle. The Kaggle CLI needs a Kaggle API token configured before it can download datasets.
For this example, we'll use bash to download a [4GB subset of the ImageNet dataset](https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000) and, optionally, upload it to an S3 bucket.

```bash
$ pip install kaggle awscli
$ kaggle datasets list -s ifigotin/imagenetmini-1000
# ref                         title                 size  lastUpdated          downloadCount  voteCount  usabilityRating  
# --------------------------  --------------------  ----  -------------------  -------------  ---------  ---------------  
# ifigotin/imagenetmini-1000  ImageNet 1000 (mini)   4GB  2020-03-10 01:05:11          11779        133  0.375
$ kaggle datasets download ifigotin/imagenetmini-1000
$ unzip imagenetmini-1000.zip -d /tmp
```

The directory structure should look like this:
```bash
$ tree /tmp/imagenet-mini
# /tmp/imagenet-mini
# ├── train
# │   ├── n01440764
# │   │   ├── n01440764_10043.JPEG
# │   │   ├── n01440764_10470.JPEG
# │   │   ├── n01440764_10744.JPEG
# ...
# └── val
#     ├── n01440764
#     │   ├── ILSVRC2012_val_00009111.JPEG
#     │   ├── ILSVRC2012_val_00030740.JPEG
# ...
```

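If you want to read from S3, run the following commands, then pass the S3 URI `s3://imagenetmini-1000/train` in the Python code below.
S3 bucket names are globally unique, so you may need to pick a different bucket name and use it everywhere `imagenetmini-1000` appears in an S3 URI.
Otherwise, you can pass the local path `/tmp/imagenet-mini/train`.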
```bash
$ aws s3 mb s3://imagenetmini-1000
$ aws s3 sync /tmp/imagenet-mini s3://imagenetmini-1000
$ aws s3 ls s3://imagenetmini-1000/train/
#                           PRE n01440764/
#                           PRE n01443537/
#                           PRE n01484850/
#                           ...
```

## Loading the training dataset

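Let's start by loading the training dataset and examining the data.
To speed things up, we'll first download the dataset from S3 to local disk.
If you're using a multi-node cluster, you can instead read directly from S3.
If you skipped the S3 upload, skip this download step and set `TRAIN_DATASET_URL` to `/tmp/imagenet-mini/train` in the code below.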


```bash
$ aws s3 sync s3://imagenetmini-1000/ /tmp/imagenetmini-1000
```

We can use {meth}`~ray.data.Dataset.take_batch` to produce a single batch of data.
Later, we'll apply preprocessing transforms to the images, one batch at a time.

```python
import ray
from IPython.display import display
from PIL import Image


TRAIN_DATASET_URL = "/tmp/imagenetmini-1000/train"
# Uncomment to read directly from S3.
# This is recommended if you're using multiple nodes for training
# or if you don't have enough disk space to hold the entire dataset.
# TRAIN_DATASET_URL = "s3://imagenetmini-1000/train"
ds = ray.data.read_images(TRAIN_DATASET_URL, mode="RGB")

batch = ds.take_batch(batch_size=2)
# print(batch)
# {'image': array([array([[[110, 117, 123],
#                [111, 118, 124],
#                [113, 118, 124],
#                ...,
#                [164, 191, 210]]], dtype=uint8)], dtype=object)}

def show(image_array):
    """Display a (height, width, channels) uint8 array in the notebook."""
    image = Image.fromarray(image_array)
    display(image)

show(batch["image"][0])
```

When we pass the `Dataset` to Ray Train, we will receive a {class}`ray.data.DataIterator` on each training worker.
Let's use that API to iterate through the dataset here.
We'll limit the dataset to 1000 rows to shorten execution, and we'll use {meth}`ray.data.DataIterator.iter_torch_batches` since we'll be using a PyTorch model later on.

Try uncommenting the loop below and see what happens.

```python
# This is the interface that will be used by Ray Train workers.
it = ds.limit(1000).iterator()

# for torch_batch in it.iter_torch_batches(batch_size=32):
#     print(torch_batch)
```

Uh-oh, did you get an error above?

You should have seen an error like this:
```python
RuntimeError: Numpy array of object dtype cannot be converted to a Torch Tensor. This may because the numpy array is a ragged tensor--it contains items of different sizes. If using `iter_torch_batches()` API, you can pass in a `collate_fn` argument to specify custom logic to convert the Numpy array batch to a Torch tensor batch.
```

This happens because our images have different sizes, so a batch of them can't be stacked into a single tensor. Later we'll crop all of the images to the same size, which resolves this error.
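To see the failure in isolation, here is a minimal NumPy sketch, independent of Ray, assuming two images of hypothetical sizes 375x500 and 333x500. Stacking the ragged arrays raises an error, while stacking crops of a common size succeeds. The naive top-left crop here only stands in for the random crop used later.

```python
import numpy as np

# Two hypothetical RGB images with different heights, stored in the
# (height, width, channels) layout that read_images produces.
ragged = [
    np.zeros((375, 500, 3), dtype=np.uint8),
    np.zeros((333, 500, 3), dtype=np.uint8),
]

# Stacking fails because the per-image shapes disagree.
try:
    np.stack(ragged)
    stack_failed = False
except ValueError:
    stack_failed = True

# After cropping every image to a common 224x224 size, stacking works.
SIZE = 224
cropped = [image[:SIZE, :SIZE, :] for image in ragged]
stacked = np.stack(cropped)

print(stack_failed, stacked.shape)  # True (2, 224, 224, 3)
```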

## Extracting labels from the pathname

We also want our dataset to include the label for each image.
For images where the subdirectory name is also the class name, we can use a {class}`~ray.data.datasource.Partitioning` to extract the class name from each file's path and attach it as an additional field to each row.

```python
from ray.data.datasource.partitioning import Partitioning

ds = ray.data.read_images(
    TRAIN_DATASET_URL,
    mode="RGB",
    partitioning=Partitioning("dir", field_names=["class"], base_dir=TRAIN_DATASET_URL),
)
batch = ds.take_batch(batch_size=2)
print(batch["class"])
# ['n01608432' 'n01608432']
```
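Conceptually, `"dir"` partitioning strips `base_dir` from each file's directory and assigns the remaining directory levels, in order, to `field_names`. The sketch below is a hypothetical pure-Python helper, not Ray's implementation, and assumes POSIX-style paths; it only illustrates that mapping for a single file.

```python
import posixpath


def dir_partition_fields(path, base_dir, field_names):
    """Map the directory levels below base_dir to field names.

    Hypothetical illustration of what a "dir" Partitioning extracts.
    """
    # Directory of the file, relative to the dataset root.
    relative = posixpath.relpath(posixpath.dirname(path), base_dir)
    dir_names = relative.split("/")
    # Pair each directory level with its field name, in order.
    return dict(zip(field_names, dir_names))


fields = dir_partition_fields(
    "/tmp/imagenetmini-1000/train/n01440764/n01440764_10043.JPEG",
    base_dir="/tmp/imagenetmini-1000/train",
    field_names=["class"],
)
print(fields)  # {'class': 'n01440764'}
```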

If you have a different directory structure, you can instead attach each image's file path using `include_paths=True` and add a custom {meth}`~ray.data.Dataset.map_batches` call to extract the class label.

```python
import os

ds = ray.data.read_images(TRAIN_DATASET_URL, mode="RGB", include_paths=True)
batch = ds.take_batch(batch_size=2)
print(batch["path"])
# array(['/tmp/imagenet-mini/train/n01560419/n01560419_2363.JPEG',
#        '/tmp/imagenet-mini/train/n01560419/n01560419_2368.JPEG'],
#       dtype=object)

def extract_class_from_basename(batch):
    # "n01560419_2363.JPEG" -> "n01560419"
    batch["class"] = [os.path.basename(path).split("_")[0] for path in batch["path"]]
    return batch

ds = ds.map_batches(extract_class_from_basename)
batch = ds.take_batch(batch_size=2)
print(batch["class"])
# array(['n01530575', 'n01530575'], dtype=object)
```
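The basename approach relies on training files being named `<class>_<id>.JPEG`. The validation files in the directory tree earlier (for example, `ILSVRC2012_val_00009111.JPEG`) don't follow that pattern, so splitting their basename would yield `ILSVRC2012`. A sketch of an alternative UDF, assuming the class is always the immediate parent directory, reads the parent directory name instead. It's shown on a hand-built dict batch rather than a Ray `Dataset`, and could be passed to `ds.map_batches` in the same way.

```python
import os


def extract_class_from_parent_dir(batch):
    # Use the name of the directory containing each file, e.g.
    # ".../train/n01440764/n01440764_10043.JPEG" -> "n01440764".
    # This also works for val files named like "ILSVRC2012_val_00009111.JPEG".
    batch["class"] = [
        os.path.basename(os.path.dirname(path)) for path in batch["path"]
    ]
    return batch


example_batch = {
    "path": [
        "/tmp/imagenetmini-1000/train/n01440764/n01440764_10043.JPEG",
        "/tmp/imagenetmini-1000/val/n01440764/ILSVRC2012_val_00009111.JPEG",
    ]
}
print(extract_class_from_parent_dir(example_batch)["class"])
# ['n01440764', 'n01440764']
```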

To use the class as an integer label, we can add another `map_batches` call that converts each class string to an integer.
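For ImageNet-style WordNet IDs such as `n01440764`, that conversion could be a UDF like the hypothetical `convert_wnid_to_int` below, which drops the leading `n` and parses the remaining digits. It's a sketch on a plain dict batch; note that the resulting integers are large and not contiguous.

```python
def convert_wnid_to_int(batch):
    # Hypothetical helper: "n01440764" -> 1440764 by dropping the
    # leading "n" and parsing the remaining digits.
    batch["class"] = [int(class_name[1:]) for class_name in batch["class"]]
    return batch


print(convert_wnid_to_int({"class": ["n01440764", "n01608432"]})["class"])
# [1440764, 1608432]
```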

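If your class names don't contain an integer, you can instead assign each class name an integer. This is also worthwhile when they do: the integers embedded in ImageNet's WordNet IDs (for example, `n01608432`) aren't contiguous indices starting at 0, which is what classification losses such as `torch.nn.CrossEntropyLoss` expect. To build the mapping, we first call {meth}`~ray.data.Dataset.unique` to get the unique label values and then assign each value an integer.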

```python
ds = ray.data.read_images(
    TRAIN_DATASET_URL,
    mode="RGB",
    partitioning=Partitioning("dir", field_names=["class"], base_dir=TRAIN_DATASET_URL),
)
# Create a dict mapping from class_name -> integer.
# Sorting makes the mapping independent of the order unique() returns values in.
classes = sorted(ds.unique(column="class"))
classes_to_idx = {class_name: idx for idx, class_name in enumerate(classes)}
```
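It's often useful to keep the inverse mapping too, so that predicted indices can be turned back into class names at evaluation time. Here is a small sketch on a hypothetical list of class names standing in for the output of `ds.unique`, using separate variable names so it doesn't overwrite `classes_to_idx`.

```python
# Hypothetical class names standing in for the output of ds.unique().
example_classes = ["n01608432", "n01440764", "n01530575"]

# Assign indices in sorted order so the mapping is reproducible.
example_classes_to_idx = {
    name: idx for idx, name in enumerate(sorted(example_classes))
}
# Invert the mapping to decode model predictions back into class names.
example_idx_to_classes = {idx: name for name, idx in example_classes_to_idx.items()}

print(example_classes_to_idx)
# {'n01440764': 0, 'n01530575': 1, 'n01608432': 2}
print(example_idx_to_classes[2])
# n01608432
```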

Next, we call `map_batches` again to convert each image's class from a string to an integer.

```python
# Convert the string class names to integers in the dataset.
def convert_class_to_idx(batch, classes_to_idx):
    batch["class"] = [classes_to_idx[class_name] for class_name in batch["class"]]
    return batch

ds = ray.data.read_images(
    TRAIN_DATASET_URL,
    mode="RGB",
    partitioning=Partitioning("dir", field_names=["class"], base_dir=TRAIN_DATASET_URL),
)
ds = ds.map_batches(convert_class_to_idx, fn_kwargs={"classes_to_idx": classes_to_idx})
batch = ds.take_batch(batch_size=2)
print(batch["class"])
# [0 0]
```

## Preprocessing images

Next we'll apply some preprocessing transforms to our images.
Let's use [torchvision](https://pytorch.org/vision/stable/index.html) to define a preprocessor function that randomly crops and flips a batch of images, then returns the images as `torch.Tensor`s.
This code matches the spec for the [MLPerf image classification benchmark](https://github.com/mlcommons/training/tree/master/image_classification).

```python
import torch
import torchvision
import numpy as np

DEFAULT_IMAGE_SIZE = 224

def crop_and_flip_image_batch(image_batch):
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomResizedCrop(
                size=DEFAULT_IMAGE_SIZE,
                scale=(0.05, 1.0),
                ratio=(0.75, 1.33),
            ),
            torchvision.transforms.RandomHorizontalFlip(),
        ]
    )

    # Tip: Use ray.data.read_images(size=(H, W)) to resize all images
    # during loading. Then, you can use vectorized transforms here,
    # shown below, instead of applying the transform one image at a time.
    # Note that a transform applied to a whole batch tensor samples one
    # crop and one flip for the entire batch rather than one per image.
    # tensor_batch = torch.Tensor(np.transpose(image_batch["image"], axes=(0, 3, 1, 2)))
    # image_batch["image"] = transform(tensor_batch)

    def crop_and_flip_image(image):
        # Transpose to match torchvision's expected shape.
        # (height, width, channels) -> (channels, height, width).
        return transform(torch.Tensor(np.transpose(image, axes=(2, 0, 1))))

    image_batch["image"] = [crop_and_flip_image(image) for image in image_batch["image"]]
    return image_batch
```
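To make the `scale` and `ratio` arguments concrete, here is a simplified, standard-library-only re-implementation of how `RandomResizedCrop` picks its crop box, modeled on torchvision's `RandomResizedCrop.get_params`: sample an area fraction from `scale` and an aspect ratio log-uniformly from `ratio`, retry up to 10 times, then fall back to a centered crop. The function name `sample_resized_crop_box` is hypothetical; this is a sketch for intuition, not a replacement for torchvision, which also resizes the sampled box to `DEFAULT_IMAGE_SIZE`.

```python
import math
import random


def sample_resized_crop_box(height, width, scale=(0.05, 1.0), ratio=(0.75, 1.33), rng=None):
    """Return (top, left, crop_height, crop_width) for a random crop box."""
    if rng is None:
        rng = random.Random()
    area = height * width
    log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
    for _ in range(10):
        # Sample the crop's area as a fraction of the image area...
        target_area = area * rng.uniform(*scale)
        # ...and its aspect ratio (width / height) log-uniformly.
        aspect_ratio = math.exp(rng.uniform(*log_ratio))
        crop_w = round(math.sqrt(target_area * aspect_ratio))
        crop_h = round(math.sqrt(target_area / aspect_ratio))
        if 0 < crop_w <= width and 0 < crop_h <= height:
            # randint is inclusive, so the box always fits in the image.
            top = rng.randint(0, height - crop_h)
            left = rng.randint(0, width - crop_w)
            return top, left, crop_h, crop_w
    # Fallback: a centered crop clamped to the allowed aspect-ratio range.
    in_ratio = width / height
    if in_ratio < min(ratio):
        crop_w, crop_h = width, round(width / min(ratio))
    elif in_ratio > max(ratio):
        crop_w, crop_h = round(height * max(ratio)), height
    else:
        crop_w, crop_h = width, height
    return (height - crop_h) // 2, (width - crop_w) // 2, crop_h, crop_w


# Sample one crop box for a hypothetical 375x500 image.
print(sample_resized_crop_box(375, 500, rng=random.Random(0)))
```

Because the area fraction can be as small as 5%, some crops cover only a small part of the image, which is a fairly aggressive augmentation.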

First, let's try the function manually on a single batch from the dataset.

```python
import numpy as np

print("Before")
batch = ds.take_batch(batch_size=2)
show(batch["image"][0])

batch = crop_and_flip_image_batch(batch)
print("After")
# Convert the (channels, height, width) float tensor back to a
# (height, width, channels) uint8 array for display.
cropped_and_flipped_img = np.transpose(
    np.array(batch["image"][0], dtype=np.uint8),
    axes=(1, 2, 0))
show(cropped_and_flipped_img)
```
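The two `np.transpose` calls above are inverses: `(2, 0, 1)` moves channels first for torchvision, and `(1, 2, 0)` moves them back to last for PIL. A NumPy-only check on a small hypothetical array, including the batched `(0, 3, 1, 2)` permutation from the vectorized-transform tip:

```python
import numpy as np

# A small hypothetical image in (height, width, channels) layout.
hwc_image = np.arange(4 * 5 * 3, dtype=np.uint8).reshape(4, 5, 3)

# HWC -> CHW, as done before applying the torchvision transform.
chw_image = np.transpose(hwc_image, axes=(2, 0, 1))
# CHW -> HWC, as done before displaying the result with PIL.
restored = np.transpose(chw_image, axes=(1, 2, 0))

# For a stacked batch (N, H, W, C), the batch axis stays first.
nhwc_batch = np.stack([hwc_image, hwc_image])
nchw_batch = np.transpose(nhwc_batch, axes=(0, 3, 1, 2))

print(chw_image.shape, restored.shape, nchw_batch.shape)
# (3, 4, 5) (4, 5, 3) (2, 3, 4, 5)
```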

Now let's apply the preprocessor to the entire dataset.

```python
ds = ray.data.read_images(
    TRAIN_DATASET_URL,
    mode="RGB",
    partitioning=Partitioning("dir", field_names=["class"], base_dir=TRAIN_DATASET_URL),
)
ds = ds.map_batches(convert_class_to_idx, fn_kwargs={"classes_to_idx": classes_to_idx})
ds = ds.map_batches(crop_and_flip_image_batch)
```

Now that all of the images have the same dimensions, we're ready to iterate using {meth}`ray.data.DataIterator.iter_torch_batches`.

```python
# This is the interface that will be used by Ray Train workers.
it = ds.limit(1000).iterator()

num_rows_read = 0
for torch_batch in it.iter_torch_batches(batch_size=32):
    num_rows_read += len(torch_batch["image"])
print(f"Read {num_rows_read} rows.")
```
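The loop should report 1000 rows. As a quick sketch of the batch arithmetic, assuming the default `drop_last=False` of `iter_torch_batches`: 1000 rows at `batch_size=32` give 31 full batches plus a final batch of 8 rows, while `drop_last=True` would skip that final partial batch.

```python
import math

num_rows = 1000   # rows kept by ds.limit(1000)
batch_size = 32   # batch_size passed to iter_torch_batches

# With drop_last=False (the default), the final partial batch is kept.
num_batches = math.ceil(num_rows / batch_size)
last_batch_size = num_rows - (num_batches - 1) * batch_size

# With drop_last=True, only the full batches would be produced.
num_full_batches = num_rows // batch_size

print(num_batches, last_batch_size, num_full_batches)
# 32 8 31
```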

Note that because we're using a randomized preprocessor, each call to `DataIterator.iter_torch_batches` produces differently cropped and flipped images.
This is useful during training, where you typically want a fresh random augmentation of each image in every epoch.