In [None]:
# Install the third-party packages imported by the cells below (-q = quiet output).
!pip install -q timm
!pip install -q torchcontrib
# pytorch_lightning is pinned to 1.2.5; the Trainer arguments used in the training
# cell (gpus=, checkpoint_callback=, weights_summary=) are written against this version.
!pip install -q pytorch_lightning==1.2.5
# Replace any preinstalled albumentations with the current GitHub master.
# NOTE(review): this install is unpinned, so the resolved version can change between runs.
!pip uninstall -q -y albumentations && pip install -q git+https://github.com/albumentations-team/albumentations
# NOTE(review): wandb is presumably required by WandbLogger when `debug` is False — confirm and re-enable.
# !pip install -q wandb

In [None]:
import os

import numpy as np
import pytorch_lightning as pl
import torch
import pandas as pd
import timm
import torch.nn as nn

from PIL import Image
from sklearn.model_selection import KFold
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, StochasticWeightAveraging
from pytorch_lightning.metrics import Metric
from torchcontrib.optim import SWA
from typing import List, Dict

#Augmentation
import albumentations as A
from torchvision import transforms as tsfm
from albumentations.pytorch import ToTensorV2

import matplotlib.pyplot as plt

In [None]:
# Show the installed PyTorch Lightning version (the install cell pins 1.2.5).
print(pl.__version__)

## Config

In [None]:
# Central configuration namespace: every path and hyper-parameter used by the
# notebook is stored as a class attribute and read as CFG.<name>.
class CFG:
    seed = 42
    # Paths
    root_dir_origin = "../input/plant-pathology-2021-fgvc8/"
    root_dir_resized = "../input/resized-plantpathology2021fgvc8-train-data-new/resized_plant-pathology-2021-fgvc8_train_data"
    train_csv_path = os.path.join(root_dir_origin, 'train.csv')
    train_imgs_dir = os.path.join(root_dir_resized, 'resized_train_images_360_512')

    #     folds_csv_path = "../input/pp2021-dataset-gnueih/6folds_pp2021.csv"
    folds_csv_path = "../input/pp2021-kfold-tfrecords-0/folds.csv"
    
    num_classes = 5
    # Names of the one-hot label columns used as targets, in model-output order
    labels = np.array(['powdery_mildew',
                     'scab',
                     'complex',
                     'frog_eye_leaf_spot',
                     'rust',])    
    
    # Version tag used for the logger, and the name of the model to build
    version = 'kag_final_v18'
    model_name = 'tf_efficientnet_b4_ns'
    
    # Training parameters
    use_sgd=False # Use Adam for faster convergence, since training only runs for 3 epochs
    fl_alpha = 1.0  # focal-loss scale factor (alpha)
    fl_gamma = 2.0  # focal-loss shape parameter gamma (exponent applied to (1 - pt))
    # Per-class loss weights, taken from https://www.kaggle.com/crissallan/pytorchlightning-efficientnet-focalloss-inference
    cls_weight = [0.3648, 0.0813, 0.2184, 0.1066, 0.2290] 
    # Positive-label weights that balance the label ratio (samples with a given disease are far fewer than samples without it)
    pos_weight = [1.6990, 1.3010, 1.6990, 1.4771, 1.6990] 
    
    num_epochs = 3
    batch_size = 8
#     scheduler_freq = 1548//36
#     t_max = 36
    scheduler_freq = 1855//35 # Split one epoch into 35 learning-rate updates
    t_max = 35 # One learning-rate scheduler cycle corresponds to one epoch
    lr = 3e-4 # Initial learning rate for the optimizer
    min_lr = 1e-7 # Minimum learning rate used by CosineAnnealingWarmRestarts
    
    n_train_fold = 5 # Number of folds used for training
    reserve_fold = 5 # Index of the fold that is held out from training
    
    num_workers = 4 
    accum_grad_batch = 1
    early_stop_delta = 1e-7
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Cố định seed để kết quả có thể tái lập (reproducible).

In [None]:
# Seed the random number generators through pytorch_lightning.seed_everything, using CFG.seed.
seed_everything(CFG.seed)

## Data

