https://www.kaggle.com/docs/tpu

In [None]:
# Download the PyTorch/XLA environment setup script, run it for release 1.7,
# then install timm (used below for the pretrained model).
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.7
!pip install timm

import os
# XLA runtime settings, exported before torch_xla is imported in a later cell.
# NOTE(review): per the PyTorch/XLA docs, XLA_USE_BF16=1 makes float tensors
# use bfloat16 on the TPU device - confirm for the installed release.
os.environ["XLA_USE_BF16"] = "1"
# NOTE(review): presumably an upper bound (in bytes) for the XLA tensor
# allocator - confirm against the PyTorch/XLA documentation.
os.environ["XLA_TENSOR_ALLOCATOR_MAXSIZE"] = "100000000"

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from pprint import pprint
from tqdm.notebook import tqdm
import os, cv2, glob, time, random, gc
from datetime import datetime

from sklearn.metrics import accuracy_score
from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler

import timm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, distributed

import albumentations as A
from albumentations.pytorch import ToTensorV2

# https://pytorch.org/xla/release/1.7/index.html
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.serialization as xser

import ignite.distributed as igd

In [None]:
# Locations of the competition data. TRAIN_DIR ends with a slash because the
# dataset class builds file paths by concatenating it with the image id.
TRAIN_DIR = '../input/cassava-leaf-disease-classification/train_images/'
LABEL_JSON_PATH = '../input/cassava-leaf-disease-classification/label_num_to_disease_map.json'
TEST_DIR = '../input/cassava-leaf-disease-classification/test_images'

# Training table (read below through its 'image_id' and 'label' columns) and
# the sample submission file.
df = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
df_test = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')

In [None]:
# Load the mapping from numeric class label to disease name. The path comes
# from the LABEL_JSON_PATH constant declared with the other dataset locations,
# so the location is defined in exactly one place.
label_map = pd.read_json(LABEL_JSON_PATH, orient='index')
display(label_map)

In [None]:
# List every timm architecture that ships with pretrained weights and
# pretty-print the names (MODEL_ARCH below is chosen by name).
model_names = timm.list_models(pretrained=True)
pprint(model_names)

In [None]:
# Side length (pixels) used by the crop/resize transforms.
IMG_SIZE = 384
# Per-channel statistics passed to A.Normalize (the usual ImageNet values).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# Batch size of the DataLoader; also used as the weight of each meter update.
BATCH_SIZE = 8
# Print a progress line every ITER_FREQ training steps.
ITER_FREQ = 100
# Number of passes over the training data.
EPOCHS = 2
# Worker processes of the DataLoader.
NUM_WORKERS = 4
# timm architecture name; also embedded in the log and checkpoint file names.
MODEL_ARCH = 'tf_efficientnet_b4_ns'
# Base learning rate; engine() multiplies it by the number of participating cores.
LR = 1e-5
# Weight decay handed to the Adam optimizer.
WEIGHT_DECAY = 1e-6
# Number of batches whose gradients are combined into one optimizer step.
ITERS_TO_ACCUMULATE = 1

# Module-level fallback device (CUDA if available, otherwise CPU).
# NOTE(review): engine() creates its own XLA device and passes it down, so this
# value does not appear to be used by the training code in this notebook.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/multi-core-alexnet-fashion-mnist.ipynb
# Do not initialize device type here.
# device = xm.xla_device()

