## 05 · Define Image Transformation  

To make the dataset usable by PyTorch, we need to preprocess the raw image arrays with the same steps used in the PyTorch `DataLoader` pipeline.  

- Define a function `transform_images(row)` that:  
  * Converts the `"image"` NumPy array into a PIL image.  
  * Applies the standard TorchVision transforms:  
    - `ToTensor()` → converts the image to a float tensor with values in [0, 1].  
    - `Normalize((0.5,), (0.5,))` → subtracts a mean of 0.5 and divides by a standard deviation of 0.5, which maps [0, 1] to [-1, 1].  
  * Replaces the `"image"` entry in the row with the transformed tensor.  

This function will later be applied in parallel across the Ray Dataset.  
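
The arithmetic behind the two transforms can be checked without PyTorch. The sketch below is plain Python and only mirrors the math for a single uint8 pixel; the helper names `to_unit_range` and `normalize` are hypothetical and are not part of TorchVision.

```python
# Plain-Python sketch of the per-pixel arithmetic behind
# ToTensor() followed by Normalize((0.5,), (0.5,)).
# `to_unit_range` and `normalize` are hypothetical helper names.

def to_unit_range(pixel: int) -> float:
    """Mirror ToTensor() scaling for uint8 input: [0, 255] -> [0.0, 1.0]."""
    return pixel / 255.0


def normalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Mirror Normalize(mean, std): (value - mean) / std."""
    return (value - mean) / std


# Darkest, mid-gray, and brightest pixel values
for pixel in (0, 128, 255):
    print(pixel, "->", round(normalize(to_unit_range(pixel)), 4))
```

The darkest pixel (0) maps to -1.0 and the brightest (255) maps to 1.0, which is why the normalized range is [-1, 1].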


In [None]:
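# 05. Define preprocessing transform for Ray Data

# Imports used by this cell (harmless if already imported earlier in the notebook)
import numpy as np
from PIL import Image
from torchvision.transforms import Compose, Normalize, ToTensor

def transform_images(row: dict) -> dict:
    # Build the TorchVision transform pipeline
    transform = Compose([
        ToTensor(),                 # convert to a float tensor in [0, 1]
        Normalize((0.5,), (0.5,)),  # normalize to [-1, 1]
    ])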

    # Ensure image is in uint8 before conversion
    image_arr = np.array(row["image"], dtype=np.uint8)

    # Apply transforms and replace the "image" field with tensor
    row["image"] = transform(Image.fromarray(image_arr))
    return row

<div class="alert alert-block alert-info">

**Note**: Unlike the PyTorch DataLoader, which preprocesses data on the same machine as the training process, Ray Data can run the preprocessing on any node in the cluster.

The data is then passed to the training workers through the Ray object store (a distributed in-memory object store).

</div>

## 06 · Apply Transformations with Ray Data  

Now we apply the preprocessing function to the dataset using `map()`:  

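- `train_ds.map(transform_images)` → applies the `transform_images` function to every row of the dataset.  
- Ray Data evaluates `map()` lazily: the call only records the transformation, and the work runs when the dataset is consumed (for example, when training iterates over it).  
- Transformations are executed **in parallel across the cluster**, so preprocessing can scale independently of training.  
- Each row of the transformed dataset has:  
  * `"image"` → normalized PyTorch tensors  
  * `"label"` → unchanged integer labels  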

This makes the dataset ready to be streamed into the training loop.  
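
The row-in, row-out contract and deferred execution of `map()` can be illustrated with Python's built-in `map`, which is also lazy. This is an analogy only, not how Ray Data is implemented: the `scale_row` function and the sample rows are hypothetical, and nothing here runs on a cluster.

```python
# Analogy only: Python's built-in map() over a list of dict rows.
# `scale_row` and the sample rows are hypothetical stand-ins.

calls = []  # records the labels of rows that have actually been processed


def scale_row(row: dict) -> dict:
    """Dict-in, dict-out transform with the same shape as transform_images."""
    calls.append(row["label"])
    return {"image": [p / 255.0 for p in row["image"]], "label": row["label"]}


rows = [
    {"image": [0, 255], "label": 7},
    {"image": [51, 102], "label": 3},
]

lazy_rows = map(scale_row, rows)  # only records the work; no row is processed yet
processed_before = list(calls)    # still empty at this point

first = next(lazy_rows)           # consuming one row triggers exactly one call
print(processed_before, first, calls)
```

As with the Ray Dataset, the `"label"` value passes through unchanged while the `"image"` value is replaced by its transformed version.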

In [None]:
# 06. Apply the preprocessing transform across the Ray Dataset

# Run transform_images() on each row (parallelized across cluster workers)
train_ds = train_ds.map(transform_images)