<!-- ---
title: How to do Cross Validation in Ignite
weight: 7
date: 2021-09-21
downloads: true
sidebar: true
tags:
  - cross validation
--- -->

# How to do Cross Validation in Ignite

This how-to guide demonstrates how to do K-fold cross validation with PyTorch-Ignite and collect the results of every fold.

<!--more-->

In this example, we will use a [ResNet18](https://pytorch.org/vision/stable/models.html#torchvision.models.resnet18) model on the [MNIST](https://pytorch.org/vision/stable/datasets.html#torchvision.datasets.MNIST) dataset. The base code is the same as the one used in the [Getting Started Guide](https://pytorch-ignite.ai/tutorials/getting-started/).

```shell
pip install pytorch-ignite
```

## Basic Setup

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler, ConcatDataset
from torchvision.datasets import MNIST
from torchvision.models import resnet18
from torchvision.transforms import Compose, Normalize, ToTensor

from sklearn.model_selection import KFold
import numpy as np

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
```

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # ResNet18 with a 10-class output layer, one class per MNIST digit
        self.model = resnet18(num_classes=10)
        # MNIST images have a single channel, so replace the 3-channel
        # input convolution with a 1-channel 3x3 convolution
        self.model.conv1 = nn.Conv2d(
            1, 64, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        return self.model(x)


# Convert images to tensors, then normalize with the MNIST mean and std
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

train_dataset = MNIST(download=True, root=".", transform=data_transform, train=True)
test_dataset = MNIST(download=True, root=".", transform=data_transform, train=False)
```
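
The transform first scales each pixel to [0, 1] with `ToTensor` and then applies `Normalize`, which computes (x - mean) / std per channel; 0.1307 and 0.3081 are the commonly used MNIST mean and standard deviation. The stdlib sketch below traces a single pixel through both steps. `mnist_pixel_to_input` is a hypothetical helper for illustration, not a torchvision function.

```python
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # the values passed to Normalize above


def mnist_pixel_to_input(pixel):
    """Map a raw 0-255 pixel to the normalized value the model sees (sketch)."""
    scaled = pixel / 255.0                    # ToTensor: uint8 [0, 255] -> float [0.0, 1.0]
    return (scaled - MNIST_MEAN) / MNIST_STD  # Normalize: (x - mean) / std


print(round(mnist_pixel_to_input(0), 4))    # black background pixel
print(round(mnist_pixel_to_input(255), 4))  # white stroke pixel
```

Background pixels end up slightly negative and stroke pixels well above 1, so the inputs are roughly centred around zero.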

Let's first concatenate the training and test datasets so that we can later divide the combined dataset into k folds.

```python
# 60,000 training + 10,000 test images = 70,000 samples in total
dataset = ConcatDataset([train_dataset, test_dataset])
```
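
`ConcatDataset` exposes both datasets as one index space: indices 0 to 59,999 come from the training set and 60,000 to 69,999 from the test set. The stdlib sketch below imitates that lookup using cumulative sizes and `bisect`, which is roughly how PyTorch resolves an index into a concatenated dataset. `locate` is a hypothetical helper for illustration, not part of PyTorch.

```python
import bisect
from itertools import accumulate


def locate(idx, sizes):
    """Map a global index to (dataset position, local index) (hypothetical sketch)."""
    cumulative = list(accumulate(sizes))  # e.g. [60000, 70000] for MNIST
    if not 0 <= idx < cumulative[-1]:
        raise IndexError(idx)
    # First dataset whose cumulative size is strictly greater than idx
    dataset_pos = bisect.bisect_right(cumulative, idx)
    offset = cumulative[dataset_pos - 1] if dataset_pos > 0 else 0
    return dataset_pos, idx - offset


mnist_sizes = [60000, 10000]  # MNIST train and test split sizes
print(locate(59999, mnist_sizes))  # last training image
print(locate(60000, mnist_sizes))  # first test image
print(locate(69999, mnist_sizes))  # last test image
```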

```python
num_folds = 3
# Shuffle the indices before splitting; a fixed random_state keeps the folds reproducible
splits = KFold(n_splits=num_folds, shuffle=True, random_state=42)
```
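
With `shuffle=True`, `KFold` permutes the sample indices once (seeded by `random_state`) and then cuts them into `n_splits` contiguous validation folds; each fold's training set is everything else. The pure-Python sketch below reproduces that partitioning idea, including scikit-learn's rule that the first `n_samples % n_splits` folds receive one extra sample. `kfold_indices` is a hypothetical helper, and its shuffle will not reproduce scikit-learn's exact permutation.

```python
import random


def kfold_indices(n_samples, n_splits, seed=None):
    """Yield (train_idx, val_idx) pairs, one per fold (hypothetical sketch)."""
    indices = list(range(n_samples))
    if seed is not None:
        random.Random(seed).shuffle(indices)  # shuffle once before splitting

    # Equal-sized folds; the first n_samples % n_splits folds get one extra sample
    fold_sizes = [n_samples // n_splits] * n_splits
    for i in range(n_samples % n_splits):
        fold_sizes[i] += 1

    start = 0
    for size in fold_sizes:
        val_idx = sorted(indices[start:start + size])
        val_set = set(val_idx)
        train_idx = [i for i in range(n_samples) if i not in val_set]
        yield train_idx, val_idx
        start += size


for train_idx, val_idx in kfold_indices(70000, 3, seed=42):
    print(len(train_idx), len(val_idx))
```

Every sample lands in exactly one validation fold, which is what allows the per-fold scores to be averaged into a single estimate.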

```python
def setup_dataflow(dataset, train_idx, val_idx):
    # Each sampler draws only from its split's indices, in random order
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    # Both loaders wrap the full dataset; the samplers restrict each one to its split
    train_loader = DataLoader(dataset, batch_size=128, sampler=train_sampler)
    val_loader = DataLoader(dataset, batch_size=256, sampler=val_sampler)

    return train_loader, val_loader
```
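
`SubsetRandomSampler` yields only the indices it was given, in a new random order each time the loader is iterated, so one `DataLoader` over the full dataset can serve either split. The stdlib sketch below imitates that behaviour and counts batches per epoch for a 3-fold split of 70,000 samples (46,666 or 46,667 training and 23,333 or 23,334 validation indices). `subset_random_order` and `num_batches` are hypothetical helpers; since `DataLoader` defaults to `drop_last=False`, the final partial batch is kept, hence the ceiling division.

```python
import math
import random


def subset_random_order(indices, rng=random):
    """Return the given indices in a fresh random order (hypothetical sketch)."""
    order = list(indices)
    rng.shuffle(order)
    return order


def num_batches(n_indices, batch_size, drop_last=False):
    """Batches per epoch; a final partial batch is kept unless drop_last=True."""
    if drop_last:
        return n_indices // batch_size
    return math.ceil(n_indices / batch_size)


val_idx = [3, 7, 11, 42]
print(subset_random_order(val_idx, random.Random(0)))  # some permutation of val_idx

print(num_batches(46667, 128))  # 365 training batches per epoch
print(num_batches(23334, 256))  # 92 validation batches per epoch
```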

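```python
def initialize():
    # Build a fresh model, optimizer and loss for every fold so that
    # no trained weights carry over from one fold into the next
    model = Net().to(device)
    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-06)
    criterion = nn.CrossEntropyLoss()

    return model, optimizer, criterion
```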

```python
def train_model(train_loader, val_loader):
    train_results = []
    val_results = []

    model, optimizer, criterion = initialize()

    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    evaluator = create_supervised_evaluator(
        model, metrics={"Accuracy": Accuracy(), "Loss": Loss(criterion)}, device=device
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        evaluator.run(train_loader)
        # Store a copy: the same evaluator is reused for validation and may
        # overwrite its metrics dict in place on the next run
        metrics = dict(evaluator.state.metrics)
        train_results.append(metrics)
        print(f"Training Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['Accuracy']:.2f} Avg loss: {metrics['Loss']:.2f}")

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        evaluator.run(val_loader)
        metrics = dict(evaluator.state.metrics)
        val_results.append(metrics)
        print(f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['Accuracy']:.2f} Avg loss: {metrics['Loss']:.2f}")

    trainer.run(train_loader, max_epochs=3)

    return train_results, val_results
```
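
Both handlers run the same `evaluator`, and Ignite writes each metric into `evaluator.state.metrics`. If that dictionary object is reused between runs, appending it directly stores references to one shared dict, so earlier entries silently change to the latest values. The stdlib sketch below simulates this pitfall with a hypothetical `FakeEvaluator` that updates a single dict in place, and shows why a copy keeps each entry independent.

```python
class FakeEvaluator:
    """Hypothetical stand-in for an evaluator whose metrics dict is reused."""

    def __init__(self):
        self.metrics = {}

    def run(self, accuracy):
        # Update the same dict in place, as a reused state object would
        self.metrics["Accuracy"] = accuracy


evaluator = FakeEvaluator()
aliased, copied = [], []

for acc in (0.90, 0.89):  # e.g. a training pass, then a validation pass
    evaluator.run(acc)
    aliased.append(evaluator.metrics)       # stores a reference to the shared dict
    copied.append(dict(evaluator.metrics))  # stores a snapshot of the current values

print([m["Accuracy"] for m in aliased])  # [0.89, 0.89]: both entries are one dict
print([m["Accuracy"] for m in copied])   # [0.9, 0.89]
```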

```python
results_per_fold = []

# KFold only needs the number of samples, so split an index array of that length
for fold_idx, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(dataset)))):
    print(f"Fold {fold_idx + 1}")

    train_loader, val_loader = setup_dataflow(dataset, train_idx, val_idx)
    train_results, val_results = train_model(train_loader, val_loader)
    results_per_fold.append([train_results, val_results])
```

```
Fold 1
Training Results - Epoch[1] Avg accuracy: 0.68 Avg loss: 1.36
Validation Results - Epoch[1] Avg accuracy: 0.68 Avg loss: 1.37
Training Results - Epoch[2] Avg accuracy: 0.84 Avg loss: 0.86
Validation Results - Epoch[2] Avg accuracy: 0.84 Avg loss: 0.87
Training Results - Epoch[3] Avg accuracy: 0.90 Avg loss: 0.57
Validation Results - Epoch[3] Avg accuracy: 0.89 Avg loss: 0.58

Fold 2
Training Results - Epoch[1] Avg accuracy: 0.72 Avg loss: 1.40
Validation Results - Epoch[1] Avg accuracy: 0.71 Avg loss: 1.41
Training Results - Epoch[2] Avg accuracy: 0.85 Avg loss: 0.89
Validation Results - Epoch[2] Avg accuracy: 0.84 Avg loss: 0.90
Training Results - Epoch[3] Avg accuracy: 0.89 Avg loss: 0.59
Validation Results - Epoch[3] Avg accuracy: 0.89 Avg loss: 0.60

Fold 3
Training Results - Epoch[1] Avg accuracy: 0.73 Avg loss: 1.32
Validation Results - Epoch[1] Avg accuracy: 0.73 Avg loss: 1.33
Training Results - Epoch[2] Avg accuracy: 0.86 Avg loss: 0.83
Validation Results - Epoch[2] Avg accuracy: 0.86 Avg loss: 0.83
Training Results - Epoch[3] Avg accuracy: 0.91 Avg loss: 0.53
Validation Results - Epoch[3] Avg accuracy: 0.90 Avg loss: 0.54
```

```python
print(results_per_fold)
```
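
`results_per_fold` holds one `[train_results, val_results]` pair per fold, each a list of per-epoch metric dicts. A common way to report K-fold results is the mean and spread of the final-epoch validation metric, and the best-scoring fold can be picked the same way. The sketch below assumes that structure; `summarize_folds` is a hypothetical helper, and the demo data reuses the rounded validation numbers printed above rather than full-precision values.

```python
import statistics


def summarize_folds(results_per_fold, metric="Accuracy", higher_is_better=True):
    """Mean, sample std and 1-based best fold of the final-epoch validation metric."""
    final_scores = [val_results[-1][metric] for _, val_results in results_per_fold]
    pick = max if higher_is_better else min
    best_fold = pick(range(len(final_scores)), key=final_scores.__getitem__) + 1
    return statistics.mean(final_scores), statistics.stdev(final_scores), best_fold


# Demo data: rounded validation metrics printed by the run above;
# training results are left empty because the summary does not use them
demo_results = [
    [[], [{"Accuracy": 0.68, "Loss": 1.37}, {"Accuracy": 0.84, "Loss": 0.87}, {"Accuracy": 0.89, "Loss": 0.58}]],
    [[], [{"Accuracy": 0.71, "Loss": 1.41}, {"Accuracy": 0.84, "Loss": 0.90}, {"Accuracy": 0.89, "Loss": 0.60}]],
    [[], [{"Accuracy": 0.73, "Loss": 1.33}, {"Accuracy": 0.86, "Loss": 0.83}, {"Accuracy": 0.90, "Loss": 0.54}]],
]

mean_acc, std_acc, best_fold = summarize_folds(demo_results)
print(f"Validation accuracy: {mean_acc:.3f} +/- {std_acc:.3f} (best: fold {best_fold})")
```

For a loss, pass `higher_is_better=False` so the fold with the lowest value is reported as best.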