> NOTE: This notebook doesn't utilize TPU properly. TPU idle time is more than 50% most of the time.

In [None]:
# Install torchvision and torchtext (the latter pinned to 0.9) into the notebook environment.
# NOTE(review): torchvision is not pinned here — confirm the resolved version is compatible
# with the torch build installed by the TPU setup cell (which requests version 1.8).
!pip install "torchvision" "torchtext==0.9"

## TPU Setup

In [None]:
# Download the environment setup script from the pytorch/xla repository (master branch).
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# Run it for version 1.8 and have it install the libomp5 and libopenblas-dev apt packages.
!python pytorch-xla-env-setup.py --version 1.8 --apt-packages libomp5 libopenblas-dev

### Import Packages

In [None]:
import json
import os
from typing import Dict, List
import logging

import numpy as np
import pandas as pd

import tqdm

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_style('darkgrid')

import cv2
import albumentations as A
from albumentations.core.composition import Compose
from albumentations.pytorch import ToTensorV2

from torch.utils.data import Dataset, TensorDataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn import metrics

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

import torch
import torchvision.models as models
from torch import nn
from torch.optim import AdamW, Adam
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau


import torch_xla.core.xla_model as xm

### Set seed for everything (numpy, torch and python)

In [None]:
# Seed the Python, NumPy and PyTorch random number generators with a single fixed value.
from pytorch_lightning import seed_everything
seed_everything(42)

### Set Directories

In [None]:
# Root folder of the competition data on Kaggle.
ROOT_DIR = '/kaggle/input/cassava-leaf-disease-classification/'
# Names of the folders and files expected under ROOT_DIR.
TRAIN_IMAGES_FOLDER = 'train_images'
TEST_IMAGES_FOLDER = 'test_images'
TRAIN_CSV = 'train.csv'
SAMPLE_SUBMISSION_CSV = 'sample_submission.csv'
LABEL_NUM_TO_DISEASE_MAP_JSON = 'label_num_to_disease_map.json'


# Folder the datasets read training images from.
image_dir = os.path.join(ROOT_DIR, TRAIN_IMAGES_FOLDER)

### Set HyperParameters

In [None]:
# Default learning rate of ClassifierModule's optimizer.
LEARNING_RATE = 1e-4
# Passed to the Trainer as max_epochs.
MAX_EPOCHS = 1
# Batch size of both the training and the validation DataLoader.
# NOTE(review): a batch of 4 may contribute to the low TPU utilization mentioned
# at the top of the notebook — confirm by profiling.
BATCH_SIZE = 4

### Other Parameters

In [None]:
# Passed as num_workers to both the training and the validation DataLoader.
NUM_WORKERS = 3

# Output size (in pixels) of the crop/resize step in the augmentation pipelines.
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512

### Reading Data

In [None]:
# Load the label-number -> disease-name mapping shipped with the data, then pretty-print it.
label_map_path = os.path.join(ROOT_DIR, LABEL_NUM_TO_DISEASE_MAP_JSON)
with open(label_map_path, 'r') as file:
    label_to_disease = json.load(file)

print(json.dumps(label_to_disease, indent=4))

In [None]:
# JSON object keys are always strings; re-key the mapping by integer label number.
label_to_disease_mapping = {
    int(label_number): disease_name
    for label_number, disease_name in label_to_disease.items()
}
label_to_disease_mapping

In [None]:
# Read the training metadata table and preview its first rows.
train_csv_path = os.path.join(ROOT_DIR, TRAIN_CSV)
train_df = pd.read_csv(train_csv_path)
train_df.head()

In [None]:
# (number of rows, number of columns) of the training metadata table.
train_df.shape

## Preparing Dataset

