In [None]:
import sys
sys.path.append('../input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master')
sys.path.append('../input/timm-h/pytorch-image-models-master_h')

In [None]:
%matplotlib inline
import os
import cv2
import glob
import random
import itertools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
import albumentations as albu
from albumentations.pytorch import ToTensorV2

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import metrics

import timm
from timm import create_model

# Passing None removes pandas' display limit, so DataFrames shown in this
# notebook are rendered with all of their rows and columns.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

## Data

In [None]:
# List the entries of the competition input directory.
os.listdir('/kaggle/input/cassava-leaf-disease-classification')

In [None]:
# Root directory of the competition data; reused by later cells.
data_dir = '/kaggle/input/cassava-leaf-disease-classification'
# Training table (the dataset class reads its 'image_id' and 'label' columns).
train = pd.read_csv(os.path.join(data_dir, 'train.csv'))
# Sample submission file, loaded for inspection.
sub = pd.read_csv(os.path.join(data_dir, 'sample_submission.csv'))

In [None]:
# Preview the first rows of the training table.
train.head()

In [None]:
# Preview the first rows of the sample submission.
sub.head()

## Seed Setting

In [None]:
def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic
    mode.

    Args:
        seed (int): Seed value applied to all generators.
    """
    random.seed(seed)
    # Exported for child processes; the running interpreter fixed its own
    # hash seed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also cover the multi-GPU case (silently ignored when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes the convolution algorithm per input shape and
    # may pick a different one on each run, so it must be off for reproducible
    # results.
    torch.backends.cudnn.benchmark = False

## Label Distribution

In [None]:
# Bar chart of how many training rows carry each label value.
sns.countplot(x='label', data=train)
plt.show()

In [None]:
# Same distribution as numbers, ordered by label value.
train['label'].value_counts().sort_index()

## Dataset

In [None]:
class CassavaDataset(Dataset):
    """Image dataset for the cassava leaf data.

    In the ``'test'`` phase samples are the ``*.jpg`` files found in
    ``<data_dir>/test_images`` and each item is ``(image, file_name)``.
    In every other phase samples are the rows of ``df`` (columns
    ``image_id`` and ``label``), images are read from
    ``<data_dir>/train_images`` and each item is ``(image, label)`` where
    ``label`` is a ``torch.long`` tensor of shape ``(1,)``.

    Args:
        data_dir (str): Directory containing ``train_images`` / ``test_images``.
        transform (callable, optional): Called as ``transform(img, phase)`` on
            the RGB image array. When ``None`` the image is converted to a
            channel-first tensor and divided by 255.
        phase (str): ``'test'`` selects the test behaviour described above;
            the value is also forwarded to ``transform``.
        df (pandas.DataFrame, optional): Table of samples for non-test phases.
    """

    def __init__(self, data_dir, transform=None, phase='train', df=None):
        self.df = df
        self.data_dir = data_dir
        self.transform = transform
        self.phase = phase
        img_dir = 'test_images' if self.phase == 'test' else 'train_images'
        # glob returns files in an unspecified, filesystem-dependent order;
        # sorting keeps the index -> file mapping stable between runs.
        self.img_path = sorted(glob.glob(os.path.join(self.data_dir, img_dir, '*.jpg')))

    def __len__(self):
        """Return the number of rows in ``df``, or of globbed files without one."""
        if self.df is None:
            return len(self.img_path)
        return len(self.df)

    def __getitem__(self, idx):
        """Load, colour-convert and transform the sample at position ``idx``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        if self.phase == 'test':
            row = None
            target_img_path = self.img_path[idx]
        else:
            row = self.df.iloc[idx]
            target_img_path = os.path.join(self.data_dir, 'train_images', f"{row['image_id']}")

        img = cv2.imread(target_img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {target_img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            img = self.transform(img, self.phase)
        else:
            img = torch.from_numpy(img.transpose((2, 0, 1)))
            img = img / 255.

        if self.phase == 'test':
            return img, os.path.basename(target_img_path)

        # Take the label from the row already fetched instead of scanning the
        # whole frame for the image id on every access. The shape stays (1,)
        # so batches keep their (batch, 1) layout.
        label = torch.tensor([int(row['label'])], dtype=torch.long)
        return img, label

## ImageTransform

In [None]:
class ImageTransform:
    """Albumentations pipelines keyed by phase (``'train'``, ``'val'``, ``'test'``).

    Every pipeline ends with ``Normalize(mean, std)`` followed by
    ``ToTensorV2()``. The ``'train'`` pipeline adds random augmentations, the
    ``'val'`` pipeline only resizes, and the ``'test'`` pipeline resizes and
    applies random horizontal/vertical flips.

    Args:
        img_size (int): Output height and width.
        mean (tuple): Per-channel mean passed to ``Normalize``.
        std (tuple): Per-channel standard deviation passed to ``Normalize``.
    """

    def __init__(self, img_size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        train_steps = [
            albu.RandomShadow(p=0.5),
            albu.RandomResizedCrop(img_size, img_size, interpolation=cv2.INTER_AREA),
            albu.ColorJitter(p=0.5),
            albu.CLAHE(p=0.5),
            albu.HorizontalFlip(p=0.5),
            albu.VerticalFlip(p=0.5),
            albu.Transpose(p=0.5),
            albu.ShiftScaleRotate(p=0.5),
            # Exactly one of the two blur variants, applied half of the time.
            albu.OneOf([albu.Blur(p=1.0), albu.GaussianBlur(p=1.0)], p=0.5),
            albu.CoarseDropout(max_height=15, max_width=15, min_holes=3, p=0.5),
            albu.Normalize(mean, std),
            ToTensorV2(),
        ]
        val_steps = [
            albu.Resize(img_size, img_size),
            albu.Normalize(mean, std),
            ToTensorV2(),
        ]
        test_steps = [
            albu.Resize(img_size, img_size),
            albu.HorizontalFlip(p=0.5),
            albu.VerticalFlip(p=0.5),
            albu.Normalize(mean, std),
            ToTensorV2(),
        ]

        self.transform = {
            'train': albu.Compose(train_steps, p=1.0),
            'val': albu.Compose(val_steps, p=1.0),
            'test': albu.Compose(test_steps, p=1.0),
        }

    def __call__(self, img, phase='train'):
        """Run the pipeline registered for ``phase`` on ``img`` and return its ``'image'`` entry."""
        pipeline = self.transform[phase]
        return pipeline(image=img)['image']

In [None]:
# Sanity Check: build the training dataset and inspect its first sample
# (tensor size, label, and the value range after the transform).
transform = ImageTransform()
dataset = CassavaDataset(data_dir, transform, phase='train', df=train)
img, label = dataset[0]
print(img.size(), label)
print(img.max())
print(img.min())

In [None]:
# Pull one batch of 8 training samples and print the batched image size.
dataloader = DataLoader(dataset, batch_size=8)

imgs, labels = next(iter(dataloader))
print(imgs.size())

In [None]:
# Sanity Check: same inspection for the test phase. Here the second element
# of a sample is the image file name rather than a class label.
transform = ImageTransform()
dataset = CassavaDataset(data_dir, transform, phase='test', df=None)
img, label = dataset[0]
print(img.size(), label)
print(img.max())
print(img.min())

In [None]:
# Pull one single-sample batch from the test dataset and print its size.
# In the test phase the second element of the batch holds image file names.
dataloader = DataLoader(dataset, batch_size=1)

imgs, labels = next(iter(dataloader))
print(imgs.size())

## Lightning DataModule

In [None]:
class CassavaDataModule(pl.LightningDataModule):
    """Lightning data module providing train/val/test loaders for one CV fold.

    Args:
        data_dir (str): Directory containing ``train.csv`` and the image folders.
        cfg: Config object; ``cfg.train['batch_size']`` is read for the loaders.
        transform (callable): Transform handed to every ``CassavaDataset``.
        cv: Splitter exposing ``split(X, y)`` that yields
            ``(train_idx, val_idx)`` pairs.
        fold (int): Index of the split used as the validation set.
    """

    def __init__(self, data_dir, cfg, transform, cv, fold):
        super(CassavaDataModule, self).__init__()
        self.data_dir = data_dir
        self.cfg = cfg
        self.transform = transform
        self.cv = cv
        self.fold = fold
        # Populated by prepare_data() or, failing that, by setup().
        self.df = None

    def _read_train_csv(self):
        """Read ``train.csv`` from ``data_dir``."""
        return pd.read_csv(os.path.join(self.data_dir, 'train.csv'))

    def prepare_data(self):
        """Load the training table (kept so existing callers still find ``self.df``)."""
        self.df = self._read_train_csv()

    def setup(self, stage=None):
        """Assign fold ids and build the train/val/test datasets."""
        # Lightning only runs prepare_data() on a single process, so setup()
        # must not rely on state assigned there.
        if self.df is None:
            self.df = self._read_train_csv()

        # Validation: tag every row with the split in which it is held out.
        self.df['fold'] = -1
        for i, (trn_idx, val_idx) in enumerate(self.cv.split(self.df, self.df['label'])):
            self.df.loc[val_idx, 'fold'] = i
        train = self.df[self.df['fold'] != self.fold].reset_index(drop=True)
        val = self.df[self.df['fold'] == self.fold].reset_index(drop=True)

        # Dataset
        self.train_dataset = CassavaDataset(self.data_dir, self.transform, phase='train', df=train)
        self.val_dataset = CassavaDataset(self.data_dir, self.transform, phase='val', df=val)
        self.test_dataset = CassavaDataset(self.data_dir, self.transform, phase='test', df=None)

    def train_dataloader(self):
        """Shuffled loader over the training rows."""
        return DataLoader(self.train_dataset,
                          batch_size=self.cfg.train['batch_size'],
                          pin_memory=True,
                          num_workers=4,
                          shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the held-out rows."""
        return DataLoader(self.val_dataset,
                          batch_size=self.cfg.train['batch_size'],
                          pin_memory=True,
                          num_workers=4,
                          shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the test images."""
        # Cap the batch size at the dataset size, but never let it reach 0,
        # which DataLoader rejects.
        batch_size = max(1, min(len(self.test_dataset), self.cfg.train['batch_size']))
        return DataLoader(self.test_dataset,
                          batch_size=batch_size,
                          pin_memory=True,
                          num_workers=4,
                          shuffle=False)

## Model

In [None]:
class Timm_model(nn.Module):
    """Backbone from ``create_model`` with its classification layer replaced.

    The attribute holding the final linear layer is chosen from the model
    name: ``classifier`` when the name contains ``'efficientnet'``, ``head``
    when it contains ``'vit'``, and ``fc`` otherwise.

    Args:
        model_name (str): Name passed to ``create_model``.
        pretrained (bool): Forwarded to ``create_model``.
        out_dim (int): Number of output features of the new linear layer.
    """

    def __init__(self, model_name='efficientnet_b0', pretrained=True, out_dim=5):
        super(Timm_model, self).__init__()
        self.base = create_model(model_name, pretrained=pretrained)

        if 'efficientnet' in model_name:
            head_attr = 'classifier'
        elif 'vit' in model_name:
            head_attr = 'head'
        else:
            head_attr = 'fc'

        # Swap in a fresh linear layer that keeps the input width of the old one.
        old_head = getattr(self.base, head_attr)
        new_head = nn.Linear(in_features=old_head.in_features, out_features=out_dim)
        setattr(self.base, head_attr, new_head)

    def forward(self, x):
        """Return the backbone output for ``x``."""
        return self.base(x)

In [None]:
# Sanity Check: push a random batch of four 3x224x224 inputs through an
# untrained model and print the output size.
z = torch.randn(4, 3, 224, 224)
model = Timm_model(pretrained=False)
out = model(z)
print(out.size())

## Lightning Module

In [None]:
class CassavaLightningSystem(pl.LightningModule):
    """Lightning module that trains, validates and runs test inference for ``net``.

    Args:
        net (nn.Module): Network mapping an image batch to per-class scores.
        cfg: Config object; ``cfg.train['lr']`` and ``cfg.exp['exp_name']`` are read.
        experiment (optional): Logger exposing ``log_metrics`` and ``log_model``.
            When ``None`` nothing is logged and saved weight files are kept on disk.

    Attributes set here and read by the notebook:
        loss_list / acc_list: validation loss / accuracy, one entry per
            ``validation_epoch_end`` call.
        sub: DataFrame of test predictions, set by ``test_epoch_end``.
    """

    def __init__(self, net, cfg, experiment=None):
        super(CassavaLightningSystem, self).__init__()
        self.net = net
        self.cfg = cfg
        self.experiment = experiment
        self.criterion = nn.CrossEntropyLoss()
        self.best_loss = 1e+9
        self.best_acc = None
        self.epoch_num = 0
        self.acc_fn = metrics.Accuracy()
        self.loss_list = []
        self.acc_list = []

    def configure_optimizers(self):
        """Create the AdamW optimizer and a StepLR schedule (x0.5 every 2 steps)."""
        self.optimizer = optim.AdamW(self.parameters(), lr=self.cfg.train['lr'], weight_decay=1e-5)
        self.scheduler = StepLR(self.optimizer, step_size=2, gamma=0.5)
#         self.scheduler = CosineAnnealingLR(self.optimizer, T_max=self.cfg.train['epoch'], eta_min=0)

        return [self.optimizer], [self.scheduler]

    def forward(self, x):
        """Return the raw network output for ``x``."""
        return self.net(x)

    def step(self, batch):
        """Compute the loss and class probabilities for one labelled batch.

        Returns:
            tuple: ``(loss, label, probabilities)`` where ``label`` is the
            untouched batch label tensor.
        """
        inp, label = batch
        out = self.forward(inp)
        # reshape(-1) instead of squeeze(): squeeze() would also drop the batch
        # dimension of a single-sample batch and leave a 0-dim target.
        loss = self.criterion(out, label.reshape(-1))

        # The classes are mutually exclusive and the loss is cross-entropy, so
        # softmax (not a per-class sigmoid) yields the class probabilities.
        return loss, label, torch.softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        """Run one training batch; logs ``train/loss`` when an experiment is set."""
        loss, label, logits = self.step(batch)

        if self.experiment is not None:
            logs = {'train/loss': loss.item()}
            self.experiment.log_metrics(logs, step=batch_idx)

        return {'loss': loss, 'logits': logits, 'labels': label}

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; logs ``val/loss`` when an experiment is set."""
        loss, label, logits = self.step(batch)

        if self.experiment is not None:
            val_logs = {'val/loss': loss.item()}
            self.experiment.log_metrics(val_logs, step=batch_idx)

        return {'val_loss': loss, 'logits': logits, 'labels': label}

    def validation_epoch_end(self, outputs):
        """Aggregate validation batches, record metrics and save improved weights."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        logits = torch.cat([x['logits'] for x in outputs])
        labels = torch.cat([x['labels'] for x in outputs])

        # Accuracy (labels flattened the same way as in step()).
        acc = self.acc_fn(logits, labels.reshape(-1))

        print(f'Epoch: {self.epoch_num}  Loss: {avg_loss.item():.4f}  Acc {acc.item():.4f}')
        self.loss_list.append(avg_loss.item())
        self.acc_list.append(acc.item())

        if self.experiment is not None:
            logs = {'val/epoch_loss': avg_loss, 'val/epoch_acc': acc}
            # Logging
            self.experiment.log_metrics(logs, step=self.epoch_num)

        # Save Weights whenever the average validation loss improves.
        if self.best_loss > avg_loss:
            self.best_loss = avg_loss.item()
            self.best_acc = acc.item()
            expname = self.cfg.exp['exp_name']
            filename = f'{expname}_epoch_{self.epoch_num}_loss_{self.best_loss:.3f}_acc_{self.best_acc:.3f}.pth'
            torch.save(self.net.state_dict(), filename)
            if self.experiment is not None:
                # The local copy is removed once handed to the experiment logger.
                self.experiment.log_model(name=filename, file_or_folder='./'+filename)
                os.remove(filename)

        # Update Epoch Num
        self.epoch_num += 1

        return {'avg_val_loss': avg_loss}

    def test_step(self, batch, batch_idx):
        """Predict class probabilities for one batch of ``(images, image ids)``."""
        inp, img_id = batch
        out = self.forward(inp)
        logits = torch.softmax(out, dim=1)

        return {'preds': logits, 'image_id': img_id}

    def test_epoch_end(self, outputs):
        """Collect all test predictions into ``self.sub``.

        ``self.sub`` has an ``image_id`` column followed by one ``label_<k>``
        column per model output.
        """
        preds = torch.cat([x['preds'] for x in outputs])
        # Cast to float32 so later sums over several passes are not done in a
        # reduced-precision dtype.
        preds = preds.detach().float().cpu().numpy()
        # One column per output class instead of a hard-coded count.
        preds = pd.DataFrame(preds, columns=[f'label_{c}' for c in range(preds.shape[1])])
        # Each batch contributes a sequence of ids; flatten them in batch order.
        img_ids = list(itertools.chain.from_iterable(list(x['image_id']) for x in outputs))
        self.sub = preds
        self.sub.insert(0, 'image_id', img_ids)

        return None


    # learning rate warm-up
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        """Apply the optimizer update, scaling the lr linearly over the first 500 global steps."""
        # warm up lr
        if self.trainer.global_step < 500:
            lr_scale = min(1., float(self.trainer.global_step + 1) / float(500))
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.cfg.train['lr']

        # update params
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

## Config

In [None]:
class cfg:
    """Notebook configuration, grouped into three plain dicts."""

    # Experiment naming (used in saved weight file names).
    exp = dict(exp_name='test')

    # Image size given to the transform and number of CV splits.
    data = dict(img_size=256, n_splits=5)

    # Training hyper-parameters and backbone name.
    train = dict(
        batch_size=64,
        epoch=10,
        seed=42,
        lr=0.005,
        model_name='efficientnet_b0',
    )

## Trainer

In [None]:
# Data Dir  #################################################################
data_dir = '/kaggle/input/cassava-leaf-disease-classification'
# Seed all RNGs before any splitter, transform or model is created.
seed_everything(cfg.train['seed'])

# Validation  ###############################################################
# Shuffled stratified K-fold; random_state makes the split repeatable.
cv = StratifiedKFold(n_splits=cfg.data['n_splits'], shuffle=True, random_state=cfg.train['seed'])

# Transform  ################################################################
# Replaces the default-size `transform` created in the sanity-check cells.
transform = ImageTransform(img_size=cfg.data['img_size'])

In [None]:
def main(data_dir, transform, cfg, cv, fold, TTA=5):
    """Train one cross-validation fold, predict the test set and save the result.

    Test inference is repeated ``TTA`` times (at least once) and the per-class
    scores of all passes are summed. The summed scores are written to
    ``submission_fold{fold}.csv``.

    Args:
        data_dir (str): Directory with the competition data.
        transform (callable): Transform handed to the data module.
        cfg: Config object (``cfg.train['model_name']``, ``cfg.train['epoch']``
            are read here; the modules read further keys).
        cv: Splitter handed to the data module.
        fold (int): Index of the validation fold.
        TTA (int): Number of test passes; values <= 0 mean a single pass.

    Returns:
        CassavaLightningSystem: The fitted Lightning module.
    """
    # Model  ####################################################################
    net = Timm_model(model_name=cfg.train['model_name'], pretrained=False)

    # Lightning Module  #########################################################
    dm = CassavaDataModule(data_dir, cfg, transform, cv, fold=fold)
    model = CassavaLightningSystem(net, cfg, experiment=None)

    trainer = Trainer(
        logger=False,
        max_epochs=cfg.train['epoch'],
        gpus=1,
        amp_backend='apex',
        amp_level='O2',
    )

    # Train & Test  ############################################################
    # Train
    trainer.fit(model, datamodule=dm)

    # Test: a single loop covers both the TTA and the plain case.
    res = None
    for _ in range(max(TTA, 1)):
        trainer.test(model, datamodule=dm)
        if res is None:
            res = model.sub.copy()
        else:
            # Add every score column, however many classes the model outputs.
            score_cols = [c for c in res.columns if c != 'image_id']
            res[score_cols] = res[score_cols] + model.sub[score_cols]

    res.to_csv(f'submission_fold{fold}.csv', index=False)

    # `model` still references the network, so this only drops the local names.
    del net, dm
    torch.cuda.empty_cache()

    return model

In [None]:
# Train every cross-validation fold and keep the fitted Lightning modules
# (their loss/accuracy histories are plotted at the end of the notebook).
TTA = 3
models = []

for fold in range(cfg.data['n_splits']):
    models.append(main(data_dir, transform, cfg, cv, fold, TTA))

In [None]:
# Summarize Predictions
# Sorted so the fold files are always combined in the same order.
sub_paths = sorted(glob.glob('submission_fold*'))
if not sub_paths:
    raise FileNotFoundError("No 'submission_fold*' files found; run the training cell first.")

fold_preds = [pd.read_csv(path) for path in sub_paths]

# Sum the scores per image_id instead of by row position, so the result does
# not depend on every fold file listing the images in the same order.
# sort=False keeps the images in order of first appearance.
res = (
    pd.concat(fold_preds, ignore_index=True)
    .groupby('image_id', sort=False)
    .sum()
    .reset_index()
)

# The predicted class is the position of the largest summed score.
label_cols = [c for c in res.columns if c != 'image_id']
res['label'] = np.argmax(res[label_cols].values, axis=1)
res = res[['image_id', 'label']]
res.to_csv('submission.csv', index=False)

## Learning Plot

In [None]:
# Validation loss and accuracy recorded by the first fold's module,
# one value per validation_epoch_end call.
first_fold = models[0]
history = pd.DataFrame({'loss': first_fold.loss_list, 'acc': first_fold.acc_list})

fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(16, 6))

# One panel per recorded metric, in column order: loss, then acc.
for ax, metric in zip(axes.ravel(), history.columns):
    ax.plot(history[metric])
    ax.set_title(metric)
    ax.set_xlabel('Epoch')

plt.tight_layout()
plt.show()