In [None]:
def seed_torch(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports
    ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


seed_torch(seed=1111)

OUTPUT_DIR = './'
# exist_ok=True replaces the check-then-create sequence: it is a no-op when the
# directory already exists and cannot race with another process creating it.
os.makedirs(OUTPUT_DIR, exist_ok=True)


def init_logger(log_file=OUTPUT_DIR+'train.log'):
    """Return this module's logger, writing to the console and to ``log_file``.

    The logger is looked up by name, so repeated calls (for example when the
    notebook cell is re-run) get the same object back. Handlers installed by an
    earlier call are removed and closed first; otherwise every call would stack
    another console/file pair and each message would be emitted several times.

    Args:
        log_file: Path of the log file to write to.

    Returns:
        The configured ``logging.Logger`` with exactly two handlers.
    """
    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # Iterate over a copy because removeHandler mutates logger.handlers.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    for handler in (StreamHandler(), FileHandler(filename=log_file)):
        handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(handler)
    return logger


LOGGER = init_logger()

In [None]:
class AverageMeter:
    """Running statistics for a scalar metric.

    Attributes:
        val: The value passed to the most recent ``update`` call.
        sum: Weighted total of all values seen since the last reset.
        count: Total weight seen since the last reset.
        avg: ``sum / count`` after the most recent update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
class CLD_train_ds(Dataset):
    """Training dataset that reads images from ``TRAIN_DIR``.

    Args:
        df: Table with an ``image_id`` column (file name inside ``TRAIN_DIR``)
            and a ``label`` column (integer class index).
        transform: Optional callable invoked as ``transform(image=img)`` that
            returns a mapping with the result under ``'image'``.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.image_id = df['image_id'].values
        self.label = df['label'].values
        self.transform = transform

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is read with OpenCV, converted from BGR to RGB and scaled
        to float32 values in [0, 1] before the optional transform is applied.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        image_path = os.path.join(TRAIN_DIR, self.image_id[idx])
        img = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        img = img / 255.0

        # Compare against None: a transform object may define its own
        # truthiness (e.g. an empty pipeline), which must not disable it.
        if self.transform is not None:
            img = self.transform(image=img)['image']

        label = torch.tensor(self.label[idx]).long()

        return img, label

In [None]:
def get_transform(*, train=True):
    """Build the albumentations pipeline for training or evaluation.

    Args:
        train: When true, return the augmenting pipeline (random resized crop,
            transpose, flips, shift/scale/rotate); otherwise return the
            deterministic centre-crop + resize pipeline.

    Returns:
        An ``A.Compose`` whose output image is normalized with ``MEAN``/``STD``
        and converted to a tensor.
    """
    # CLD_train_ds already scales pixels to float32 values in [0, 1].
    # Normalize multiplies mean/std by max_pixel_value (255 by default), which
    # would squash such inputs to a near-constant tensor, so state the real
    # input range explicitly.
    normalize = A.Normalize(mean=MEAN, std=STD, max_pixel_value=1.0)

    if train:
        return A.Compose([
            A.RandomResizedCrop(IMG_SIZE, IMG_SIZE),
            A.Transpose(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(p=0.5),
            normalize,
            ToTensorV2(),
        ])

    return A.Compose([
        A.CenterCrop(IMG_SIZE, IMG_SIZE),
        A.Resize(IMG_SIZE, IMG_SIZE),
        normalize,
        ToTensorV2(),
    ])

In [None]:
class eff_net(nn.Module):
    """timm backbone with a fresh linear classification head.

    Args:
        model_arch: Architecture name understood by ``timm.create_model``.
        n_class: Number of output classes.
        pretrained: Whether to load pretrained weights for the backbone.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        # Pass options by keyword: the positional order of create_model's
        # parameters after `pretrained` is not the same across timm releases.
        self.model = timm.create_model(
            model_arch,
            pretrained=pretrained,
            num_classes=n_class,
        )
        # Replace the head with a newly initialised layer of the right width.
        n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(n_features, n_class)

    def forward(self, x):
        """Return the class logits produced by the wrapped model for ``x``."""
        return self.model(x)
# class resnext(nn.Module):
#     def __init__(self, model_name, n_class, pretrained=False):
#         super().__init__()
#         self.model = timm.create_model(model_name, pretrained=pretrained)
#         n_features = self.model.fc.in_features
#         self.model.fc = nn.Linear(n_features, n_class)

#     def forward(self, x):
#         x = self.model(x)
#         return x

In [None]:
def train_fn(model, dataloader, epoch, optimizer, device, criterion):
    """Train ``model`` for one epoch and return the running averages.

    Args:
        model: Network to train; switched to training mode here.
        dataloader: Iterable of ``(images, labels)`` batches that supports ``len``.
        epoch: Zero-based epoch index, used only for the progress message.
        optimizer: Optimizer stepped through ``xm.optimizer_step``.
        device: Device the batches are moved to.
        criterion: Loss function called as ``criterion(output, labels)``.

    Returns:
        Tuple ``(average loss, average accuracy)`` over the batches seen.
    """
    model.train()

    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    accuracies = AverageMeter()

    start = time.time()

    # Start from clean gradients. Inside the loop they are cleared only after
    # an optimizer step; clearing them on every batch would throw away the
    # gradients being accumulated when ITERS_TO_ACCUMULATE > 1.
    optimizer.zero_grad()

    loader = tqdm(dataloader, total=len(dataloader))
    for step, (images, labels) in enumerate(loader):

        images = images.to(device, dtype=torch.float32)
        labels = labels.to(device, dtype=torch.int64)
        n_samples = labels.size(0)

        data_time.update(time.time() - start)

        output = model(images)
        loss = criterion(output, labels)
        # Report the true batch loss, not the value scaled for accumulation.
        losses.update(loss.item(), n_samples)

        # backward pass: scale so the accumulated gradient is the mean over
        # the ITERS_TO_ACCUMULATE batches that make up one optimizer step.
        (loss / ITERS_TO_ACCUMULATE).backward()

        # Calculate Accuracy
        accuracy = (output.argmax(dim=1) == labels).float().mean()
        accuracies.update(accuracy.item(), n_samples)

        if (step + 1) % ITERS_TO_ACCUMULATE == 0:
            # Clip the fully accumulated gradient right before it is applied.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1000)
            # Run the provided optimizer step and issue the XLA device step computation.
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()

        batch_time.update(time.time() - start)
        start = time.time()

        if step % ITER_FREQ == 0:

            xm.master_print('Epoch: [{0}][{1}/{2}]\t'
                            'Batch Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s), '
                            'Data Time {data_time.val:.3f}s ({data_time.avg:.3f}s)\t'
                            'Loss: {loss.val:.4f} ({loss.avg:.4f}), '
                            'Accuracy {accuracy.val:.4f} ({accuracy.avg:.4f})'.format((epoch+1),
                                                                                      step, len(dataloader),
                                                                                      batch_time=batch_time,
                                                                                      data_time=data_time,
                                                                                      loss=losses,
                                                                                      accuracy=accuracies))
        # To check the loss real-time while iterating over data.
        loader.set_postfix(loss=losses.avg, accuracy=accuracies.avg)

    return losses.avg, accuracies.avg

In [None]:
def engine():
    """Run the full training loop on the XLA device owned by this process.

    Builds the dataset, a distributed sampler and loader, the model, optimizer
    and loss, then trains for ``EPOCHS`` epochs. After every epoch a summary
    line is appended to ``TPU_log_<MODEL_ARCH>.txt`` and both the state dict
    and the whole model are saved with ``xm.save``.

    Returns:
        Dict with per-epoch lists under the keys ``'loss'`` and ``'accuracy'``.
    """
    train_data = CLD_train_ds(df, transform=get_transform())

    # https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
    train_sampler = distributed.DistributedSampler(train_data,
                                                   # .xrt_world_size() retrieves the number of devices (cores)
                                                   # taking part in the replication.
                                                   num_replicas=xm.xrt_world_size(),
                                                   # .get_ordinal() retrieves the replication ordinal of the
                                                   # current process. The ordinals range from 0 to xrt_world_size()-1
                                                   rank=xm.get_ordinal(),
                                                   shuffle=True)

    train_loader = DataLoader(train_data,
                              batch_size=BATCH_SIZE,
                              sampler=train_sampler,
                              num_workers=NUM_WORKERS,
                              drop_last=True,
                              pin_memory=True  # enables faster data transfer to CUDA-enabled GPUs.
                              )

    # Acquires the (unique) Cloud TPU core corresponding to this process's index
    device = xm.xla_device()
    model = eff_net(MODEL_ARCH, 5, True)
    xm.set_rng_state(1111, device)
    model.to(device)

    params = filter(lambda p: p.requires_grad, model.parameters())
    lr = LR * xm.xrt_world_size()  # Number of cores taking part in replication.
    optimizer = torch.optim.Adam(params, lr=lr, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss().to(device)

    loss = []
    accuracy = []

    xm.master_print(f"Initializing training on {xm.xrt_world_size()} TPU cores.")
    start_time = datetime.now()
    xm.master_print(f"Start Time: {start_time}")

    for epoch in range(EPOCHS):

        gc.collect()
        # The sampler derives its shuffle from the epoch number; without this
        # call every epoch would iterate the data in exactly the same order.
        train_sampler.set_epoch(epoch)

        # ParallelLoader wraps an existing PyTorch DataLoader with background data upload.
        # We pass a ParallelLoader and not a DataLoader (for parallel training only).
        pl_loader = pl.ParallelLoader(train_loader, [device])

        epoch_start = time.time()
        avg_loss, avg_accuracy = train_fn(model,
                                          # .per_device_loader() gets the loader
                                          # iterator object for the given device.
                                          pl_loader.per_device_loader(device),
                                          epoch, optimizer, device, criterion)
        epoch_end = time.time() - epoch_start  # elapsed seconds for this epoch
        loss.append(avg_loss)
        accuracy.append(avg_accuracy)

        # NOTE(review): every spawned process appends its own line here, so the
        # file receives one unlabeled entry per core and epoch - confirm intended.
        content = f'Epoch {epoch+1} - avg_loss: {avg_loss:.4f} avg_accuracy: {avg_accuracy:.4f} time: {epoch_end:.0f}s'
        with open(f'TPU_log_{MODEL_ARCH}.txt', 'a') as appender:
            appender.write(content + '\n')

        # Save the model to use it for inference.
        xm.save(model.state_dict(), f'TPU_{MODEL_ARCH}_epoch_{(epoch+1)}.pth')
        xm.save(model, f'TPU_{MODEL_ARCH}_epoch_{(epoch+1)}')

        # Drop the loader before collecting so its buffers can be reclaimed.
        del pl_loader
        gc.collect()
        torch.cuda.empty_cache()

    xm.master_print(f"Execution time: {datetime.now() - start_time}")

    return {'loss': loss, 'accuracy': accuracy}

In [None]:
# def main():
#     logs = engine()

import json

# File through which worker 0 hands its metrics back to the notebook process.
_LOGS_PATH = f'TPU_logs_{MODEL_ARCH}.json'


def _mp_fn(rank, flags):
    """Entry point executed in every process started by ``xmp.spawn``.

    Args:
        rank: Index of this process as passed in by ``xmp.spawn``.
        flags: Extra arguments forwarded from the spawn call (unused).
    """
    # Sets the default torch.Tensor type to given floating point type tensor.
    torch.set_default_tensor_type("torch.FloatTensor")
    logs = engine()
    # A value assigned here exists only inside this worker process, so persist
    # it; a single writer (rank 0) keeps the file well-formed.
    if rank == 0:
        with open(_LOGS_PATH, 'w') as sink:
            json.dump(logs, sink)


FLAGS = {}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')

# Make the metrics recorded by worker 0 available at notebook scope.
with open(_LOGS_PATH) as source:
    logs = json.load(source)

In [None]:
# Show the per-epoch training metrics. This needs `logs` to exist at notebook
# scope: a value assigned only inside `_mp_fn` lives in a spawned worker
# process and is not visible here.
logs