# Fine-tuning a PyTorch Image Classifier with Ray AIR
In this example, we fine-tune a pretrained ResNet model with Ray Train.

For this example, our network architecture consists of the intermediate layer output of a pretrained ResNet model, which feeds into a randomly initialized linear layer that outputs classification logits for our new task.
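To make the shapes concrete, the NumPy sketch below mimics only the new classification head. It assumes the 2048-dimensional pooled feature vector that ResNet-50 (the backbone used in `initialize_model` below) passes to its final `fc` layer. The random features, the `linear_head` helper, and the small-Gaussian initialization are illustrative stand-ins, not PyTorch's `nn.Linear` defaults.

```python
import numpy as np

NUM_FEATURES = 2048  # size of ResNet-50's pooled features (the input to model.fc)
NUM_CLASSES = 2      # ants and bees

rng = np.random.default_rng(0)

# Randomly initialized head parameters (illustrative init, not PyTorch's default scheme)
weight = rng.normal(scale=0.01, size=(NUM_CLASSES, NUM_FEATURES))
bias = np.zeros(NUM_CLASSES)


def linear_head(features: np.ndarray) -> np.ndarray:
    """Map (N, 2048) backbone features to (N, 2) logits, like nn.Linear(2048, 2)."""
    return features @ weight.T + bias


# Stand-in for the intermediate-layer output of the pretrained backbone
features = rng.normal(size=(4, NUM_FEATURES))
logits = linear_head(features)
print(logits.shape)  # (4, 2)
```

Only `weight` and `bias` start from scratch; everything that produces `features` comes from the pretrained model.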


## Load and preprocess the fine-tuning dataset with Ray Data
This example is adapted from PyTorch's [Finetuning Torchvision Models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html) tutorial.
We use *hymenoptera_data* as the fine-tuning dataset, which contains two classes (bees and ants) and 397 images in total (across training and validation). This is a very small dataset, and we use it only for demonstration purposes.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
import numpy as np
from typing import Dict

import ray
from ray.data.datasource.partitioning import Partitioning
from ray.train.torch import TorchCheckpoint


In [2]:
import warnings

warnings.filterwarnings("ignore")


The dataset can be downloaded [here](https://download.pytorch.org/tutorial/hymenoptera_data.zip).

In [3]:
! wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
! unzip -o hymenoptera_data.zip

First, we use {meth}`ray.data.read_images <ray.data.read_images>` to load the images. Since the dataset is already structured with directory names as the labels, we can use the {class}`Partitioning <ray.data.datasource.Partitioning>` API to automatically extract image labels.
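The idea behind directory-based partitioning can be sketched in plain Python: the label is the first directory name below `base_dir`. This is a simplified illustration under that assumption, not Ray's implementation; the `label_from_path` helper and the file names are hypothetical.

```python
from pathlib import PurePosixPath


def label_from_path(path: str, base_dir: str) -> str:
    """Return the first directory name below base_dir (e.g. 'ants' for base_dir/ants/x.jpg)."""
    relative = PurePosixPath(path).relative_to(PurePosixPath(base_dir))
    return relative.parts[0]


base_dir = "./hymenoptera_data/train"
# Hypothetical file names; only the directory structure matters here
for path in [
    "./hymenoptera_data/train/ants/example_001.jpg",
    "./hymenoptera_data/train/bees/example_002.jpg",
]:
    print(path, "->", label_from_path(path, base_dir))
```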

In [4]:
ray_img_datasets = {}
for split in ["train", "val"]:
    data_folder = f"./hymenoptera_data/{split}"
    partitioning = Partitioning("dir", field_names=["class"], base_dir=data_folder)
    ray_img_datasets[split] = ray.data.read_images(
        data_folder, partitioning=partitioning, mode="RGB"
    )




We have now loaded the images into Ray Datasets, which partition the data into blocks and distribute these blocks across the nodes in the cluster, so that preprocessing and data ingestion run in parallel. Note that the ResNet model was pretrained on images normalized with fixed per-channel mean and standard deviation values. We keep these values unchanged for fine-tuning, as shown in *data_transforms*. More details can be found [here](https://pytorch.org/hub/pytorch_vision_resnet/).
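`transforms.ToTensor()` converts an image to a float tensor of shape (C, H, W) with values in [0, 1], and `transforms.Normalize(mean, std)` then subtracts each channel's mean and divides by its standard deviation. A NumPy sketch of that second step, using the same values as *data_transforms* (the `normalize_chw` helper is hypothetical):

```python
import numpy as np

# ImageNet statistics used in data_transforms (RGB channel order)
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def normalize_chw(image: np.ndarray) -> np.ndarray:
    """Normalize a (3, H, W) float image channel by channel: (x - mean) / std."""
    return (image - MEAN[:, None, None]) / STD[:, None, None]


# A pixel equal to the channel means normalizes to zero in every channel
mean_pixel = MEAN.reshape(3, 1, 1)
print(normalize_chw(mean_pixel).ravel())

# A white pixel (all ones) maps to (1 - mean) / std per channel
white_pixel = np.ones((3, 1, 1))
print(normalize_chw(white_pixel).ravel())
```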

In [5]:
# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    "train": transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}

class_to_idx = {"ants": 0, "bees": 1}


def preprocess(batch: Dict[str, np.ndarray], split: str) -> Dict[str, np.ndarray]:
    transform = data_transforms[split]
    batch["image"] = np.array([transform(img).numpy() for img in batch["image"]])
    batch["label"] = np.array([class_to_idx[cls_name] for cls_name in batch["class"]])
    batch.pop("class")
    return batch


ray_datasets = {
    split: ds.map_batches(
        fn=preprocess, fn_kwargs={"split": split}, batch_format="numpy"
    )
    for split, ds in ray_img_datasets.items()
}




```{note}
Note that **batch** here refers to the chunk of data for preprocessing, not the batch for model training. To learn more about writing functions for {meth}`map_batches <ray.data.Dataset.map_batches>`, read [writing user-defined functions](transform_datasets_writing_udfs).
```
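To see what a preprocessing "batch" is, here is a plain-NumPy simulation (a sketch under simplifying assumptions, not how Ray Data schedules work): the rows are cut into chunks, and the UDF is called once per chunk with a dict of column arrays. `encode_labels` is a hypothetical, image-free version of `preprocess`, and the rows are made up.

```python
import numpy as np
from typing import Dict

class_to_idx = {"ants": 0, "bees": 1}


def encode_labels(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Same label logic as `preprocess`: replace the 'class' column with an integer 'label'."""
    batch["label"] = np.array([class_to_idx[c] for c in batch["class"]])
    batch.pop("class")
    return batch


# Hypothetical rows; the UDF receives chunks of them as a dict of columns
classes = np.array(["ants", "bees", "bees", "ants", "bees"])
chunk_size = 2  # a preprocessing chunk size, unrelated to the training batch_size of 32
outputs = [
    encode_labels({"class": classes[i : i + chunk_size]})
    for i in range(0, len(classes), chunk_size)
]
labels = np.concatenate([out["label"] for out in outputs])
print(labels)  # [0 1 1 0 1]
```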

## Initialize Model and Fine-tuning configs

Next, let's define our model. You can create a model from a pretrained ResNet, or reload the model checkpoint from a previous run.

In [6]:
def initialize_model():
    # Load pretrained model params
    model = models.resnet50(pretrained=True)

    # Replace the original classifier with a new Linear layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 2)

    # Ensure all params get updated during fine-tuning
    for param in model.parameters():
        param.requires_grad = True
    return model


Next, let's define the training configuration, which is passed to the training loop function later.

In [7]:
train_loop_config = {
    "input_size": 224,  # Input image size (224 x 224)
    "batch_size": 32,  # Batch size for training
    "num_epochs": 10,  # Number of epochs to train for
    "lr": 0.001,  # Learning Rate
    "momentum": 0.9,  # SGD optimizer momentum
}


## Define the Training Loop

The `train_loop_per_worker` function defines the finetuning procedure for each worker.

**1. Load dataset shard for each worker**:
- The Trainer takes a dictionary of Ray {class}`~ray.data.Dataset`s as input. Each worker accesses them in its training loop via {meth}`session.get_dataset_shard() <ray.air.session.get_dataset_shard>`.
- By default, only the dataset under the key "train" is split into multiple shards; `session.get_dataset_shard()` returns the full dataset for all other keys. To configure this, see {class}`~ray.air.DatasetConfig`.
- Use {meth}`iter_torch_batches <ray.data.Dataset.iter_torch_batches>` to iterate over the datasets with automatic tensor batching and device placement. If you need more flexible, customized batching, refer to the lower-level {meth}`iter_batches <ray.data.Dataset.iter_batches>` API.
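A plain-Python sketch of what each worker sees with 4 workers. It assumes an even, contiguous split purely for illustration (Ray's actual block assignment may differ), and the 244/153 train/val sizes are inferred from the logged accuracies and are consistent with the 397 total images. The `shard` helper is hypothetical.

```python
# Illustrative numbers: the global batch size from train_loop_config and 4 workers
global_batch_size = 32
world_size = 4
# Same formula as in train_loop_per_worker
worker_batch_size = global_batch_size // world_size


def shard(rows, num_shards):
    """Split rows into num_shards contiguous, nearly equal pieces (a simplification)."""
    base, extra = divmod(len(rows), num_shards)
    shards, start = [], 0
    for i in range(num_shards):
        end = start + base + (1 if i < extra else 0)
        shards.append(rows[start:end])
        start = end
    return shards


train_rows = list(range(244))  # assumed number of training images
val_rows = list(range(153))    # assumed number of validation images

train_shards = shard(train_rows, world_size)       # "train" is split across workers
val_views = [val_rows for _ in range(world_size)]  # other keys: every worker sees all rows

print(worker_batch_size, [len(s) for s in train_shards], [len(v) for v in val_views])
```

Dividing the global batch size by the world size keeps the effective batch size per optimizer step at 32, regardless of how many workers you use.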

**2. Prepare your model**:
- {meth}`train.torch.prepare_model() <ray.train.torch.prepare_model>` prepares the model for distributed training. Under the hood, it moves the model to the worker's device and, when training with more than one worker, wraps it in a `DistributedDataParallel` model, which averages gradients across all workers so that the model weights stay synchronized.
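The synchronization idea can be simulated with NumPy. This is a conceptual sketch, not how `DistributedDataParallel` is implemented (DDP all-reduces gradients in buckets during the backward pass); it assumes plain SGD without momentum and made-up gradients.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size = 4
lr = 0.001

# DDP starts every replica from the same weights (rank 0's are broadcast at setup)
replica_weights = [np.ones(3) for _ in range(world_size)]

# Each worker computes a different gradient from its own data shard
local_grads = [rng.normal(size=3) for _ in range(world_size)]

# All-reduce step: every worker ends up with the average of all local gradients
avg_grad = np.mean(local_grads, axis=0)

# Each replica applies the same SGD update, so the replicas remain identical
replica_weights = [w - lr * avg_grad for w in replica_weights]
print(all(np.array_equal(w, replica_weights[0]) for w in replica_weights))  # True
```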

**3. Report metrics and checkpoint**:
- {meth}`session.report() <ray.air.session.report>` reports metrics and checkpoints to Ray AIR.
- Saving checkpoints through {meth}`session.report(metrics, checkpoint=...) <ray.air.session.report>` automatically [uploads checkpoints to cloud storage](tune-cloud-checkpointing) (if configured) and makes it easy to enable Ray AIR worker fault tolerance later.
- If you set `checkpoint_score_attribute` in {class}`CheckpointConfig <ray.air.config.CheckpointConfig>`, Ray AIR keeps the best checkpoints ranked by that metric. In this example, we don't set it and use `num_to_keep=1`, so only the most recent checkpoint is kept.
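To rank checkpoints by validation accuracy, `CheckpointConfig` accepts `checkpoint_score_attribute` and `checkpoint_score_order`, for example `CheckpointConfig(num_to_keep=1, checkpoint_score_attribute="acc", checkpoint_score_order="max")`. The plain-Python sketch below only illustrates the ranking rule on the per-epoch validation accuracies printed in the training log of this run; it is not Ray's checkpoint manager, and the tie-breaking shown (earlier epoch wins) is this sketch's choice, not a claim about Ray. `keep_top_k` is a hypothetical helper.

```python
# Validation accuracy reported at the end of epochs 0-9 (from the training log)
val_acc = [0.5817, 0.9281, 0.8758, 0.9477, 0.9412, 0.9542, 0.9542, 0.9412, 0.9412, 0.9346]


def keep_top_k(scores, k, higher_is_better=True):
    """Return the epoch indices of the k best checkpoints (ties keep the earlier epoch)."""
    # Python's sort is stable, also with reverse=True, so equal scores keep epoch order
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=higher_is_better)
    return sorted(order[:k])


# Without a score attribute, num_to_keep=1 keeps the most recent checkpoint instead
latest_epoch = len(val_acc) - 1

print(keep_top_k(val_acc, k=1), latest_epoch)  # [5] 9
```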

In [8]:
import ray.train as train
from ray.air import session
from ray.train.torch import TorchCheckpoint


def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects


def train_loop_per_worker(configs):
    # Prepare dataloader for each worker
    datasets = dict()
    datasets["train"] = session.get_dataset_shard("train")
    datasets["val"] = session.get_dataset_shard("val")

    # Calculate the batch size for a single worker
    worker_batch_size = configs["batch_size"] // session.get_world_size()

    device = train.torch.get_device()

    # Prepare DDP Model, optimizer, and loss function
    model = initialize_model()

    model = train.torch.prepare_model(model)

    optimizer = optim.SGD(
        model.parameters(), lr=configs["lr"], momentum=configs["momentum"]
    )
    criterion = nn.CrossEntropyLoss()

    # Start training loops
    for epoch in range(configs["num_epochs"]):
        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Create a dataset iterator for the shard on the current worker
            dataset_iterator = datasets[phase].iter_torch_batches(
                batch_size=worker_batch_size, device=device
            )
            for batch in dataset_iterator:
                inputs = batch["image"]
                labels = batch["label"]

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # calculate statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += evaluate(outputs, labels)

            epoch_loss = running_loss / datasets[phase].count()
            epoch_acc = running_corrects / datasets[phase].count()

            if session.get_world_rank() == 0:
                print(
                    "Epoch {}-{} Loss: {:.4f} Acc: {:.4f}".format(
                        epoch, phase, epoch_loss, epoch_acc
                    )
                )

            # Report metrics and checkpoint every epoch
            if phase == "val":
                checkpoint = TorchCheckpoint.from_dict(
                    {
                        "epoch": epoch,
                        "model": model.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    }
                )
                session.report(
                    metrics={"loss": epoch_loss, "acc": epoch_acc},
                    checkpoint=checkpoint,
                )
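Inside the loop, `running_loss` accumulates `loss.item() * inputs.size(0)` because `CrossEntropyLoss` averages over the batch by default. Multiplying by the batch size and dividing by the row count gives the per-sample mean even when the last batch is smaller. Since only the "train" dataset is sharded, the training numbers that rank 0 prints cover that worker's shard, while the validation numbers cover the whole validation set. A small check with made-up per-sample losses:

```python
import numpy as np

# Hypothetical per-sample losses: 10 samples in batches of 4, so the last batch has only 2
per_sample = np.array([0.9, 0.7, 0.4, 0.2, 0.5, 0.3, 0.8, 0.1, 0.6, 0.5])
batch_size = 4
batches = [per_sample[s : s + batch_size] for s in range(0, len(per_sample), batch_size)]

running_loss = 0.0
for batch in batches:
    loss = batch.mean()                # CrossEntropyLoss averages over the batch by default
    running_loss += loss * len(batch)  # mirrors loss.item() * inputs.size(0)
epoch_loss = running_loss / len(per_sample)  # mirrors dividing by datasets[phase].count()

# Averaging the per-batch means instead would over-weight the short last batch
naive_mean = np.mean([batch.mean() for batch in batches])
print(epoch_loss, per_sample.mean(), naive_mean)
```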


Next, set up the `TorchTrainer`:

In [9]:
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.air.config import ScalingConfig, RunConfig, CheckpointConfig

# Scale out model training across 4 workers, each assigned 1 CPU and 1 GPU.
scaling_config = ScalingConfig(
    num_workers=4, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)

# Save only the latest checkpoint
checkpoint_config = CheckpointConfig(num_to_keep=1)

# Set experiment name and checkpoint configs
run_config = RunConfig(
    name="finetune-resnet",
    local_dir="/tmp/ray_results",
    checkpoint_config=checkpoint_config,
)

# Optionally, initialize the model from a previous checkpoint instead:
# CHECKPOINT_URI = "s3://air-example-data/finetune-resnet-checkpoint/TorchTrainer_4f69f_00000_0_2023-02-14_14-04-09/checkpoint_000000/"
# def initialize_model_from_ckpt():
#     checkpoint = TorchCheckpoint.from_uri(CHECKPOINT_URI)
#     model = initialize_model()
#     return checkpoint.get_model(model=model)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=ray_datasets,
)


Now, run the training with `trainer.fit()`. Once training finishes, the latest checkpoint is saved to the experiment directory, and you can check the experiment metrics and checkpoint information:

In [10]:
result = trainer.fit()


| Field | Value |
|---|---|
| Current time | 2023-02-15 14:30:01 |
| Running for | 00:01:08.29 |
| Memory | 7.5/62.0 GiB |

| Trial name | status | loc | iter | total time (s) | loss | acc | _timestamp |
|---|---|---|---|---|---|---|---|
| TorchTrainer_20605_00000 | TERMINATED | 10.0.51.23:15947 | 10 | 56.2452 | 0.234332 | 0.934641 | 1676500191 |



(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 0-train Loss: 1.0319 Acc: 0.3443
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 0-val Loss: 0.5913 Acc: 0.5817


| Field | Value |
|---|---|
| Trial name | TorchTrainer_20605_00000 |
| _time_this_iter_s | 4.13338 |
| _timestamp | 1676500191 |
| _training_iteration | 10 |
| acc | 0.934641 |
| date | 2023-02-15_14-29-52 |
| done | True |
| episodes_total | |
| experiment_id | 17be262490e740238915a0697f5c1ef7 |
| experiment_tag | 0 |
| hostname | ip-10-0-51-23 |
| iterations_since_restore | 10 |
| loss | 0.234332 |
| node_ip | 10.0.51.23 |
| pid | 15947 |
| should_checkpoint | True |
| time_since_restore | 56.2452 |
| time_this_iter_s | 4.20876 |
| time_total_s | 56.2452 |
| timestamp | 1676500192 |
| timesteps_since_restore | 0 |
| timesteps_total | |
| training_iteration | 10 |
| trial_id | 20605_00000 |
| warmup_time | 0.185671 |


(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 1-train Loss: 0.6640 Acc: 0.5574
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 1-val Loss: 0.4608 Acc: 0.9281
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 2-train Loss: 0.5956 Acc: 0.7213
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 2-val Loss: 0.3758 Acc: 0.8758
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 3-train Loss: 0.5258 Acc: 0.7377
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 3-val Loss: 0.3247 Acc: 0.9477
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 4-train Loss: 0.4128 Acc: 0.7869
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 4-val Loss: 0.2768 Acc: 0.9412
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 5-train Loss: 0.3988 Acc: 0.7541
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 5-val Loss: 0.2693 Acc: 0.9542
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 6-train Loss: 0.3330 Acc: 0.7869
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 6-val Loss: 0.2469 Acc: 0.9542
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 7-train Loss: 0.3207 Acc: 0.7541
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 7-val Loss: 0.2467 Acc: 0.9412
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 8-train Loss: 0.2849 Acc: 0.8197
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 8-val Loss: 0.2342 Acc: 0.9412
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 9-train Loss: 0.2719 Acc: 0.8197
(RayTrainWorker pid=16024, ip=10.0.51.23) Epoch 9-val Loss: 0.2343 Acc: 0.9346


2023-02-15 14:30:01,519	INFO tune.py:762 -- Total run time: 68.63 seconds (68.29 seconds for the tuning loop).


## Load the fine-tuned model for batch prediction

Now, we load the trained model and evaluate it on the validation data.
We can use `TorchCheckpoint.from_uri()` to load the resulting checkpoint from our fine-tuning run. The {class}`~ray.train.batch_predictor.BatchPredictor` identifies the dict key `"model"` and loads the corresponding parameters into the model.


Example path for a checkpoint folder:
`"/tmp/ray_results/finetune-resnet/TorchTrainer_94bb5_00000_0_2023-02-14_14-40-28/checkpoint_000009"`

In [11]:
checkpoint_folder = result.checkpoint.uri


In [12]:
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchPredictor

ckpt = TorchCheckpoint.from_uri(checkpoint_folder)
predictor = BatchPredictor.from_checkpoint(
    ckpt, TorchPredictor, model=initialize_model()
)


In [13]:
prediction_ds = predictor.predict(
    ray_datasets["val"],
    feature_columns=["image"],
    keep_columns=["label"],
    num_gpus_per_worker=1,
)


2023-02-15 14:30:03,395	INFO batch_predictor.py:184 -- `num_gpus_per_worker` is set for `BatchPreditor`.Automatically enabling GPU prediction for this predictor. To disable set `use_gpu` to `False` in `BatchPredictor.predict`.


## Evaluate prediction results

The prediction has finished! We can use `ds.schema()` and `ds.take()` to inspect the data types and record structure.

We can see that there are two keys in the prediction results:
- "predictions": The output logits of our fine-tuned ResNet model: 2 values per image, one logit for each class (ants and bees).
- "label": The image label, kept in the output because it is listed in `keep_columns` in `predictor.predict()`.

In [14]:
prediction_ds.schema()


predictions: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64

In [15]:
prediction_ds.take(1)


[ArrowRow({'predictions': array([ 0.9406432 , -0.83088267], dtype=float32),
           'label': 0})]

Here, we define a function `convert_logits_to_classes` that converts the output logits to predicted labels and records whether each prediction is correct.
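`np.argmax(..., axis=1)` picks the index of the largest logit in each row, which is the predicted class index (0 for ants, 1 for bees, per `class_to_idx`). The sketch applies that logic to the first prediction row shown above; the softmax step is our optional addition for reading the logits as probabilities and is not needed by the tutorial.

```python
import numpy as np

# First prediction row shown above: logits for [ants, bees], and its label
logits = np.array([[0.9406432, -0.83088267]], dtype=np.float32)
labels = np.array([0])

pred_label = np.argmax(logits, axis=1)  # same as in convert_logits_to_classes
correct = pred_label == labels

# Optional: softmax turns logits into class probabilities
shifted = logits - logits.max(axis=1, keepdims=True)  # subtract the max for numerical stability
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
print(pred_label, correct, probs.round(3))
```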

In [16]:
def convert_logits_to_classes(batch):
    batch["pred_label"] = np.argmax(batch["predictions"], axis=1)
    batch["correct"] = batch["pred_label"] == batch["label"]
    return batch


predictions = prediction_ds.map_batches(convert_logits_to_classes, batch_format="numpy")
predictions.show(1)

print("Evaluation Accuracy = ", predictions.mean(on="correct"))




{'predictions': array([ 0.9406432 , -0.83088267], dtype=float32), 'label': 0, 'pred_label': 0, 'correct': True}




Evaluation Accuracy =  0.934640522875817


You can also reuse the `evaluate` function defined for the training loop by iterating over the dataset. Note that the previous approach using `map_batches()` is more efficient, because it evaluates the dataset blocks in parallel instead of iterating over the batches sequentially.
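The loop sums the correct counts across batches and divides by `prediction_ds.count()` once, which matches the mean of the `correct` column computed with `map_batches()`. A NumPy check with made-up correctness flags shows why it divides once instead of averaging per-batch accuracies:

```python
import numpy as np

# Hypothetical correctness flags for 7 predictions, iterated in batches of 3
correct = np.array([1, 1, 0, 1, 1, 1, 0])
batch_size = 3
batches = [correct[s : s + batch_size] for s in range(0, len(correct), batch_size)]

# Sum the correct counts per batch, then divide by the total count once
total_correct = sum(int(batch.sum()) for batch in batches)
accuracy = total_correct / len(correct)

# Averaging per-batch accuracies would give the 1-element last batch too much weight
mean_of_batch_acc = np.mean([batch.mean() for batch in batches])
print(accuracy, correct.mean(), mean_of_batch_acc)
```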

In [17]:
def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects


accuracy = 0
for batch in prediction_ds.iter_torch_batches(batch_size=10):
    accuracy += evaluate(batch["predictions"], batch["label"])
accuracy /= prediction_ds.count()

print("Evaluation Accuracy = ", accuracy)


Evaluation Accuracy =  0.934640522875817
