# Fine-tuning a PyTorch ResNet Model for Image Classification
In this example, we fine-tune a pretrained ResNet model with Ray Train. You should be familiar with [PyTorch](https://pytorch.org/) before starting the tutorial.

For fine-tuning, our network architecture consists of a pretrained ResNet model as the backbone and a randomly initialized linear layer as the classifier. The ResNet model is pretrained on the 1000-class ImageNet dataset. We unfreeze and retrain all parameters of the model for the new task.




## Load and Transform Datasets
We use *hymenoptera_data* as the fine-tuning dataset. It contains two classes (bees and ants) and 397 images (244 for training, 153 for validation). The dataset is provided by PyTorch and can be downloaded [here](https://download.pytorch.org/tutorial/hymenoptera_data.zip). The dataset folder is structured so that it can be loaded directly with the PyTorch [ImageFolder](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html) dataset.
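`ImageFolder` treats each subdirectory of a split as one class and assigns label indices by sorting the class folder names. The sketch below mimics that mapping with the standard library only, assuming the `train/ants` and `train/bees` layout of the extracted archive. The helper `class_to_index` is hypothetical and is not part of torchvision.

```python
import os
import tempfile

def class_to_index(split_dir):
    # Hypothetical helper: mirrors ImageFolder's rule of one class per
    # subdirectory, with indices assigned in sorted (alphabetical) order
    classes = sorted(
        entry.name for entry in os.scandir(split_dir) if entry.is_dir()
    )
    return {name: idx for idx, name in enumerate(classes)}

# Build a throwaway directory tree shaped like hymenoptera_data/train
with tempfile.TemporaryDirectory() as root:
    for name in ["bees", "ants"]:
        os.makedirs(os.path.join(root, "train", name))
    mapping = class_to_index(os.path.join(root, "train"))

print(mapping)  # {'ants': 0, 'bees': 1}
```

Under this rule, label `0` corresponds to ants and label `1` to bees in the outputs that follow.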

Note that the ResNet model was pretrained with fixed normalization values (the ImageNet per-channel mean and standard deviation). We keep these values the same for fine-tuning, as shown in `data_transforms`. More details can be found [here](https://pytorch.org/hub/pytorch_vision_resnet/).
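`transforms.Normalize(mean, std)` rescales every channel as `(x - mean[c]) / std[c]`. The NumPy sketch below reproduces that arithmetic on tiny hand-made arrays, assuming the channels-first `(C, H, W)` layout that `transforms.ToTensor()` produces; `normalize` is a hypothetical stand-in, not the torchvision implementation.

```python
import numpy as np

# ImageNet per-channel statistics used by the pretrained ResNet
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize(image_chw):
    # image_chw: float array in [0, 1] with shape (channels, height, width)
    return (image_chw - MEAN[:, None, None]) / STD[:, None, None]

# A 2x2 "image" whose pixels equal the channel means normalizes to zeros
mean_image = np.broadcast_to(MEAN[:, None, None], (3, 2, 2)).astype(np.float32)
normalized = normalize(mean_image)
print(np.allclose(normalized, 0.0))  # True

# A pure-white pixel (1.0 in every channel) maps to (1 - mean) / std per channel
white = np.ones((3, 1, 1), dtype=np.float32)
print(normalize(white).ravel())
```

Reusing these exact statistics keeps the fine-tuning inputs on the same scale the backbone saw during ImageNet pretraining.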

```python
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import numpy as np

# Replace with your own path of the dataset
DATA_DIR = "/mnt/cluster_storage/hymenoptera_data"

# Data augmentation and normalization for training
# Just normalization for validation
input_size = 224
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

torch_datasets = dict()
for split in ["train", "val"]:
    torch_datasets[split] = datasets.ImageFolder(os.path.join(DATA_DIR, split), data_transforms[split])
```

Next, we convert the ImageFolder datasets into Ray datasets, which partition the data and distribute the data blocks across the nodes in the cluster. This enables faster parallel preprocessing and data ingestion.

Note that **batch** here refers to the chunk of data that the map function executes on, not the batch used for model training. To learn more about writing functions for {meth}`map_batches <ray.data.Dataset.map_batches>`, read [writing user-defined functions](https://docs.ray.io/en/latest/data/transforming-datasets.html#transform-datasets-writing-udfs).
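To make the difference concrete, the sketch below runs the same row-to-column conversion on a hand-made chunk of three records, outside Ray. It assumes each record is an `(image, label)` pair, which is how the conversion function in the next cell unpacks them, and it shrinks the images to `3x4x4` for readability. `convert_records_to_columns` is a hypothetical name chosen so it does not shadow the notebook's function.

```python
import numpy as np

def convert_records_to_columns(batch):
    # Same logic as convert_batch_to_numpy, but fed plain NumPy arrays
    # instead of torch tensors, so `.numpy()` is not needed here
    images = np.array([image for image, _ in batch])
    labels = np.array([label for _, label in batch])
    return {"image": images, "label": labels}

# A hypothetical map_batches chunk of 3 records, each an (image, label) pair
fake_batch = [(np.zeros((3, 4, 4), dtype=np.float32), label) for label in [0, 1, 1]]
out = convert_records_to_columns(fake_batch)
print(out["image"].shape)  # (3, 3, 4, 4)
print(out["label"])        # [0 1 1]
```

The chunk size here (3) is unrelated to `configs["batch_size"]`; training batches are formed later by `iter_torch_batches`.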

```python
import ray

def convert_batch_to_numpy(batch):
    # `batch` is a chunk of (image_tensor, label) records from ray.data.from_torch;
    # stack them into one array per column
    images = np.array([image.numpy() for image, _ in batch])
    labels = np.array([label for _, label in batch])
    return {"image": images, "label": labels}

ray_datasets = dict()
for split in ["train", "val"]:
    ray_datasets[split] = ray.data.from_torch(torch_datasets[split]).map_batches(convert_batch_to_numpy)
    print(ray_datasets[split].schema())
```

```
image: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64

image: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64
```

## Initialize Model and Fine-tuning Configs

```python
def initialize_model(num_classes):
    # Load pretrained ImageNet weights (torchvision >= 0.13 API;
    # older releases use `pretrained=True` instead)
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

    # Replace the original classifier with a new Linear layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)

    # Ensure all params get updated during fine-tuning
    for param in model.parameters():
        param.requires_grad = True
    return model
```
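`initialize_model` swaps ResNet-18's final `fc` layer, whose input width is 512, for a fresh `nn.Linear(512, num_classes)`. The NumPy sketch below reproduces what that new head computes, `logits = features @ W.T + b`, and draws its initial parameters from U(-1/sqrt(512), 1/sqrt(512)), the range PyTorch's default `nn.Linear` initialization uses. The random `features` are placeholders for the backbone's pooled output, not real activations.

```python
import numpy as np

rng = np.random.default_rng(0)

in_features = 512   # ResNet-18's pooled feature width (model.fc.in_features)
num_classes = 2     # ants vs. bees

# Uniform init in [-1/sqrt(in_features), 1/sqrt(in_features)], matching the
# bounds of PyTorch's default nn.Linear initialization
bound = 1.0 / np.sqrt(in_features)
W = rng.uniform(-bound, bound, size=(num_classes, in_features))
b = rng.uniform(-bound, bound, size=num_classes)

# Placeholder backbone outputs for a batch of 4 images
features = rng.standard_normal((4, in_features))
logits = features @ W.T + b

print(logits.shape)  # (4, 2)
```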

```python
configs = dict()

# Input image size (224 x 224)
configs["input_size"] = 224

# Number of label classes
configs["num_classes"] = 2

# Batch size for training (change depending on how much memory you have)
configs["batch_size"] = 8

# Number of epochs to train for
configs["num_epochs"] = 15

# Hyper-parameters for optimizer
configs["lr"] = 0.001
configs["momentum"] = 0.9
```
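`batch_size` is the global batch size: the training loop divides it by the number of workers with integer division. The hypothetical helper below only restates that arithmetic. When the global size is not divisible by the worker count, the effective global batch shrinks.

```python
def per_worker_batch_size(global_batch_size, num_workers):
    # Same integer division as in train_loop_per_worker
    return global_batch_size // num_workers

print(per_worker_batch_size(8, 2))  # 4 -> 2 workers x 4 = 8 images per step
print(per_worker_batch_size(8, 3))  # 2 -> 3 workers x 2 = 6 images per step
```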

## Define and Run Training Loop

The `train_loop_per_worker` function handles the training and validation of a given model.

1. Load the dataset shard for each worker:
   - A Ray Trainer takes a dictionary of Ray datasets as input. The dataset under the `"train"` key is automatically split into multiple dataset shards, which each worker accesses with `session.get_dataset_shard("train")`. All other datasets are not split.
   - Use {meth}`iter_torch_batches <ray.data.Dataset.iter_torch_batches>` to iterate over the datasets with automatic tensor batching. If you need a more flexible, customized batching function, refer to the lower-level {meth}`iter_batches <ray.data.Dataset.iter_batches>` API.
2. Prepare your model:
   - `train.torch.prepare_model` prepares the model for distributed execution. It moves the model to the worker's device and, under the hood, wraps it in `DistributedDataParallel` to synchronize gradients and buffers.
3. Report metrics and checkpoints:
   - `session.report` gathers the metrics from each worker and saves them to log files.
   - The best checkpoints are kept according to the reported metric specified in {class}`CheckpointConfig <ray.air.config.CheckpointConfig>`.
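A plain-Python picture of what the sharding rule means for the two workers used later: the 244 training images are divided between the workers, while each worker iterates over all 153 validation images. The round-robin `split_evenly` helper is a hypothetical stand-in, not Ray's splitting algorithm. This is consistent with the training logs, where the two workers report different training metrics but identical validation metrics.

```python
def split_evenly(records, num_shards):
    # Illustrative round-robin split; Ray's actual splitting strategy differs,
    # this only shows that shards are disjoint and together cover the dataset
    return [records[i::num_shards] for i in range(num_shards)]

train_records = list(range(244))  # stand-ins for the 244 training images
val_records = list(range(153))    # stand-ins for the 153 validation images

train_shards = split_evenly(train_records, 2)
print([len(s) for s in train_shards])  # [122, 122]

# "val" is not split: every worker iterates over all 153 records
val_per_worker = [val_records for _ in range(2)]
print([len(v) for v in val_per_worker])  # [153, 153]
```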

```python
import ray.train as train
from ray.air import session
from ray.train.torch import TorchCheckpoint

def evaluate(logits, labels):
    # Count how many predictions (argmax over class logits) match the labels
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects

def train_loop_per_worker(configs):
    # Get this worker's dataset shards ("train" is split across workers, "val" is not)
    datasets = dict()
    datasets["train"] = session.get_dataset_shard("train")
    datasets["val"] = session.get_dataset_shard("val")

    # Split the global batch size evenly across workers
    worker_batch_size = configs["batch_size"] // session.get_world_size()

    # Each worker is assigned its own GPU, which it sees as cuda:0
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"

    # Prepare DDP Model, optimizer, and loss function
    model = initialize_model(num_classes=configs["num_classes"])
    model = train.torch.prepare_model(model)

    optimizer = optim.SGD(model.parameters(), lr=configs["lr"], momentum=configs["momentum"])
    criterion = nn.CrossEntropyLoss()

    # Start training loops
    for epoch in range(configs["num_epochs"]):
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Create a dataset iterator for the shard on the current worker
            dataset_iterator = datasets[phase].iter_torch_batches(batch_size=worker_batch_size)
            for batch in dataset_iterator:
                inputs = batch["image"].to(device)
                labels = batch["label"].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # calculate statistics (loss is a batch mean, so weight it by batch size)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += evaluate(outputs, labels)

            epoch_loss = running_loss / datasets[phase].count()
            epoch_acc = running_corrects / datasets[phase].count()

            print('Epoch {}-{} Loss: {:.4f} Acc: {:.4f}'.format(epoch, phase, epoch_loss, epoch_acc))

            # Report metrics and checkpoint every epoch
            if phase == "val":
                # prepare_model only wraps the model in DDP when there are multiple workers,
                # so unwrap it only if a `module` attribute exists
                model_to_save = model.module if hasattr(model, "module") else model
                checkpoint = TorchCheckpoint.from_dict(
                    {
                        "epoch": epoch,
                        "model": model_to_save.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict()
                    }
                )
                session.report(metrics={"loss": epoch_loss, "acc": epoch_acc}, checkpoint=checkpoint)
```
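`nn.CrossEntropyLoss` averages over the batch by default (`reduction="mean"`), so the loop multiplies each batch loss by `inputs.size(0)` before summing and then divides by the record count. The NumPy sketch below, using random placeholder logits, checks that this yields the exact per-sample mean even when the last batch is smaller. `cross_entropy_per_sample` is a hypothetical re-implementation for illustration, not the PyTorch code.

```python
import numpy as np

def cross_entropy_per_sample(logits, labels):
    # Numerically stable log-softmax, then pick the true-class log-probability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

rng = np.random.default_rng(0)
logits = rng.standard_normal((10, 2))
labels = rng.integers(0, 2, size=10)

# Uneven batches (4 + 4 + 2), like a shard whose size isn't a multiple of the batch size
running_loss = 0.0
for start in range(0, 10, 4):
    batch_logits, batch_labels = logits[start:start + 4], labels[start:start + 4]
    batch_loss = cross_entropy_per_sample(batch_logits, batch_labels).mean()  # reduction="mean"
    running_loss += batch_loss * len(batch_labels)  # loss.item() * inputs.size(0)

epoch_loss = running_loss / len(labels)
print(np.isclose(epoch_loss, cross_entropy_per_sample(logits, labels).mean()))  # True
```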


Next, set up the `TorchTrainer`:

```python
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.air.config import ScalingConfig, RunConfig, CheckpointConfig
from ray.tune.syncer import SyncConfig

# Scale out model training across 2 GPUs (2 workers with 1 GPU each)
scaling_config = ScalingConfig(num_workers=2, use_gpu=True, resources_per_worker={"CPU": 4, "GPU": 1})

# Keep only the best checkpoint, ranked by highest validation accuracy
checkpoint_config = CheckpointConfig(num_to_keep=1, checkpoint_score_attribute="acc", checkpoint_score_order="max")

# Set experiment name and checkpoint configs
run_config = RunConfig(
    name="resnet-finetune",
    local_dir="/mnt/cluster_storage/ray_results",  # Use shared filesystem for checkpointing, no sync required
    sync_config=SyncConfig(syncer=None),
    checkpoint_config=checkpoint_config
)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=configs,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=ray_datasets,
)
```
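Each of the `num_workers` training workers reserves `resources_per_worker`, so the workers above request 8 CPUs and 2 GPUs in total across the cluster. The snippet only restates that arithmetic with plain dicts (not Ray objects) and ignores any small extra reservation the trainer process itself may make; `ray status` shows whether your cluster has that capacity.

```python
# Mirrors the ScalingConfig arguments above (plain dicts, not Ray objects)
num_workers = 2
resources_per_worker = {"CPU": 4, "GPU": 1}

total = {name: amount * num_workers for name, amount in resources_per_worker.items()}
print(total)  # {'CPU': 8, 'GPU': 2}
```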

```python
result = trainer.fit()
```

| Current time | Running for | Memory |
|---|---|---|
| 2023-02-05 18:15:51 | 00:00:52.27 | 10.3/62.0 GiB |

| Trial name | status | loc | iter | total time (s) | loss | acc | _timestamp |
|---|---|---|---|---|---|---|---|
| TorchTrainer_0e7c2_00000 | TERMINATED | 10.0.34.113:14558 | 15 | 46.7192 | 0.236638 | 0.921569 | 1675649747 |


```
(RayTrainWorker pid=14635, ip=10.0.34.113) 2023-02-05 18:15:04,714	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=2]
(RayTrainWorker pid=14635, ip=10.0.34.113) 2023-02-05 18:15:05,366	INFO train_loop_utils.py:270 -- Moving model to device: cuda:0
(RayTrainWorker pid=18732) 2023-02-05 18:15:05,367	INFO train_loop_utils.py:270 -- Moving model to device: cuda:0
(RayTrainWorker pid=18732) 2023-02-05 18:15:06,975	INFO train_loop_utils.py:330 -- Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=14635, ip=10.0.34.113) 2023-02-05 18:15:06,985	INFO train_loop_utils.py:330 -- Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=14635, ip=10.0.34.113) Epoch 0-train Loss: 0.5568 Acc: 0.6803
(RayTrainWorker pid=18732) Epoch 0-train Loss: 0.7022 Acc: 0.6066
(RayTrainWorker pid=18732) Epoch 0-val Loss: 0.2785 Acc: 0.8954
(RayTrainWorker pid=14635, ip=10.0.34.113) Epoch 0-val Loss: 0.2785 Acc: 0.8954
(RayTrainWorker pid=14635, ip=10.0.34.113) Epoch 1-train Loss: 0.2617 Acc: 0.8934
(RayTrainWorker pid=18732) Epoch 1-train Loss: 0.3784 Acc: 0.8443
(RayTrainWorker pid=18732) Epoch 1-val Loss: 0.3210 Acc: 0.8693
(RayTrainWorker pid=14635, ip=10.0.34.113) Epoch 1-val Loss: 0.3210 Acc: 0.8693
(RayTrainWorker pid=18732) Epoch 2-train Loss: 0.2643 Acc: 0.9098
(RayTrainWorker pid=14635, ip=10.0.34.113) Epoch 2-train Loss: 0.2164 Acc: 0.9262
(RayTrainWorker pid=14635, ip=10.0.34.113) Epoch 2-val Loss: 0.3782 Acc: 0.8497
(RayTrainWorker pid=18732) Epoch 2-val Loss: 0.3782 Acc: 0.8497
(RayTrainWorker pid=18732) Epoch 3-train Loss: 0.2506 Acc: 0.9262
(RayTrainWorker pid=14635, ip=10.0.34.113) Epoch 3-train Loss: 0.3308 Acc: 0.8525
(RayTrainWorker pid=18732) Epoch 3-val Loss: 0.5937 Acc: 0.7908
(RayTrainWorker pid=14635, ip=10.0.34.113) Epoch 3-val Loss: 0.5937 Acc: 0.7908
(RayTrainWorker pid=18732) Epoch 4-train Loss: 0.3079 Acc: 0.8934
(RayTrainWorker pid=14635, ip=10.0.34.113) Epoch 4-train L

2023-02-05 18:15:51,839	INFO tune.py:762 -- Total run time: 52.49 seconds (52.24 seconds for the tuning loop).
```


The training procedure completed in about 52 seconds and saved the best checkpoint in the `local_dir` provided to the trainer. You can now check the experiment metrics and checkpoint information:
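With `num_to_keep=1`, `checkpoint_score_attribute="acc"`, and `checkpoint_score_order="max"`, only the checkpoint with the highest reported validation accuracy survives. The sketch below applies that rule by hand to the accuracies actually visible in the output (epochs 0 to 3 from the truncated log, plus the final result at index 14). Epochs 4 to 13 are missing from the log, so this is illustrative only, not a reconstruction of Ray's bookkeeping.

```python
# Validation accuracy per checkpoint index, using only the epochs visible in
# the (truncated) training log plus the final reported result
reported_acc = {0: 0.8954, 1: 0.8693, 2: 0.8497, 3: 0.7908, 14: 0.921569}

# Keep the single checkpoint whose reported "acc" is highest ("max" order)
best_index = max(reported_acc, key=reported_acc.get)
print(f"checkpoint_{best_index:06d}")  # checkpoint_000014
```

Among these values, the rule selects index 14, which matches the `checkpoint_000014` directory loaded for batch prediction.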

```python
print(result)
```

```
Result(metrics={'loss': 0.23663791493524863, 'acc': 0.9215686274509803, '_timestamp': 1675649747, '_time_this_iter_s': 2.6999073028564453, '_training_iteration': 15, 'should_checkpoint': True, 'done': True, 'trial_id': '0e7c2_00000', 'experiment_tag': '0'}, error=None, log_dir=PosixPath('/mnt/cluster_storage/ray_results/resnet-finetune/TorchTrainer_0e7c2_00000_0_2023-02-05_18-14-59'))
```


## Load the Checkpoint for Batch Prediction

`TorchTrainer` has already saved the best model parameters in `log_dir`. Now we load this model into memory and run batch prediction and evaluation on the validation data.
`TorchCheckpoint.from_directory` automatically extracts the pickled parameters. `BatchPredictor` identifies the dict key `"model"` and loads the corresponding parameters into the model. You can also specify the
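The training loop checkpoints the state dict of `model.module`, the network inside the `DistributedDataParallel` wrapper, rather than the wrapper itself. The wrapper's state-dict keys carry a `module.` prefix that a plain, unwrapped model such as the one built by `initialize_model` would not recognize. If you ever end up with a state dict saved from the wrapper, a hypothetical helper like the one below can strip the prefix; the parameter names and integer values in the example are placeholders.

```python
def strip_ddp_prefix(state_dict, prefix="module."):
    # Hypothetical helper: drop the "module." prefix that DistributedDataParallel
    # adds to every parameter name, leaving other keys unchanged
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

ddp_keys = {"module.conv1.weight": 1, "module.fc.weight": 2, "module.fc.bias": 3}
print(strip_ddp_prefix(ddp_keys))
# {'conv1.weight': 1, 'fc.weight': 2, 'fc.bias': 3}
```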
 

```python
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchPredictor

# Replace with the checkpoint directory under your own `log_dir`
ckpt = TorchCheckpoint.from_directory("/mnt/cluster_storage/ray_results/resnet-finetune/TorchTrainer_0e7c2_00000_0_2023-02-05_18-14-59/checkpoint_000014")
# Provide the model architecture; the checkpoint's "model" state dict is loaded into it
predictor = BatchPredictor.from_checkpoint(ckpt, TorchPredictor, model=initialize_model(configs["num_classes"]))
```
```python
prediction_ds = predictor.predict(ray_datasets["val"], feature_columns=["image"], keep_columns=["label"])
print(prediction_ds.schema())
print(prediction_ds.take(5))
```

```
predictions: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64
[{'predictions': array([ 1.525892 , -2.5796869], dtype=float32), 'label': 0}, {'predictions': array([ 1.4257711 , -0.95987207], dtype=float32), 'label': 0}, {'predictions': array([ 2.4064574, -3.74627  ], dtype=float32), 'label': 0}, {'predictions': array([ 1.2741024, -1.4860458], dtype=float32), 'label': 0}, {'predictions': array([ 4.2864375, -3.927444 ], dtype=float32), 'label': 0}]
```


## Evaluate Prediction Results
`BatchPredictor` returns a Ray dataset that consists of a `predictions` column plus the columns specified by the `keep_columns` argument. The `predictions` column contains the model's raw tensor output (logits). Here we define a function `convert_logits_to_classes` that converts these outputs to class labels.
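The conversion is an argmax over each row of logits. Applying it by hand (NumPy only, no Ray) to the five rows printed by `prediction_ds.take(5)` gives class 0 for all of them, matching their labels.

```python
import numpy as np

# The five logit pairs printed by prediction_ds.take(5), all with label 0
logits = np.array([
    [1.525892, -2.5796869],
    [1.4257711, -0.95987207],
    [2.4064574, -3.74627],
    [1.2741024, -1.4860458],
    [4.2864375, -3.927444],
], dtype=np.float32)
labels = np.array([0, 0, 0, 0, 0])

predicted = logits.argmax(axis=1)  # index of the larger logit = predicted class
correct = predicted == labels
print(predicted.tolist(), correct.mean())  # [0, 0, 0, 0, 0] 1.0
```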

```python
import pandas as pd

def convert_logits_to_classes(df):
    # Predicted class = index of the largest logit
    pred_class = df["predictions"].map(lambda x: x.argmax())
    df["prediction"] = pred_class
    df["correct"] = df["prediction"] == df["label"]
    return df[["prediction", "label", "correct"]]

predictions = prediction_ds.map_batches(convert_logits_to_classes)
predictions.show(1)

print("Evaluation Accuracy = ", predictions.mean(on="correct"))
```

```
{'prediction': 0, 'label': 0, 'correct': True}
Evaluation Accuracy =  0.9215686274509803
```


Instead of writing a new evaluation function for pandas batches, you can also reuse the `evaluate` function from the training loop.
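The loop in the next cell sums correct counts across batches and divides once by the total record count. That matters because 153 validation records in batches of 10 leave a final batch of 3, and averaging per-batch accuracies would over-weight it. The sketch below shows the effect on 11 hypothetical correctness flags in batches of 10.

```python
import numpy as np

# Hypothetical correctness flags for 11 predictions, iterated in batches of 10
correct = np.array([True] * 5 + [False] * 5 + [True])
batch_size = 10
batches = [correct[i:i + batch_size] for i in range(0, len(correct), batch_size)]

# Sum correct counts, then divide once by the total (what the evaluation loop does)
pooled = sum(int(b.sum()) for b in batches) / len(correct)

# Averaging per-batch accuracies instead over-weights the short last batch
mean_of_batches = np.mean([b.mean() for b in batches])

print(round(pooled, 4), mean_of_batches)  # 0.5455 0.75
```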

```python
# Same `evaluate` function as in the training loop
def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects

# Sum correct predictions over all batches, then divide by the total count
num_correct = 0
for batch in prediction_ds.iter_torch_batches(batch_size=10):
    num_correct += evaluate(batch["predictions"], batch["label"])
accuracy = num_correct / prediction_ds.count()

print("Evaluation Accuracy = ", accuracy)
```

```
Evaluation Accuracy =  0.9215686274509803
```


This example is adapted from PyTorch's [Finetuning Torchvision Models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html) tutorial.