This notebook is an attempt at K-fold cross-validation with the folds trained in parallel, one fold per TPU core. However, training currently runs on only one core at a time.

In [None]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
!pip uninstall -q typing --yes

In [None]:
!pip install pytorch-lightning
!pip install timm

In [None]:
import os
import cv2
import pandas as pd
import numpy as np
import random
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import torch
# from torchvision import models
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as albu
from albumentations.pytorch.transforms import ToTensorV2
from sklearn.model_selection import StratifiedKFold
# from efficientnet_pytorch import EfficientNet
import torch_xla.core.xla_model as xm
import torch_xla
import timm

In [None]:
import random
def seed_everything(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's global RNG and torch's CPU/CUDA
    generators, exports the seed as ``PYTHONHASHSEED`` and puts cuDNN in its
    deterministic, non-autotuning mode.

    Args:
        seed: Integer seed applied to every generator (default 42).
    """
    # Only subprocesses started afterwards pick this up; the running
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything()

In [None]:
# Competition data location as mounted on Kaggle.
_DATA_ROOT = '../input/cassava-leaf-disease-classification'
TRAIN_CSV = f'{_DATA_ROOT}/train.csv'
TRAIN_IMAGE_FOLDER = f'{_DATA_ROOT}/train_images'
# Number of outputs of the classification head.
CLASSES = 5

### Hyperparameters

In [None]:
FOLDS = 8                # number of CV folds (one training job per fold)
BATCH_SIZE = 8           # DataLoader batch size
LR = 0.01                # initial Adam learning rate
EPOCHS = 2               # maximum epochs per fold
LOSS_FUNCTION = nn.BCEWithLogitsLoss()  # applied to one-hot label targets
IMG_SIZE = 128           # side length of the square training crops
EARLY_STOPPING = True    # attach EarlyStopping on val_accuracy
MODEL_ARCH = 'resnet50'  # timm architecture name

### Dataset

In [None]:
class CassavaDataset(Dataset):
    """Dataset of Cassava leaf images described by a DataFrame.

    Args:
        train: DataFrame with an ``image_id`` column (file name inside
            ``TRAIN_IMAGE_FOLDER``) and, in train mode, one-hot label columns
            whose names start with ``'label'`` (e.g. from ``pd.get_dummies``).
        train_mode: When True, items carry the target under ``"y"``; when
            False only the image ``"x"`` is returned.
        transforms: Optional albumentations pipeline, called as
            ``transforms(image=...)["image"]``.
    """

    def __init__(self, train, train_mode=True, transforms=None):
        self.train = train
        self.transforms = transforms
        self.train_mode = train_mode

    def __len__(self):
        return self.train.shape[0]

    def __getitem__(self, index):
        """Load one image as RGB, transform it and attach its target."""
        image_path = os.path.join(TRAIN_IMAGE_FOLDER, self.train.iloc[index].image_id)
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # instead of raising; fail here with the offending path.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transforms:
            image = self.transforms(image=image)["image"]

        if not self.train_mode:
            return {"x": image}

        label_mask = self.train.columns.str.startswith('label')
        # float32 so the target matches the model's float logits in
        # BCEWithLogitsLoss (float64 targets mismatch the output dtype).
        target = self.train.iloc[index, label_mask].to_numpy(dtype=np.float32)
        return {
            "x": image,
            "y": torch.tensor(target),
        }

### Transforms

In [None]:
def get_augmentations():
    """Build the training and validation albumentations pipelines.

    Returns:
        ``(train_augmentations, valid_augmentations)``. Both emit
        ``IMG_SIZE x IMG_SIZE`` images normalized with ImageNet mean/std and
        converted to tensors by ``ToTensorV2``.
    """
    # ImageNet channel statistics (the usual normalization for
    # ImageNet-pretrained backbones).
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    train_augmentations = albu.Compose([
        albu.RandomResizedCrop(IMG_SIZE, IMG_SIZE, p=1.0),
        albu.Transpose(p=0.5),
        albu.HorizontalFlip(p=0.5),
        albu.VerticalFlip(p=0.5),
        albu.Normalize(mean, std, max_pixel_value=255, always_apply=True),
        ToTensorV2(p=1.0),
    ], p=1.0)

    # Resize to the training resolution so validation evaluates the model at
    # the same scale it was trained on, instead of on full-size images.
    valid_augmentations = albu.Compose([
        albu.Resize(IMG_SIZE, IMG_SIZE, p=1.0),
        albu.Normalize(mean, std, max_pixel_value=255, always_apply=True),
        ToTensorV2(p=1.0),
    ], p=1.0)

    return train_augmentations, valid_augmentations


train_augs, val_augs = get_augmentations()

### NN Model

In [None]:
# # These are the available model architectures in timm
# from pprint import pprint
# model_names = timm.list_models(pretrained=True)
# pprint(model_names)

In [None]:
class Model(nn.Module):
    """timm backbone ``MODEL_ARCH`` with a ``CLASSES``-way classification head.

    The head is built by timm itself via ``num_classes``, so this works for
    any timm architecture (ResNets' ``fc``, EfficientNets' ``classifier``, ...)
    without architecture-specific head surgery. Pretrained weights are loaded
    for the backbone; the resized head starts freshly initialized.
    """

    def __init__(self):
        super().__init__()
        self.model = timm.create_model(MODEL_ARCH, pretrained=True, num_classes=CLASSES)

    def forward(self, x):
        """Return raw (pre-activation) class logits for the image batch ``x``."""
        return self.model(x)

### K-Fold CV

In [None]:
# Give every training image exactly one validation fold, stratified by label,
# and persist the assignment for CassavaDataModule to read back.
traincsv = pd.read_csv(TRAIN_CSV)
traincsv['kfold'] = -1
# Shuffle the rows (DataFrame.sample without random_state draws from NumPy's
# global RNG) before the non-shuffling StratifiedKFold splits them.
traincsv = traincsv.sample(frac=1).reset_index(drop=True)
stratifier = StratifiedKFold(n_splits=FOLDS)

fold_splits = stratifier.split(X=traincsv.image_id.values, y=traincsv.label.values)
for fold, (_, val_index) in enumerate(fold_splits):
    # Rows held out in this split are tagged with its fold id.
    traincsv.loc[val_index, "kfold"] = fold

traincsv.to_csv("train_folds.csv", index=False)

### PyTorch Lightning DataModule

In [None]:
class CassavaDataModule(pl.LightningDataModule):
    """LightningDataModule serving a single cross-validation fold.

    Reads ``./train_folds.csv``, one-hot encodes ``label`` into ``label_*``
    columns, and uses rows with ``kfold == fold`` for validation and all other
    rows for training.

    Args:
        fold: Index of the fold held out for validation.
    """

    def __init__(self, fold):
        super().__init__()
        self.train_aug, self.valid_aug = get_augmentations()
        self.fold = fold
        self.batch_size = BATCH_SIZE

    def setup(self, stage=None):
        folds = pd.read_csv('./train_folds.csv')
        # Encode on the full table so both splits share the same label_*
        # columns, even if a class were absent from one of them.
        folds = pd.get_dummies(folds, columns=['label'])
        train_fold = folds.loc[folds["kfold"] != self.fold]
        val_fold = folds.loc[folds["kfold"] == self.fold]

        # Use this instance's pipelines (built in __init__) rather than
        # module-level globals, so per-instance augmentations take effect.
        self.train_ds = CassavaDataset(train_fold, transforms=self.train_aug)
        self.val_ds = CassavaDataset(val_fold, transforms=self.valid_aug)

    def train_dataloader(self):
        return DataLoader(self.train_ds, self.batch_size, num_workers=4, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, self.batch_size, num_workers=4, shuffle=False)
        

### Callbacks

In [None]:
# Stop a run once val_accuracy has not improved for 3 validation checks.
# NOTE(review): EarlyStopping keeps per-run state (best score, wait count);
# reusing this single instance across concurrently running Trainers mixes
# their state -- prefer building one instance per Trainer.
early_stopping = EarlyStopping('val_accuracy', patience=3, mode='max')

callbacks = []

if EARLY_STOPPING:
    callbacks.append(early_stopping)

### PyTorch Lightning Module

In [None]:
class CassavaPLModule(pl.LightningModule):
    """LightningModule training ``model`` against one-hot class targets.

    Args:
        hparams: Dict with at least ``'lr'`` (Adam learning rate); stored as
            ``self.hparams``.
        model: Module mapping an image batch to ``(N, CLASSES)`` logits.
    """

    def __init__(self, hparams, model):
        super().__init__()
        self.hparams = hparams
        self.model = model
        self.criterion = LOSS_FUNCTION
        self.accuracy = pl.metrics.Accuracy()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau stepped once per epoch on ``val_loss``."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.hparams.lr)
        scheduler = {
            'scheduler':
                torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer, patience=3,
                    threshold=0.001,
                    mode='min', verbose=True
                ),
            'interval': 'epoch',
            'monitor': 'val_loss'
        }
        return [optimizer], [scheduler]

    def _shared_step(self, batch):
        """Return ``(loss, accuracy)`` for one ``{"x": images, "y": one-hot}`` batch."""
        out = self(batch['x'])
        # Match the target dtype to the logits; the dataset may yield float64
        # one-hot vectors while the model emits float32 (or reduced precision).
        targets = batch['y'].to(out.dtype)
        loss = self.criterion(out, targets)
        # Targets are one-hot: compare predicted vs. true class indices rather
        # than handing raw logits and one-hot vectors to the metric.
        acc = self.accuracy(out.argmax(dim=1), targets.argmax(dim=1))
        return loss, acc

    def training_step(self, batch, batch_index):
        loss, acc = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_accuracy", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # Lightning only runs backward + optimizer.step() when training_step
        # returns the loss; returning None skips the update entirely.
        return loss

    def validation_step(self, batch, batch_index):
        loss, acc = self._shared_step(batch)
        # Logged once per epoch (averaged over batches): these are the values
        # the LR scheduler, EarlyStopping and ModelCheckpoint monitor.
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_accuracy", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)

### Training

In [None]:
def train(fold):
    """Train one cross-validation fold on a dedicated TPU core.

    Args:
        fold: Index of the held-out fold. Fold ``i`` is placed on TPU core
            ``i + 1``, so this assumes ``FOLDS`` does not exceed the number of
            TPU cores (8 on a v3-8 -- TODO confirm for other hardware).

    Returns:
        The best-checkpoint path reported by this fold's ModelCheckpoint
        (lowest ``val_loss``).
    """
    # Lightning fills ``{...}`` fields in ``filename`` from logged metrics, not
    # from local variables, so the fold index must be baked in here; otherwise
    # every fold's checkpoints share one name pattern in the same directory.
    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints/',
        filename=f'model_{fold}-{{val_loss:.2f}}',
        monitor='val_loss', verbose=True,
        save_last=False, save_top_k=1, save_weights_only=False,
        mode='min', period=1, prefix=''
    )

    # Fresh callback instances per fold: EarlyStopping keeps per-run state,
    # and folds are trained concurrently.
    fold_callbacks = []
    if EARLY_STOPPING:
        fold_callbacks.append(EarlyStopping('val_accuracy', patience=3, mode='max'))

    # Lightning's tpu_cores=[k] takes a 1-based core index.
    tpu_core = fold + 1

    trainer = pl.Trainer(
        tpu_cores=[tpu_core],
        precision=16,
        max_epochs=EPOCHS,
        checkpoint_callback=checkpoint_callback,
        callbacks=fold_callbacks,
    )
    model = Model()
    pl_dm = CassavaDataModule(fold=fold)
    pl_module = CassavaPLModule(hparams={'lr': LR, 'batch_size': BATCH_SIZE}, model=model)

    # NOTE(review): forces Lightning off native (CUDA) AMP; presumably a
    # workaround for precision=16 on TPU -- confirm it is still required.
    trainer.use_native_amp = False
    trainer.fit(pl_module, pl_dm)

    print(checkpoint_callback.best_model_path, checkpoint_callback.best_model_score)
    return checkpoint_callback.best_model_path
    

In [None]:
import joblib as jl
# Run every fold concurrently: one joblib thread per fold, each fold dispatched
# on its own (batch_size=1) and each thread calling train() for its fold.
# NOTE(review): threads share one Python process and its GIL; this may be why
# only one TPU core is busy at a time -- consider process-based launching.
parallel = jl.Parallel(n_jobs=FOLDS, backend='threading', batch_size=1)
fold_jobs = (jl.delayed(train)(fold_idx) for fold_idx in range(FOLDS))
parallel(fold_jobs)