# Finetuning a PyTorch Image Classifier with Ray AIR
In this example we will finetune a pretrained ResNet model with Ray Train. 

For this example, our network architecture consists of the intermediate layer output of a pretrained ResNet model, which feeds into a randomly initialized linear layer that outputs classification logits for our new task.


## Load and preprocess finetuning dataset with Ray Data
This example is adapted from PyTorch's [Finetuning Torchvision Models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html) tutorial.
We will use *hymenoptera_data* as the finetuning dataset, which contains two classes (bees and ants) and 397 total images (across training and validation). This is a very small dataset, used here only for demonstration purposes.

In [13]:
# To run the full example, set this to False
SMOKE_TEST = True
if SMOKE_TEST:
    import os
    os.system("pip install -U moto[s3,server]==2.4.1")


In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from typing import Dict
import numpy as np
import warnings

warnings.filterwarnings("ignore")

import ray
from ray.data.datasource.partitioning import Partitioning
from ray.train.torch import TorchCheckpoint


The dataset can be downloaded [here](https://download.pytorch.org/tutorial/hymenoptera_data.zip).

First, we use {meth}`ray.data.read_images <ray.data.read_images>` to load the images. Since the dataset is already structured with directory names as the labels, we can use the {class}`Partitioning <ray.data.datasource.Partitioning>` API to automatically extract image labels.
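As a rough illustration of what directory-based partitioning does (this is not Ray's implementation), the label for each image is simply the name of the first directory below `base_dir`. A minimal sketch with a hypothetical helper `class_from_path` and made-up file names:

```python
from pathlib import PurePosixPath


def class_from_path(path: str, base_dir: str) -> str:
    """Hypothetical helper: the first directory below base_dir names the class."""
    relative = PurePosixPath(path).relative_to(base_dir)  # e.g. "ants/img_001.jpg"
    return relative.parts[0]


# Illustrative paths following the <split>/<class>/<image> layout
print(class_from_path("hymenoptera_data/train/ants/img_001.jpg", "hymenoptera_data/train"))  # ants
print(class_from_path("hymenoptera_data/val/bees/img_042.jpg", "hymenoptera_data/val"))  # bees
```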

In [2]:
ray_img_datasets = {}
for split in ["train", "val"]:
    data_folder = f"s3://anonymous@air-example-data-2/hymenoptera_data/{split}"
    partitioning = Partitioning("dir", field_names=["class"], base_dir=data_folder)
    ray_img_datasets[split] = ray.data.read_images(
        data_folder, size=(256, 256), partitioning=partitioning, mode="RGB"
    )


2023-02-22 10:40:25,388	INFO worker.py:1360 -- Connecting to existing Ray cluster at address: 10.0.62.233:6379...
2023-02-22 10:40:25,422	INFO worker.py:1548 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265
2023-02-22 10:40:25,426	INFO packaging.py:330 -- Pushing file package 'gcs://_ray_pkg_a3a790f8196b6a46eed365be4d734a9d.zip' (0.52MiB) to Ray cluster...
2023-02-22 10:40:25,432	INFO packaging.py:343 -- Successfully pushed file package 'gcs://_ray_pkg_a3a790f8196b6a46eed365be4d734a9d.zip'.


We have now loaded the images into a Ray Dataset, which partitions the data and distributes the data blocks across the nodes in the cluster, so preprocessing and data ingestion run in parallel. Note that the ResNet model was pretrained with fixed normalization values (the ImageNet per-channel means and standard deviations). We'll keep these values the same for fine-tuning, as shown in *data_transforms*. More details can be found [here](https://pytorch.org/hub/pytorch_vision_resnet/).
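After pixels are scaled to [0, 1], `transforms.Normalize` computes `(x - mean) / std` separately for each channel. A NumPy sketch of that arithmetic for a single RGB pixel (channel-last here for readability; the torchvision pipeline itself works channel-first), using the same constants as *data_transforms*:

```python
import numpy as np

# ImageNet per-channel statistics, as used in data_transforms
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_pixel(rgb_uint8: np.ndarray) -> np.ndarray:
    """Scale an (..., 3) uint8 RGB array to [0, 1], then standardize per channel."""
    scaled = rgb_uint8.astype(np.float32) / 255.0
    return (scaled - MEAN) / STD


pixel = np.array([255, 128, 0], dtype=np.uint8)
print(normalize_pixel(pixel))
```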

Define a label preprocessor with {class}`BatchMapper <ray.data.preprocessors.BatchMapper>`:

In [20]:
from ray.data.preprocessors import BatchMapper

class_to_idx = {"ants": 0, "bees": 1}

# 1. Map the image folder names to label ids
def map_labels(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    batch["label"] = np.vectorize(class_to_idx.__getitem__)(batch["class"])
    batch.pop("class")
    return batch


label_preprocessor = BatchMapper(fn=map_labels, batch_format="numpy")


```{note}
Note that **batch** here refers to the chunk of data for preprocessing, not the batch for model training. To learn more about writing functions for mapping batches, read [writing user-defined functions](transform_datasets_writing_udfs).
```
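With `batch_format="numpy"`, a preprocessing batch is a dict that maps column names to arrays sharing the same first dimension. A standalone sketch with a hand-made three-row batch (placeholder values, not real data), applying the same class-to-id lookup idea as `map_labels`:

```python
from typing import Dict

import numpy as np

# A numpy-format preprocessing batch: column name -> array with one entry per row
toy_batch: Dict[str, np.ndarray] = {
    "image": np.zeros((3, 8, 8, 3), dtype=np.uint8),  # three tiny placeholder images
    "class": np.array(["ants", "bees", "ants"]),
}

label_ids = {"ants": 0, "bees": 1}
# Replace the string "class" column with an integer "label" column
toy_batch["label"] = np.array([label_ids[name] for name in toy_batch.pop("class")])
print({name: column.shape for name, column in toy_batch.items()}, toy_batch["label"])
```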

Next, define an image preprocessor with {class}`TorchVisionPreprocessor <ray.data.preprocessors.TorchVisionPreprocessor>`:

In [21]:
from ray.data.preprocessors import TorchVisionPreprocessor

# 2. Convert input image to tensors
def to_tensor(batch: np.ndarray) -> torch.Tensor:
    tensor = torch.as_tensor(batch, dtype=torch.float)
    # (B, H, W, C) -> (B, C, H, W)
    tensor = tensor.permute(0, 3, 1, 2).contiguous()
    # [0., 255.] -> [0., 1.]
    tensor = tensor.div(255)
    return tensor


# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    "train": transforms.Compose(
        [
            transforms.Lambda(to_tensor),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.Lambda(to_tensor),
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}

# Accelerate image processing with batched transformations
image_preprocessors = {
    split: TorchVisionPreprocessor(
        columns=["image"], transform=data_transforms[split], batched=True
    )
    for split in ["train", "val"]
}
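To make the layout change in `to_tensor` concrete, here is a NumPy analogue (an illustration only; the pipeline itself uses `torch.Tensor.permute`) applied to random data shaped like the 256 x 256 RGB images returned by `read_images`:

```python
import numpy as np


def to_chw_float(batch: np.ndarray) -> np.ndarray:
    """NumPy analogue of to_tensor: (B, H, W, C) uint8 -> (B, C, H, W) float32 in [0, 1]."""
    chw = np.ascontiguousarray(batch.transpose(0, 3, 1, 2))  # like permute(...).contiguous()
    return chw.astype(np.float32) / 255.0


rng = np.random.default_rng(0)
sample_images = rng.integers(0, 256, size=(2, 256, 256, 3), dtype=np.uint8)
converted = to_chw_float(sample_images)
print(converted.shape, converted.dtype)
```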


Combine the two preprocessors with {class}`Chain <ray.data.preprocessors.Chain>` and transform the raw datasets:

In [22]:
from ray.data.preprocessors import Chain

ray_datasets = {}
for split in ["train", "val"]:
    preprocessor = Chain(image_preprocessors[split], label_preprocessor)
    ray_datasets[split] = preprocessor.fit_transform(ray_img_datasets[split])


2023-02-22 10:53:10,874	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[read->TorchVisionPreprocessor]
read->TorchVisionPreprocessor: 100%|██████████| 128/128 [00:03<00:00, 37.63it/s]
2023-02-22 10:53:14,346	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[BatchMapper]
BatchMapper: 100%|██████████| 128/128 [00:00<00:00, 4499.57it/s]
2023-02-22 10:53:14,518	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[read->TorchVisionPreprocessor]
read->TorchVisionPreprocessor: 100%|██████████| 128/128 [00:03<00:00, 41.15it/s]
2023-02-22 10:53:17,699	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[BatchMapper]
BatchMapper: 100%|██████████| 128/128 [00:00<00:00, 3793.20it/s]
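`Chain` applies its preprocessors in order, feeding the output of one into the next. Conceptually (ignoring fitting and Ray's distributed execution), this is function composition over batches. A minimal sketch with a hypothetical `chain` helper and two toy steps standing in for the image and label preprocessors:

```python
from functools import reduce
from typing import Callable, Dict

Batch = Dict[str, list]


def chain(*steps: Callable[[Batch], Batch]) -> Callable[[Batch], Batch]:
    """Hypothetical helper: compose batch transforms left to right, like Chain."""
    def apply(batch: Batch) -> Batch:
        return reduce(lambda acc, step: step(acc), steps, batch)
    return apply


def scale_image(batch: Batch) -> Batch:  # stands in for the image preprocessor
    return {**batch, "image": [px / 255 for px in batch["image"]]}


def encode_label(batch: Batch) -> Batch:  # stands in for the label preprocessor
    out = {k: v for k, v in batch.items() if k != "class"}
    out["label"] = [{"ants": 0, "bees": 1}[c] for c in batch["class"]]
    return out


toy_pipeline = chain(scale_image, encode_label)
toy_output = toy_pipeline({"image": [0, 255], "class": ["bees", "ants"]})
print(toy_output)  # {'image': [0.0, 1.0], 'label': [1, 0]}
```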


## Initialize Model and Fine-tuning configs

Next, let's define our model. You can create a model from a pretrained ResNet or reload the model checkpoint from a previous run.

In [9]:
def initialize_model():
    # Load pretrained model params
    model = models.resnet50(pretrained=True)

    # Replace the original classifier with a new Linear layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 2)

    # Ensure all params get updated during finetuning
    for param in model.parameters():
        param.requires_grad = True
    return model
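The replacement head is an ordinary affine layer: `nn.Linear(2048, 2)` computes `x @ W.T + b`, mapping the 2048-dimensional pooled ResNet-50 features to two class logits. A NumPy sketch of the shapes involved, with random values standing in for the layer's actual initialization:

```python
import numpy as np

rng = np.random.default_rng(0)
num_features, num_classes, toy_batch_size = 2048, 2, 4

# Random stand-ins for the new layer's weight (out, in) and bias (out,)
head_weight = rng.normal(scale=0.01, size=(num_classes, num_features))
head_bias = np.zeros(num_classes)

features = rng.normal(size=(toy_batch_size, num_features))  # pooled ResNet features
head_logits = features @ head_weight.T + head_bias  # same formula as nn.Linear

trainable = head_weight.size + head_bias.size
print(head_logits.shape, trainable)  # (4, 2) 4098
```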


Next, let's define the training configuration, which will be passed into the training loop function later.

In [10]:
train_loop_config = {
    "input_size": 224,  # Input image size (224 x 224)
    "batch_size": 32,  # Batch size for training
    "num_epochs": 10,  # Number of epochs to train for
    "lr": 0.001,  # Learning Rate
    "momentum": 0.9,  # SGD optimizer momentum
}
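`lr` and `momentum` are passed to `torch.optim.SGD`. With PyTorch's defaults (no dampening, no Nesterov, no weight decay), each step updates a velocity buffer `v = momentum * v + grad` and then sets `param = param - lr * v`. A NumPy sketch of that rule on a toy quadratic (not part of the training code):

```python
import numpy as np


def sgd_momentum_step(param, grad, velocity, lr=0.001, momentum=0.9):
    """One step of SGD with momentum as PyTorch defines it (dampening=0, nesterov=False)."""
    velocity = momentum * velocity + grad
    return param - lr * velocity, velocity


# Toy objective f(p) = 0.5 * ||p||^2, whose gradient is p itself
toy_param = np.array([1.0, -2.0])
toy_velocity = np.zeros_like(toy_param)
for _ in range(100):
    toy_param, toy_velocity = sgd_momentum_step(toy_param, toy_param, toy_velocity)

print(toy_param)  # moves toward the minimum at [0, 0]
```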


## Define the Training Loop

The `train_loop_per_worker` function defines the finetuning procedure for each worker.

**1. Load dataset shard for each worker**:
- The Trainer will take a dictionary of Ray {class}`~ray.data.Dataset`s as input. These will be preprocessed and accessible in the worker's training loop via {meth}`session.get_dataset_shard() <ray.air.session.get_dataset_shard>`.
- By default, only the dataset under the key "train" will be split into multiple shards. `session.get_dataset_shard()` will return the full dataset for other keys. To configure this, see {class}`~ray.air.DatasetConfig`.
- Use {meth}`iter_torch_batches <ray.data.Dataset.iter_torch_batches>` to iterate the datasets with automatic tensor batching and device placement. If you need a more flexible customized batching function, please refer to our lower-level {meth}`iter_batches <ray.data.Dataset.iter_batches>` API.
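As a rough mental model (not how Ray splits blocks internally), the "train" dataset is divided into nearly equal shards, one per worker, while every worker reads "val" in full; `worker_batch_size` then divides the global batch size by the number of workers. A NumPy sketch with the 4-worker configuration used later and the 244 training images of hymenoptera_data:

```python
import numpy as np

num_workers = 4  # ScalingConfig(num_workers=4) used later in this example
global_batch_size = 32  # train_loop_config["batch_size"]
train_indices = np.arange(244)  # one index per training image

# Illustrative split only: each worker gets a contiguous, nearly equal shard
shards = np.array_split(train_indices, num_workers)
per_worker_batch = global_batch_size // num_workers

print([len(s) for s in shards], per_worker_batch)  # [61, 61, 61, 61] 8
```

With the 8 CPU workers of the smoke test, the same division gives a per-worker batch size of 4.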

**2. Prepare your model**:
- {meth}`train.torch.prepare_model() <ray.train.torch.prepare_model>` prepares the model for distributed training. Under the hood, it moves the model to the worker's device and wraps it in a `DistributedDataParallel` model, which synchronizes gradients across all workers.
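`DistributedDataParallel` keeps replicas in sync by averaging gradients across workers after each backward pass (an all-reduce), so every replica applies the same update. A NumPy sketch with a toy least-squares loss, showing that averaging per-shard gradients over equal-sized shards reproduces the full-batch gradient:

```python
import numpy as np

rng = np.random.default_rng(0)
X_toy = rng.normal(size=(64, 3))
y_toy = rng.normal(size=64)
w_toy = np.zeros(3)


def mse_grad(X, y, w):
    """Gradient of mean((X @ w - y) ** 2) with respect to w."""
    return 2.0 * X.T @ (X @ w - y) / len(y)


full_grad = mse_grad(X_toy, y_toy, w_toy)

# Four "workers", each holding an equal 16-row shard of the batch
shard_grads = [
    mse_grad(Xs, ys, w_toy)
    for Xs, ys in zip(np.array_split(X_toy, 4), np.array_split(y_toy, 4))
]
allreduced_grad = np.mean(shard_grads, axis=0)  # what the averaging all-reduce produces

print(np.allclose(full_grad, allreduced_grad))  # True
```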

**3. Report metrics and checkpoint**:
- {meth}`session.report() <ray.air.session.report>` will report metrics and checkpoints to Ray AIR.
- Saving checkpoints through {meth}`session.report(metrics, checkpoint=...) <ray.air.session.report>` will automatically [upload checkpoints to cloud storage](tune-cloud-checkpointing) (if configured), and allow you to easily enable Ray AIR worker fault tolerance in the future.
- By default, the most recent checkpoints are kept. To keep the best checkpoints instead, set `checkpoint_score_attribute` (for example, `"acc"`) in {class}`CheckpointConfig <ray.air.config.CheckpointConfig>`. In this example, `num_to_keep=1` keeps only the latest checkpoint.
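The two retention policies can be pictured as rankings over the reported checkpoints. A plain-Python sketch (a hypothetical helper, not Ray's implementation) contrasting "keep the latest" with "keep the best by `acc`", using the validation accuracies this run logs for epochs 0 to 4; breaking ties by recency is an arbitrary choice of the sketch:

```python
from typing import Dict, List, Optional


def retained_epochs(
    reports: List[Dict], num_to_keep: int, score_attr: Optional[str] = None
) -> List[int]:
    """Hypothetical helper: epochs whose checkpoints survive a keep-N policy."""
    if score_attr is None:
        # Default policy: rank by recency
        ranked = sorted(reports, key=lambda r: r["epoch"], reverse=True)
    else:
        # Rank by score (higher is better); ties go to the more recent epoch
        ranked = sorted(reports, key=lambda r: (r[score_attr], r["epoch"]), reverse=True)
    return [r["epoch"] for r in ranked[:num_to_keep]]


epoch_reports = [
    {"epoch": 0, "acc": 0.7059},
    {"epoch": 1, "acc": 0.9281},
    {"epoch": 2, "acc": 0.9346},
    {"epoch": 3, "acc": 0.9346},
    {"epoch": 4, "acc": 0.9281},
]
print(retained_epochs(epoch_reports, num_to_keep=1))  # [4]
print(retained_epochs(epoch_reports, num_to_keep=1, score_attr="acc"))  # [3]
```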

In [11]:
import ray.train as train
from ray.air import session
from ray.train.torch import TorchCheckpoint


def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects


def train_loop_per_worker(configs):
    warnings.filterwarnings("ignore")
    # Prepare dataloader for each worker
    datasets = dict()
    datasets["train"] = session.get_dataset_shard("train")
    datasets["val"] = session.get_dataset_shard("val")

    # Calculate the batch size for a single worker
    worker_batch_size = configs["batch_size"] // session.get_world_size()

    device = train.torch.get_device()

    # Prepare DDP Model, optimizer, and loss function
    model = initialize_model()

    model = train.torch.prepare_model(model)

    optimizer = optim.SGD(
        model.parameters(), lr=configs["lr"], momentum=configs["momentum"]
    )
    criterion = nn.CrossEntropyLoss()

    # Start training loops
    for epoch in range(configs["num_epochs"]):
        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Create a dataset iterator for the shard on the current worker
            dataset_iterator = datasets[phase].iter_torch_batches(
                batch_size=worker_batch_size, device=device
            )
            for batch in dataset_iterator:
                inputs = batch["image"]
                labels = batch["label"]

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # calculate statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += evaluate(outputs, labels)

            epoch_loss = running_loss / datasets[phase].count()
            epoch_acc = running_corrects / datasets[phase].count()

            if session.get_world_rank() == 0:
                print(
                    "Epoch {}-{} Loss: {:.4f} Acc: {:.4f}".format(
                        epoch, phase, epoch_loss, epoch_acc
                    )
                )

            # Report metrics and checkpoint every epoch
            if phase == "val":
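                checkpoint = TorchCheckpoint.from_dict(
                    {
                        "epoch": epoch,
                        # prepare_model() only wraps the model in DDP when there is
                        # more than one worker, so unwrap `.module` only if it exists
                        "model": (
                            model.module if hasattr(model, "module") else model
                        ).state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    }
                )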
                session.report(
                    metrics={"loss": epoch_loss, "acc": epoch_acc},
                    checkpoint=checkpoint,
                )
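The epoch statistics weight each batch's mean loss by its size (`loss.item() * inputs.size(0)`) and then divide by the number of rows, which gives the exact per-sample mean even when the last batch is smaller. A NumPy check with made-up per-sample losses for a 61-row shard (so the last batch of 8 holds only 5 rows):

```python
import numpy as np

rng = np.random.default_rng(0)
per_sample_loss = rng.random(61)  # made-up losses, one per row
toy_batch_size = 8

weighted_sum = 0.0
batch_means = []
for start in range(0, len(per_sample_loss), toy_batch_size):
    chunk = per_sample_loss[start:start + toy_batch_size]
    # CrossEntropyLoss defaults to reduction="mean", so each batch yields its mean
    batch_means.append(chunk.mean())
    weighted_sum += chunk.mean() * len(chunk)  # mirrors loss.item() * inputs.size(0)

weighted_epoch_loss = weighted_sum / len(per_sample_loss)
# The weighted version matches the true mean; a plain average of batch means may not
print(weighted_epoch_loss, per_sample_loss.mean(), np.mean(batch_means))
```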


Next, set up the TorchTrainer:

In [14]:
# @title +
if SMOKE_TEST:
    from moto.server import ThreadedMotoServer
    from moto import mock_s3
    import boto3
    import os
    import logging

    server = ThreadedMotoServer(port=5002)
    server.start()

    s3 = boto3.client("s3", endpoint_url="http://localhost:5002")

    bucket_name = "checkpoint-bucket"
    s3.create_bucket(Bucket=bucket_name)
    logging.getLogger("werkzeug").setLevel(logging.WARNING)


In [15]:
UPLOAD_DIR = "s3://YOUR_BUCKET_NAME"


In [None]:
# @title +
if SMOKE_TEST:
    UPLOAD_DIR = f"s3://{bucket_name}/results?endpoint_override=http://localhost:5002"


In [17]:
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.air.config import ScalingConfig, RunConfig, CheckpointConfig
from ray.tune.syncer import SyncConfig

# Scale out model training across 4 workers, each assigned 1 CPU and 1 GPU.
scaling_config = ScalingConfig(
    num_workers=4, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)

# Save only the latest checkpoint
checkpoint_config = CheckpointConfig(num_to_keep=1)

# Set experiment name and checkpoint configs
run_config = RunConfig(
    name="finetune-resnet",
    sync_config=SyncConfig(upload_dir=UPLOAD_DIR),
    checkpoint_config=checkpoint_config,
)


In [18]:
# @title +
if SMOKE_TEST:
    scaling_config = ScalingConfig(
        num_workers=8, use_gpu=False, resources_per_worker={"CPU": 1}
    )
    train_loop_config["num_epochs"] = 1


In [23]:
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=ray_datasets,
)

result = trainer.fit()


| Status | Value |
|---|---|
| Current time | 2023-02-22 10:54:34 |
| Running for | 00:01:03.44 |
| Memory | 9.4/62.0 GiB |

| Trial name | status | loc | iter | total time (s) | loss | acc |
|---|---|---|---|---|---|---|
| TorchTrainer_3303a_00000 | TERMINATED | 10.0.56.151:3546 | 10 | 54.9253 | 0.168786 | 0.928105 |


(RayTrainWorker pid=3633, ip=10.0.56.151) 2023-02-22 10:53:37,556	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=4]
(TorchTrainer pid=3546, ip=10.0.56.151) 2023-02-22 10:53:38,162	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[randomize_block_order]
(RayTrainWorker pid=52252) 2023-02-22 10:53:40,867	INFO train_loop_utils.py:307 -- Moving model to device: cuda:0
(RayTrainWorker pid=3767, ip=10.0.18.83) 2023-02-22 10:53:40,877	INFO train_loop_utils.py:307 -- Moving model to device: cuda:0
(RayTrainWorker pid=3633, ip=10.0.56.151) 2023-02-22 10:53:40,882	INFO train_loop_utils.py:307 -- Moving model to device: cuda:0
(RayTrainWorker pid=3572, ip=10.0.53.242) 2023-02-22 10:53:40,869	INFO train_loop_utils.py:307 -- Moving model to device: cuda:0
(RayTrainWorker pid=3572, ip=10.0.53.242) 2023-02-22 10:53:42,522	INFO train_loop_utils.py:367 -- Wrapping provided model in DistributedDataParallel.

(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 0-train Loss: 0.6499 Acc: 0.6721


(RayTrainWorker pid=3633, ip=10.0.56.151) 2023-02-22 10:53:46,178	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[randomize_block_order]
(RayTrainWorker pid=3572, ip=10.0.53.242) 2023-02-22 10:53:46,124	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[randomize_block_order]


(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 0-val Loss: 0.5461 Acc: 0.7059


| Trial name | acc | date | done | experiment_tag | hostname | iterations_since_restore | loss | node_ip | pid | should_checkpoint | time_since_restore | time_this_iter_s | time_total_s | timestamp | training_iteration | trial_id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| TorchTrainer_3303a_00000 | 0.928105 | 2023-02-22_10-54-29 | True | 0 | ip-10-0-56-151 | 10 | 0.168786 | 10.0.56.151 | 3546 | True | 54.9253 | 4.22899 | 54.9253 | 1677092067 | 10 | 3303a_00000 |


(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 1-train Loss: 0.4958 Acc: 0.7869
(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 1-val Loss: 0.3475 Acc: 0.9281




(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 2-train Loss: 0.3287 Acc: 0.8852
(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 2-val Loss: 0.2657 Acc: 0.9346




(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 3-train Loss: 0.2355 Acc: 0.9344
(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 3-val Loss: 0.2206 Acc: 0.9346




(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 4-train Loss: 0.1555 Acc: 0.9672
(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 4-val Loss: 0.1965 Acc: 0.9281




(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 5-train Loss: 0.1060 Acc: 0.9836
(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 5-val Loss: 0.1854 Acc: 0.9346




(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 6-train Loss: 0.0758 Acc: 1.0000
(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 6-val Loss: 0.1783 Acc: 0.9346




(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 7-train Loss: 0.0546 Acc: 1.0000
(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 7-val Loss: 0.1729 Acc: 0.9281




(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 8-train Loss: 0.0415 Acc: 1.0000
(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 8-val Loss: 0.1703 Acc: 0.9281




(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 9-train Loss: 0.0333 Acc: 1.0000
(RayTrainWorker pid=3633, ip=10.0.56.151) Epoch 9-val Loss: 0.1688 Acc: 0.9281


2023-02-22 10:54:36,004	INFO tune.py:825 -- Total run time: 65.35 seconds (63.44 seconds for the tuning loop).


Training has finished! The latest checkpoint has been saved to the experiment directory, and you can inspect the final metrics and checkpoint through the returned `result` object (`result.metrics` and `result.checkpoint`).

## Load the fine-tuned model for batch prediction

Now, let's load the trained model and evaluate it on the validation data.
We can use {meth}`TorchCheckpoint.from_uri() <ray.air.checkpoint.Checkpoint.from_uri>` to load the resulting checkpoint from our fine-tuning run. The {class}`~ray.train.batch_predictor.BatchPredictor` will identify the dict key `"model"` and load the corresponding parameters into the model.
 

The logs and checkpoints are saved to the `upload_dir` specified in the `SyncConfig` of the trainer's `RunConfig`.

For example:
`s3://YOUR_BUCKET_NAME/finetune-resnet/TorchTrainer_94bb5_00000_0_2023-02-14_14-40-28/checkpoint_000009`

In [25]:
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchPredictor

ckpt = TorchCheckpoint.from_uri(result.checkpoint.uri)
predictor = BatchPredictor.from_checkpoint(
    ckpt, TorchPredictor, model=initialize_model()
)


In [None]:
prediction_ds = predictor.predict(
    ray_datasets["val"],
    feature_columns=["image"],
    keep_columns=["label"],
    num_gpus_per_worker=1,
)


## Evaluate prediction results

The prediction has finished! We can use `prediction_ds.schema()` and `prediction_ds.take()` to inspect the data types and record structure.

We can see that there are two keys in the prediction results:
- "predictions": The output logits of our ResNet model, a tensor of length 2 (one logit per class).
- "label": The image label. Specified by `keep_columns` in `predictor.predict()`.
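Because "predictions" holds raw logits, `argmax` is enough to pick a class, but class probabilities require a softmax. A NumPy sketch (an extra step, not part of the original pipeline) using the logits of the first validation record from this run:

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


first_logits = np.array([[0.37351528, -0.7467256]], dtype=np.float32)
probs = softmax(first_logits)
print(probs, probs.argmax(axis=1))
```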

In [27]:
prediction_ds.schema()


predictions: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64

In [28]:
prediction_ds.take(1)


[{'predictions': array([ 0.37351528, -0.7467256 ], dtype=float32), 'label': 0}]

Here we define a function `convert_logits_to_classes` to convert tensor outputs to labels. 

In [29]:
import pandas as pd


def convert_logits_to_classes(batch):
    batch["pred_label"] = np.argmax(batch["predictions"], axis=1)
    batch["correct"] = batch["pred_label"] == batch["label"]
    return batch


predictions = prediction_ds.map_batches(convert_logits_to_classes, batch_format="numpy")
predictions.show(1)

print("Evaluation Accuracy = ", predictions.mean(on="correct"))


2023-02-22 10:57:15,005	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(convert_logits_to_classes)]
MapBatches(convert_logits_to_classes): 100%|██████████| 1/1 [00:00<00:00, 243.40it/s]


{'predictions': array([ 0.37351528, -0.7467256 ], dtype=float32), 'label': 0, 'pred_label': 0, 'correct': True}


2023-02-22 10:57:15,026	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[aggregate]
Shuffle Map: 100%|██████████| 1/1 [00:00<00:00, 138.26it/s]
Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 121.42it/s]


Evaluation Accuracy =  0.9281045751633987


You can also reuse the evaluation function defined in the training loop by iterating over the dataset. Note that the previous approach using `map_batches()` is more efficient because it parallelizes the evaluation on each partition.

In [30]:
def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects


accuracy = 0
for batch in prediction_ds.iter_torch_batches(batch_size=10):
    accuracy += evaluate(batch["predictions"], batch["label"])
accuracy /= prediction_ds.count()

print("Evaluation Accuracy = ", accuracy)


Evaluation Accuracy =  0.9281045751633987
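The `map_batches()` approach parallelizes well because accuracy decomposes: each partition only needs to report its number of correct predictions and its row count, and the global accuracy is the ratio of the summed counts. A NumPy sketch with made-up labels and predictions split into hypothetical partitions:

```python
import numpy as np

rng = np.random.default_rng(0)
true_labels = rng.integers(0, 2, size=153)  # same size as the validation set
# Made-up predictions that agree with the labels roughly 90% of the time
pred_classes = np.where(rng.random(153) < 0.9, true_labels, 1 - true_labels)

# Each (hypothetical) partition reports only (num_correct, num_rows)
partials = [
    (int((p == t).sum()), len(t))
    for p, t in zip(np.array_split(pred_classes, 4), np.array_split(true_labels, 4))
]
combined_accuracy = sum(c for c, _ in partials) / sum(n for _, n in partials)
print(combined_accuracy, (pred_classes == true_labels).mean())
```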


In [None]:
# @title +
if SMOKE_TEST:
    server.stop()
