# Finetuning a PyTorch ResNet Model for Image Classification
In this example we will finetune a pretrained ResNet model with Ray Train. You should be familiar with [PyTorch](https://pytorch.org/) before starting the tutorial. 

For fine-tuning, our network architecture consists of a pretrained ResNet model as the backbone and a randomly initialized linear layer as the classifier. The ResNet model is pretrained on the 1000-class ImageNet dataset. We will unfreeze and retrain all parameters of the model for the new task.




## Load and transform datasets
We will use the *hymenoptera_data* dataset for fine-tuning. It contains two classes (bees and ants) and 397 images (244 for training, 153 for validation). The dataset is provided by PyTorch and can be downloaded [here](https://download.pytorch.org/tutorial/hymenoptera_data.zip). The dataset folder is structured so that we can load it with the PyTorch [ImageFolder](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html) dataset class.

Notice that the ResNet model was pretrained with hard-coded normalization values. We'll keep these numbers the same for fine-tuning, as shown in *data_transforms*. More details can be found [here](https://pytorch.org/hub/pytorch_vision_resnet/).
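
As a rough illustration of what the normalization step computes, the sketch below applies the same per-channel mean and standard deviation to a single RGB pixel in plain Python. The helper name `normalize_pixel` is made up for this example; `transforms.Normalize` performs the equivalent operation on whole tensors.

```python
# Per-channel statistics used in data_transforms (ImageNet mean and std).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Hypothetical helper: apply (value - mean) / std to each channel of one pixel.

    `rgb` holds channel values already scaled to [0, 1], as ToTensor() produces.
    """
    return [(value - m) / s for value, m, s in zip(rgb, MEAN, STD)]

# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel(MEAN))
# A pure white pixel lands a bit more than two standard deviations above the mean.
print(normalize_pixel([1.0, 1.0, 1.0]))
```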

In [None]:
import os
os.system("wget https://download.pytorch.org/tutorial/hymenoptera_data.zip")
os.system("unzip hymenoptera_data.zip")

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import numpy as np

# Replace with your own path of the dataset
DATA_DIR = "./hymenoptera_data"

# Data augmentation and normalization for training
# Just normalization for validation
input_size = 224
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

torch_datasets = dict()
for split in ["train", "val"]:
    torch_datasets[split] = datasets.ImageFolder(os.path.join(DATA_DIR, split), data_transforms[split])

Next we will transform our ImageFolder dataset into a Ray dataset, which partitions the whole dataset and distributes the data blocks across the nodes in the cluster. This gives you faster parallel pre-processing and data ingestion.

Note that **batch** here refers to the chunk of data that the map function executes on, not the batch we use for model training. To learn more about writing functions for {meth}`map_batches <ray.data.Dataset.map_batches>`, read [writing user-defined functions](https://docs.ray.io/en/latest/data/transforming-datasets.html#transform-datasets-writing-udfs).
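
To make the distinction concrete, here is a toy, pure-Python stand-in for how a batch-wise map feeds chunks of rows to a user function. `toy_map_batches` and its `chunk_size` parameter are hypothetical and only mimic the idea; they are not Ray's implementation, and the chunk size has nothing to do with the training batch size.

```python
def toy_map_batches(rows, fn, chunk_size):
    """Hypothetical stand-in: call `fn` once per chunk of rows and collect the outputs."""
    outputs = []
    for start in range(0, len(rows), chunk_size):
        outputs.append(fn(rows[start:start + chunk_size]))
    return outputs

def split_columns(batch):
    """Turn a chunk of (image, label) pairs into a dict of columns."""
    return {
        "image": [image for image, _ in batch],
        "label": [label for _, label in batch],
    }

# Five fake samples; strings stand in for image tensors.
rows = [("img0", 0), ("img1", 1), ("img2", 0), ("img3", 1), ("img4", 0)]
for columns in toy_map_batches(rows, split_columns, chunk_size=2):
    print(columns)
```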

In [2]:
import ray

def convert_batch_to_numpy(batch):
    images = np.array([image.numpy() for image, _ in batch])
    labels = np.array([label for _, label in batch])
    return {"image": images, "label": labels}

ray_datasets = dict()
for split in ["train", "val"]:
    ray_datasets[split] = ray.data.from_torch(torch_datasets[split]).map_batches(convert_batch_to_numpy)
    print(ray_datasets[split].schema())

2023-02-11 20:07:46,488	INFO worker.py:1352 -- Connecting to existing Ray cluster at address: 10.0.29.191:6379...
2023-02-11 20:07:46,631	INFO worker.py:1529 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265
2023-02-11 20:07:47,037	INFO packaging.py:373 -- Pushing file package 'gcs://_ray_pkg_36533f2e8a8ea9ade3d97c03aa113b26.zip' (135.91MiB) to Ray cluster...
2023-02-11 20:07:49,572	INFO packaging.py:386 -- Successfully pushed file package 'gcs://_ray_pkg_36533f2e8a8ea9ade3d97c03aa113b26.zip'.
Map_Batches: 100%|██████████| 244/244 [00:07<00:00, 33.66it/s] 


image: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64


Map_Batches: 100%|██████████| 153/153 [00:00<00:00, 287.58it/s]


image: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64


## Initialize Model and Fine-tuning configs

In [3]:
def initialize_model(num_classes):
    # Load pretrained model params
    model = models.resnet18(pretrained=True)

    # Replace the original classifier with a new Linear layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)

    # Ensure all params get updated during fine-tuning
    for param in model.parameters():
        param.requires_grad = True
    return model

In [4]:
configs = dict()

# Input image size (224 x 224)
configs["input_size"] = 224

# Number of label classes
configs["num_classes"] = 2

# Batch size for training (change depending on how much memory you have)
configs["batch_size"] = 16

# Number of epochs to train for
configs["num_epochs"] = 15

# Hyper-parameters for optimizer
configs["lr"] = 0.001
configs["momentum"] = 0.9

## Define the Training Loop

The `train_loop_per_worker` function defines the finetuning procedure for each worker.

**1. Load dataset shard for each worker**:
- A Ray trainer takes a dictionary of Ray datasets as input. You can access these datasets in the workers with `session.get_dataset_shard(DATASET_KEY)`.
- Only the dataset with key "train" is split into multiple shards; all other datasets are passed to every worker unchanged.
- You can use {meth}`iter_torch_batches <ray.data.Dataset.iter_torch_batches>` to iterate over a dataset with automatic tensor batching. If you need a more flexible, customized batching function, refer to the lower-level API {meth}`iter_batches <ray.data.Dataset.iter_batches>`.
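
A small arithmetic sketch of this sharding behaviour, using the dataset sizes from this tutorial (244 training and 153 validation images) and the 4 workers configured later on. The function `toy_shard_sizes` is a hypothetical helper that assumes an even split; how Ray actually balances shards may differ.

```python
def toy_shard_sizes(num_rows, num_workers):
    """Hypothetical helper: split `num_rows` as evenly as possible across workers."""
    base, extra = divmod(num_rows, num_workers)
    return [base + (1 if rank < extra else 0) for rank in range(num_workers)]

num_workers = 4
dataset_sizes = {"train": 244, "val": 153}

per_worker_rows = {
    # Only the "train" dataset is split into shards.
    "train": toy_shard_sizes(dataset_sizes["train"], num_workers),
    # Every other dataset is passed to each worker in full.
    "val": [dataset_sizes["val"]] * num_workers,
}
print(per_worker_rows)
```

Under this assumption each worker trains on 61 images, which is consistent with the per-worker training accuracies in the logs of this example, such as 0.7869 (48/61), while every worker reports the same validation accuracy.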

**2. Prepare your model**:
- `train.torch.prepare_model` prepares the model for distributed training. Under the hood, it converts your torch model to a `DistributedDataParallel` model, which synchronizes the gradients and buffers across all workers.
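
Conceptually, gradient synchronization means that after each backward pass every worker ends up with the same averaged gradient. The plain-Python sketch below mimics that averaging step for a toy two-parameter model; `average_gradients` is a hypothetical name and the gradient values are made up. The real `DistributedDataParallel` does this with collective communication between processes rather than Python lists.

```python
def average_gradients(per_worker_grads):
    """Hypothetical helper: element-wise mean of the gradients computed by each worker."""
    num_workers = len(per_worker_grads)
    return [sum(values) / num_workers for values in zip(*per_worker_grads)]

# Made-up gradients for a model with two parameters, one list per worker.
per_worker_grads = [
    [0.5, -1.0],
    [1.5, 0.0],
    [1.0, -2.0],
    [1.0, -1.0],
]

synced = average_gradients(per_worker_grads)
# Every worker applies the same update, so the model replicas stay identical.
print(synced)
```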

**3. Report metrics and checkpoint**:
- `session.report` gathers the metrics from each worker and saves them to log files.
- You don't have to save checkpoints manually with `torch.save()`; `session.report()` syncs checkpoints to local or cloud storage for you.
- The best checkpoints are kept according to the `checkpoint_score_attribute` specified in {class}`CheckpointConfig <ray.air.config.CheckpointConfig>`. Here we keep only the checkpoint with the highest validation accuracy.
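
The selection rule can be pictured with a short sketch. The function `keep_best` below is hypothetical and only imitates what keeping a single checkpoint with a maximized score attribute implies; the accuracies are the validation values visible in the training logs of this tutorial for a few epochs.

```python
def keep_best(reports, score_attribute, order="max"):
    """Hypothetical helper: return the report whose score would be kept as the best checkpoint."""
    pick = max if order == "max" else min
    return pick(reports, key=lambda report: report[score_attribute])

# Validation accuracy reported at the end of a few epochs (taken from the logs).
reports = [
    {"epoch": 0, "acc": 0.7647},
    {"epoch": 1, "acc": 0.7124},
    {"epoch": 8, "acc": 0.8824},
    {"epoch": 9, "acc": 0.8758},
    {"epoch": 14, "acc": 0.8889},
]

best = keep_best(reports, score_attribute="acc", order="max")
print(best)
```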

In [5]:
import ray.train as train
from ray.air import session
from ray.train.torch import TorchCheckpoint

def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects

def train_loop_per_worker(configs):
    # Prepare dataloader for each worker
    datasets = dict()
    datasets["train"] = session.get_dataset_shard("train")
    datasets["val"] = session.get_dataset_shard("val")

    # Calculate the batch size for a single worker
    worker_batch_size = configs["batch_size"] // session.get_world_size()

    # Use the device that Ray Train assigned to this worker
    # (a CPU device is returned when no GPU is available)
    device = train.torch.get_device()

    # Prepare DDP Model, optimizer, and loss function
    model = initialize_model(num_classes=configs["num_classes"])
    model = train.torch.prepare_model(model)

    optimizer = optim.SGD(model.parameters(), lr=configs["lr"], momentum=configs["momentum"])
    criterion = nn.CrossEntropyLoss()

    # Start training loops
    for epoch in range(configs["num_epochs"]):
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Create a dataset iterator for the shard on the current worker
            dataset_iterator = datasets[phase].iter_torch_batches(batch_size=worker_batch_size)
            for batch in dataset_iterator:
                inputs = batch["image"].to(device)
                labels = batch["label"].to(device)
            
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # calculate statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += evaluate(outputs, labels)

            epoch_loss = running_loss / datasets[phase].count()
            epoch_acc = running_corrects / datasets[phase].count()

            print('Epoch {}-{} Loss: {:.4f} Acc: {:.4f}'.format(epoch, phase, epoch_loss, epoch_acc))

            # Report metrics and checkpoint every epoch
            if phase == "val":
                checkpoint = TorchCheckpoint.from_dict(
                    {
                        "epoch": epoch,
                        "model": model.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict()
                    }
                )
                session.report(metrics={"loss": epoch_loss, "acc": epoch_acc}, checkpoint=checkpoint)
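
The bookkeeping in the loop above can be checked with plain arithmetic. This sketch assumes the values used in this tutorial (a global batch size of 16, 4 workers, and a training shard of 61 images per worker, i.e. 244 images split evenly), assumes the last partial batch is kept, and uses made-up per-batch losses. `batch_sizes` is a hypothetical helper.

```python
def batch_sizes(num_rows, batch_size):
    """Hypothetical helper: sizes of the batches produced when iterating over `num_rows` rows."""
    full, remainder = divmod(num_rows, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])

global_batch_size = 16
world_size = 4
worker_batch_size = global_batch_size // world_size  # 4 samples per worker per step

sizes = batch_sizes(61, worker_batch_size)  # 15 full batches plus one batch of size 1

# Made-up mean loss for each batch; the last, smaller batch has a large loss.
batch_losses = [0.5] * 15 + [2.0]

# Weight each batch mean by its size, then divide by the number of samples,
# mirroring `running_loss += loss.item() * inputs.size(0)` in the loop.
running_loss = sum(loss * size for loss, size in zip(batch_losses, sizes))
epoch_loss = running_loss / sum(sizes)
print(worker_batch_size, len(sizes), round(epoch_loss, 4))
```

Weighting by batch size keeps the smaller final batch from counting as much as a full one; an unweighted mean of the batch losses would give 0.59375 here instead of roughly 0.5246.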


Next, set up the `TorchTrainer`:

In [6]:
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.air.config import ScalingConfig, RunConfig, CheckpointConfig
from ray.tune.syncer import SyncConfig

# Scale out model training across 4 GPUs.
scaling_config = ScalingConfig(num_workers=4, use_gpu=True, resources_per_worker={"CPU": 4, "GPU": 1})

# Save the best checkpoint with highest validation accuracy 
checkpoint_config = CheckpointConfig(num_to_keep=1, checkpoint_score_attribute="acc", checkpoint_score_order="max")

# Set experiment name and checkpoint configs
run_config = RunConfig(
    name="resnet-finetune",
    local_dir="/tmp/ray_results",
    sync_config=SyncConfig(),
    checkpoint_config=checkpoint_config
)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=configs,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=ray_datasets,
)

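Call `trainer.fit()` to start training. In the run recorded below, training took about a minute (56.3 seconds of training time, 67.9 seconds in total) and the best checkpoint was saved under the `local_dir` provided to the trainer. You can then check the experiment metrics and checkpoint information: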

In [7]:
result = trainer.fit()
print(result)

Current time: 2023-02-11 20:09:59
Running for: 00:01:07.54
Memory: 26.4/62.0 GiB

| Trial name | status | loc | iter | total time (s) | loss | acc | _timestamp |
|---|---|---|---|---|---|---|---|
| TorchTrainer_f56f5_00000 | TERMINATED | 10.0.50.149:4404 | 15 | 56.321 | 0.342963 | 0.888889 | 1676174990 |


(RayTrainWorker pid=4481, ip=10.0.50.149) 2023-02-11 20:08:57,904	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=4]
(RayTrainWorker pid=4315, ip=10.0.62.175) 2023-02-11 20:08:59,066	INFO train_loop_utils.py:270 -- Moving model to device: cuda:0
(RayTrainWorker pid=1964717) 2023-02-11 20:08:59,067	INFO train_loop_utils.py:270 -- Moving model to device: cuda:0
(RayTrainWorker pid=2126, ip=10.0.53.38) 2023-02-11 20:08:59,047	INFO train_loop_utils.py:270 -- Moving model to device: cuda:0
(RayTrainWorker pid=4481, ip=10.0.50.149) 2023-02-11 20:08:59,059	INFO train_loop_utils.py:270 -- Moving model to device: cuda:0
(RayTrainWorker pid=4315, ip=10.0.62.175) 2023-02-11 20:09:00,583	INFO train_loop_utils.py:330 -- Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=1964717) 2023-02-11 20:09:00,636	INFO train_loop_utils.py:330 -- Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=4481, ip=10.0.50.149) 2023-02-11 20:09:00,60

(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 0-train Loss: 0.5419 Acc: 0.7869
(RayTrainWorker pid=1964717) Epoch 0-train Loss: 0.5718 Acc: 0.7213
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 0-train Loss: 0.8408 Acc: 0.4754
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 0-train Loss: 0.7600 Acc: 0.5410
(RayTrainWorker pid=1964717) Epoch 0-val Loss: 0.4704 Acc: 0.7647
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 0-val Loss: 0.4704 Acc: 0.7647
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 0-val Loss: 0.4704 Acc: 0.7647
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 0-val Loss: 0.4704 Acc: 0.7647


Trial name,_time_this_iter_s,_timestamp,_training_iteration,acc,date,done,episodes_total,experiment_id,experiment_tag,hostname,iterations_since_restore,loss,node_ip,pid,should_checkpoint,time_since_restore,time_this_iter_s,time_total_s,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
TorchTrainer_f56f5_00000,3.06015,1676174990,15,0.888889,2023-02-11_20-09-51,True,,e2bfb42152344b97bd6089d2c1bf2c1b,0,ip-10-0-50-149,15,0.342963,10.0.50.149,4404,True,56.321,3.10375,56.321,1676174991,0,,15,f56f5_00000,0.183336


(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 1-train Loss: 0.2978 Acc: 0.9016
(RayTrainWorker pid=1964717) Epoch 1-train Loss: 0.6057 Acc: 0.6721
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 1-train Loss: 0.5975 Acc: 0.6721
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 1-train Loss: 0.5828 Acc: 0.6557
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 1-val Loss: 0.5355 Acc: 0.7124
(RayTrainWorker pid=1964717) Epoch 1-val Loss: 0.5355 Acc: 0.7124
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 1-val Loss: 0.5355 Acc: 0.7124
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 1-val Loss: 0.5355 Acc: 0.7124
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 2-train Loss: 0.4905 Acc: 0.7377
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 2-train Loss: 0.3701 Acc: 0.8525
(RayTrainWorker pid=1964717) Epoch 2-train Loss: 0.8300 Acc: 0.5574
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 2-train Loss: 0.6091 Acc: 0.6066
(RayTrainWorker pid=1964717) Epoch 2-val Loss: 0.5647 Acc: 0.7190
(RayTrainWor



(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 8-train Loss: 0.2025 Acc: 0.9016
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 8-train Loss: 0.1332 Acc: 0.9836
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 8-train Loss: 0.0755 Acc: 1.0000
(RayTrainWorker pid=1964717) Epoch 8-train Loss: 0.2536 Acc: 0.9180
(RayTrainWorker pid=1964717) Epoch 8-val Loss: 0.3055 Acc: 0.8824
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 8-val Loss: 0.3055 Acc: 0.8824
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 8-val Loss: 0.3055 Acc: 0.8824
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 8-val Loss: 0.3055 Acc: 0.8824




(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 9-train Loss: 0.0725 Acc: 1.0000
(RayTrainWorker pid=1964717) Epoch 9-train Loss: 0.2266 Acc: 0.9836
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 9-train Loss: 0.1900 Acc: 0.9672
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 9-train Loss: 0.1149 Acc: 0.9836
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 9-val Loss: 0.3029 Acc: 0.8758
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 9-val Loss: 0.3029 Acc: 0.8758
(RayTrainWorker pid=1964717) Epoch 9-val Loss: 0.3029 Acc: 0.8758
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 9-val Loss: 0.3029 Acc: 0.8758




(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 10-train Loss: 0.1107 Acc: 0.9836
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 10-train Loss: 0.1787 Acc: 0.9672
(RayTrainWorker pid=1964717) Epoch 10-train Loss: 0.2134 Acc: 0.9836
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 10-train Loss: 0.0671 Acc: 1.0000
(RayTrainWorker pid=1964717) Epoch 10-val Loss: 0.3204 Acc: 0.8758
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 10-val Loss: 0.3204 Acc: 0.8758
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 10-val Loss: 0.3204 Acc: 0.8758
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 10-val Loss: 0.3204 Acc: 0.8758




(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 11-train Loss: 0.1709 Acc: 0.9672
(RayTrainWorker pid=1964717) Epoch 11-train Loss: 0.1989 Acc: 0.9836
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 11-train Loss: 0.0642 Acc: 1.0000
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 11-train Loss: 0.1046 Acc: 0.9836
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 11-val Loss: 0.3125 Acc: 0.8758
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 11-val Loss: 0.3125 Acc: 0.8758
(RayTrainWorker pid=1964717) Epoch 11-val Loss: 0.3125 Acc: 0.8758
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 11-val Loss: 0.3125 Acc: 0.8758




(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 12-train Loss: 0.0609 Acc: 1.0000
(RayTrainWorker pid=1964717) Epoch 12-train Loss: 0.1879 Acc: 0.9344
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 12-train Loss: 0.1594 Acc: 0.9672
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 12-train Loss: 0.0977 Acc: 0.9672
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 12-val Loss: 0.3345 Acc: 0.8758
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 12-val Loss: 0.3345 Acc: 0.8758
(RayTrainWorker pid=1964717) Epoch 12-val Loss: 0.3345 Acc: 0.8758
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 12-val Loss: 0.3345 Acc: 0.8758




(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 13-train Loss: 0.0903 Acc: 0.9836
(RayTrainWorker pid=1964717) Epoch 13-train Loss: 0.1829 Acc: 0.9836
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 13-train Loss: 0.0582 Acc: 1.0000
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 13-train Loss: 0.1531 Acc: 0.9836
(RayTrainWorker pid=1964717) Epoch 13-val Loss: 0.3233 Acc: 0.8758
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 13-val Loss: 0.3233 Acc: 0.8758
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 13-val Loss: 0.3233 Acc: 0.8758
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 13-val Loss: 0.3233 Acc: 0.8758




(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 14-train Loss: 0.1386 Acc: 1.0000
(RayTrainWorker pid=1964717) Epoch 14-train Loss: 0.1710 Acc: 0.9180
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 14-train Loss: 0.0554 Acc: 1.0000
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 14-train Loss: 0.0840 Acc: 1.0000
(RayTrainWorker pid=2126, ip=10.0.53.38) Epoch 14-val Loss: 0.3430 Acc: 0.8889
(RayTrainWorker pid=1964717) Epoch 14-val Loss: 0.3430 Acc: 0.8889
(RayTrainWorker pid=4315, ip=10.0.62.175) Epoch 14-val Loss: 0.3430 Acc: 0.8889
(RayTrainWorker pid=4481, ip=10.0.50.149) Epoch 14-val Loss: 0.3430 Acc: 0.8889


2023-02-11 20:09:59,693	INFO tune.py:762 -- Total run time: 67.87 seconds (67.54 seconds for the tuning loop).


Result(metrics={'loss': 0.34296345968960534, 'acc': 0.8888888888888888, '_timestamp': 1676174990, '_time_this_iter_s': 3.0601508617401123, '_training_iteration': 15, 'should_checkpoint': True, 'done': True, 'trial_id': 'f56f5_00000', 'experiment_tag': '0'}, error=None, log_dir=PosixPath('/tmp/ray_results/resnet-finetune/TorchTrainer_f56f5_00000_0_2023-02-11_20-08-52'))


## Load the checkpoint for batch prediction:

`TorchTrainer` has already saved the best model parameters in `log_dir`. Now we load this model into memory and perform batch prediction and evaluation on the validation data.
`TorchCheckpoint.from_directory` automatically extracts the pickled parameters. `BatchPredictor` identifies the dict key "model" and loads the corresponding parameters into the model.

Logs and checkpoints are saved into the `local_dir` specified in the `RunConfig` passed to `TorchTrainer`. The checkpoint path has the format `{local_dir}/{experiment_name}/{trial_name}/{checkpoint_name}`.
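
As an illustration of that layout, the sketch below assembles a checkpoint path from its parts and also shows how a `file://` prefix can be stripped from a URI string. The trial and checkpoint names are copied from the example path used in this tutorial; your run will produce different ones.

```python
import os

local_dir = "/tmp/ray_results"
experiment_name = "resnet-finetune"
# Example names from one particular run; the trial name changes on every run.
trial_name = "TorchTrainer_42662_00000_0_2023-02-11_19-56-42"
checkpoint_name = "checkpoint_000014"

checkpoint_path = os.path.join(local_dir, experiment_name, trial_name, checkpoint_name)
print(checkpoint_path)

# A local checkpoint URI carries a "file://" scheme; dropping those 7 characters
# leaves the plain filesystem path.
uri = "file://" + checkpoint_path
assert uri[len("file://"):] == checkpoint_path
```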

In [8]:
# Example checkpoint path from an earlier run; the trial name differs for every run
checkpoint_folder = "/tmp/ray_results/resnet-finetune/TorchTrainer_42662_00000_0_2023-02-11_19-56-42/checkpoint_000014"

In [10]:
import warnings
warnings.filterwarnings("ignore")
# Use the checkpoint from the result above; [7:] strips the leading "file://" scheme from the URI
checkpoint_folder = result.checkpoint.uri[7:]

In [11]:
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchPredictor

ckpt = TorchCheckpoint.from_directory(checkpoint_folder)
predictor = BatchPredictor.from_checkpoint(ckpt, TorchPredictor, model=initialize_model(configs["num_classes"]))

In [12]:
prediction_ds = predictor.predict(ray_datasets["val"], feature_columns=["image"], keep_columns=["label"])
print(prediction_ds.schema())
print(prediction_ds.take(5))

Map Progress (3 actors 1 pending): 100%|██████████| 1/1 [00:16<00:00, 16.01s/it]


predictions: extension<arrow.py_extension_type<ArrowTensorType>>
label: int64
[{'predictions': array([ 2.3760695, -0.7187623], dtype=float32), 'label': 0}, {'predictions': array([ 2.1758888, -2.2852275], dtype=float32), 'label': 0}, {'predictions': array([ 2.9107504, -2.6489625], dtype=float32), 'label': 0}, {'predictions': array([ 2.845186, -2.409718], dtype=float32), 'label': 0}, {'predictions': array([ 3.203075 , -3.5486095], dtype=float32), 'label': 0}]


## Evaluate prediction results
The BatchPredictor returns a Ray dataset, which consists of a `predictions` column and the columns specified by the `keep_columns` argument. The `predictions` column contains the model's tensor output. Here we define a function `convert_logits_to_classes` to convert the tensor outputs to class labels.
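
For intuition, the conversion amounts to taking the index of the largest logit in each row. The plain-Python sketch below applies that to the five prediction rows printed earlier; `argmax` is a small helper written for this example.

```python
def argmax(values):
    """Index of the largest element in a sequence."""
    return max(range(len(values)), key=lambda index: values[index])

# (logits, label) pairs copied from the sample of prediction_ds shown above.
rows = [
    ([2.3760695, -0.7187623], 0),
    ([2.1758888, -2.2852275], 0),
    ([2.9107504, -2.6489625], 0),
    ([2.845186, -2.409718], 0),
    ([3.203075, -3.5486095], 0),
]

predicted = [argmax(logits) for logits, _ in rows]
correct = [pred == label for pred, (_, label) in zip(predicted, rows)]
print(predicted, sum(correct) / len(rows))
```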

In [13]:
import pandas as pd 

def convert_logits_to_classes(df):
    pred_class = df["predictions"].map(lambda x: x.argmax())
    df["prediction"] = pred_class
    df["correct"] = df["prediction"] == df["label"]
    return df[["prediction", "label", "correct"]]

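# Request pandas batches explicitly, since the function operates on a DataFrame
predictions = prediction_ds.map_batches(convert_logits_to_classes, batch_format="pandas")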
predictions.show(1)

print("Evaluation Accuracy = ", predictions.mean(on="correct"))

Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 62.13it/s]


{'prediction': 0, 'label': 0, 'correct': True}


Shuffle Map: 100%|██████████| 1/1 [00:00<00:00, 111.20it/s]
Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 168.24it/s]


Evaluation Accuracy =  0.8888888888888888


Instead of writing a new evaluation function for pandas batches, you can also reuse the evaluation function from the training loop. Note that the previous approach using `map_batches()` is more efficient than iterating over the dataset, because it parallelizes the evaluation across partitions.

In [14]:
def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects

accuracy = 0
for batch in prediction_ds.iter_torch_batches(batch_size=10):
    accuracy += evaluate(batch["predictions"], batch["label"])
accuracy /= prediction_ds.count()

print("Evaluation Accuracy = ", accuracy)

Evaluation Accuracy =  0.8888888888888888
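
Both routes give the same number because accuracy is a ratio of sums, so correct-prediction counts can be computed per partition or per batch and added up afterwards. A minimal sketch follows; the split into three partitions and the per-partition counts are assumptions, chosen so that they total the 153 validation images and the 136 correct predictions implied by the reported accuracy.

```python
# (correct predictions, total rows) for each hypothetical partition.
partitions = [(46, 51), (44, 51), (46, 51)]

total_correct = sum(correct for correct, _ in partitions)
total_rows = sum(rows for _, rows in partitions)
accuracy = total_correct / total_rows
print(total_correct, total_rows, accuracy)

# Averaging per-partition accuracies only matches when partitions have equal size,
# so summing the counts first is the safer pattern.
```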


This example is adapted from PyTorch's [Finetuning Torchvision Models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html) tutorial.