# Finetuning a PyTorch ResNet Model for Image Classification
In this example we will finetune a pretrained ResNet model with Ray Train. You should be familiar with [PyTorch](https://pytorch.org/) before starting the tutorial. 

For fine-tuning, our network architecture consists of a pretrained ResNet model as the backbone and a randomly initialized linear layer as the classifier. The ResNet model is pretrained on the 1000-class ImageNet dataset. We will unfreeze and retrain all parameters of the model for the new task.




## Load and Transform Datasets
We will use *hymenoptera_data* as the finetuning dataset. It contains two classes (bees and ants) and 397 images (244 for training, 153 for validation). The dataset is provided by PyTorch and can be downloaded [here](https://download.pytorch.org/tutorial/hymenoptera_data.zip). The dataset folder is structured so that we can load it with the PyTorch [ImageFolder](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html) dataset.

Note that the ResNet model was pretrained on inputs normalized with fixed per-channel mean and standard deviation values. We keep the same values for fine-tuning, as shown in *data_transforms*. More details can be found [here](https://pytorch.org/hub/pytorch_vision_resnet/).
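
As a rough illustration of what the normalization step computes, the sketch below applies the per-channel formula `(value - mean) / std` to a single RGB pixel in plain Python. The helper name `normalize_pixel` is made up for this example; the mean and standard deviation values are the ones used in *data_transforms*.

```python
# Per-channel statistics, the same values used in data_transforms
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel already scaled to [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]


# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel(MEAN))

# A pure white pixel lands between two and three standard deviations above the mean.
print([round(x, 3) for x in normalize_pixel([1.0, 1.0, 1.0])])
```

Reusing the pretraining statistics keeps the input distribution seen by the backbone close to what it saw during pretraining.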

In [15]:
import os

os.system("wget https://download.pytorch.org/tutorial/hymenoptera_data.zip")
os.system("unzip hymenoptera_data.zip")


In [16]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import numpy as np

torch.manual_seed(620)
# torch.manual_seed(450)

# Replace with your own path of the dataset
DATA_DIR = "./hymenoptera_data"

# Data augmentation and normalization for training
# Just normalization for validation
input_size = 224
data_transforms = {
    "train": transforms.Compose(
        [
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}

torch_datasets = dict()
for split in ["train", "val"]:
    torch_datasets[split] = datasets.ImageFolder(
        os.path.join(DATA_DIR, split), data_transforms[split]
    )


Next we will convert our ImageFolder dataset into a Ray dataset, which partitions the whole dataset and distributes the data blocks across the nodes in the cluster. This enables faster parallel preprocessing and data ingestion.

Note that **batch** here refers to the chunk of data that the map function executes on, not the batch we use for model training. To learn more about writing functions for {meth}`map_batches <ray.data.Dataset.map_batches>`, read [writing user-defined functions](https://docs.ray.io/en/latest/data/transforming-datasets.html#transform-datasets-writing-udfs).
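
To make the distinction concrete, here is a minimal plain-Python sketch of chunk-wise mapping. It is not how Ray implements `map_batches`; `apply_in_chunks` and `to_columns` are hypothetical helpers, and the rows are toy `(image, label)` pairs.

```python
def to_columns(batch):
    """Turn a list of (image, label) rows into a dict of columns."""
    return {
        "image": [image for image, _ in batch],
        "label": [label for _, label in batch],
    }


def apply_in_chunks(rows, fn, chunk_size):
    """Hypothetical stand-in for map_batches: call fn on consecutive chunks."""
    return [fn(rows[i : i + chunk_size]) for i in range(0, len(rows), chunk_size)]


# Three toy rows; each "image" is just a short list of floats.
rows = [([0.1, 0.2], 0), ([0.3, 0.4], 1), ([0.5, 0.6], 0)]
chunks = apply_in_chunks(rows, to_columns, chunk_size=2)

# The chunk size only controls how much data each call sees;
# it is unrelated to the training batch size chosen later.
print([chunk["label"] for chunk in chunks])
```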

In [17]:
import ray


def convert_batch_to_numpy(batch):
    images = np.array([image.numpy() for image, _ in batch])
    labels = np.array([label for _, label in batch])
    return {"image": images, "label": labels}


ray_datasets = dict()
for split in ["train", "val"]:
    ray_datasets[split] = ray.data.from_torch(torch_datasets[split]).map_batches(
        convert_batch_to_numpy
    )
    print(ray_datasets[split].schema())


Map_Batches: 100%|██████████| 244/244 [00:00<00:00, 245.02it/s]


image: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64


Map_Batches: 100%|██████████| 153/153 [00:00<00:00, 284.32it/s]


image: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64


## Initialize Model and Fine-tuning configs

In [18]:
configs = dict()

# Input image size (224 x 224)
configs["input_size"] = 224

# Batch size for training (change depending on how much memory you have)
configs["batch_size"] = 32

# Number of epochs to train for
configs["num_epochs"] = 10

# Hyper-parameters for optimizer
configs["lr"] = 0.001
configs["momentum"] = 0.9


Next, let's define our model. You can create a model from a pretrained ResNet or reload a model checkpoint from a previous run.

In [19]:
from ray.train.torch import TorchCheckpoint


def initialize_model():
    # Load pretrained model params
    model = models.resnet18(pretrained=True)

    # Replace the original classifier with a new Linear layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 2)

    # Ensure all params get updated during finetuning
    for param in model.parameters():
        param.requires_grad = True
    return model


# You can also initialize a model from previous checkpoint
CHECKPOINT_URI = "s3://air-example-data/finetune-resnet-checkpoint/TorchTrainer_4f69f_00000_0_2023-02-14_14-04-09/checkpoint_000000/"


def initialize_model_from_ckpt():
    checkpoint = TorchCheckpoint.from_uri(CHECKPOINT_URI)
    resnet18 = initialize_model()
    return checkpoint.get_model(model=resnet18)
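
To get a feel for how small the new classifier head is, the following sketch counts the parameters of a fully connected layer. It assumes that `model.fc.in_features` is 512 for torchvision's `resnet18`; check the value on your own model before relying on it. `linear_param_count` is a helper written for this example.

```python
def linear_param_count(in_features, out_features):
    """Number of weights plus biases in a fully connected layer."""
    return in_features * out_features + out_features


# Assumption: 512 is model.fc.in_features for torchvision's resnet18
IN_FEATURES = 512

original_head = linear_param_count(IN_FEATURES, 1000)  # 1000-class ImageNet head
new_head = linear_param_count(IN_FEATURES, 2)  # two classes: bees and ants

print(original_head, new_head)
```

Only the new head starts from random values; the backbone starts from the pretrained weights, and both are updated during fine-tuning.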


## Define the Training Loop

The `train_loop_per_worker` function defines the finetuning procedure for each worker.

**1. Load dataset shard for each worker**:
- A Ray trainer takes a dictionary of Ray datasets as input. You can access these datasets in the workers with `session.get_dataset_shard(DATASET_KEY)`.
- Only the dataset with key "train" is split into multiple shards; all the others remain unsplit.
- You can use {meth}`iter_torch_batches <ray.data.Dataset.iter_torch_batches>` to iterate over a dataset with automatic tensor batching. If you need a more flexible, customized batching function, refer to the lower-level API {meth}`iter_batches <ray.data.Dataset.iter_batches>`.

**2. Prepare your model**:
- `train.torch.prepare_model` prepares the model for distributed training. Under the hood, it converts your torch model to a `DistributedDataParallel` model, which synchronizes gradients and buffers across all workers.

**3. Report metrics and checkpoint**:
- `session.report` will gather the metrics from each worker and save them into log files.
- You don't have to save checkpoints manually with `torch.save()`; `session.report()` syncs checkpoints to local or cloud storage for you.
- The best checkpoints are kept according to the `checkpoint_score_attribute` specified in {class}`CheckpointConfig <ray.air.config.CheckpointConfig>`. Note that this example sets only `num_to_keep=1` and no score attribute, so it keeps the most recent checkpoint.
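
The sketch below works out what sharding means for a single worker, using the numbers from this tutorial (244 training images, a global batch size of 32, and 4 workers). It assumes the training split is divided evenly across workers, which is a simplification; `per_worker_plan` is a hypothetical helper.

```python
import math


def per_worker_plan(num_train_rows, global_batch_size, num_workers):
    """Rough per-worker numbers, assuming the train split is divided evenly."""
    worker_batch_size = global_batch_size // num_workers
    rows_per_worker = math.ceil(num_train_rows / num_workers)
    steps_per_epoch = math.ceil(rows_per_worker / worker_batch_size)
    return worker_batch_size, rows_per_worker, steps_per_epoch


# (per-worker batch size, rows per worker, optimizer steps per epoch)
print(per_worker_plan(244, 32, 4))
```

Dividing the global batch size by the number of workers keeps the effective batch size per optimizer step the same as in single-worker training.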

In [20]:
import ray.train as train
from ray.air import session
from ray.train.torch import TorchCheckpoint


def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects


def train_loop_per_worker(configs):
    # Prepare dataloader for each worker
    datasets = dict()
    datasets["train"] = session.get_dataset_shard("train")
    datasets["val"] = session.get_dataset_shard("val")

    # Calculate the batch size for a single worker
    worker_batch_size = configs["batch_size"] // session.get_world_size()

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"

    # Prepare DDP Model, optimizer, and loss function
    model = initialize_model_from_ckpt()
    model = train.torch.prepare_model(model)

    optimizer = optim.SGD(
        model.parameters(), lr=configs["lr"], momentum=configs["momentum"]
    )
    criterion = nn.CrossEntropyLoss()

    # Start training loops
    for epoch in range(configs["num_epochs"]):
        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Create a dataset iterator for the shard on the current worker
            dataset_iterator = datasets[phase].iter_torch_batches(
                batch_size=worker_batch_size, device=device
            )
            for batch in dataset_iterator:
                inputs = batch["image"]
                labels = batch["label"]

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # calculate statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += evaluate(outputs, labels)

            epoch_loss = running_loss / datasets[phase].count()
            epoch_acc = running_corrects / datasets[phase].count()

            if session.get_world_rank() == 0:
                print(
                    "Epoch {}-{} Loss: {:.4f} Acc: {:.4f}".format(
                        epoch, phase, epoch_loss, epoch_acc
                    )
                )

            # Report metrics and checkpoint every epoch
            if phase == "val":
                checkpoint = TorchCheckpoint.from_dict(
                    {
                        "epoch": epoch,
                        "model": model.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    }
                )
                session.report(
                    metrics={"loss": epoch_loss, "acc": epoch_acc},
                    checkpoint=checkpoint,
                )


Next, set up the TorchTrainer:

In [21]:
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.air.config import ScalingConfig, RunConfig, CheckpointConfig
from ray.tune.syncer import SyncConfig

# Scale out model training across 4 GPUs.
scaling_config = ScalingConfig(
    num_workers=4, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)

# Save the latest checkpoint
checkpoint_config = CheckpointConfig(num_to_keep=1)

# Set experiment name and checkpoint configs
run_config = RunConfig(
    name="finetune-resnet",
    local_dir="/tmp/ray_results",
    sync_config=SyncConfig(),
    checkpoint_config=checkpoint_config,
)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=configs,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=ray_datasets,
)
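
To illustrate the difference between keeping the latest checkpoint and keeping the best one, here is a plain-Python sketch of a keep-N retention policy. It is not Ray's implementation; `retain_checkpoints` and the accuracy values are hypothetical.

```python
def retain_checkpoints(reports, num_to_keep, score_attribute=None):
    """Return the reports whose checkpoints survive a keep-N policy.

    reports: metric dicts in the order they were reported.
    Without a score attribute the most recent N are kept;
    with one, the N highest-scoring are kept.
    """
    if score_attribute is None:
        return reports[-num_to_keep:]
    ranked = sorted(reports, key=lambda r: r[score_attribute], reverse=True)
    return ranked[:num_to_keep]


# Hypothetical validation accuracies for three epochs
reports = [
    {"epoch": 0, "acc": 0.73},
    {"epoch": 1, "acc": 0.91},
    {"epoch": 2, "acc": 0.89},
]

latest = retain_checkpoints(reports, num_to_keep=1)
best = retain_checkpoints(reports, num_to_keep=1, score_attribute="acc")
print(latest[0]["epoch"], best[0]["epoch"])
```

With `CheckpointConfig(num_to_keep=1)` and no score attribute, as configured above, the behavior corresponds to the first case.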


Now run the training. In the run logged below, the job finished in about 59 seconds and saved its checkpoint in the `local_dir` provided to the trainer. You can then check the experiment metrics and checkpoint information:

In [22]:
result = trainer.fit()
print(result)


| Field | Value |
|---|---|
| Current time | 2023-02-14 14:41:27 |
| Running for | 00:00:59.23 |
| Memory | 10.1/62.0 GiB |

| Trial name | status | loc | iter | total time (s) | loss | acc | _timestamp |
|---|---|---|---|---|---|---|---|
| TorchTrainer_94bb5_00000 | TERMINATED | 10.0.13.194:6490 | 10 | 47.268 | 0.211976 | 0.934641 | 1676414478 |


(RayTrainWorker pid=6567, ip=10.0.13.194) 2023-02-14 14:40:34,581	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=4]
(RayTrainWorker pid=967689) 2023-02-14 14:40:44,017	INFO train_loop_utils.py:270 -- Moving model to device: cuda:0
(RayTrainWorker pid=7300, ip=10.0.43.115) 2023-02-14 14:40:44,041	INFO train_loop_utils.py:270 -- Moving model to device: cuda:0
(RayTrainWorker pid=4757, ip=10.0.6.12) 2023-02-14 14:40:44,052	INFO train_loop_utils.py:270 -- Moving model to device: cuda:0
(RayTrainWorker pid=6567, ip=10.0.13.194) 2023-02-14 14:40:44,038	INFO train_loop_utils.py:270 -- Moving model to device: cuda:0
(RayTrainWorker pid=7300, ip=10.0.43.115) 2023-02-14 14:40:45,546	INFO train_loop_utils.py:330 -- Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=6567, ip=10.0.13.194) 2023-02-14 14:40:45,542	INFO train_loop_utils.py:330 -- Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=967689) 2023-02-14 14:40:45,560	I

(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 0-train Loss: 0.5375 Acc: 0.7213
(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 0-val Loss: 0.5027 Acc: 0.7320


| Trial name | _time_this_iter_s | _timestamp | _training_iteration | acc | date | done | episodes_total | experiment_id | experiment_tag | hostname | iterations_since_restore | loss | node_ip | pid | should_checkpoint | time_since_restore | time_this_iter_s | time_total_s | timestamp | timesteps_since_restore | timesteps_total | training_iteration | trial_id | warmup_time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| TorchTrainer_94bb5_00000 | 3.00367 | 1676414478 | 10 | 0.934641 | 2023-02-14_14-41-19 | True | | f8c3ff154b024c7b8741099e97057f54 | 0 | ip-10-0-13-194 | 10 | 0.211976 | 10.0.13.194 | 6490 | True | 47.268 | 3.01411 | 47.268 | 1676414479 | 0 | | 10 | 94bb5_00000 | 0.197936 |


(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 1-train Loss: 0.4045 Acc: 0.8689
(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 1-val Loss: 0.3641 Acc: 0.9085




(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 2-train Loss: 0.2317 Acc: 0.9508
(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 2-val Loss: 0.2869 Acc: 0.9085




(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 3-train Loss: 0.1685 Acc: 0.9672
(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 3-val Loss: 0.2549 Acc: 0.9281




(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 4-train Loss: 0.1130 Acc: 0.9836
(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 4-val Loss: 0.2357 Acc: 0.9281




(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 5-train Loss: 0.0835 Acc: 1.0000
(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 5-val Loss: 0.2250 Acc: 0.9281




(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 6-train Loss: 0.0631 Acc: 1.0000
(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 6-val Loss: 0.2205 Acc: 0.9346




(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 7-train Loss: 0.0475 Acc: 1.0000
(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 7-val Loss: 0.2165 Acc: 0.9346




(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 8-train Loss: 0.0376 Acc: 1.0000
(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 8-val Loss: 0.2134 Acc: 0.9346




(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 9-train Loss: 0.0306 Acc: 1.0000
(RayTrainWorker pid=6567, ip=10.0.13.194) Epoch 9-val Loss: 0.2120 Acc: 0.9346


2023-02-14 14:41:28,107	INFO tune.py:762 -- Total run time: 59.35 seconds (59.23 seconds for the tuning loop).


Result(metrics={'loss': 0.21197572525809794, 'acc': 0.934640522875817, '_timestamp': 1676414478, '_time_this_iter_s': 3.0036728382110596, '_training_iteration': 10, 'should_checkpoint': True, 'done': True, 'trial_id': '94bb5_00000', 'experiment_tag': '0'}, error=None, log_dir=PosixPath('/tmp/ray_results/finetune-resnet/TorchTrainer_94bb5_00000_0_2023-02-14_14-40-28'))


## Load the Checkpoint for Batch Prediction

TorchTrainer has already saved the best model parameters in `log_dir`. Now we want to load this model into memory and perform batch prediction and evaluation on test data.
`TorchCheckpoint.from_directory` will automatically extract pickled params. BatchPredictor will identify the dict key "model", and load the corresponding parameters into model. You can also specify the 
 

The logs and checkpoints are saved into the `local_dir` specified in the `RunConfig` passed to `TorchTrainer`. For example:

In [23]:
checkpoint_folder = "/tmp/ray_results/finetune-resnet/TorchTrainer_94bb5_00000_0_2023-02-14_14-40-28/checkpoint_000009"


In [24]:
import warnings

warnings.filterwarnings("ignore")
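# Drop the leading "file://" scheme (7 characters) to get a local directory path.
# This assumes the checkpoint URI points to local storage.
checkpoint_folder = result.checkpoint.uri[7:]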
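
Slicing off the first seven characters works only when the URI starts with `file://`. As a slightly more defensive sketch, the standard library's `urllib.parse` can check the scheme before extracting the path. The helper `local_path_from_uri` and the example URI are hypothetical.

```python
from urllib.parse import urlparse


def local_path_from_uri(uri):
    """Return the filesystem path of a file:// URI, or raise for other schemes."""
    parsed = urlparse(uri)
    if parsed.scheme != "file":
        raise ValueError(f"Expected a file:// URI, got: {uri}")
    return parsed.path


# Hypothetical URI, for illustration only
example_uri = "file:///tmp/ray_results/example/checkpoint_000000"
print(local_path_from_uri(example_uri))
```

For a cloud URI such as one starting with `s3://`, this helper raises an error instead of silently returning a truncated string.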

In [25]:
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchPredictor

ckpt = TorchCheckpoint.from_directory(checkpoint_folder)
predictor = BatchPredictor.from_checkpoint(
    ckpt, TorchPredictor, model=initialize_model()
)

In [26]:
prediction_ds = predictor.predict(
    ray_datasets["val"],
    feature_columns=["image"],
    keep_columns=["label"],
    num_gpus_per_worker=1,
)
print(prediction_ds.schema())
print(prediction_ds.take(1))

2023-02-14 14:41:29,186	INFO batch_predictor.py:184 -- `num_gpus_per_worker` is set for `BatchPreditor`.Automatically enabling GPU prediction for this predictor. To disable set `use_gpu` to `False` in `BatchPredictor.predict`.
Map Progress (1 actors 1 pending): 100%|██████████| 1/1 [00:08<00:00,  8.41s/it]


predictions: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64
[{'predictions': array([ 0.6944345, -2.0196059], dtype=float32), 'label': 0}]


## Evaluate Prediction Results
The BatchPredictor returns a Ray dataset as its result, which consists of a `predictions` column and the columns specified by the `keep_columns` argument. The `predictions` column contains the model's tensor output. Here we define a function `convert_logits_to_classes` to convert the tensor outputs to class labels.
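
For a single row, the conversion amounts to picking the index of the largest logit. The following plain-Python sketch shows this on the sample prediction printed earlier; `argmax` is a small helper written for this example.

```python
def argmax(values):
    """Index of the largest value; the first one wins on ties."""
    return max(range(len(values)), key=lambda i: values[i])


# Logits and label copied from the sample prediction printed earlier
logits = [0.6944345, -2.0196059]
label = 0

pred_label = argmax(logits)
print(pred_label, pred_label == label)
```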

In [27]:
import pandas as pd


def convert_logits_to_classes(batch):
    batch["pred_label"] = np.argmax(batch["predictions"], axis=1)
    batch["correct"] = batch["pred_label"] == batch["label"]
    return batch


predictions = prediction_ds.map_batches(convert_logits_to_classes, batch_format="numpy")
predictions.show(1)

print("Evaluation Accuracy = ", predictions.mean(on="correct"))


Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 94.94it/s]


{'predictions': array([ 0.6944345, -2.0196059], dtype=float32), 'label': 0, 'pred_label': 0, 'correct': True}


Shuffle Map: 100%|██████████| 1/1 [00:00<00:00, 145.21it/s]
Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 147.17it/s]


Evaluation Accuracy =  0.934640522875817


You can also reuse the evaluation function defined in the training loop by iterating over the dataset. Note that the previous approach using `map_batches()` is more efficient because it parallelizes the evaluation on each partition.

In [28]:
def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects


accuracy = 0
for batch in prediction_ds.iter_torch_batches(batch_size=10):
    accuracy += evaluate(batch["predictions"], batch["label"])
accuracy /= prediction_ds.count()

print("Evaluation Accuracy = ", accuracy)


Evaluation Accuracy =  0.934640522875817


This example is adapted from PyTorch's [Finetuning Torchvision Models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html) tutorial.