In [None]:
# Dataset Class
class ImageDataset(Dataset):
    """Map-style dataset that pairs image files on disk with their targets.

    Args:
        image_names: indexable collection of file names, relative to ``image_dir``.
        labels: indexable collection of targets aligned with ``image_names``.
        image_dir: directory that contains the image files.
        transforms: callable applied to the loaded RGB ``PIL.Image``; when it is
            ``None`` the PIL image is returned unchanged.
    """

    def __init__(self, image_names, labels, image_dir, transforms):
        self.image_names, self.labels = image_names, labels
        self.image_dir = image_dir
        self.transforms = transforms

    def __len__(self) -> int:
        """Return the number of samples (one per image name)."""
        return len(self.image_names)

    def __getitem__(self, idx: int):
        """Load sample ``idx`` from disk and return ``(image, target)``."""
        file_name = self.image_names[idx]
        # Force three channels regardless of the mode the file is stored in.
        rgb_image = Image.open(os.path.join(self.image_dir, file_name)).convert('RGB')
        target = self.labels[idx]
        if self.transforms is None:
            return rgb_image, target
        return self.transforms(rgb_image), target

In [None]:
# Pytorch Lightning Data Module 
class ImageDataModule(pl.LightningDataModule):
    """Data module that serves one train/validation split of a fold-annotated frame.

    Rows whose ``fold`` column equals ``fold_num`` form the validation (and test)
    split; all remaining rows form the training split. ``fold_num`` may be
    reassigned between calls to :meth:`setup` to switch folds.

    Args:
        df: frame with an ``image`` column (file names), a ``fold`` column and
            one column per entry of ``configurations.labels``.
        train_transforms: transform applied to training images.
        valid_transforms: transform applied to validation/test images.
        no_aug_transforms: stored on the instance; not used by the active code.
        image_dir: directory that contains the image files.
        fold_num: fold index used as the validation split.
        configurations: configuration object exposing ``labels``, ``batch_size``
            and ``num_workers`` as attributes.
    """

    def __init__(self,
                 df: pd.DataFrame,
                 train_transforms,
                 valid_transforms,
                 no_aug_transforms,
                 image_dir: str,
                 fold_num: int,
                 configurations):
        super().__init__()
        self.df = df
        self.train_transforms = train_transforms
        self.valid_transforms = valid_transforms
        self.no_aug_transforms = no_aug_transforms
        self.image_dir = image_dir
        self.fold_num = fold_num
        self.CFG = configurations
        # Only populated by setup('test'); test_dataloader() returns None until then.
        self.test_dataset = None

    def _build_dataset(self, frame: pd.DataFrame, transforms):
        """Wrap the image names and label columns of ``frame`` in an ImageDataset."""
        return ImageDataset(image_names=frame.image.values,
                            labels=frame[self.CFG.labels].values,
                            image_dir=self.image_dir,
                            transforms=transforms)

    def _build_loader(self, dataset, shuffle: bool):
        """Create a DataLoader with the batch size / worker count from the config."""
        return DataLoader(
            dataset,
            batch_size=self.CFG.batch_size,
            num_workers=self.CFG.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    def setup(self, stage=None) -> None:
        """Split ``self.df`` by ``self.fold_num`` and build the datasets for ``stage``.

        ``stage`` of ``None`` or ``'fit'`` builds the train and validation
        datasets; ``'test'`` builds the test dataset from the validation rows.
        """
        # The mask is derived from self.df itself (not from a notebook-level
        # global), so it is always aligned with the frame this module was given.
        is_valid_row = self.df.fold == self.fold_num
        train_df = self.df[~is_valid_row].reset_index(drop=True)
        valid_df = self.df[is_valid_row].reset_index(drop=True)

        print(f"Size of Train Dataset: {len(train_df.index)}")
        print(f"Size of Validation Dataset: {len(valid_df.index)}")
        if stage is None or stage == 'fit':
            self.train_dataset = self._build_dataset(train_df, self.train_transforms)
            self.valid_dataset = self._build_dataset(valid_df, self.valid_transforms)
        elif stage == 'test':
            self.test_dataset = self._build_dataset(valid_df, self.valid_transforms)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation split."""
        return self._build_loader(self.valid_dataset, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the test split, or ``None`` before ``setup('test')``."""
        if self.test_dataset is None:
            return None
        return self._build_loader(self.test_dataset, shuffle=False)

In [None]:
# Convert the space-separated text labels into one-hot label columns
def one_hot_encoded_df(dataset_df):
    """Return a copy of ``dataset_df`` with one 0/1 column per distinct class name.

    Each value of the ``labels`` column is a whitespace-separated list of class
    names. One integer column is added per class name found anywhere in that
    column, in sorted order so the layout is the same on every run, and it is
    set to 1 on the rows whose label string contains that class.

    Args:
        dataset_df: frame with a string column named ``labels``.

    Returns:
        A new frame; ``dataset_df`` itself is not modified.
    """
    encoded_df = dataset_df.copy()
    unique_label_strings = encoded_df['labels'].unique()

    # sorted() makes the column order deterministic; iterating a bare set of
    # strings depends on per-process hash randomisation.
    class_names = sorted({name
                          for label_string in unique_label_strings
                          for name in label_string.split()})
    for class_name in class_names:
        encoded_df[class_name] = 0

    # Rows that share a label string are updated together.
    for label_string in unique_label_strings:
        row_mask = encoded_df['labels'] == label_string
        encoded_df.loc[row_mask, label_string.split()] = 1
    return encoded_df

### Load CSVs

In [None]:
# Read the fold assignment and the label file, one-hot encode the label
# strings, and attach the encoding to each fold row by image name.
df = pd.read_csv(CFG.folds_csv_path).merge(
    one_hot_encoded_df(pd.read_csv(CFG.train_csv_path)),
    on='image',
)
df.head()

### Data transforms

In [None]:
# Wrapper that lets albumentations' CoarseDropout be used together with the default PyTorch (torchvision) transforms
class AlbWrapper(object):
    """Callable adapter: PIL image in, CoarseDropout-augmented PIL image out."""

    def __init__(self):
        # Maximum hole side is 8 % of 360 px.
        max_hole_side = int(360 * 0.08)
        self.tf = A.CoarseDropout(max_height=max_hole_side, max_width=max_hole_side, max_holes=5, p=0.3)

    def __call__(self, img):
        # The albumentations transform is called with an array and returns a dict;
        # convert to an array on the way in and back to PIL on the way out.
        augmented = self.tf(image=np.array(img))
        return Image.fromarray(augmented['image'])

In [None]:
# train_aug = A.Compose([
#         # A.RandomResizedCrop(360, 512, scale=(0.9, 1), p=0.5), 
#         A.Flip(p=0.5),
#         A.OneOf([ 
#             A.ShiftScaleRotate(scale_limit=(-0.1, 0.00), rotate_limit=10),
#             A.Perspective(scale=(0.05, 0.2)),
#         ],p=0.3),
#         A.OneOf([
#             A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10),
#             A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.2, 0.2)),     
#         ], p=0.3),
#         A.CoarseDropout(max_height=int(360 * 0.08), max_width=int(360 * 0.08), max_holes=5, p=0.3),
# ])

# Training-time augmentation, applied to PIL images in this order.
train_aug = tsfm.Compose([
    # Perspective warp behind a p=0.3 gate.
    # NOTE(review): RandomPerspective presumably has its own probability argument as
    # well, so the effective rate may be lower than 0.3 — confirm against torchvision.
    tsfm.RandomApply([tsfm.RandomPerspective(distortion_scale=0.2)], p=0.3),
    # Colour jitter followed by a rotation of up to 10 degrees, gated together at p=0.3.
    tsfm.RandomApply(
        [
            tsfm.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            tsfm.RandomAffine(degrees=10),
        ],
        p=0.3,
    ),
    tsfm.RandomVerticalFlip(p=0.5),
    tsfm.RandomHorizontalFlip(p=0.5),
    # CoarseDropout from albumentations, adapted to PIL input/output.
    AlbWrapper(),
])

#### Plot một số ví dụ về Augmentation


In [None]:
def plot_sample(ds):
    """Display the first 18 samples of ``ds`` in a 3 x 6 grid with axes hidden.

    ``ds[i]`` must return an ``(image, target)`` pair whose image can be passed
    to ``Axes.imshow``; the target is ignored.
    """
    _, grid = plt.subplots(nrows=3, ncols=6, figsize=[20, 8])
    for position, cell in enumerate(grid.flat):
        picture, _ = ds[position]
        cell.imshow(picture)
        cell.axis('off')
    plt.show()

# Preview dataset over every row of `df`; transforms=None makes it return the raw PIL images.
ds = ImageDataset(image_names=df.image.values, 
                  labels=df[CFG.labels].values, 
                  image_dir=CFG.train_imgs_dir, 
                  transforms=None)

In [None]:
# No augmentation: `ds` was created with transforms=None, so the raw images are shown
# (this holds only if the cells are run in order).
plot_sample(ds)

In [None]:
# Training augmentation: show the same samples after switching the dataset's transforms to train_aug.
# Note that this mutates `ds` in place.
ds.transforms = train_aug
plot_sample(ds)

In [None]:
# Remove the preview dataset from the notebook namespace.
del ds

#### Train và Valid Transform hoàn chỉnh

In [None]:
# train_transform = A.Compose([
#       train_aug, 
#       A.Normalize(),
#       ToTensorV2(),
# ])

# valid_transform = A.Compose([
# #     A.Resize(height=CFG2.img_size, width=CFG2.img_size, p=1.0),
#     A.Normalize(),
#     ToTensorV2(),
# ])

# Per-channel mean / std used to normalise image tensors
# (these values equal the commonly used ImageNet statistics).
DATASET_IMAGE_MEAN = (0.485, 0.456, 0.406)
DATASET_IMAGE_STD = (0.229, 0.224, 0.225)

# Training pipeline: augment the PIL image, convert it to a tensor, then normalise.
train_transform = tsfm.Compose([
    train_aug,
    tsfm.ToTensor(),
    tsfm.Normalize(mean=DATASET_IMAGE_MEAN, std=DATASET_IMAGE_STD),
])

# Validation pipeline: the same tensor conversion and normalisation, without augmentation.
valid_transform = tsfm.Compose([
    tsfm.ToTensor(),
    tsfm.Normalize(mean=DATASET_IMAGE_MEAN, std=DATASET_IMAGE_STD),
])

## Build Model

In [None]:
# f = -1
# count = df[df.fold != f][CFG.labels].values.sum(axis=0)
# pos_weight = ((len(df[df.fold != f].index) - count) / count)
# pos_weight = torch.round(torch.FloatTensor(pos_weight))
# # pos_weight = torch.clip(pos_weight, 0, 5)
# print(pos_weight)
# pos_weight = torch.log10(pos_weight) + 1
# print(pos_weight)

### Focal Loss

Implement cùng class weight và pos_weight

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class FocalLoss(nn.modules.loss._Loss):
    """Focal loss on top of ``binary_cross_entropy_with_logits`` for multi-label targets.

    Per element the loss is ``alpha * (1 - pt) ** gamma * bce`` where ``bce`` is the
    (optionally ``pos_weight``-ed) binary cross entropy and ``pt = exp(-bce)``. The
    result is optionally multiplied by ``weight`` and then reduced.

    Args:
        gamma: exponent of the ``(1 - pt)`` modulating factor.
        alpha: constant scale factor applied to every element.
        weight: optional tensor multiplied into the element-wise loss (it must
            broadcast against the loss); ``None`` disables the weighting.
        size_average, reduce: legacy reduction arguments forwarded to ``_Loss``.
        reduction: ``'mean'``, ``'sum'``, or any other value for no reduction.
        pos_weight: optional tensor forwarded to ``binary_cross_entropy_with_logits``
            as ``pos_weight``; it is used only while ``self.use_pw`` is true.
    """

    def __init__(self, gamma=2., alpha=1., weight = None, size_average=None, reduce=None, reduction = 'mean',
                 pos_weight = None) -> None:
        super(FocalLoss, self).__init__(size_average, reduce, reduction)
        # Buffers follow the module across devices; a None buffer is allowed.
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)
        self.alpha = alpha
        self.gamma = gamma
        # Runtime switch that lets callers turn pos_weight off without rebuilding the loss.
        self.use_pw = True

    def reduce_loss(self, loss):
        """Apply ``self.reduction`` ('mean' / 'sum' / anything else = none) to ``loss``."""
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

    def forward(self, input, target):
        """Return the reduced focal loss for logits ``input`` and float targets ``target``."""
        assert self.weight is None or isinstance(self.weight, Tensor)
        assert self.pos_weight is None or isinstance(self.pos_weight, Tensor)
        pos_weight = self.pos_weight if self.use_pw else None
        logpt = -F.binary_cross_entropy_with_logits(input, target,
                                                    pos_weight=pos_weight,
                                                    reduction='none')
        pt = torch.exp(logpt)

        # Down-weight confident elements: (1 - pt) ** gamma shrinks as pt approaches 1.
        focal_loss = -(self.alpha * (1 - pt) ** self.gamma) * logpt
        # weight defaults to None, so only multiply when a weight tensor was given.
        if self.weight is not None:
            focal_loss = focal_loss * self.weight
        return self.reduce_loss(focal_loss)

