This notebook uses the PyTorch/XLA package to train and validate an EfficientNet-B7 (effb7) model on TPU. Some settings follow [Alien's notebook](https://www.kaggle.com/h053473666/siim-covid19-efnb7-train-study/data). You can also compare this PyTorch TPU code with the [tensorflow tpu code](https://www.kaggle.com/h053473666/siim-covid19-efnb7-train-study/data) to see how they differ.

For more about PyTorch/XLA please see its [Github](https://github.com/pytorch/xla) or its [documentation](http://pytorch.org/xla/).

## Kaggle or Colab
Comment out one of the two cells below to run this code on Kaggle or on [colab](https://drive.google.com/drive/my-drive).

In [None]:
# Run this code with kaggle TPU
# ENVIRONMENT selects the dataset/package paths used by the later cells.
ENVIRONMENT = "kaggle"
# Working directory: training metrics and the model weights are written under it.
# NOTE(review): `dir` shadows the Python builtin of the same name; later cells read
# this variable, so renaming it would have to be done notebook-wide.
dir = "/kaggle/working"

In [None]:
# # Run this code with colab TPU
# # You can use larger batch_size, image_size and model with colab high memory mode
# ENVIRONMENT = "colab"
# dir = "/content/gdrive/MyDrive/siim-effnets-classification-train"  # directory of this code

# from google.colab import drive
# drive.mount('/content/gdrive/')

# !pip install kaggle > /dev/null 2>&1

# # Upload your kaggle.json file to connect with your kaggle account
# from google.colab import files
# files.upload()  
# !mkdir -p ~/.kaggle
# !cp kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json

# # Download specific kaggle dataset which contains .npy data processed from original .dcm data 
# !mkdir -p '/content/siim-dicom-to-npy-v1'  
# !kaggle datasets download -d shangweichen/siim-dicom-to-npy-v1 -p '../content/gdrive/MyDrive'
# !unzip -n '/content/gdrive/MyDrive/siim-dicom-to-npy-v1.zip' -d '/content/siim-dicom-to-npy-v1' > /dev/null 2>&1
# !rm '/content/gdrive/MyDrive/siim-dicom-to-npy-v1.zip'

# # Download timm package
# !mkdir -p '/content/timm-pytorch-image-models'
# !kaggle datasets download -d kozodoi/timm-pytorch-image-models -p '../content/gdrive/MyDrive'
# !unzip -n '/content/gdrive/MyDrive/timm-pytorch-image-models.zip' -d '/content/timm-pytorch-image-models' > /dev/null 2>&1
# !rm '/content/gdrive/MyDrive/timm-pytorch-image-models.zip'

# !pip install pydicom > /dev/null 2>&1
# !pip install --upgrade albumentations > /dev/null 2>&1

In [None]:
# Fail fast when ENVIRONMENT is not one of the two supported values.
if ENVIRONMENT != "kaggle" and ENVIRONMENT != "colab":
    raise ValueError("ENVIRONMENT Wrong!")

## Import

In [None]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py > /dev/null 2>&1
!python pytorch-xla-env-setup.py --version 20210331 --apt-packages libomp5 libopenblas-dev > /dev/null 2>&1

In [None]:
import sys
if ENVIRONMENT == "kaggle":
    sys.path.append("/kaggle/input/timm-pytorch-image-models/pytorch-image-models-master")
elif ENVIRONMENT == "colab":
    sys.path.append("/content/timm-pytorch-image-models/pytorch-image-models-master")

import platform
import numpy as np
import pandas as pd
import os
from tqdm.notebook import tqdm
import cv2
import pydicom
import random
import glob
import gc
from math import ceil
import albumentations as A
import matplotlib.pyplot as plt
from pydicom.pixel_data_handlers.util import apply_voi_lut
from sklearn.metrics import roc_auc_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold, train_test_split
import torch
import timm
import time
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR, StepLR
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import warnings

# Silence every Python warning for the rest of the notebook.
warnings.simplefilter('ignore')
# Print numpy floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)
# XLA runtime settings passed through environment variables.
# NOTE(review): presumably XLA_USE_BF16=1 makes XLA run float tensors as bfloat16 and
# XLA_TENSOR_ALLOCATOR_MAXSIZE caps the tensor allocator's block size - confirm
# against the PyTorch/XLA documentation.
os.environ['XLA_USE_BF16']="1"
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

In [None]:
# Sanity check: load one preprocessed sample and show the shape of its array.
if ENVIRONMENT == "kaggle":
    sample = np.load("/kaggle/input/siim-dicom-to-npy-v1/trainset/000a312787f2.npy")
    print(sample.shape)
elif ENVIRONMENT == "colab":
    sample = np.load("/content/siim-dicom-to-npy-v1/trainset/000a312787f2.npy")
    print(sample.shape)

## Necessary Settings
If you choose the effb7 model, the gradients explode when lr is set to 1e-3/8, and the batch size cannot reach 16 * 8 without raising an OOM (out of memory) error. This is the most noticeable difference between this notebook and the [tensorflow tpu code](https://www.kaggle.com/h053473666/siim-covid19-efnb7-train-study/data).

In [None]:
class Config:
    """
    Central training configuration.

    The class is used as a namespace: attributes are read as ``Config.<name>`` and
    the factory helpers are called as ``Config.get_*()``. The helpers are declared
    as static methods so they behave the same whether they are called on the class
    or on an instance.
    """
    # Fraction of the labelled rows used for training; the rest is the validation split
    train_pcent = 0.80
    # timm model name; the model-selection cell also picks the wrapper class from it
    model_name = 'tf_efficientnet_b7'
    # Crop size: index 0 is passed as width and index 1 as height to RandomResizedCrop
    image_size = (400, 400)
    # Global batch size; the loaders divide it by the XLA world size
    batch_size = 8 * 8
    epochs = 20
    seed = 2021
    # Base learning rate; multiplied by the XLA world size when the optimizer is built
    lr = 1e-4 / 8  
    # Number of DataLoader worker processes
    workers = 8
    # Passed as drop_last to both the train and the validation DataLoader
    drop_last = True
    # Training-time augmentation pipeline (only used when a dataset enables augments)
    augments = A.Compose([
                   A.augmentations.crops.transforms.RandomResizedCrop(height=image_size[1], 
                                                                      width=image_size[0], 
                                                                      scale=(0.88*0.88, 1), 
                                                                      ratio=(0.8, 1.2), 
                                                                      p=0.5),
                   A.augmentations.transforms.HorizontalFlip(p=0.5),
                   A.augmentations.transforms.VerticalFlip(p=0.5),
                   A.augmentations.geometric.rotate.Rotate(p=0.5),
                   A.OneOf([
                       A.augmentations.transforms.Blur(),
                       A.augmentations.transforms.GlassBlur(),
                       A.augmentations.transforms.GaussianBlur(),
                       A.augmentations.transforms.GaussNoise(),
                       A.augmentations.transforms.RandomGamma(),
                       A.augmentations.transforms.InvertImg(),
                       A.augmentations.transforms.RandomFog()
                   ], p=0.5)
    ])

    @staticmethod
    def get_loss_fn():
        """Return the classification loss (cross entropy over the class logits)."""
        return nn.CrossEntropyLoss()

    @staticmethod
    def get_optimizer(model, learning_rate):
        """Return an Adam optimizer over all parameters of ``model``."""
        return torch.optim.Adam(model.parameters(), lr=learning_rate)

    @staticmethod
    def get_scheduler(optimizer):
        """
        Return a ReduceLROnPlateau scheduler for ``optimizer``.

        The monitored value is expected to decrease (mode='min'); after 5 epochs
        without improvement the learning rate is multiplied by 0.1, never going
        below 1e-5.
        """
        return ReduceLROnPlateau(optimizer, 
                                 mode='min', 
                                 factor=0.1, 
                                 patience=5, 
                                 verbose=False, 
                                 min_lr=1e-5)
    

# Make results reproducible
def seed_everything(seed):
    """
    Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports the seed
    as PYTHONHASHSEED and switches cuDNN to its deterministic, non-benchmark mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Each of these callables takes the seed as its only argument
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
seed_everything(Config.seed)

## Data Preprocessing

In [None]:
# Locate the competition CSV files for the active environment
if ENVIRONMENT == "kaggle":
    csv_root = "/kaggle/input/siim-covid19-detection"
elif ENVIRONMENT == "colab":
    csv_root = f"{dir}/SIIM-FISABIO-RSNA COVID-19 Detection"

train_image = pd.read_csv(f"{csv_root}/train_image_level.csv")
train_study = pd.read_csv(f"{csv_root}/train_study_level.csv")

# Study ids carry a "_study" suffix; strip it so they match the image-level key,
# then replace the original id column with the cleaned key
train_study['StudyInstanceUID'] = train_study['id'].map(lambda study_id: study_id.replace('_study', ''))
train_study = train_study.drop(columns=['id'])

# Attach the study-level label columns to every image row
train = train_image.merge(train_study, on='StudyInstanceUID', how='left')

# The class index is the position of the largest value among the four label columns
label_columns = ['Negative for Pneumonia', 'Typical Appearance',
                 'Indeterminate Appearance', 'Atypical Appearance']
train["class"] = train[label_columns].to_numpy().argmax(axis=1)

train.head()

In [None]:
class SIIMData(Dataset):
    """
    Dataset that serves preprocessed ``.npy`` images together with their class index.

    Args:
        df: table with an ``id`` column (the part before the first "_" is the
            ``.npy`` file stem) and an integer ``class`` column.
        augments: when true, ``Config.augments`` is applied to every image.
    """
    # Directory holding the preprocessed training arrays for each environment
    _DATA_ROOTS = {
        "kaggle": "/kaggle/input/siim-dicom-to-npy-v1/trainset",
        "colab": "/content/siim-dicom-to-npy-v1/trainset",
    }

    def __init__(self, df, augments=True):
        super().__init__()
        # Rows are shuffled once at construction and re-indexed 0..n-1
        self.df = df.sample(frac=1).reset_index(drop=True)
        self.augments = Config.augments if augments else None

    def __getitem__(self, idx):
        """Return ``(image, label)``: the image with its last axis moved first and divided by 255."""
        index = self.df.loc[idx, "id"].split("_")[0]
        # A dictionary lookup fails loudly on an unsupported ENVIRONMENT instead of
        # leaving the image undefined
        image = np.load(f"{self._DATA_ROOTS[ENVIRONMENT]}/{index}.npy")

        # Explicit None check: the decision must not depend on the truthiness of the pipeline object
        if self.augments is not None:
            image = torch.from_numpy(self.augments(image=image)['image'])
        else:
            image = torch.tensor(image)

        # assumes the stored arrays are (H, W, C) - the permute below moves axis 2 to the front
        image = image.permute(2, 0, 1)
        # Single label-based lookup; avoids materialising the whole row first
        label = torch.tensor(self.df.loc[idx, "class"])
        return image / 255.0, label

    def __len__(self):
        """Number of rows in the underlying table."""
        return len(self.df)

In [None]:
# Stratified hold-out split: Config.train_pcent of the rows train, the rest validate
train_data, valid_data = train_test_split(
    train,
    test_size=1-Config.train_pcent,
    stratify=train["class"].to_numpy(),
    random_state=Config.seed,
)
print(f"Training on {len(train_data)} samples and Validation on {len(valid_data)} samples")

# Augmentations are enabled for the training split only
train_set = SIIMData(train_data, augments=True)
valid_set = SIIMData(valid_data, augments=False)

## Model

In [None]:
class EfficientNetModel(nn.Module):
    """
    timm EfficientNet backbone whose ``classifier`` layer is replaced by a fresh
    linear layer with ``num_classes`` outputs.
    """
    def __init__(self, num_classes=4, model_name=Config.model_name, pretrained=True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained, in_chans=3)
        # Keep the backbone's feature width, swap only the output layer
        backbone.classifier = nn.Linear(backbone.classifier.in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the backbone's output for the batch ``x``."""
        return self.model(x)
    

class NFNetModel(nn.Module):
    """
    timm NFNet backbone whose ``head.fc`` layer is replaced by a fresh linear
    layer with ``num_classes`` outputs.
    """
    def __init__(self, num_classes=4, model_name=Config.model_name, pretrained=True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained, in_chans=3)
        # Keep the head's input width, swap only the final fully connected layer
        backbone.head.fc = nn.Linear(backbone.head.fc.in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the backbone's output for the batch ``x``."""
        return self.model(x)

In [None]:
# Pick the wrapper class from substrings of the configured timm model name
if any(tag in Config.model_name for tag in ("efficient", "eff")):
    model_ = EfficientNetModel()
elif "nfnet" in Config.model_name:
    model_ = NFNetModel()
else:
    raise RuntimeError("Must specify a valid model type to train.")
print(f"Training Model: {Config.model_name}")

## Train & Valid

In [None]:
class Record:
    '''
    Records labels and predictions within one epoch.

    ``update`` is called once per batch; ``get_labels`` / ``get_preds`` return
    everything recorded so far as single arrays.
    '''
    def __init__(self):
        self.labels = []
        self.preds = []
        
    def update(self, cur_labels, cur_logits):
        '''
        Store one batch.

        cur_labels: tensor of class indices, one per sample.
        cur_logits: tensor of raw model outputs with one row per sample; the rows
            are converted to probabilities with a softmax before being stored.
        '''
        cur_labels = cur_labels.detach().cpu().numpy()
        cur_logits = cur_logits.detach().cpu().numpy()
        # Subtract the row maximum before exponentiating: the softmax value is
        # unchanged, but large logits no longer overflow to inf (and then NaN)
        shifted = cur_logits - np.max(cur_logits, axis=1, keepdims=True)
        exps = np.exp(shifted)
        cur_preds = exps / np.sum(exps, axis=1, keepdims=True)
        self.labels.append(cur_labels)
        self.preds.append(cur_preds)

    def get_labels(self):
        '''Return all recorded labels as one array.'''
        return np.concatenate(self.labels) # (n, )

    def get_preds(self):
        '''Return all recorded probabilities stacked along axis 0.'''
        return np.concatenate(self.preds, axis=0) # (n, num_classes)

    @staticmethod
    def get_acc(confusion_mat):
        '''
        Accuracy in percent (rounded to 2 decimals) from a square confusion matrix.

        The diagonal is summed with np.trace, so any number of classes works.
        Returns 0.0 for a matrix without any samples.
        '''
        confusion_mat = np.asarray(confusion_mat)
        total = np.sum(confusion_mat)
        if total == 0:
            return 0.0
        return round(np.trace(confusion_mat) / total * 100, 2)
    
    @staticmethod
    def get_auc(labels, preds):
        '''Weighted one-vs-rest ROC AUC of ``preds`` against ``labels``, rounded to 2 decimals.'''
        return round(roc_auc_score(labels, preds, average='weighted', multi_class='ovr'), 2)


class Trainer:
    """Runs training and validation epochs for one model on one device."""

    def __init__(self, model, optimizer, loss_fn, device):
        """
        Keep references to the model, its optimizer, the loss function and the
        device every batch is moved to.
        """
        self.model, self.optimizer = model, optimizer
        self.loss_fn, self.device = loss_fn, device
    
    def train_one_cycle(self, train_loader):
        """
        Train for one epoch.

        Returns the sample-weighted mean loss, the recorded labels and the
        recorded class probabilities. The model is left in eval mode.
        """
        self.model.train()
        epoch_record = Record()
        loss_sum, sample_count = 0, 0

        for images, targets in train_loader:
            images = images.to(self.device, dtype=torch.float)
            targets = targets.to(self.device, dtype=torch.long)
            batch_size = targets.size(0)

            self.optimizer.zero_grad()
            logits = self.model(images)
            batch_loss = self.loss_fn(logits, targets)
            
            # Weight the batch loss by its size so the epoch value is a per-sample mean
            loss_sum += batch_loss.detach().item() * batch_size
            sample_count += batch_size
            epoch_record.update(targets, logits)
            
            batch_loss.backward()
            # xm.optimizer_step applies the update and, in the Cloud TPU context,
            # synchronizes the gradients across the processes so that every
            # process's copy of the network stays identical.
            xm.optimizer_step(self.optimizer)
            
        self.model.eval()
        return loss_sum / sample_count, epoch_record.get_labels(), epoch_record.get_preds()

    def valid_one_cycle(self, valid_loader):
        """
        Evaluate for one epoch without computing gradients.

        Returns the sample-weighted mean loss, the recorded labels and the
        recorded class probabilities.
        """
        self.model.eval()
        epoch_record = Record()
        loss_sum, sample_count = 0, 0
        
        for images, targets in valid_loader:
            with torch.no_grad():
                images = images.to(self.device, dtype=torch.float)
                targets = targets.to(self.device, dtype=torch.long)
                batch_size = targets.size(0)
        
                logits = self.model(images)
                batch_loss = self.loss_fn(logits, targets)

                loss_sum += batch_loss.detach().item() * batch_size
                sample_count += batch_size
                epoch_record.update(targets, logits)
        
        return loss_sum / sample_count, epoch_record.get_labels(), epoch_record.get_preds()

In [None]:
def _make_loader(dataset, shuffle):
    '''
    Build the (sampler, loader) pair that feeds this process.

    The DistributedSampler restricts the process to its own portion of
    ``dataset``; the per-process batch size is the global batch size divided by
    the number of processes.
    '''
    world_size = xm.xrt_world_size()
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=xm.get_ordinal(),
        shuffle=shuffle,
    )
    loader = DataLoader(
        dataset,
        batch_size=Config.batch_size // world_size,
        sampler=sampler,
        drop_last=Config.drop_last,
        num_workers=Config.workers,
    )
    return sampler, loader


def _gather_metrics(split, loss, labels, preds):
    '''
    Combine the per-process results of one epoch.

    ``split`` ("train" or "valid") is only used to build the mesh_reduce tags.
    Returns (mean loss over processes, accuracy, AUC, confusion matrix).
    '''
    loss_avg = xm.mesh_reduce(f'{split}_loss_reduce', loss, lambda alist: sum(alist) / len(alist))
    labels_concat = xm.mesh_reduce(f'{split}_labels_concat', labels, lambda alist: np.concatenate(alist))
    preds_concat = xm.mesh_reduce(f'{split}_preds_concat', preds, lambda alist: np.concatenate(alist, axis=0))
    confusion_mat = confusion_matrix(labels_concat, np.argmax(preds_concat, axis=1))
    acc = Record.get_acc(confusion_mat)
    auc = Record.get_auc(labels_concat, preds_concat)
    return loss_avg, acc, auc, confusion_mat


def _mp_fn(rank, flags):
    '''
    Train and validate the model in one process (one per TPU core).

    rank: index of this process; rank 0 additionally writes the metric histories.
    flags: placeholder argument required by the spawn call; not read here.

    Side effects: writes ``<dir>/<metric>.npy`` history files (rank 0 only) and
    ``<dir>/pretrained_model.bin`` through xm.save.
    '''
    torch.set_default_tensor_type('torch.FloatTensor')

    # Sets a common random seed both for initialization and ensuring graph is the same
    torch.manual_seed(Config.seed)

    # Acquires the (unique) Cloud TPU core corresponding to this process's index
    device = xm.xla_device()
    
    # load the model into each tpu core
    model = model_.to(device)
    
    train_sampler, train_loader = _make_loader(train_set, shuffle=True)
    _, valid_loader = _make_loader(valid_set, shuffle=False)

    optimizer = Config.get_optimizer(model, Config.lr * xm.xrt_world_size())
    loss_fn = Config.get_loss_fn()
    scheduler = Config.get_scheduler(optimizer)
    
    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
    )

    # Per-epoch histories; the keys double as the file names written at the end
    history = {
        "train_losses": [], "train_accs": [], "train_aucs": [],
        "valid_losses": [], "valid_accs": [], "valid_aucs": [],
        "lr_scheduler": [],
    }
    
    gc.collect()
    for epoch in range(Config.epochs):
        xm.master_print(f"{'-'*30} EPOCH: {epoch+1}/{Config.epochs} {'-'*30}")

        # Without set_epoch the DistributedSampler reuses the same shuffled
        # order in every epoch
        train_sampler.set_epoch(epoch)
        
        # Run one training epoch
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loss, train_labels, train_preds = trainer.train_one_cycle(para_loader.per_device_loader(device))

        # Compute training metrics
        train_loss_avg, train_acc, train_auc, train_confusion_mat = _gather_metrics(
            'train', train_loss, train_labels, train_preds)
        xm.master_print(f"Train Loss: {train_loss_avg:.4f}  Train Acc: {train_acc}%  Train AUC: {train_auc}")
        xm.master_print(train_confusion_mat)

        # Run one validation epoch
        para_loader = pl.ParallelLoader(valid_loader, [device])
        valid_loss, valid_labels, valid_preds = trainer.valid_one_cycle(para_loader.per_device_loader(device))
        
        # Compute validation metrics
        valid_loss_avg, valid_acc, valid_auc, valid_confusion_mat = _gather_metrics(
            'valid', valid_loss, valid_labels, valid_preds)
        xm.master_print(f"Valid Loss: {valid_loss_avg:.4f}  Valid Acc: {valid_acc}%  Valid AUC: {valid_auc}")
        xm.master_print(valid_confusion_mat)

        # Every process steps its scheduler with the same reduced validation loss
        scheduler.step(valid_loss_avg)
        gc.collect()
        
        if rank == 0:
            history["train_losses"].append(train_loss_avg)
            history["train_accs"].append(train_acc)
            history["train_aucs"].append(train_auc)
            history["valid_losses"].append(valid_loss_avg)
            history["valid_accs"].append(valid_acc)
            history["valid_aucs"].append(valid_auc)
            # Recorded after scheduler.step, i.e. the rate in effect after this epoch
            history["lr_scheduler"].append(optimizer.param_groups[0]['lr'])
            
    if rank == 0:
        # All tpu cores need to do below operations if without "if rank == 0"
        for name, values in history.items():
            np.save(f"{dir}/{name}", np.array(values))
        
    # Called by every process; xm.save writes the weights from the main process only
    xm.save(model.state_dict(), f"{dir}/pretrained_model.bin")

spawn() takes a function (the "map function"), a tuple of arguments (the placeholder flags dict), the number of processes to create, and whether to create these new processes by "forking" or "spawning." While spawning new processes is generally recommended, Colab only supports forking.

spawn() will create eight processes, one for each Cloud TPU core, and call _mp_fn() -- the map function -- on each process. The inputs to _mp_fn() are an index (zero through seven, process id) and the placeholder flags. When the processes acquire their device they actually acquire their corresponding Cloud TPU core automatically.

In [None]:
%%time

# Placeholder flags handed to every process; _mp_fn accepts them but does not read them.
FLAGS = {}
# Fork 8 processes, each of which runs _mp_fn(index, FLAGS).
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')

## Plot

In [None]:
def get_first_occur(learning_rates_list):
    """
    Locate the learning-rate changes in a per-epoch learning-rate history.

    For every distinct learning rate except the largest one, the index of its
    first appearance minus one is returned, ordered by ascending learning rate.

    Args:
        learning_rates_list: 1-D array (or list) of learning rates, one per epoch.

    Returns:
        Integer numpy array of positions. It is empty - but still of integer
        dtype, so it can be used directly as an index array - when the history
        holds fewer than two distinct values.
    """
    rates = np.asarray(learning_rates_list).ravel()
    # return_index yields the first position of each sorted unique value
    _, first_positions = np.unique(rates, return_index=True)
    # Drop the largest rate (last after sorting) and shift by one position
    return first_positions[:-1].astype(int) - 1


def _plot_metric_history(axis, train_values, valid_values, tag, title_tag, marks, pick_best, decimals):
    """
    Draw the train/valid curves of one metric on ``axis``.

    marks: integer positions highlighted with a red star on both curves.
    pick_best: np.argmin or np.argmax, applied to the validation curve to choose
        the epoch reported in the title.
    decimals: rounding used in the title.

    Returns (train value, valid value) at the chosen epoch.
    """
    epochs = range(1, len(train_values) + 1)
    best_idx = pick_best(valid_values)
    best_train, best_valid = train_values[best_idx], valid_values[best_idx]

    axis.plot(epochs, train_values, "bo-", label=f"train_{tag}")
    axis.plot(epochs, valid_values, "go-", label=f"valid_{tag}")
    axis.scatter(marks + 1, train_values[marks], s=300, c="r", marker="*")
    axis.scatter(marks + 1, valid_values[marks], s=300, c="r", marker="*")
    axis.set_xlabel("epoch")
    axis.legend()
    axis.grid()
    axis.set_title(f"Best Train {title_tag}: {np.round(best_train, decimals)}  Best Valid {title_tag}: {np.round(best_valid, decimals)}")
    return best_train, best_valid


# Histories written by the rank-0 training process
train_losses = np.load(f"{dir}/train_losses.npy")
train_accs = np.load(f"{dir}/train_accs.npy")
train_aucs = np.load(f"{dir}/train_aucs.npy")
valid_losses = np.load(f"{dir}/valid_losses.npy")
valid_accs = np.load(f"{dir}/valid_accs.npy")
valid_aucs = np.load(f"{dir}/valid_aucs.npy")
lr_scheduler= np.load(f"{dir}/lr_scheduler.npy")
# Force an integer dtype: an empty float array (learning rate never changed)
# cannot be used as an index and would crash the scatter calls
first_occur = np.asarray(get_first_occur(lr_scheduler), dtype=int)

fig, ax = plt.subplots(2, 2, figsize=(20, 16))

# Loss is best at its minimum, accuracy and AUC at their maximum
best_train_loss, best_valid_loss = _plot_metric_history(
    ax[0][0], train_losses, valid_losses, "loss", "Loss", first_occur, np.argmin, 4)
best_train_acc, best_valid_acc = _plot_metric_history(
    ax[0][1], train_accs, valid_accs, "acc", "Acc", first_occur, np.argmax, 3)
best_train_auc, best_valid_auc = _plot_metric_history(
    ax[1][0], train_aucs, valid_aucs, "auc", "Auc", first_occur, np.argmax, 3)

ax[1][1].plot(range(1, len(lr_scheduler)+1), lr_scheduler, "ro-", label="learning_rate")
ax[1][1].set_xlabel("epoch")
ax[1][1].legend()
ax[1][1].grid()
ax[1][1].set_title(f"MinLR={np.round(min(lr_scheduler), 6)}")

# Clean up the intermediate history files; re-running this cell afterwards
# requires a new training run
for history_name in ("train_losses", "train_accs", "train_aucs",
                     "valid_losses", "valid_accs", "valid_aucs", "lr_scheduler"):
    os.remove(f'{dir}/{history_name}.npy')