# Finetuning a PyTorch Image Classifier with Ray AIR
In this example we will finetune a pretrained ResNet model with Ray Train. 

For this example, our network architecture consists of the intermediate layer output of a pretrained ResNet model, which feeds into a randomly initialized linear layer that outputs classification logits for our new task.




## Load and preprocess finetuning dataset
This example is adapted from PyTorch's [Finetuning Torchvision Models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html) tutorial.
We will use *hymenoptera_data* as the finetuning dataset, which contains two classes (bees and ants) and 397 total images (across training and validation). This is quite a small dataset, and we use it only for demonstration purposes.

In [None]:
# To run the full example, set SMOKE_TEST to False
SMOKE_TEST = True


In [None]:
import os

os.system("wget https://download.pytorch.org/tutorial/hymenoptera_data.zip")
os.system("unzip hymenoptera_data.zip")


The dataset is structured with directory names as the labels. We use `torchvision.datasets.ImageFolder()` to load the images and their corresponding labels.
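
As a rough illustration of that convention, the sketch below mimics how class indices can be derived from sorted subdirectory names. It uses only the standard library and a temporary directory. The helper name `find_classes_sketch` is hypothetical, and the real `ImageFolder` additionally filters files by extension and loads the images.

```python
import os
import tempfile


def find_classes_sketch(root):
    """Map each subdirectory name under `root` to an integer label.

    Hypothetical helper for illustration: class names are sorted
    alphabetically and numbered from 0.
    """
    classes = sorted(
        entry.name for entry in os.scandir(root) if entry.is_dir()
    )
    return {name: idx for idx, name in enumerate(classes)}


# Build a throwaway directory tree shaped like hymenoptera_data/train
with tempfile.TemporaryDirectory() as tmp:
    for class_name in ["bees", "ants"]:
        os.makedirs(os.path.join(tmp, "train", class_name))
    class_to_idx = find_classes_sketch(os.path.join(tmp, "train"))

print(class_to_idx)  # {'ants': 0, 'bees': 1}
```

On an `ImageFolder` instance, the same mapping is available as the `class_to_idx` attribute.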

In [44]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import numpy as np

# Replace with your own path of the dataset
DATA_DIR = "./hymenoptera_data"

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    "train": transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}

torch_datasets = dict()
for split in ["train", "val"]:
    torch_datasets[split] = datasets.ImageFolder(
        os.path.join(DATA_DIR, split), data_transforms[split]
    )


In [None]:
if SMOKE_TEST:
    from torch.utils.data import Subset

    for split in ["train", "val"]:
        indices = list(range(100))
        torch_datasets[split] = Subset(torch_datasets[split], indices)


Next, we load the images from the PyTorch ImageFolder into a Ray dataset, which partitions the whole dataset and distributes the data blocks across the nodes in the cluster. This gives you faster parallel preprocessing and data ingestion.

In [3]:
import ray


def convert_batch_to_numpy(batch):
    images = np.array([image.numpy() for image, _ in batch])
    labels = np.array([label for _, label in batch])
    return {"image": images, "label": labels}


ray_datasets = dict()
for split in ["train", "val"]:
    ray_datasets[split] = ray.data.from_torch(torch_datasets[split]).map_batches(
        convert_batch_to_numpy
    )
    print(ray_datasets[split].schema())


2023-02-28 20:58:28,112	INFO worker.py:1360 -- Connecting to existing Ray cluster at address: 10.0.3.6:6379...
2023-02-28 20:58:28,124	INFO worker.py:1548 -- Connected to Ray cluster. View the dashboard at https://console.anyscale-staging.com/api/v2/sessions/ses_49hwcjc1pzcddc2nf6cg9itj6b/services?redirect_to=dashboard
2023-02-28 20:58:28,521	INFO packaging.py:330 -- Pushing file package 'gcs://_ray_pkg_281310c9b14127d68ef8fadde956f87a.zip' (136.06MiB) to Ray cluster...
2023-02-28 20:58:31,340	INFO packaging.py:343 -- Successfully pushed file package 'gcs://_ray_pkg_281310c9b14127d68ef8fadde956f87a.zip'.
2023-02-28 20:58:32,007	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(convert_batch_to_numpy)]
MapBatches(convert_batch_to_numpy): 100%|██████████| 200/200 [00:06<00:00, 31.23it/s]


image: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64


2023-02-28 20:58:39,675	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(convert_batch_to_numpy)]
MapBatches(convert_batch_to_numpy): 100%|██████████| 153/153 [00:00<00:00, 2200.92it/s]


image: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64


```{note}
Note that **batch** here refers to the chunk of data for preprocessing, not the batch for model training. To learn more about writing functions for mapping batches, read [writing user-defined functions](transform_datasets_writing_udfs).
```
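
To make the batch format concrete, here is a minimal sketch of the same row-to-column conversion on a toy batch. It assumes, as `convert_batch_to_numpy` does, that a batch arrives as a list of `(image, label)` pairs. Nested Python lists stand in for image tensors so that the snippet runs without PyTorch or Ray, and the name `convert_batch_to_columns` is hypothetical.

```python
def convert_batch_to_columns(batch):
    """Turn a list of (image, label) rows into a dict of columns."""
    images = [image for image, _ in batch]
    labels = [label for _, label in batch]
    return {"image": images, "label": labels}


# A toy preprocessing batch: three "images" of two pixels each, with labels
toy_batch = [([0.1, 0.2], 0), ([0.3, 0.4], 1), ([0.5, 0.6], 0)]
columns = convert_batch_to_columns(toy_batch)

print(columns["label"])       # [0, 1, 0]
print(len(columns["image"]))  # 3
```

The number of rows in such a preprocessing batch is independent of the `batch_size` used later for model training.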

## Initialize Model and Fine-tuning configs

Next, let's define the training configuration, which will be passed into the training loop function later.

In [4]:
train_loop_config = {
    "input_size": 224,  # Input image size (224 x 224)
    "batch_size": 32,  # Batch size for training
    "num_epochs": 10,  # Number of epochs to train for
    "lr": 0.001,  # Learning Rate
    "momentum": 0.9,  # SGD optimizer momentum
}


Next, let's define our model. You can either create a model from pretrained weights or reload the model checkpoint from a previous run.

In [5]:
from ray.train.torch import TorchCheckpoint

# Option 1: Initialize model with pretrained weights
def initialize_model():
    # Load pretrained model params
    model = models.resnet50(pretrained=True)

    # Replace the original classifier with a new Linear layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 2)

    # Ensure all params get updated during finetuning
    for param in model.parameters():
        param.requires_grad = True
    return model


# Option 2: Initialize model with an AIR checkpoint
CHECKPOINT_URI = "s3://air-example-data/finetune-resnet-checkpoint/TorchTrainer_4f69f_00000_0_2023-02-14_14-04-09/checkpoint_000000/"


def initialize_model_from_ckpt():
    checkpoint = TorchCheckpoint.from_uri(CHECKPOINT_URI)
    # Build the same ResNet-50 architecture, then load the checkpointed weights
    model = initialize_model()
    return checkpoint.get_model(model=model)


2023-02-28 20:58:39,990	INFO instantiator.py:21 -- Created a temporary directory at /tmp/tmp0wtd7j01
2023-02-28 20:58:39,992	INFO instantiator.py:76 -- Writing /tmp/tmp0wtd7j01/_remote_module_non_scriptable.py


## Define the Training Loop

The `train_loop_per_worker` function defines the finetuning procedure for each worker.

**1. Load dataset shard for each worker**:
- The Trainer will take a dictionary of Ray {class}`~ray.data.Dataset`s as input. These will be preprocessed and accessible in the worker's training loop via {meth}`session.get_dataset_shard() <ray.air.session.get_dataset_shard>`.
- By default, only the dataset under the key "train" will be split into multiple shards. `session.get_dataset_shard()` will return the full dataset for other keys. To configure this, see {class}`~ray.air.DatasetConfig`.
- Use {meth}`iter_torch_batches <ray.data.Dataset.iter_torch_batches>` to iterate over the dataset with automatic tensor batching and device placement. If you need more flexible, customized batching, refer to the lower-level {meth}`iter_batches <ray.data.Dataset.iter_batches>` API.
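
The default splitting behavior described above can be pictured with a small, simplified sketch. The round-robin `split_rows` helper is hypothetical and only stands in for Ray's actual block-based splitting. The point is that each worker sees a disjoint part of `"train"`, while every worker sees all of `"val"`.

```python
def split_rows(rows, world_size):
    """Hypothetical round-robin split; Ray's real splitting works on blocks."""
    return [rows[rank::world_size] for rank in range(world_size)]


world_size = 4
train_rows = list(range(10))  # stand-in for 10 training samples
val_rows = list(range(6))     # stand-in for 6 validation samples

train_shards = split_rows(train_rows, world_size)
# Keys other than "train" are not split by default: each worker gets everything
val_shards = [val_rows for _ in range(world_size)]

for rank in range(world_size):
    print(rank, train_shards[rank], len(val_shards[rank]))
```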

**2. Prepare your model**:
- {meth}`train.torch.prepare_model() <ray.train.torch.prepare_model>` prepares the model for distributed training. Under the hood, it wraps your torch model in `DistributedDataParallel`, which keeps the model weights synchronized across all workers.
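
Conceptually, keeping replicas in sync comes down to averaging gradients across workers before each optimizer step. The sketch below shows that idea on a single scalar weight in plain Python. It is a simplified illustration that assumes a squared-error loss, not the `DistributedDataParallel` implementation, and all names in it are hypothetical.

```python
def local_gradient(weight, shard):
    """Gradient of the mean squared error for y = weight * x on one shard."""
    return sum(2 * (weight * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values):
    """Stand-in for the all-reduce step: every worker receives the mean."""
    mean = sum(values) / len(values)
    return [mean for _ in values]


# Two workers, each holding a different shard of (x, y) pairs with y = 3x
shards = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0), (4.0, 12.0)]]
weights = [0.0, 0.0]  # identical initial weights on both replicas
lr = 0.01

for step in range(200):
    grads = [local_gradient(w, shard) for w, shard in zip(weights, shards)]
    synced = all_reduce_mean(grads)
    # Every replica applies the same averaged gradient, so weights stay equal
    weights = [w - lr * g for w, g in zip(weights, synced)]

print(weights)
```

Because both replicas start from the same weight and apply the same averaged gradient, they remain identical after every step even though each one only saw its own shard.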

**3. Report metrics and checkpoint**:
- {meth}`session.report() <ray.air.session.report>` will report metrics and checkpoints to Ray AIR.
- Saving checkpoints through {meth}`session.report(metrics, checkpoint=...) <ray.air.session.report>` will automatically [upload checkpoints to cloud storage](tune-cloud-checkpointing) (if configured), and allow you to easily enable Ray AIR worker fault tolerance in the future.
- The best checkpoints are kept according to the `checkpoint_score_attribute` specified in {class}`CheckpointConfig <ray.air.config.CheckpointConfig>`. In this example we set `num_to_keep=1` without a score attribute, so only the most recent checkpoint is kept.
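
As a rough mental model for score-based retention, the sketch below keeps the `num_to_keep` best checkpoints ranked by a reported metric. It is a simplified stand-in, not Ray's implementation. The `keep_best` helper is hypothetical, and the metric values are made up for illustration.

```python
def keep_best(reported, num_to_keep, score_attribute, order="max"):
    """Return the names of the top `num_to_keep` checkpoints by score."""
    ranked = sorted(
        reported,
        key=lambda item: item["metrics"][score_attribute],
        reverse=(order == "max"),
    )
    return [item["name"] for item in ranked[:num_to_keep]]


# Illustrative (made-up) per-epoch reports
reported = [
    {"name": "checkpoint_000000", "metrics": {"acc": 0.56, "loss": 0.57}},
    {"name": "checkpoint_000001", "metrics": {"acc": 0.95, "loss": 0.36}},
    {"name": "checkpoint_000002", "metrics": {"acc": 0.93, "loss": 0.28}},
]

best_by_acc = keep_best(reported, num_to_keep=1, score_attribute="acc")
best_by_loss = keep_best(
    reported, num_to_keep=2, score_attribute="loss", order="min"
)
print(best_by_acc)   # ['checkpoint_000001']
print(best_by_loss)  # ['checkpoint_000002', 'checkpoint_000001']
```

In Ray AIR, the corresponding settings are `num_to_keep`, `checkpoint_score_attribute`, and `checkpoint_score_order` on `CheckpointConfig`.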

In [13]:
import ray.train as train
from ray.air import session
from ray.train.torch import TorchCheckpoint


def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects


def train_loop_per_worker(configs):
    import warnings

    warnings.filterwarnings("ignore")

    # Prepare dataloader for each worker
    datasets = dict()
    datasets["train"] = session.get_dataset_shard("train")
    datasets["val"] = session.get_dataset_shard("val")

    # Calculate the batch size for a single worker
    worker_batch_size = configs["batch_size"] // session.get_world_size()

    device = train.torch.get_device()

    # Prepare DDP Model, optimizer, and loss function
    model = initialize_model()
    model = train.torch.prepare_model(model)

    optimizer = optim.SGD(
        model.parameters(), lr=configs["lr"], momentum=configs["momentum"]
    )
    criterion = nn.CrossEntropyLoss()

    # Start training loops
    for epoch in range(configs["num_epochs"]):
        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Create a dataset iterator for the shard on the current worker
            dataset_iterator = datasets[phase].iter_torch_batches(
                batch_size=worker_batch_size, device=device
            )
            for batch in dataset_iterator:
                inputs = batch["image"]
                labels = batch["label"]

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # calculate statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += evaluate(outputs, labels)

            epoch_loss = running_loss / datasets[phase].count()
            epoch_acc = running_corrects / datasets[phase].count()

            if session.get_world_rank() == 0:
                print(
                    "Epoch {}-{} Loss: {:.4f} Acc: {:.4f}".format(
                        epoch, phase, epoch_loss, epoch_acc
                    )
                )

            # Report metrics and checkpoint every epoch
            if phase == "val":
                checkpoint = TorchCheckpoint.from_dict(
                    {
                        "epoch": epoch,
                        "model": model.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    }
                )
                session.report(
                    metrics={"loss": epoch_loss, "acc": epoch_acc},
                    checkpoint=checkpoint,
                )
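

The statistics in the loop above weight each batch's mean loss by its batch size, so a smaller final batch does not skew the epoch average (`nn.CrossEntropyLoss` returns the mean over the batch by default). A minimal sketch of that arithmetic, using made-up loss values and a hypothetical `epoch_average` helper:

```python
def epoch_average(batch_losses, batch_sizes):
    """Average per-sample loss computed from per-batch mean losses."""
    running_loss = 0.0
    for loss, size in zip(batch_losses, batch_sizes):
        running_loss += loss * size  # undo the per-batch mean
    return running_loss / sum(batch_sizes)


# Made-up values: two full batches of 8 and a final partial batch of 2
batch_losses = [0.5, 0.3, 1.0]
batch_sizes = [8, 8, 2]

weighted = epoch_average(batch_losses, batch_sizes)
naive = sum(batch_losses) / len(batch_losses)
print(weighted, naive)
```

The naive mean of batch losses overweights the small final batch, while the size-weighted version matches the true per-sample average.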


Next, set up the TorchTrainer:

In [11]:
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.air.config import ScalingConfig, RunConfig, CheckpointConfig
from ray.tune.syncer import SyncConfig

# Scale out model training across 4 GPUs.
scaling_config = ScalingConfig(
    num_workers=4, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)

# Save the latest checkpoint
checkpoint_config = CheckpointConfig(num_to_keep=1)

# Set experiment name and checkpoint configs
run_config = RunConfig(
    name="finetune-resnet",
    local_dir="/tmp/ray_results",
    checkpoint_config=checkpoint_config,
)


In [8]:
if SMOKE_TEST:
    scaling_config = ScalingConfig(
        num_workers=8, use_gpu=False, resources_per_worker={"CPU": 1}
    )
    train_loop_config["num_epochs"] = 1


In [14]:
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=ray_datasets,
)

result = trainer.fit()
print(result)


Current time: 2023-02-28 21:06:20
Running for: 00:00:42.16
Memory: 9.1/62.0 GiB

| Trial name | status | loc | iter | total time (s) | loss | acc |
|---|---|---|---|---|---|---|
| TorchTrainer_b4f06_00000 | TERMINATED | 10.0.3.6:99124 | 10 | 36.3842 | 0.157967 | 0.96732 |


(RayTrainWorker pid=99244) 2023-02-28 21:05:45,574	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=4]
(TorchTrainer pid=99124) 2023-02-28 21:05:46,211	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[randomize_block_order]
(RayTrainWorker pid=2059, ip=10.0.47.22) 2023-02-28 21:05:46,949	INFO train_loop_utils.py:307 -- Moving model to device: cuda:0
(RayTrainWorker pid=1945, ip=10.0.23.119) 2023-02-28 21:05:46,948	INFO train_loop_utils.py:307 -- Moving model to device: cuda:0
(RayTrainWorker pid=99244) 2023-02-28 21:05:46,953	INFO train_loop_utils.py:307 -- Moving model to device: cuda:0
(RayTrainWorker pid=1951, ip=10.0.31.3) 2023-02-28 21:05:46,952	INFO train_loop_utils.py:307 -- Moving model to device: cuda:0
(RayTrainWorker pid=1951, ip=10.0.31.3) 2023-02-28 21:05:48,628	INFO train_loop_utils.py:367 -- Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=2059, ip=10.0.47.22) 2023-02-28 21:05:48,625	

(RayTrainWorker pid=99244) Epoch 0-train Loss: 0.6717 Acc: 0.5574


(RayTrainWorker pid=2059, ip=10.0.47.22) 2023-02-28 21:05:52,228	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[randomize_block_order]


(RayTrainWorker pid=99244) Epoch 0-val Loss: 0.5739 Acc: 0.5621


| Trial name | acc | date | done | experiment_tag | hostname | iterations_since_restore | loss | node_ip | pid | should_checkpoint | time_since_restore | time_this_iter_s | time_total_s | timestamp | training_iteration | trial_id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| TorchTrainer_b4f06_00000 | 0.96732 | 2023-02-28_21-06-18 | True | 0 | ip-10-0-3-6 | 10 | 0.157967 | 10.0.3.6 | 99124 | True | 36.3842 | 2.61072 | 36.3842 | 1677647177 | 10 | b4f06_00000 |


(RayTrainWorker pid=99244) Epoch 1-train Loss: 0.4580 Acc: 0.8361
(RayTrainWorker pid=99244) Epoch 1-val Loss: 0.3622 Acc: 0.9477
(RayTrainWorker pid=99244) Epoch 2-train Loss: 0.2940 Acc: 0.9508
(RayTrainWorker pid=99244) Epoch 2-val Loss: 0.2779 Acc: 0.9542
(RayTrainWorker pid=99244) Epoch 3-train Loss: 0.2059 Acc: 0.9344
(RayTrainWorker pid=99244) Epoch 3-val Loss: 0.2230 Acc: 0.9673
(RayTrainWorker pid=99244) Epoch 4-train Loss: 0.1383 Acc: 0.9836
(RayTrainWorker pid=99244) Epoch 4-val Loss: 0.1988 Acc: 0.9608
(RayTrainWorker pid=99244) Epoch 5-train Loss: 0.0962 Acc: 1.0000
(RayTrainWorker pid=99244) Epoch 5-val Loss: 0.1810 Acc: 0.9673
(RayTrainWorker pid=99244) Epoch 6-train Loss: 0.0724 Acc: 1.0000
(RayTrainWorker pid=99244) Epoch 6-val Loss: 0.1710 Acc: 0.9673
(RayTrainWorker pid=99244) Epoch 7-train Loss: 0.0511 Acc: 1.0000
(RayTrainWorker pid=99244) Epoch 7-val Loss: 0.1644 Acc: 0.9673
(RayTrainWorker pid=99244) Epoch 8-train Loss: 0.0397 Acc: 1.0000
(RayTrainWorker pid=9924

2023-02-28 21:06:20,577	INFO tune.py:825 -- Total run time: 42.17 seconds (42.16 seconds for the tuning loop).


Result(
  metrics={'loss': 0.15796673334404535, 'acc': 0.9673202614379085, 'should_checkpoint': True, 'done': True, 'trial_id': 'b4f06_00000', 'experiment_tag': '0'},
  log_dir=PosixPath('/tmp/ray_results/finetune-resnet/TorchTrainer_b4f06_00000_0_2023-02-28_21-05-38'),
  checkpoint=TorchCheckpoint(local_path=/tmp/ray_results/finetune-resnet/TorchTrainer_b4f06_00000_0_2023-02-28_21-05-38/checkpoint_000009)
)


## Load the checkpoint for batch prediction

The metadata and checkpoints have already been saved into the `local_dir` specified in the `RunConfig` passed to TorchTrainer:

In [52]:
import os

checkpoint_folder = result.checkpoint.uri.replace("file://", "")
print(checkpoint_folder)
os.listdir(checkpoint_folder)


/tmp/ray_results/finetune-resnet/TorchTrainer_b4f06_00000_0_2023-02-28_21-05-38/checkpoint_000009


['.metadata.pkl', 'dict_checkpoint.pkl', '.is_checkpoint', '.tune_metadata']

Now, we want to load the trained model and evaluate it on the validation data. TorchTrainer has already saved the model parameters of the retained checkpoint under `log_dir`. We can use {ref}`TorchCheckpoint.from_directory() <ray.train.torch.TorchCheckpoint.from_directory>` to load the resulting checkpoint from our fine-tuning run.

In [53]:
checkpoint = TorchCheckpoint.from_directory(checkpoint_folder)
model = checkpoint.get_model(initialize_model())
device = torch.device("cuda")


In [55]:
if SMOKE_TEST:
    device = torch.device("cpu")


Finally, define a simple evaluation loop and check the checkpoint model performance.

In [56]:
model = model.to(device)
model.eval()

dataloader = DataLoader(torch_datasets["val"], batch_size=32, num_workers=4)
corrects = 0
# Gradients are not needed for evaluation
with torch.no_grad():
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        preds = model(inputs)
        corrects += evaluate(preds, labels)

print("Accuracy: ", corrects / len(dataloader.dataset))


Accuracy:  0.9673202614379085