In [None]:
class ImageDataset(Dataset):
    """
    Cassava Leaf Dataset.

    Each sample is read from ``image_dir`` with OpenCV, converted from BGR to
    RGB, passed through ``transforms`` and returned together with its label.

    Args:
        image_names: File names of the images, relative to ``image_dir``.
        labels: Integer class label for every entry of ``image_names``.
        image_dir: Directory that contains the image files.
        transforms: Callable invoked as ``transforms(image=image)``; it must
            return a mapping holding the transformed image under ``'image'``.
        labels_to_ohe: When True, labels are stored as one-hot rows of width
            ``num_class``; otherwise they are stored as given.
        num_class: Number of classes (width of the one-hot rows).

    Raises:
        ValueError: If ``image_names`` and ``labels`` differ in length.
    """
    def __init__(self,
                image_names: List[str],
                labels: List[int],
                image_dir: str, 
                transforms,
                labels_to_ohe: bool=False,
                num_class: int = 5):
        if len(image_names) != len(labels):
            raise ValueError(
                f"image_names and labels must have the same length, "
                f"got {len(image_names)} and {len(labels)}"
            )

        self.image_names = image_names
        self.image_dir = image_dir
        self.transforms = transforms
        self.num_class = num_class

        if labels_to_ohe:
            # One row per sample with a single 1 in the column of its class.
            self.labels = np.zeros((len(labels), num_class))
            self.labels[np.arange(len(labels)), np.asarray(labels)] = 1
        else:
            self.labels = np.array(labels)

    def __len__(self):
        """Return the number of samples (one per image name)."""
        return len(self.image_names)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        """
        Load, transform and return the sample at position ``idx``.

        Returns:
            Dict with ``'image_path'`` (str), ``'image'`` (output of the
            transform pipeline) and ``'target'`` (label as a ``torch.Tensor``).

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        image_path = os.path.join(self.image_dir, self.image_names[idx])
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports a missing/undecodable file by returning None
            # rather than raising; fail here with the offending path instead of
            # with an opaque error inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        target = self.labels[idx]

        transformed_image = self.transforms(image=image)['image']
        return {
            'image_path': image_path,
            'image': transformed_image,
            'target': torch.tensor(target),
        }


In [None]:
class ImageDataModule(pl.LightningDataModule):
    """
    Data module that serves the train/validation split of one stratified fold.

    Args:
        df: DataFrame with an ``image_id`` column (image file names) and a
            ``label`` column (class labels, also used for stratification).
        train_transforms: Transform pipeline applied to training images.
        valid_transforms: Transform pipeline applied to validation images.
        image_dir: Directory that contains the image files.
        fold_num: Index of the fold whose held-out part is used for validation.
        n_splits: Number of stratified folds ``df`` is divided into.
        seed: Random state of the fold shuffling. Fixing it makes every call to
            ``setup`` produce the same partition, so data modules created for
            different ``fold_num`` values get non-overlapping validation sets.
    """
    def __init__(self,
                 df,
                 train_transforms,
                 valid_transforms,
                 image_dir,
                 fold_num=0,
                 n_splits=5,
                 seed=42):
        super().__init__()
        self.df = df
        self.train_transforms = train_transforms
        self.valid_transforms = valid_transforms
        self.image_dir = image_dir
        self.fold_num = fold_num
        self.n_splits = n_splits
        self.seed = seed

    def prepare_data(self):
        """Nothing to download or pre-process."""
        pass

    def setup(self, stage=None):
        """Split ``df`` into stratified folds and build the datasets of fold ``fold_num``."""
        # An explicit random_state keeps the shuffled split independent of the
        # global NumPy RNG state at the time setup() happens to run.
        folds = StratifiedKFold(n_splits=self.n_splits, shuffle=True, random_state=self.seed)

        train_indexes, valid_indexes = list(folds.split(self.df, self.df['label']))[self.fold_num]

        train_df = self.df.iloc[train_indexes]
        valid_df = self.df.iloc[valid_indexes]

        self.train_dataset = ImageDataset(image_names=train_df.image_id.values,
                                          labels=train_df.label.values,
                                          image_dir=self.image_dir,
                                          transforms=self.train_transforms)

        self.valid_dataset = ImageDataset(image_names=valid_df.image_id.values,
                                          labels=valid_df.label.values,
                                          image_dir=self.image_dir,
                                          transforms=self.valid_transforms)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader using the notebook-level BATCH_SIZE and NUM_WORKERS."""
        return DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """
        Return the shuffled training DataLoader.

        No sampler is set here.
        NOTE(review): Lightning is expected to insert the distributed sampler
        itself when training on several TPU cores — confirm for the installed version.
        """
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation DataLoader (fixed order, no sampler set here)."""
        return self._make_loader(self.valid_dataset, shuffle=False)

    def test_dataloader(self):
        """No test set is served by this data module."""
        return None


### Image Augmentation for Train and Validation


In [None]:
# Optional blur stage for training: the group as a whole is applied with p=0.2
# and, when it is, one of the listed blur transforms is chosen.
blur_variants = A.OneOf(
    [
        A.MotionBlur(p=.2),
        A.MedianBlur(blur_limit=3, p=0.1),
        A.Blur(blur_limit=3, p=0.1),
    ],
    p=0.2,
)

# Training pipeline: random crop to the target size, geometric and photometric
# augmentation, then normalization and conversion to a tensor.
train_augs = A.Compose(
    [
        A.RandomResizedCrop(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, p=1.0),
        A.Flip(),
        A.RandomBrightnessContrast(),
        A.ShiftScaleRotate(),
        blur_variants,
        A.Normalize(),
        ToTensorV2(),
    ]
)

# Validation pipeline: plain resize, then the same normalization and tensor conversion.
valid_augs = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, p=1.0),
        A.Normalize(),
        ToTensorV2(),
    ]
)

In [None]:
class ClassifierModule(pl.LightningModule):
    """
    ResNet-101 classifier whose final layer is replaced by a ``num_classes``-way linear head.

    Args:
        learning_rate: Learning rate handed to the AdamW optimizer.
        num_classes: Number of output classes of the replaced ``fc`` layer.

    NOTE(review): the step and epoch hooks report metrics through the ``'log'``
    and ``'progress_bar'`` keys of the dicts they return. Newer PyTorch
    Lightning releases expect ``self.log(...)`` instead — confirm that the
    installed version still honours these keys (the scheduler and early
    stopping both monitor ``'valid_loss'``).
    """
    def __init__(self, learning_rate=LEARNING_RATE, num_classes=5):
        super().__init__()
        self.metric = pl.metrics.Accuracy()
        self.learning_rate = learning_rate
        self.model = models.resnet101(pretrained=True)
        # Swap the pretrained classification head for one sized to this task.
        self.model.fc = nn.Linear(in_features=self.model.fc.in_features, out_features=num_classes)

    def forward(self, x):
        """Return the logits for a 4-D image batch, reshaped to ``(batch_size, -1)``."""
        batch_size, _, _, _ = x.shape
        x = self.model(x)

        return x.reshape(batch_size, -1)

    def configure_optimizers(self):
        """AdamW plus a ReduceLROnPlateau scheduler stepped per epoch on ``'valid_loss'``."""
        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=0.001)
        scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=2)

        return (
            [optimizer],
            [{'scheduler': scheduler, 'interval': 'epoch', 'monitor': 'valid_loss'}],
        )

    def _get_loss(self, y_hat, y):
        """Cross-entropy between logits ``y_hat`` and targets ``y``."""
        # Functional form: avoids constructing a new loss module on every step.
        return F.cross_entropy(y_hat, y)

    def _shared_step(self, batch):
        """Run the forward pass on one batch; return ``(logits, target, loss, accuracy)``."""
        image = batch['image']
        y = batch['target']
        y_hat = self(image)
        loss = self._get_loss(y_hat, y)
        score = self.metric(y_hat.argmax(1), y)
        return y_hat, y, loss, score

    def _epoch_summary(self, outputs):
        """Return the mean step loss and the accuracy over all predictions in ``outputs``."""
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        y_true = torch.cat([x['target'] for x in outputs])
        y_pred = torch.cat([x['logits'] for x in outputs])
        score = self.metric(y_pred.argmax(1), y_true)
        return avg_loss, score

    def training_step(self, batch, batch_idx):
        """Compute loss and accuracy for one training batch."""
        y_hat, y, loss, score = self._shared_step(batch)

        logs = {'train_loss': loss, 'train_accuracy': score}
        return {
            'loss': loss,
            'log': logs,
            'progress_bar': logs,
            # Only the argmax of the logits is needed at epoch end; detaching
            # avoids holding each step's autograd graph for the whole epoch.
            'logits': y_hat.detach(),
            'target': y,
            'train_accuracy': score,
        }

    def training_epoch_end(self, outputs):
        """Aggregate the training-step outputs of one epoch into loss and accuracy."""
        avg_loss, score = self._epoch_summary(outputs)

        logs = {'train_loss': avg_loss, 'train_accuracy': score}

        return {'log': logs, 'progress_bar': logs}

    def validation_step(self, batch, batch_idx):
        """Compute loss and accuracy for one validation batch."""
        y_hat, y, loss, score = self._shared_step(batch)
        logs = {'valid_loss': loss, 'valid_accuracy': score}

        return {
            'loss': loss,
            'log': logs,
            'progress_bar': logs,
            'logits': y_hat.detach(),
            'target': y,
            'valid_accuracy': score,
        }

    def validation_epoch_end(self, outputs):
        """Aggregate the validation-step outputs of one epoch into loss and accuracy."""
        avg_loss, score = self._epoch_summary(outputs)

        logs = {'valid_loss': avg_loss, 'valid_accuracy': score, 'accuracy': score}

        return {'valid_loss': avg_loss, 'log': logs, 'progress_bar': logs}

### Training

In [None]:
# Data module for the selected cross-validation fold; change fold_num to train another fold.
fold_num = 0
data_module = ImageDataModule(
    df=train_df,
    train_transforms=train_augs,
    valid_transforms=valid_augs,
    image_dir=image_dir,
    fold_num=fold_num,
)

# Stop when 'valid_loss' has not improved for 5 consecutive checks.
early_stopping = EarlyStopping(monitor='valid_loss', patience=5, mode='min')

# Disabled options kept for reference:
#   checkpoint_callback=ModelCheckpoint(monitor='train_loss', save_top_k=1, filename='resnet101-foldnum-0_{epoch}_{valid_loss:.4f}_{accuracy:.4f}', mode='min')
#   gpus=1 if torch.cuda.is_available() else 0
trainer = pl.Trainer(
    deterministic=True,
    tpu_cores=8,
    max_epochs=MAX_EPOCHS,
    num_sanity_val_steps=1,
    weights_summary='top',
    callbacks=[early_stopping],
)

lightning = ClassifierModule()

In [None]:
# Fit the classifier using the train/validation loaders provided by the data module.
trainer.fit(lightning, data_module)

### Resources:

1 - Inspired by Artgor's notebook [Cassava disease identification with lightning](https://www.kaggle.com/artgor/cassava-disease-identification-with-lightning/notebook)

2 - [TPU SUPPORT](https://pytorch-lightning.readthedocs.io/en/stable/tpu.html)