In [None]:
from pytorch_lightning.callbacks import Callback

# Callback that adjusts the use of pos_weight and changes the learning rate
class RemovePosWeight(Callback):
    """Lower the learning rate to 1e-4 when epoch 2 starts.

    Epoch 1 is intentionally a no-op: the experiment that switched the loss's
    ``pos_weight`` off (and raised the learning rate) at that point is disabled.
    """

    def on_train_epoch_start(self, trainer, pl_module):
        if pl_module.current_epoch != 2:
            return

        new_lr = 1e-4
        # Only the first optimizer / first parameter group is touched.
        param_group = trainer.optimizers[0].param_groups[0]
        param_group['initial_lr'] = new_lr
        param_group['lr'] = new_lr
        # Also replace the scheduler's stored base learning rates with the new value.
        trainer.lr_schedulers[0]['scheduler'].base_lrs = [new_lr]

### F1 Metrics

Được tính dựa trên quan sát từ https://www.kaggle.com/buinyi/understanding-the-evaluation-metric-cv

In [None]:
# F1 score metric
class F1Score(Metric):
    """Micro-averaged F1 over the label columns plus an implicit "healthy" label.

    A sample counts as "healthy" when none of its label columns is set; this
    extra column is appended to both predictions and targets before true
    positives, false positives and false negatives are counted over all
    elements. The counts accumulate across ``update`` calls.

    Args:
        threshold: probability above which a prediction counts as positive.
        dist_sync_on_step: forwarded to ``Metric``.
    """

    def __init__(self, threshold: float = 0.5, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.threshold = threshold
        for state_name in ("tp", "fp", "fn"):
            self.add_state(state_name, default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor, sigmoid=True):
        """Accumulate tp / fp / fn counts for one batch.

        Args:
            preds: logits (or probabilities when ``sigmoid`` is False), same shape as ``target``.
            target: 0/1 labels; integer and floating-point tensors are both accepted.
            sigmoid: apply a sigmoid to ``preds`` before thresholding.
        """
        assert preds.shape == target.shape
        with torch.no_grad():
            if sigmoid:
                preds = torch.sigmoid(preds)
            preds = (preds > self.threshold).long()
            # Binarise to integers so float targets cannot leak a float count
            # into the integer states (0/1 integer targets are unchanged).
            target = (target > 0.5).long()

            # "healthy" is 1 exactly when no label column is set.
            target_healthy = 1 - target.sum(dim=-1, keepdim=True).clamp(0, 1)
            pred_healthy = 1 - preds.sum(dim=-1, keepdim=True).clamp(0, 1)
            preds = torch.cat([preds, pred_healthy], -1)
            target = torch.cat([target, target_healthy], -1)

            tp = (preds * target).sum()
            fp = preds.sum() - tp
            fn = ((1 - preds) * target).sum()

        self.tp += tp.item()
        self.fp += fp.item()
        self.fn += fn.item()

    def compute(self):
        """Return ``2*tp / (2*tp + fn + fp)``, or 0 when no counts were accumulated."""
        denominator = 2.0 * self.tp + self.fn + self.fp
        if denominator == 0:
            # Nothing accumulated yet: report 0 instead of 0/0 = NaN.
            return torch.zeros_like(denominator)
        return 2.0 * self.tp / denominator

### PyTorch Lightning Module


In [None]:
# Apply label smoothing to the targets
# Based on https://github.com/pytorch/pytorch/issues/7455
@torch.no_grad()
def smooth_target(target, smoothing=0.1):
    """Return ``|target - smoothing|``, computed without gradient tracking.

    For 0/1 targets this maps 0 to ``smoothing`` and 1 to ``1 - smoothing``.
    """
    return (target - smoothing).abs()

In [None]:
class Lit(pl.LightningModule):
    """LightningModule wrapping a timm classifier trained with focal loss.

    Every hyper-parameter is read from the ``cfg`` object passed to the
    constructor (never from a notebook-level global), so one notebook can build
    modules with different configurations.

    Args:
        cfg: configuration object exposing ``model_name``, ``num_classes``,
            ``fl_alpha``, ``fl_gamma``, ``cls_weight``, ``pos_weight``, ``use_sgd``,
            ``lr``, ``t_max``, ``min_lr`` and ``scheduler_freq`` as attributes.
    """

    def __init__(self, cfg):
        super(Lit, self).__init__()
        self.cfg = cfg
        self.model = timm.create_model(cfg.model_name, pretrained=True, num_classes=cfg.num_classes)
        self.criterion = FocalLoss(alpha=cfg.fl_alpha, gamma=cfg.fl_gamma,
                                   weight=torch.tensor(cfg.cls_weight),
                                   pos_weight=torch.tensor(cfg.pos_weight))
        # Separate instances so training and validation counts never mix.
        self.train_metric = F1Score()
        self.valid_metric = F1Score()

    def forward(self, x):
        """Return the logits of the wrapped model for a batch of images."""
        return self.model(x)

    def configure_optimizers(self):
        """Build the optimizer (SGD or Adam) and a step-wise cosine warm-restart schedule."""
        if self.cfg.use_sgd:
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.cfg.lr, momentum=0.9)
        else:
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)

        scheduler = {
              'scheduler': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer,
                                                                      T_0=self.cfg.t_max,
                                                                      eta_min=self.cfg.min_lr,
                                                                      verbose=False),
              # The scheduler is stepped once every `scheduler_freq` optimizer steps.
              'interval': 'step',
              'frequency': self.cfg.scheduler_freq,
              'monitor': 'valid_loss',
        }
        return [self.optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Compute the focal loss on smoothed targets and update the training F1."""
        images, targets = batch
        logits = self.model(images)
        loss = self.criterion(logits, smooth_target(targets))
        # The metric is fed the hard targets; only the loss sees smoothed ones.
        self.train_metric(logits, targets)

        logs = {'train_loss': loss, 'lr': self.optimizer.param_groups[0]['lr']}
        self.log_dict(logs, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # Logging the metric object (not a per-batch value) makes Lightning call
        # compute() on the accumulated counts at epoch end and reset them.
        self.log('train_f1', self.train_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the focal loss on smoothed targets and update the validation F1."""
        images, targets = batch
        logits = self.model(images)
        loss = self.criterion(logits, smooth_target(targets))
        self.valid_metric(logits, targets)

        self.log('valid_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('valid_f1', self.valid_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

## Training

Ở đây ta sẽ chỉ chạy thử một vài batches

In [None]:
# import wandb
# wandb.login(key='KEY')
from pytorch_lightning.loggers import WandbLogger

In [None]:
# Debug mode: when True, the training loop disables the logger and the checkpoint callback
debug = True

In [None]:
# Create the data module from every row except the reserved fold (CFG.reserve_fold).
# fold_num=0 is only the initial value; the training loop reassigns it for each fold.
# The validation transform is also passed as the no-augmentation transform.
data_module = ImageDataModule(df=df[df.fold != CFG.reserve_fold],
                              train_transforms=train_transform,
                              valid_transforms=valid_transform,
                              no_aug_transforms=valid_transform,
                              image_dir= CFG.train_imgs_dir,
                              fold_num=0,
                              configurations=CFG)

In [None]:
# Train one model per fold
for fold_idx in range(CFG.n_train_fold):
    # Use this fold as the validation split and rebuild the datasets
    data_module.fold_num = fold_idx
    data_module.setup('fit')

    callbacks = [RemovePosWeight()]

    # Logger and checkpointing are only enabled outside debug mode
    if not debug:
        run_name = f'{CFG.model_name}_f{fold_idx}_{CFG.version}'
        logger = WandbLogger(project='PP2021_0', name=run_name, id=run_name)
        logger.log_hyperparams(CFG.__dict__)
        # The fold index is part of the file name so checkpoints written by
        # different folds into the shared directory can be told apart.
        callbacks.append(ModelCheckpoint(dirpath=os.path.join('ckpt', f'{CFG.model_name}_{CFG.version}'),
                                         monitor='valid_f1',
                                         save_top_k=1,
                                         save_last=False,
                                         save_weights_only=True,
                                         filename=f'{CFG.model_name}_f{fold_idx}_' + '{epoch:02d}-{valid_f1:.4f}',
                                         verbose=False,
                                         mode='max'))
    else:
        logger = False

    # Trainer
    trainer = Trainer(max_epochs=CFG.num_epochs,
                      gpus=0,  # CPU only: the Kaggle GPU quota was used up
                      accumulate_grad_batches=CFG.accum_grad_batch,
                      # The ModelCheckpoint instance (if any) travels in `callbacks`;
                      # `checkpoint_callback` is only the on/off flag.
                      callbacks=callbacks,
                      checkpoint_callback=not debug,
                      logger=logger,

                      # Only run 10 train batches and 10 val batches per epoch
                      limit_train_batches=10,
                      limit_val_batches=10,

                      weights_summary='top')

    # A fresh model for every fold
    model = Lit(CFG)

    # Fit
    trainer.fit(model, data_module)
    if not debug:
        logger.close()
        # Finish the W&B run through the logger's own run object; the notebook
        # never imports the `wandb` module itself.
        logger.experiment.finish()