# Basic Training Notebook
This is the most basic training method, without the One Cycle Policy.

First, make sure you set `tpu_utility_1` and `tpu_cache_ds_utils` as utility scripts and **add** them to this notebook. When you Quick Save the scripts, make sure they are saved as **Python scripts rather than Jupyter Notebooks** (change this via File $\rightarrow$ Editor Type $\rightarrow$ Script), or the import will fail.

In [None]:
from tpu_utility_1 import *

# jpeg4py (which needs libturbojpeg) is used for fast JPEG decoding
!apt-get install libturbojpeg
!pip install jpeg4py

from IPython.display import clear_output
clear_output()

Below is the starting code cell you get when creating a new Kaggle notebook. We comment out the `for dirname...` loop so it doesn't print an extremely long list of files (if your input contains lots of them).

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import albumentations

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score, roc_auc_score

import matplotlib.pyplot as plt
import jpeg4py as jpeg  # will fail if you don't run the first code cell. 
import pathlib
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
from torchvision import models

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu

Next, we define `FLAGS`, a Python dict containing the training configuration.

In [None]:
FLAGS = {
    "data_dir": Path("../input/plant-pathology-2021-fgvc8/train_images"),
    "bs": 32,  # Batch size PER TPU CORE. Total bs = bs * xm.xrt_world_size()
    "num_workers": os.cpu_count(),  # for pre-fetching data. 
    "lr": 0.001,  # multiplied by xm.xrt_world_size() inside train_tpu() for multiprocessing. 
    "momentum": 0.9,  # ONLY IF using SGD
    "num_epochs": 10,
    "num_cores": 8 if os.environ.get("TPU_NAME", None) else 1,
    "log_steps": 20,  # We are not using this but defined for historical reasons. 
    "metrics_debug": False,
    "seed": 1397,
    "save_path": "/kaggle/working/model.pth"
}
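The comments in `FLAGS` imply some simple arithmetic when moving from one core to eight. The sketch below is a hypothetical helper (not part of `tpu_utility_1`) that works it out, assuming each core sees an equal shard of the training set and the last partial batch is kept; the training-set size of 15,000 is made up for illustration.

```python
import math

def effective_settings(bs_per_core, base_lr, world_size, n_train):
    """Hypothetical helper: global batch size, scaled LR and steps per epoch.

    Assumes each core sees an equal shard of the n_train samples and the
    last partial batch is kept (drop_last=False).
    """
    global_bs = bs_per_core * world_size      # "Total bs = bs * world size"
    scaled_lr = base_lr * world_size          # lr multiplied by the world size
    steps_per_epoch = math.ceil(n_train / global_bs)
    return global_bs, scaled_lr, steps_per_epoch

one_core = effective_settings(32, 0.001, world_size=1, n_train=15_000)    # → (32, 0.001, 469)
eight_cores = effective_settings(32, 0.001, world_size=8, n_train=15_000)  # → (256, 0.008, 59)
```

With 8 cores, the per-core batch of 32 becomes a global batch of 256 and the learning rate becomes 0.008, so each epoch takes roughly 8x fewer optimizer steps.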

First, preprocess the data into the format we want to feed into the model. 

In [None]:
diseases = ['healthy', 'rust', 'scab', 'frog_eye_leaf_spot', 'powdery_mildew', 'complex']
df = pd.read_csv('../input/plant-pathology-2021-fgvc8/train.csv')
df.head()

In [None]:
labels = df.labels.str.split()
occ = np.zeros((len(labels), len(diseases)))
for col, disease in enumerate(diseases):
    # 1.0 where the disease appears among the image's space-separated labels
    occ[:, col] = [disease in label_list for label_list in labels]

dis_df = pd.DataFrame(occ.astype(np.float32), columns=diseases)
dis_df.head()
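The cell above turns the space-separated `labels` string into one 0/1 column per disease (a multi-hot encoding). The same idea for a single row, in plain Python, is sketched below; `multi_hot` and `DISEASES` are hypothetical names for illustration. Splitting on whitespace first matters: membership is then tested against whole label tokens, not substrings.

```python
DISEASES = ['healthy', 'rust', 'scab', 'frog_eye_leaf_spot', 'powdery_mildew', 'complex']

def multi_hot(label_str, classes=DISEASES):
    """Turn a space-separated label string into a 0.0/1.0 vector, one entry per class."""
    present = set(label_str.split())  # whole tokens, e.g. {'scab', 'frog_eye_leaf_spot'}
    return [1.0 if c in present else 0.0 for c in classes]

example = multi_hot("scab frog_eye_leaf_spot")  # → [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]
```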

In [None]:
torch_df = pd.concat([df["image"], dis_df], axis=1)
torch_df.head()

For convenience, we use `StratifiedKFold` to split our data rather than doing it ourselves. We stratify on the raw `labels` string, so each label combination counts as its own class. In practice, you should train once per fold, repeating for the number of folds you specify. Here, we just take the last fold as a demonstration. You can easily modify the code to suit your needs.

In [None]:
skf = StratifiedKFold(n_splits=5)

tdf_np = df.to_numpy()
X, y = tdf_np[:, 0], tdf_np[:, 1]
for train_index, test_index in skf.split(X, y): pass

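After this, we have `train_index` and `test_index` (from the last fold) in place. Be careful not to delete or overwrite them, or you will need to rerun this cell. The split is reproducible as-is because `StratifiedKFold` does not shuffle by default; if you pass `shuffle=True`, also set `random_state` (e.g. `FLAGS["seed"]`) to get the same folds every run.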
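To train on every fold instead of only the last one, keep all the `(train_index, test_index)` pairs and loop over them. The sketch below illustrates this in plain Python with a deliberately simplified stand-in for `StratifiedKFold` (it deals each label's indices round-robin into the folds; this is *not* scikit-learn's exact algorithm), and `train_one_fold` is a hypothetical placeholder for building the datasets and running training. In the notebook itself, you would loop directly over `skf.split(X, y)`.

```python
from collections import defaultdict

def simple_stratified_folds(labels, n_splits=5):
    """Simplified stand-in for StratifiedKFold (NOT scikit-learn's exact algorithm).

    Deals the indices of each label round-robin into n_splits folds, so every
    fold gets roughly the same label proportions. Yields (train_idx, val_idx).
    """
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    fold_of = [0] * len(labels)
    for label in sorted(by_label):
        for position, idx in enumerate(by_label[label]):
            fold_of[idx] = position % n_splits

    for k in range(n_splits):
        val_idx = [i for i, fold in enumerate(fold_of) if fold == k]
        train_idx = [i for i, fold in enumerate(fold_of) if fold != k]
        yield train_idx, val_idx


def train_one_fold(train_idx, val_idx):
    """Hypothetical placeholder: build AppleDatasets with these splits, train a
    fresh model and return its validation score. Here it only returns sizes."""
    return len(train_idx), len(val_idx)


toy_labels = ["healthy"] * 10 + ["scab"] * 5
fold_results = [train_one_fold(tr, va) for tr, va in simple_stratified_folds(toy_labels)]
# → [(12, 3), (12, 3), (12, 3), (12, 3), (12, 3)]
```

Each fold's validation set keeps the 2:1 healthy-to-scab ratio of the full toy set, and every index is used for validation exactly once across the folds.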

Our final array would look like this.

In [None]:
torch_df.to_numpy()[train_index]

## Defining Dataset
This is the usual way to define how PyTorch fetches your data; nothing TPU-specific is needed. You need to write this code yourself **unless you have a `xcd.CachedDataset()` saved** (in which case you don't need to define the dataset below).

In [None]:
class AppleDataset(D.Dataset):
    def __init__(self, parent_path, df=None, shuffle=False,
                 seed=None, transform=None, split=None):
        """
        parent_path: (pathlib.Path or str) Parent path of images. 
        df: (pandas.DataFrame) Image names in the first column, y-labels in the rest. 
        shuffle: (Boolean) Not used inside the dataset; shuffling is normally
            done by the sampler/DataLoader. Default: False. 
        seed: (integer) Seed, stored but not used here. Default: None. 
        transform: (Albumentations) Transformation to images, default: None. 
            Requires writing more code to use PyTorch Transform on cpu. 
        split: (python list) Index array for the train/test split. 
            Here we pass in train_index or test_index. 
        """
        self.df_np = df.to_numpy()
        self.parent_path = Path(parent_path)  # accepts both str and pathlib.Path
        self.seed = seed
        self.transform = transform
        
        if split is not None: self.df_np = self.df_np[split]
            
    def __len__(self):
        return len(self.df_np)
    
    def __getitem__(self, idx):
        item = self.df_np[idx]
        item_path = self.parent_path / item[0]  # 'image' column
        image = jpeg.JPEG(item_path).decode()
        
        if self.transform is not None:  # only works for albumentations library. 
            image = self.transform(**{"image": image})["image"]
            
        image = torch.from_numpy(image).permute(2, 0, 1)  # HWC -> CHW format
        target = torch.from_numpy(item[1:].astype(np.float32))
        
        return image, target  

Next, we list the Albumentations transforms to apply to the training and validation sets.

In [None]:
size = 224  # final image size

train_transform = albumentations.Compose([
    albumentations.RandomResizedCrop(height=size, width=size, always_apply=True, scale=(0.5, 1.0)),
    albumentations.Flip(),
    albumentations.Rotate(limit=180, p=0.75),
    albumentations.CoarseDropout(max_holes=4, p=0.1),
    albumentations.RandomBrightnessContrast(p=0.75),
    albumentations.Normalize(always_apply=True, p=1.0)
])

val_transform = albumentations.Compose([
    albumentations.Resize(height=size, width=size, always_apply=True),
    albumentations.Normalize(always_apply=True, p=1.0)
])
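`albumentations.Normalize` with its default arguments uses the ImageNet statistics (`mean=(0.485, 0.456, 0.406)`, `std=(0.229, 0.224, 0.225)`) and `max_pixel_value=255.0`, computing `(pixel - mean * max_pixel_value) / (std * max_pixel_value)` per channel. `AppleDataset.__getitem__` then reorders the result from HWC to CHW with `permute(2, 0, 1)`. The plain-Python sketch below works both steps through on a tiny 1x2 image; the helper names are hypothetical.

```python
# Default albumentations.Normalize arguments (ImageNet statistics).
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)
MAX_PIXEL_VALUE = 255.0

def normalize_pixel(rgb):
    """Per channel: (pixel - mean * max_pixel_value) / (std * max_pixel_value)."""
    return [(v - m * MAX_PIXEL_VALUE) / (s * MAX_PIXEL_VALUE)
            for v, m, s in zip(rgb, MEAN, STD)]

def hwc_to_chw(image):
    """Reorder a nested list from [H][W][C] to [C][H][W], like permute(2, 0, 1)."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [[[image[h][w][c] for w in range(width)] for h in range(height)]
            for c in range(channels)]

tiny_image = [[[255, 0, 128], [0, 255, 64]]]  # H=1, W=2, C=3 (uint8-style values)
normalized = [[normalize_pixel(px) for px in row] for row in tiny_image]
chw = hwc_to_chw(normalized)
# chw[0][0][0] is the red value of the first pixel: (1 - 0.485) / 0.229, about 2.249
```

The ImageNet statistics match what the pretrained ResNet-50 saw during its original training.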

Now **define** our `get_dataset` function, which builds the training and validation datasets from `train_index` and `test_index`.

In [None]:
def get_dataset():
    parent_path = FLAGS["data_dir"]
    train_ds = AppleDataset(parent_path, torch_df, shuffle=True, seed=FLAGS["seed"],
                           transform=train_transform, split=train_index)
    val_ds = AppleDataset(parent_path, torch_df, shuffle=False, seed=FLAGS["seed"],
                         transform=val_transform, split=test_index)
    
    return train_ds, val_ds  # required return format

## Model definition
Let's define our model. 

In [None]:
num_classes = len(diseases)

model = models.resnet50(pretrained=True)
num_features = model.fc.in_features

# Freeze the pretrained backbone (the new head defined below stays trainable)
for param in model.parameters():
    param.requires_grad_(False)

Then we define the head. Here we use a small MLP head (blocks of Linear, ReLU, BatchNorm and Dropout), but you can always use a single linear layer instead: `model.fc = nn.Linear(num_features, num_classes)`.

In [None]:
def create_head(num_features, num_classes, dropout=0.1, act_func=nn.ReLU):
    features_lst = [num_features, num_features // 2, num_features // 4]
    layers = []
    
    for in_f, out_f in zip(features_lst[:-1], features_lst[1:]):
        layers.append(nn.Linear(in_f, out_f))
        layers.append(act_func())
        layers.append(nn.BatchNorm1d(out_f))
        if dropout != 0: layers.append(nn.Dropout(dropout))
            
    layers.append(nn.Linear(features_lst[-1], num_classes))
    return nn.Sequential(*layers)


model.fc = create_head(num_features, num_classes)
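To see what this head adds, the sketch below counts its trainable parameters by hand for ResNet-50, whose `fc.in_features` is 2048. Each `nn.Linear(n_in, n_out)` has `n_in * n_out` weights plus `n_out` biases, each `nn.BatchNorm1d(n)` has `2 * n` learnable parameters (its running statistics are buffers, not parameters), and ReLU/Dropout have none. The helper names are hypothetical; in the notebook you can check the result with `sum(p.numel() for p in model.parameters() if p.requires_grad)`, which counts only the head because the backbone is frozen.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out   # weight matrix + bias

def batchnorm_params(n):
    return 2 * n                  # learnable gamma and beta

def head_param_count(num_features, num_classes):
    """Trainable parameters of create_head(num_features, num_classes) with its
    default layout: num_features -> num_features//2 -> num_features//4 -> num_classes."""
    sizes = [num_features, num_features // 2, num_features // 4]
    total = 0
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        total += linear_params(n_in, n_out) + batchnorm_params(n_out)
    total += linear_params(sizes[-1], num_classes)
    return total

head_trainable = head_param_count(2048, 6)  # → 2629126
linear_only = linear_params(2048, 6)        # → 12294
```

The MLP head trains about 2.6M parameters, versus about 12K for a single `nn.Linear` head.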

## LR Finder

In [None]:
opt = torch.optim.Adam(model.parameters(), lr=FLAGS["lr"])
criterion = nn.BCEWithLogitsLoss()

Define the dataloaders and run the LR finder (`dataloader()` and `lr_finder()` come from `tpu_utility_1`).

In [None]:
train_ds, val_ds = get_dataset()
dls = dataloader(train_ds, val_ds, FLAGS)

lr_finder(model, opt, criterion, dls, device=xm.xla_device())
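The implementation of `lr_finder` lives in `tpu_utility_1` and isn't shown here. Assuming it follows the common LR range test, the idea is to train briefly with a learning rate that grows exponentially from tiny to huge, record the loss at each step, then pick a rate somewhat below the point where the loss is lowest. The sketch below shows only the schedule and one common heuristic (the LR at minimum loss divided by 10); both function names are hypothetical, and the losses in the example are made up.

```python
def lr_range(start_lr=1e-7, end_lr=10.0, num_steps=100):
    """Exponentially spaced learning rates from start_lr to end_lr (inclusive)."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_steps - 1)) for i in range(num_steps)]

def suggest_lr(lrs, losses, divisor=10.0):
    """Heuristic: take the LR with the lowest recorded loss and back off by `divisor`."""
    best_step = min(range(len(losses)), key=losses.__getitem__)
    return lrs[best_step] / divisor

lrs = lr_range(1e-5, 1e-1, num_steps=5)   # about [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
losses = [1.0, 0.8, 0.5, 0.3, 2.0]        # made-up losses; they blow up at the last LR
suggested = suggest_lr(lrs, losses)       # about 1e-3
```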

## Then Restart the Notebook
This is because calling `xm.xla_device()` in the main process (as the LR finder cell does) prevents multiprocessing from running afterwards; only a single process would work. After restarting, comment out the previous cell.

**As long as you call `xm.xla_device()` only once per kernel restart**, you won't get this specific error (something along the lines of requesting 8 devices but only having 1).

**To avoid multiprocessing**, force `FLAGS['num_cores'] = 1`.

Alternatively, you can write a normal training loop, changing only `device` to `xm.xla_device()` and `opt.step()` to `xm.optimizer_step(opt, barrier=True)`. The `barrier=True` is needed because, without a `ParallelLoader`, nothing else marks the end of each XLA step.

First, we define `train_tpu()`. **Every function used during training must be defined inside it.**

In [None]:
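SERIAL_EXEC = xmp.MpSerialExecutor()  # runs a function serially across processes (not used below)
WRAPPED_MODEL = xmp.MpModelWrapper(model)  # keeps one copy of the model in host memory for forked processes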

In `train_loop_fn`, you write the training loop as usual, with a few TPU-specific additions: create `tracker = xm.RateTracker()` at the beginning, call `xm.optimizer_step(opt)` **instead of `opt.step()`**, and call `tracker.add(FLAGS["bs"])` inside the loop over the dataloader. That's it! `test_loop_fn` needs no changes: just put whatever your validation loop requires inside it and it will work fine.
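`xm.RateTracker` measures throughput: each `tracker.add(n)` records `n` processed samples, and torch_xla exposes the resulting samples per second via `tracker.rate()` and `tracker.global_rate()` (the notebook's `train_loop_fn` never reads them, so this only matters if you log it). The sketch below is a hypothetical, simplified stand-in to illustrate the idea, not torch_xla's implementation; the injectable `clock` only exists to make it testable.

```python
import time

class SimpleRateTracker:
    """Hypothetical stand-in for a samples/second tracker (not xm.RateTracker)."""

    def __init__(self, clock=time.monotonic):
        self._clock = clock
        self._start = clock()
        self._count = 0

    def add(self, count):
        self._count += count  # e.g. FLAGS["bs"] samples per training step

    def global_rate(self):
        """Average samples per second since the tracker was created."""
        elapsed = self._clock() - self._start
        return self._count / elapsed if elapsed > 0 else 0.0

fake_times = iter([0.0, 2.0])  # start time, then the time of the query
tracker = SimpleRateTracker(clock=lambda: next(fake_times))
for _ in range(2):
    tracker.add(32)            # two steps with a batch size of 32
samples_per_sec = tracker.global_rate()  # → 32.0
```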

Since we're using multiprocessing, pass `flush=True` to `print(...)` so that each process's output appears immediately rather than staying buffered.

Also notice that `train_loop_fn(loader)` and `test_loop_fn(loader)` each take one (and only one) input: `train_dl` and `val_dl`, respectively. Both are available inside `dls`, so passing `dls` to `train_cycle_distrib` is enough.

`train_cycle_distrib()` returns `returned_val`, which groups together all the values returned by `test_loop_fn`. Currently, **it breaks if `test_loop_fn` doesn't return anything** (one will fix this in the future).

In [None]:
def train_tpu():
    torch.manual_seed(FLAGS["seed"])
    
    dls = distrib_dataloader(get_dataset, FLAGS)
    
    device = xm.xla_device()
    lr = FLAGS["lr"] * xm.xrt_world_size()  # scale the LR with the number of cores
    model = WRAPPED_MODEL.to(device)
    
    # Create the optimizer after moving the model to the TPU, using the scaled `lr`
    # (the global `opt` from the LR finder cell always uses FLAGS["lr"]).
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    
    # (Optional) scheduler. scheduler.step() is called every batch, so with T_max=5
    # the LR oscillates between lr and eta_min with a period of 2 * T_max = 10 steps;
    # on a single core, eta_min (0.005) is even larger than lr (0.001). For a real
    # cosine decay, use eta_min < lr and set T_max to the total number of steps.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=5, eta_min=0.005)
    
    
    def train_loop_fn(loader):
        tracker = xm.RateTracker()  # required in the TPU multiprocessing training loop
        
        model.train()
        running_loss, total_samples = 0.0, 0
        
        for data, target in tqdm(loader):
            opt.zero_grad()
            data, target = data.to(device), target.to(device)
            
            # (Optional) For images, use the channels-last memory format. 
            data = data.to(memory_format=torch.channels_last)
            
            output = model(data)
            loss = criterion(output, target)
            
            # Multi-hot predictions (one 0/1 per disease); not used further here
            preds = torch.sigmoid(output).data > 0.5
            preds = preds.to(torch.float32)  # keep float32; TPUs use bfloat16 rather than float16
            
            loss.backward()
            xm.optimizer_step(opt)  # instead of opt.step()
            scheduler.step()
            
            tracker.add(FLAGS["bs"])  # required for TPU
            
            running_loss += loss.item() * data.size(0)
            total_samples += data.size(0)
            
        print(f"Train loss: {running_loss / total_samples}", flush=True)
        
    
    def test_loop_fn(loader):
        total_samples = 0
        running_loss, f1Score = 0.0, 0.0
        # One didn't manage to get roc_auc_score to work due to some errors. 
        
        model.eval()
        
        with torch.no_grad():  # no gradients are needed for validation
            for data, target in tqdm(loader):
                data, target = data.to(device), target.to(device)
                data = data.to(memory_format=torch.channels_last)
                
                output = model(data)
                loss = criterion(output, target)
                
                # Multi-hot predictions
                preds = torch.sigmoid(output).data > 0.5
                preds = preds.to(torch.float32)  # keep float32, as in train_loop_fn
                
                running_loss += loss.item() * data.size(0)
                total_samples += data.size(0)
                
                # Batch-weighted F1 score, computed on the CPU with scikit-learn
                target = target.cpu().to(torch.int).numpy()
                preds = preds.cpu().to(torch.int).numpy()
                f1Score += f1_score(target, preds, average="weighted") * data.size(0)
            
        epoch_loss = running_loss / total_samples
        epoch_f1score = f1Score / total_samples
        
        print(f"""
            |   Val loss    |   Val F1Score   |
            |{epoch_loss:<15}|{epoch_f1score:<17}|
        """, flush=True)
        
        return epoch_loss, epoch_f1score, data, preds, target
    
    
    returned_val = train_cycle_distrib(dls, FLAGS, train_loop_fn, test_loop_fn, device=device)
#     epoch_loss, epoch_f1score, data, preds, target = returned_val

    return returned_val, model

### Multiprocessing function
This function changes the most from project to project, so you need to define it yourself. Because what goes inside depends entirely on what you need it to do, one doesn't hide it away in a utility script. Let's see how to define it.

Usually this function is called `_mp_fn(rank)`. **The first argument must always be the rank (or index, or whatever you name it)**. `xmp.spawn` always passes it in. The rank is the index of the TPU core the process is allocated to, similar to GPU:0, GPU:1 when using multiple GPUs.

After that, you can add any extra positional arguments. `xmp.spawn` calls your function as `fn(rank, *args)`, so keyword arguments (`**kwargs`) cannot be passed through. For example, we pass in an additional argument `flags`.

In [None]:
def _mp_fn(rank, flags):
    global FLAGS
    FLAGS = flags
    torch.set_default_tensor_type(torch.FloatTensor)  # This is compulsory. 
    
    returned_val, model = train_tpu()  # pass in the training function. 
    
    # Here one saves the model once training has finished. You could do other
    # things here too, such as printing a summary. This is OPTIONAL: you can
    # safely remove the line below if you don't need it.
    if rank == 0: torch.save(model.state_dict(), FLAGS["save_path"])

We check `if rank == 0` so that only one process (the first TPU core) saves the model. The cores don't necessarily run in lockstep, and without this check every process would try to write the same file, possibly at the same time, which can raise exceptions (try it yourself and see what happens!). Note that the check by itself does not make rank 0 wait for the other cores. torch_xla also provides `xm.save()`, which moves the tensors to the CPU and writes only from the master process; it must be called from every process (not inside `if rank == 0`), because it synchronizes them.

**Finally, let's call the function and train!**

In the call below, the first argument is always `_mp_fn` (or whatever you named your function). The second, `args`, is a tuple of all the arguments you pass into `_mp_fn` (**except for `rank`**), in order. The third, `nprocs`, is the number of TPU cores to use: `nprocs=1` for single processing, `nprocs=8` for multiprocessing. You can only pass in 1 or 8 (not 2 or 4): either a single core or all of them. Finally, use `start_method="fork"`: in a notebook, the default `spawn` method generally cannot find functions defined in the notebook, and `MpModelWrapper` only saves host memory when processes are forked.
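`xmp.spawn` calls your function once per process as `fn(rank, *args)`. The sequential stand-in below (a hypothetical `fake_spawn`, not a torch_xla API, and with no real processes) illustrates that calling convention and the `rank == 0` check. It also shows why `args=(FLAGS,)` needs the trailing comma: `(FLAGS)` is just the dict itself, and unpacking a dict with `*` passes its keys as separate arguments.

```python
def fake_spawn(fn, args=(), nprocs=1):
    """Hypothetical sequential stand-in: call fn(rank, *args) for each rank."""
    return [fn(rank, *args) for rank in range(nprocs)]

def demo_mp_fn(rank, flags):
    # Every process receives its own rank first, then the contents of `args`.
    should_save = rank == 0  # only one process would save the model
    return rank, flags["num_epochs"], should_save

demo_flags = {"num_epochs": 10, "bs": 32}
spawn_results = fake_spawn(demo_mp_fn, args=(demo_flags,), nprocs=2)
# → [(0, 10, True), (1, 10, False)]

# Without the trailing comma, *args unpacks the dict's keys instead:
unpack_error = None
try:
    fake_spawn(demo_mp_fn, args=(demo_flags), nprocs=1)  # demo_mp_fn(0, 'num_epochs', 'bs')
except TypeError as err:
    unpack_error = str(err)
```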

## Important NOTE:
If using `nprocs=8`, the TPU needs to warm up (moving data to the TPU, preparing for multiprocessing, and other reasons one hasn't figured out), so you will wait **quite a long time** before the spawned processes start. **This can mean waiting up to 1 hour before the process starts**. To check that your program works, use `nprocs=1` during the experimental phase.

Even with `nprocs=1`, the process needs some time to warm up, mostly within the first epoch. In our case there are 465 batches per epoch, and the first 2-5 batches take much longer (up to 100x longer) than later ones, because XLA compiles the computation graph during the first steps. Just wait; the tqdm bar will speed up at some point.

Intermittent stalls after the first few batches (where the bar stops moving for a while) are caused by the CPU bottleneck in preprocessing images. Consider *caching your dataset* to smooth this out. This can speed up training by about 3-10x (although the CPU remains a bottleneck, since it is fully utilized fetching and prefetching data from disk, which is still slower than TPU training).

This is why using multiple TPU cores isn't guaranteed to speed up training: the CPU is the bottleneck in both cases, so single-core and multi-core training perform about the same until you upgrade your CPU. **Caching the dataset does make a difference, though**. With a cached dataset, in one's case, the CPU is only 50% utilized fetching and prefetching data for a single TPU core, but fully utilized when feeding all 8 cores. This difference halves the training time when using all TPU cores. Without that spare CPU capacity, there is no difference (in Colab, for example, you only have 2 vCPUs, so the CPU is always fully utilized regardless of whether you use one or multiple TPU cores).

In [None]:
# xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS["num_cores"], start_method="fork")
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=1, start_method="fork")