# Summary
This notebook can be used for training, validation, and testing of multiple PyTorch models on anything from CPU to 8 TPU cores. It makes use of some useful additional features like

* class balancing for highly undersampled classes
* warmstarting from models with fewer classes to models with more classes
* model and stage switching depending on needs

Supported models are ResNets and EfficientNets for specific choices of layers but feel free to extend this to your needs if you like.

Major parts of this notebook were inspired by

* https://www.kaggle.com/andypenrose/pytorch-training-inference-efficientnet-b4
* https://www.kaggle.com/rhtsingh/pytorch-training-inference-efficientnet-baseline/notebook

like the dataset class and the submission handling.

# Imports and dependencies

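We start with the necessary imports and dependencies: general Python imports, special library imports, and PyTorch imports. The first two groups are imported right away; the PyTorch-related imports follow once the run parameters are defined, since those parameters influence them.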

In [None]:
import os
import glob
import gc
gc.enable()
import time
import random
import warnings
warnings.filterwarnings("ignore")
import multiprocessing

import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
from PIL import Image
from sklearn.preprocessing import LabelEncoder

Before the remaining imports, however, we have to define the main parameters for the run (which remain constant throughout), since they influence the further imports:

In [None]:
# one of: resnet, efficientnet
MODEL = "resnet"
# what stages should be performed (test here is simply generating the submission for the test sets)
KERNEL_STAGES = ["train", "validate", "test"]
# minimum required samples per class, could be 0 for all classes (will be balanced during sampling)
MIN_SAMPLES_PER_CLASS = 0
# split between training and validation, typically 20% or 10% of the training is used for validation
VALIDATION_SPLIT = .1
# recommended: 128/256 for one TPU core, 16/32 for 8 TPU cores, 128/256 for one GPU
BATCH_SIZE = 256
NUM_EPOCHS = 1
# number of CPU cores (also used as the number of DataLoader workers) and TPU cores
# if the TPU accelerator is turned on (0 means no TPU; only 1 or 8 cores are supported)
CPU_CORES = multiprocessing.cpu_count()
TPU_CORES = 0
# log loss and GAP every LOG_FREQ minibatches; save model and optimizer every SAVE_FREQ minibatches
LOG_FREQ = 32
SAVE_FREQ = 128
# warmstarting from a previously saved model (0 means off); the value is the minimum number
# of samples per class the previous (smaller) model was trained with, used to rebuild its
# class indices - it should be >= MIN_SAMPLES_PER_CLASS
WARMSTARTING = 0
# number of top predictions
NUM_TOP_PREDICTS = 1
# minimum confidence below which to output nothing (no class)
MIN_CONF = 0.1
# when on, a run whose test set is the public one just writes the sample submission and
# stops early; enable this for a proper (final) submission to speed up the rerun
ENABLE_FAST_SKIP = False
# image width and height in pixels (assumed squares)
IMG_SIZE = 128

We can now continue and finish with all imports:

In [None]:
"""
Steps for 8x TPU acceleration to run on entire chip:
  - set TPU_CORES > 0 (1 or 8)
  - turn on TPU accelerator
  - turn on internet switch
  - exclude the test stage if included (not available for TPUs)
  - load model from a dataset (which might need an update)
  - balance CPU_CORES with TPU_CORES (memory trade-off between the two)
TODO: consider xm.mesh_reduce and xm.master_print to print just one reduced loss or accuracy
TODO: consider whether to scale learning rate as lr * xm.xrt_world_size()
"""
if TPU_CORES > 0:
    !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
    !python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev
    # use the mantisa-reducing float type
    !export XLA_USE_BF16=1
    import torch_xla
    import torch_xla.core.xla_model as xm
    if TPU_CORES > 1:
        import torch_xla.distributed.parallel_loader as pl
        import torch_xla.distributed.xla_multiprocessing as xmp

import torch
from torch import Tensor
from torch import nn
from torch.optim import lr_scheduler, Adam, SGD
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler
from torchvision import transforms
from torchvision import models

from catalyst.data.sampler import DistributedSamplerWrapper
from tqdm import tqdm
# somewhat fancier
#from tqdm.notebook import tqdm

# uncomment for offline install (needs a minimal dataset)
#!mkdir -p /tmp/pip/cache/
#!cp ../input/glr-efficientnets-ext/efficientnet_pytorch-0.6.3-py3-none-any.whl /tmp/pip/cache/
#!pip install --no-index --find-links /tmp/pip/cache/ efficientnet_pytorch
!pip install efficientnet_pytorch
import efficientnet_pytorch

Now define all paths to be used for the datasets and models:

In [None]:
# root of the competition data; all CSVs and image folders live below it
_glr_root = '../input/landmark-recognition-2020'
pretrain_df = pd.read_csv(f'{_glr_root}/train.csv')
pretest_df = pd.read_csv(f'{_glr_root}/sample_submission.csv')
train_dir = f'{_glr_root}/train/'
test_dir = f'{_glr_root}/test/'
# pretrained weights and checkpoints
model_dir = "../input/pytorch-pretrained-image-models/"
model_load_file = ""  # e.g. "../input/models/glr_resnet_81313.pth"
model_save_file = "glr_resnet.pth"

While seeding is good for reproducibility, considering the constrained resources on Kaggle and the possibility of running just one epoch, I would rather recommend commenting out the `seed_randoms()` line in order to be able to properly sample all classes across multiple sessions.

In [None]:
# seed everything to avoid non-determinism
def seed_randoms(seed=2020):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) generators and request
    deterministic cuDNN kernels.

    Note: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
    affects processes started afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
# comment this out to not oversample some classes due to short deterministic sessions
seed_randoms()

The dataset class is fairly standard in PyTorch:

In [None]:
class ImageDataset(Dataset):
    """
    Image dataset over a dataframe of landmark image ids.

    Standard class sourced from:

    https://www.kaggle.com/rhtsingh/pytorch-training-inference-efficientnet-baseline

    :param dataframe: dataframe with an ``id`` column (and ``landmark_id`` unless
                      ``mode == 'test'``)
    :param image_dir: root directory; an image is read from
                      ``<image_dir>/<id[0]>/<id[1]>/<id[2]>/<id>.jpg``
    :param mode: 'train' adds random augmentation, 'test' omits the target,
                 any other value uses the deterministic evaluation transforms
    """

    # ImageNet channel statistics used for normalization
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, dataframe: pd.DataFrame, image_dir: str, mode: str):
        self.df = dataframe
        self.mode = mode
        self.image_dir = image_dir

        # every mode resizes to the same square size so train and eval inputs match
        transforms_list = [transforms.Resize((IMG_SIZE, IMG_SIZE))]
        if self.mode == 'train':
            transforms_list += [
                transforms.RandomHorizontalFlip(),
                # apply exactly one of these augmentations per sample
                transforms.RandomChoice([
                    transforms.RandomResizedCrop(IMG_SIZE),
                    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
                    transforms.RandomAffine(degrees=15, translate=(0.2, 0.2),
                                            scale=(0.8, 1.2), shear=15,
                                            resample=Image.BILINEAR),
                ]),
            ]
        transforms_list += [
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ]
        self.transforms = transforms.Compose(transforms_list)

    def __getitem__(self, index: int):
        row = self.df.iloc[index]
        image_id = row.id
        image_path = f"{self.image_dir}/{image_id[0]}/{image_id[1]}/{image_id[2]}/{image_id}.jpg"
        # convert to RGB so grayscale/RGBA/CMYK files still yield the 3 channels that
        # Normalize expects, and close the file handle as soon as the pixels are decoded
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = self.transforms(image)

        if self.mode == 'test':
            return {'image': image}
        return {'image': image, 'target': row.landmark_id}

    def __len__(self) -> int:
        return self.df.shape[0]

Filter out classes with too few samples (used mostly due to resource constraints):

In [None]:
class_sample_totals = pretrain_df.landmark_id.value_counts()
selected_classes = class_sample_totals[class_sample_totals >= MIN_SAMPLES_PER_CLASS]
num_classes = len(selected_classes)
# take an explicit copy: the landmark_id column of train_df is overwritten after label
# encoding, and assigning into a filtered slice of pretrain_df is pandas' chained-assignment
# pitfall (SettingWithCopyWarning, which is silenced in this notebook)
train_df = pretrain_df.loc[pretrain_df.landmark_id.isin(selected_classes.index)].copy()
print(f'Classes with at least N={MIN_SAMPLES_PER_CLASS} samples: {num_classes}')

Reset the IDs used for the landmarks back to ones starting from zero and compatible with the network outputs.

In [None]:
# map to 0-reset class indices to also index the logits of the network
label_encoder = LabelEncoder().fit(train_df.landmark_id.values)
# every selected class must end up with exactly one encoded index
assert len(label_encoder.classes_) == num_classes
train_df['landmark_id'] = label_encoder.transform(train_df['landmark_id'])
train_df.reset_index(drop=True, inplace=True)

Split the training samples and labels in order to be able to validate the model.

We always reuse a fixed seed for this particular split, so that the validation set is never leaked into the training set regardless of the order of random generator calls or other events. This way, even with full randomness and shorter sessions, the split between the training and validation sets remains the same and depends only on the validation percentage.

In [None]:
sample_indices = train_df.index.to_numpy()
# reproducible split that leaves the global NumPy random stream untouched: a private
# RandomState seeded with 2020 yields the same permutation as reseeding the global one
np.random.RandomState(2020).shuffle(sample_indices)
split = int(VALIDATION_SPLIT * len(sample_indices))
val_indices = sample_indices[:split]
train_indices = sample_indices[split:]
assert len(train_indices) + len(val_indices) == len(sample_indices)
print('Train dataframe (at least N samples):', train_df.shape)
print('  for training:', len(train_indices))
print('  for validation:', len(val_indices))

For the rest of the dataset preparation we validate the test set for existing files, initialize the datasets, and validation and training samplers. The validation sampler is a normal random subset sampler over the validation indices but the training sampler is more elaborate in order to balance the sampled classes.

We use a fairly standard approach in PyTorch: a weighted random sampler that gives zero weight to the validation indices, which must never be sampled during training. It would also be possible to split the dataset itself, but relying on sampling is a cleaner approach.

In [None]:
# filter non-existing test images
def exists(img):
    """Return whether the JPEG for test image id ``img`` exists in the nested test directory."""
    return os.path.exists(f'{test_dir}/{img[0]}/{img[1]}/{img[2]}/{img}.jpg')


# reset the index so that label-based lookups such as ``test_df.id[0]`` refer to the
# first existing image, and so the frame is a standalone copy rather than a slice
test_df = pretest_df.loc[pretest_df.id.apply(exists)].reset_index(drop=True)
print('Test dataframe (existing files):', test_df.shape)

train_dataset = ImageDataset(train_df, train_dir, mode='train')
test_dataset = ImageDataset(test_df, test_dir, mode='test')

val_subsampler = SubsetRandomSampler(val_indices)

# balance the classes due to high polarity in the number of samples per class; count only
# the training portion, since validation samples are never drawn and would dilute the weights
train_labels = train_df.landmark_id.values
class_sample_count = np.bincount(train_labels[train_indices], minlength=num_classes)
# a class whose samples all landed in validation has nothing to draw anyway; clamp its
# count to 1 to avoid a division by zero
class_weights = 1 / torch.Tensor(np.maximum(class_sample_count, 1))
sample_weights = class_weights[train_labels]
# cannot draw from the validation data
sample_weights[val_indices] = 0.0
# one epoch draws (with replacement) as many samples as the training split holds
train_subsampler = WeightedRandomSampler(sample_weights, len(train_indices))

Models that need separate classes can be defined here:

In [None]:
class EfficientNetEncoderHead(nn.Module):
    """EfficientNet-b<depth> feature extractor followed by a linear classification head.

    The backbone is created with ``EfficientNet.from_name`` (architecture only, no
    weights loaded here); its own ``_fc`` layer is bypassed and ``classifier`` maps the
    pooled features to ``num_classes`` logits.
    """

    def __init__(self, depth, num_classes):
        super().__init__()
        self.depth = depth
        self.base = efficientnet_pytorch.EfficientNet.from_name(f'efficientnet-b{depth}')
        # Optionally load pretrained weights and freeze the backbone, e.g.:
        #   pretrained_file = glob.glob(f'../input/efficientnet-pytorch/efficientnet-b{depth}*')[0]
        #   self.base.load_state_dict(torch.load(pretrained_file, map_location=device))
        #   for param in self.base.parameters():
        #       param.requires_grad = False
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        feature_width = self.base._fc.in_features
        self.classifier = nn.Linear(feature_width, num_classes)

    def forward(self, x):
        """Backbone features -> global average pooling -> linear class logits."""
        features = self.base.extract_features(x)
        pooled = torch.flatten(self.avg_pool(features), 1)
        return self.classifier(pooled)

    def numel(self, only_trainable=False):
        """
        Count the parameters of the model, counting shared parameters only once.

        :param bool only_trainable: restrict the count to parameters
                                    with ``requires_grad = True``
        :returns: total number of parameters
        :rtype: int
        """
        selected = (p for p in self.parameters() if p.requires_grad or not only_trainable)
        # parameters sharing the same storage start are counted once
        by_storage = {p.data_ptr(): p for p in selected}
        return sum(p.numel() for p in by_storage.values())

    def summary(self):
        """Print every parameter's name, shape and value range, then the parameter counts."""
        separator = '----------------------------------------------'
        print()
        print('model summary\n' + separator)
        for name, param in self.named_parameters():
            low = torch.min(param).item()
            high = torch.max(param).item()
            print(f'module {name} of shape {list(param.shape)} in the range '
                  f'[{low:.4f}, {high:.4f}]\n{separator}')
        print("Trainable parameters:", self.numel(only_trainable=True))
        print("Total parameters:", self.numel())
        print()

We can now initialize the model of choice.

In [None]:
if MODEL == "resnet":
    # ResNet-50 architecture with randomly initialised weights (pretrained=False). To start
    # from pretrained weights instead, load them right after construction, e.g.
    #   net.load_state_dict(torch.load(os.path.join(model_dir, 'resnet50.pth'), map_location="cpu"))
    # (no `device` is defined yet at this point, so load on CPU) and optionally freeze them
    # by setting `param.requires_grad = False` for every param in net.parameters().
    net = models.resnet50(pretrained=False)
    # replace the fully connected layer with one for our classes to actually be trained
    net.fc = nn.Linear(net.fc.in_features, num_classes)

    def numel(module, only_trainable=False):
        """
        Returns the total number of parameters used by ``module``
        (only counting shared parameters once).

        :param module: the torch module whose parameters are counted
        :param bool only_trainable: whether to only include parameters
                                    with ``requires_grad = True``
        :returns: total number of parameters
        :rtype: int
        """
        parameters = module.parameters()
        if only_trainable:
            parameters = [p for p in parameters if p.requires_grad]
        # parameters sharing the same storage start are counted once
        unique = {p.data_ptr(): p for p in parameters}.values()
        return sum(p.numel() for p in unique)

    print()
    print("Trainable parameters:", numel(net, only_trainable=True))
    print("Total parameters:", numel(net))
    print()

elif MODEL == "efficientnet":
    net = EfficientNetEncoderHead(depth=4, num_classes=num_classes)
    net.summary()

else:
    raise ValueError(f"Inappropriate choice for model {MODEL}, only ResNets and EfficientNets supported")

We can now load a previously saved model or warmstart it from an otherwise incompatible model.

The model is only loaded if there is a provided path and is always performed on CPU to avoid limitations from loading TPU models into CPU/GPU models and other accelerator switching. Once the model is loaded on CPU, it will eventually be sent to the device that is currently used.

The warmstarting matches all previously learned classes from a smaller model to the new wider class selection.

In [None]:
def warmstarting(model, filename, prev_min_delimiter=10):
    """
    Perform warmstarting of a new model from the saved parameters of a smaller one.

    Entries whose shapes match are copied as they are. Entries whose leading (class)
    dimension differs are copied row by row. The previous model's class indices are
    rebuilt from ``pretrain_df`` using ``prev_min_delimiter``, then mapped through the
    global ``label_encoder`` to the current indices. Rows of classes that are new to
    ``model`` keep their current values.

    :param model: the (wider) model to initialise in place
    :param str filename: path of the previous model's saved state dict
    :param int prev_min_delimiter: minimum samples per class the previous model was
                                   trained with
    :raises ValueError: if a previous class is not part of the current class selection
    """
    print("Warmstarting the model")
    saved_state = torch.load(filename, map_location="cpu")
    # assigning to an entry of this dict does not change the model by itself, so the
    # merged state is loaded back into the model at the end
    net_state = model.state_dict()

    # rebuild the class selection and encoding the previous model was trained with
    prev_class_sample_totals = pretrain_df.landmark_id.value_counts()
    prev_selected_classes = prev_class_sample_totals[prev_class_sample_totals >= prev_min_delimiter]
    prev_num_classes = len(prev_selected_classes)
    prev_train_df = pretrain_df.loc[pretrain_df.landmark_id.isin(prev_selected_classes.index)]
    print(f'Previous classes with at least N={prev_min_delimiter} samples: {prev_num_classes}')

    prev_label_encoder = LabelEncoder()
    prev_label_encoder.fit(prev_train_df.landmark_id.values)
    assert len(prev_label_encoder.classes_) == prev_num_classes

    for key, saved_tensor in saved_state.items():
        if key not in net_state:
            print(f"Skipping {key} which is not part of the current model")
            continue
        net_tensor = net_state[key]
        if net_tensor.shape == saved_tensor.shape:
            print(f"Restoring {key} with matching shape {saved_tensor.shape}")
            net_state[key] = saved_tensor
            continue

        print(f"Class matching {key} from {saved_tensor.shape} to {net_tensor.shape}")
        # assert that the new model is larger than the old model (class dimension only)
        assert len(net_tensor.shape) == len(saved_tensor.shape) <= 2
        assert net_tensor.shape[0] >= saved_tensor.shape[0]

        saved_idx = np.arange(saved_tensor.shape[0])
        prev_landmarks = prev_label_encoder.inverse_transform(saved_idx)
        unknown = ~np.isin(prev_landmarks, label_encoder.classes_)
        if unknown.any():
            raise ValueError(f"{int(unknown.sum())} previous classes of {key} are not among the "
                             "current classes; the warmstarting threshold must be at least "
                             "the current minimum number of samples per class")
        net_idx = label_encoder.transform(prev_landmarks)
        # row saved_idx[j] goes to row net_idx[j]; saved_idx covers all saved rows in order
        net_tensor[torch.as_tensor(net_idx)] = saved_tensor

    model.load_state_dict(net_state)

# only touch the model when a checkpoint path was configured and the file is present
if model_load_file and os.path.exists(model_load_file):
    if WARMSTARTING > 0:
        # checkpoint comes from a model trained on a smaller class selection
        warmstarting(net, model_load_file, WARMSTARTING)
    else:
        # always load on CPU so checkpoints from any accelerator can be restored
        net.load_state_dict(torch.load(model_load_file, map_location="cpu"))
elif model_load_file:
    print("Warning: Previous model file was not found, creating a new one")
net.eval()

What follows are helpers for tracking running averages and a function that evaluates the Global Average Precision (GAP).

In [None]:
class AverageMeter:
    '''Keep the most recent value together with a count-weighted running average.'''

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1) -> None:
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.val = val
        self.avg = self.sum / self.count

def GAP(y_true, y_pred):
    """
    Compute the Global Average Precision (GAP) score.

    Predictions are ranked by descending confidence. The score is the sum of the
    precision at every rank whose prediction is correct, divided by the number of
    queries that have a ground-truth target.

    :param dict y_true: query key -> true label, or ``None`` if the query has no target
    :param dict y_pred: query key -> ``(predicted_label, confidence)``; every key must
                        also be present in ``y_true``
    :returns: the GAP score, or 0.0 when no query has a target
    :rtype: float
    """
    queries_with_target = sum(1 for label in y_true.values() if label is not None)
    if queries_with_target == 0:
        # no query can be answered correctly; avoid dividing by zero
        return 0.0
    # sorted() stays stable with reverse=True, so confidence ties keep insertion order
    ranked = sorted(y_pred, key=lambda k: y_pred[k][1], reverse=True)
    correct_predictions = 0
    total_score = 0.
    for rank, key in enumerate(ranked, 1):
        if y_true[key] == y_pred[key][0]:
            correct_predictions += 1
            # precision at this rank only contributes where the prediction is relevant
            total_score += correct_predictions / rank
    return total_score / queries_with_target

We are now ready to define the training step, which also has some conditionals for TPU accelerators, since these are not yet fully integrated into PyTorch.

In [None]:
def _save_checkpoint(model, optimizer, filename):
    """Save the model state to ``filename`` and the optimizer state to ``filename + "_opt.tar"``.

    Nothing is written when ``filename`` is empty. On TPU both states go through
    ``xm.save``, which moves XLA tensors to CPU before writing (and by default writes
    from the master replica only); previously the optimizer used ``torch.save`` in
    every replica.
    """
    if filename == "":
        return
    save = xm.save if TPU_CORES > 0 else torch.save
    save(model.state_dict(), filename)
    save(optimizer.state_dict(), filename + "_opt.tar")


def train_step(train_loader, model, device, criterion, optimizer, epoch, lr_scheduler, filename):
    """
    Train ``model`` for one epoch over ``train_loader``.

    The scheduler is stepped after every batch. Loss and a per-batch top-1 GAP are
    logged every LOG_FREQ batches. Model and optimizer state are saved to
    ``filename`` every SAVE_FREQ batches and once more after the last batch.

    :param train_loader: iterable of dicts holding 'image' and 'target' batches
    :param device: device the batches are moved to for the forward pass
    :param epoch: epoch number, used for logging only
    :param lr_scheduler: learning-rate scheduler stepped once per batch
    :param filename: checkpoint path ("" disables saving)
    """
    print(f'Epoch {epoch} with total batches: {len(train_loader)}')
    batch_time = AverageMeter()
    losses = AverageMeter()
    avg_score = AverageMeter()

    model.train()

    end = time.time()
    lr = None

    for i, data in enumerate(train_loader):
        x = data['image']
        y = data['target']

        y_ = model(x.to(device))
        loss = criterion(y_, y.to(device))

        # the top-1 class and its (unnormalised) logit act as GAP label and confidence
        confs, preds = torch.max(y_.detach(), dim=1)
        top1 = zip(preds.cpu().numpy(), confs.cpu().numpy())
        y_true = {f'{j}': label for j, label in enumerate(y.cpu().numpy())}
        y_pred = {f'{j}': pair for j, pair in enumerate(top1)}
        avg_score.update(GAP(y_true, y_pred))
        losses.update(loss.item(), x.size(0))

        optimizer.zero_grad()
        loss.backward()
        if TPU_CORES > 0:
            xm.optimizer_step(optimizer, barrier=True)
        else:
            optimizer.step()
        lr_scheduler.step()
        lr = optimizer.param_groups[0]['lr']

        batch_time.update(time.time() - end)
        end = time.time()

        if i % LOG_FREQ == 0:
            print(f'{device}/{epoch} [{i}/{len(train_loader)}]\t'
                    f'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    f'loss {losses.val:.4f} ({losses.avg:.4f})\t'
                    f'GAP {avg_score.val:.4f} ({avg_score.avg:.4f})\t'
                    f'-> rate: {lr}')
        if i % SAVE_FREQ == 0:
            _save_checkpoint(model, optimizer, filename)

    # persist the batches trained since the last periodic save
    _save_checkpoint(model, optimizer, filename)
    print(f'Average GAP on train: {avg_score.avg:.4f}')

Next are the evaluation step and submission generation where the evaluation step is a bit more compact using `tqdm` updates. The training step in comparison keeps each line in order to trace the changes in the loss and GAP accuracy.

In [None]:
def eval_step(data_loader, model, device, mode="validation"):
    """
    Evaluate ``model`` over ``data_loader`` without gradients.

    Softmax probabilities are reduced to the top NUM_TOP_PREDICTS classes per sample.
    Unless ``mode == 'test'``, every batch must carry a 'target'. In that case a running
    top-1 GAP is shown in the progress bar and printed at the end.

    :returns: ``(preds, confs, targets)``. ``preds`` and ``confs`` are the concatenated
              top-k class indices and probabilities, left on the device the model produced
              them on. ``targets`` holds the concatenated targets, or ``None`` in test mode.
    """
    avg_score = AverageMeter()
    model.eval()

    softmax = nn.Softmax(dim=1)
    all_preds, all_confs, all_targets = [], [], []

    with torch.no_grad():
        with tqdm(total=len(data_loader), ncols=100,
                  bar_format='{l_bar}{bar}| {postfix[0]} {n_fmt}/{total_fmt} ({postfix[3]}) '\
                             '[{elapsed}<{remaining}, {rate_fmt}, {postfix[1]}:{postfix[2][GAP]:>.4}]',
                  postfix=["Batch", "GAP", dict(GAP="?"), device]) as t:
            for data in data_loader:
                # we would check `data_loader.dataset.mode` but the validation set is
                # a subset of the training set, thus having training mode
                x = data['image']
                y = data['target'] if mode != 'test' else None

                y_ = softmax(model(x.to(device)))
                confs, preds = torch.topk(y_, NUM_TOP_PREDICTS)

                if y is not None:
                    # GAP scores the single best guess, so pass only the top-1 column;
                    # this also keeps it working when NUM_TOP_PREDICTS > 1
                    top1 = zip(preds[:, 0].cpu().numpy(), confs[:, 0].cpu().numpy())
                    y_true = {f'{j}': label for j, label in enumerate(y.cpu().numpy())}
                    y_pred = {f'{j}': pair for j, pair in enumerate(top1)}
                    avg_score.update(GAP(y_true, y_pred))
                    t.postfix[2]["GAP"] = avg_score.avg
                    all_targets.append(y)

                all_confs.append(confs)
                all_preds.append(preds)

                t.update()

        if mode != 'test':
            print(f'Average GAP: {avg_score.avg:.4f}')

    preds = torch.cat(all_preds)
    confs = torch.cat(all_confs)
    targets = torch.cat(all_targets) if all_targets else None

    return preds, confs, targets

def generate_submission(test_loader, model, device, label_encoder):
    """
    Predict on ``test_loader`` and write ``submission.csv``.

    Predicted class indices are mapped back to landmark ids with ``label_encoder``.
    Predictions whose confidence is not above MIN_CONF are left out. Sample-submission
    rows without a matching test image keep their original value (``DataFrame.update``
    only overwrites matching ids).
    """
    sample_sub = pd.read_csv('../input/landmark-recognition-2020/sample_submission.csv')

    predicts_gpu, confs_gpu, _ = eval_step(test_loader, model, device, mode="test")
    predicts, confs = predicts_gpu.cpu().numpy(), confs_gpu.cpu().numpy()

    # one vectorised inverse_transform instead of one call per test image
    labels = label_encoder.inverse_transform(predicts.ravel()).reshape(predicts.shape)

    def concat(label: np.ndarray, conf: np.ndarray) -> str:
        # "<landmark_id> <confidence>" pairs for confident predictions only; filtering
        # (instead of emitting '') avoids stray separators when NUM_TOP_PREDICTS > 1
        return ' '.join(f'{L} {c}' for L, c in zip(label, conf) if c > MIN_CONF)

    # work on a copy so the dataset's dataframe is left untouched
    sub = test_loader.dataset.df[['id']].copy()
    sub['landmarks'] = [concat(label, conf) for label, conf in zip(labels, confs)]

    sample_sub = sample_sub.set_index('id')
    sub = sub.set_index('id')
    sample_sub.update(sub)
    sample_sub.to_csv('submission.csv')

    print(sample_sub.head())

The final step combines all previous ones. It trains, validates, and tests either on a single device (CPU, GPU, or one TPU core) or on multiple replicas (all 8 TPU cores). Note that the number of TPU cores must be > 0 to use a TPU, and only 1 or 8 cores are currently supported.

The rest of this takes care to instantiate the correct data loaders (possibly with distributed samplers on top of the previous samplers), reset or reload a previous optimizer (e.g. from previous training session), and run the selected stages possibly excluding training, validation, or testing.

In [None]:
def map_run(index):
    """
    Run the selected KERNEL_STAGES on one device (one replica when TPU_CORES > 1).

    :param index: process index passed by ``xmp.spawn`` (0 for single-device runs);
                  unused, the replica rank comes from ``xm.get_ordinal()``
    """
    device = xm.xla_device() if TPU_CORES > 0 else torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # positional lookup: safe whatever the index labels are, and for an empty test set
    first_test_id = test_df.id.iloc[0] if len(test_df) else None
    if ENABLE_FAST_SKIP and first_test_id == "00084cdf8f600d00":
        # This is a run on the public data, skip it to speed up submission run on private data.
        print("Skipping run on public test set.")
        sample_sub = pd.read_csv('../input/landmark-recognition-2020/sample_submission.csv')
        sample_sub.to_csv('submission.csv')
        return

    if TPU_CORES > 1:
        # TODO: we cannot subsample in an elegant way until PyTorch adds a wrapper with a subsampler
        #train_sampler = DistributedSampler(train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True)
        #val_sampler = DistributedSampler(train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        train_sampler = DistributedSamplerWrapper(train_subsampler, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True)
        val_sampler = DistributedSamplerWrapper(val_subsampler, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())

        world_train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                        sampler=train_sampler, num_workers=CPU_CORES, drop_last=True)
        world_val_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                      sampler=val_sampler, num_workers=CPU_CORES, drop_last=True)

        para_train_loader = pl.ParallelLoader(world_train_loader, [device])
        para_val_loader = pl.ParallelLoader(world_val_loader, [device])

        train_loader = para_train_loader.per_device_loader(device)
        val_loader = para_val_loader.per_device_loader(device)
    else:
        train_sampler = train_subsampler
        val_sampler = val_subsampler

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                  sampler=train_sampler, num_workers=CPU_CORES, drop_last=True)
        # on CPU/GPU keep the last partial batch so every validation sample is scored;
        # on a single TPU core keep dropping it, presumably to keep XLA batch shapes static
        val_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                sampler=val_sampler, num_workers=CPU_CORES,
                                drop_last=TPU_CORES > 0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=CPU_CORES)

    net.to(device)

    if "train" in KERNEL_STAGES:
        criterion = nn.CrossEntropyLoss()
        optimizer = Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-3, weight_decay=1e-4)
        # only resume the optimizer alongside a configured model checkpoint
        opt_load_file = model_load_file + "_opt.tar"
        if model_load_file != "" and os.path.exists(opt_load_file):
            optimizer.load_state_dict(torch.load(opt_load_file, map_location="cpu"))
        else:
            print("Warning: Optimizer parameters not reloaded, training will be reset")
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader)*NUM_EPOCHS, eta_min=1e-6)

        for epoch in range(1, NUM_EPOCHS + 1):
            print('-' * 50)
            time.sleep(1)
            train_step(train_loader, net, device, criterion, optimizer, epoch, scheduler, model_save_file)

    if "validate" in KERNEL_STAGES:
        print('-' * 50)
        time.sleep(1)
        eval_step(val_loader, net, device, mode="validation")
    if "test" in KERNEL_STAGES and TPU_CORES == 0:
        print('-' * 50)
        time.sleep(1)
        generate_submission(test_loader, net, device, label_encoder)

if TPU_CORES <= 1:
    # a single device: CPU, GPU or one TPU core
    map_run(0)
else:
    # one replica per TPU core; the 'fork' start method lets each replica inherit the
    # module-level datasets, samplers and model
    xmp.spawn(map_run, args=(), nprocs=TPU_CORES, start_method='fork')

I hope this was useful, if you have any recommendations for improvements or questions I will be happy to hear!