[ref](https://github.com/taipingeric/ML-for-Newbies/blob/main/Cross%20validation/KFold-CrossValidation-PyTorch.ipynb)

In [1]:
# Basic module
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from pprint import pprint

# PyTorch
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

In [2]:
# print version of PyTorch
torch.__version__, torchvision.__version__

('1.12.1+cpu', '0.13.1+cpu')

In [3]:
# Define Config
NUM_CLASS = 10
# Build dataset with data preprocess
preprocess = transforms.Compose([transforms.ToTensor()])

train_ds = torchvision.datasets.MNIST('data', 
                                      train=True, 
                                      download=True, 
                                      transform=preprocess)
test_ds = torchvision.datasets.MNIST('data', 
                                     train=False, 
                                     download=True, 
                                     transform=preprocess)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data\MNIST\raw\train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting data\MNIST\raw\train-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data\MNIST\raw\train-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting data\MNIST\raw\train-labels-idx1-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data\MNIST\raw\t10k-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting data\MNIST\raw\t10k-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data\MNIST\raw\t10k-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting data\MNIST\raw\t10k-labels-idx1-ubyte.gz to data\MNIST\raw



In [4]:
print(len(train_ds), len(test_ds))

60000 10000


# build model

In [5]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}")

IMG_SIZE = 28

Using cpu


In [6]:
class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(16, NUM_CLASS),
        )
    def forward(self, x):
        logits = self.net(x)
        return logits

# init model and move to GPU device
model = NeuralNet().to(device)

In [7]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset) # number of samples
    num_batches = len(dataloader) # batches per epoch

    model.train() # Sets the model in training mode.
    epoch_loss, epoch_correct = 0, 0

    for batch_i, (x, y) in enumerate(tqdm(dataloader, leave=False)):
        x, y = x.to(device), y.to(device) # move data to GPU

        # Compute prediction loss
        pred = model(x)
        loss = loss_fn(pred, y)

        # Optimization by gradients
        optimizer.zero_grad() # set prevision gradient to 0
        loss.backward() # backpropagation to compute gradients
        optimizer.step() # update model params

        # write to logs
        epoch_loss += loss.item()
        epoch_correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    # return avg loss of epoch, acc of epoch
    return epoch_loss/num_batches, epoch_correct/size
    

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset) # number of samples
    num_batches = len(dataloader) # batches per epoch

    model.eval() # Sets the model in test mode.
    epoch_loss, epoch_correct = 0, 0

    # No training for test data
    with torch.no_grad():
        for batch_i, (x, y) in enumerate(tqdm(dataloader, leave=False)):
            x, y = x.to(device), y.to(device)

            pred = model(x)
            loss = loss_fn(pred, y)

            epoch_loss += loss.item()
            epoch_correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    return epoch_loss/num_batches, epoch_correct/size

In [8]:
# backup initial weights
PATH = 'init.pth'
torch.save(model.state_dict(), PATH)

In [9]:
K = 5
EPOCHS = 1
BATCH_SIZE = 256
loss_fn = nn.CrossEntropyLoss()
test_loader = torch.utils.data.DataLoader(test_ds, 
                                          batch_size=BATCH_SIZE)

from sklearn.model_selection import KFold
kfold = KFold(n_splits=K)

fold_losses = []
fold_accs = []

for fold_i, (train_ids, val_ids) in enumerate(kfold.split(train_ds)):
    print(f'train size: {len(train_ids)}, val size: {len(val_ids)}')

    # Reset model parameters
    model.load_state_dict(torch.load(PATH))

    # Sample elements from selected ids
    train_sampler = torch.utils.data.SubsetRandomSampler(train_ids)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_ids)
    # Use sampler to select data for training and validation
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE,
                                               sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE,
                                             sampler=val_sampler)
    
    optimizer = torch.optim.Adam(params=model.parameters())
    # Training
    for epoch in tqdm(range(EPOCHS), leave=False):
        train_loss, train_acc = train(train_loader, model, loss_fn, optimizer)
        val_loss, val_acc = test(val_loader, model, loss_fn)
    
    # Test
    test_loss, test_acc = test(test_loader, model, loss_fn)
    print(f'Fold {fold_i}, test acc: {test_acc:.3f}')

    fold_losses.append(test_loss)
    fold_accs.append(test_acc)

train size: 48000, val size: 12000


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=47.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Fold 0, test acc: 0.218
train size: 48000, val size: 12000


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=47.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Fold 1, test acc: 0.225
train size: 48000, val size: 12000


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=47.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Fold 2, test acc: 0.228
train size: 48000, val size: 12000


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=47.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Fold 3, test acc: 0.224
train size: 48000, val size: 12000


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=47.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Fold 4, test acc: 0.227


In [10]:
print(f"Loss: mean {np.mean(fold_losses):.3f}, std: {np.std(fold_losses):.3f}")
print(f"Acc: mean {np.mean(fold_accs):.3f}, std: {np.std(fold_accs):.3f}")

Loss: mean 2.086, std: 0.012
Acc: mean 0.224, std: 0.003


In [19]:
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([1, 2, 3, 4])
kf = KFold(n_splits=2)

for train_index, test_index in kf.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

TRAIN: [2 3] TEST: [0 1]
TRAIN: [0 1] TEST: [2 3]
