In [None]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm

In [None]:
!pip install torchcontrib

In [None]:
import os
import numpy as np
import pytorch_lightning as pl
import torch
import pandas as pd
import timm
import torch.nn as nn

from PIL import Image
from sklearn.model_selection import KFold
from torchvision import transforms as tsfm
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.metrics import Metric
from torchcontrib.optim import SWA

In [None]:
class CFG:
    """Static configuration shared by every cell of this notebook.

    Holds input locations, the label <-> index mapping for the five disease
    classes, the checkpoint selection and the data-loading / loss settings.
    """

    # --- directories / files -------------------------------------------------
    root_dir_origin = "../input/plant-pathology-2021-fgvc8/"
    root_dir_resized = "../input/resized-plantpathology2021fgvc8-train-data-new/resized_plant-pathology-2021-fgvc8_train_data"
    train_csv_path = os.path.join(root_dir_origin, 'train.csv')
    train_imgs_dir = os.path.join(root_dir_resized, 'resized_train_images_360_512')

    # --- label mapping -------------------------------------------------------
    # Class index -> label name. 'healthy' has no index of its own: elsewhere
    # in the notebook it is represented by an all-zero label vector.
    label_num2str = dict(enumerate(['powdery_mildew',
                                    'scab',
                                    'complex',
                                    'frog_eye_leaf_spot',
                                    'rust']))

    # Label name -> class index (exact inverse of label_num2str).
    label_str2num = {name: idx for idx, name in label_num2str.items()}

    # --- model / checkpoint selection -----------------------------------------
    model_name = 'tf_efficientnet_b4_ns'
    pretrained_dir = '../input/efficientb4-focalloss'
    which_to_load = 'best_perform'  # last or best_perform
    needed_fold = list(range(6))    # indices of the fold checkpoints to load

    # --- data / loss settings --------------------------------------------------
    seed = 77
    num_classes = 5
    img_size = [360, 512]
    n_fold = 6
    batch_size = 8
    num_workers = 8
    fl_alpha = 1.0
    fl_gamma = 2.0

In [None]:
"""
Define dataset class
"""
class PlantDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs.

    Args:
        cfg: configuration object; only ``cfg.train_imgs_dir`` is read.
        img_names: image file names, relative to ``cfg.train_imgs_dir``.
        labels: one label per entry of ``img_names`` (same order).
        transform: optional callable applied to the RGB PIL image. When it is
            ``None`` the untransformed PIL image is returned.
    """

    def __init__(self, cfg, img_names: list, labels: list, transform=None):
        self.img_dir = cfg.train_imgs_dir
        self.img_names = img_names
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Number of samples (one per image name)."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return ``(image, label)``."""
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        # convert() produces an independent in-memory image, so the file
        # handle behind the lazily loaded original can be closed right away.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        # transform is optional (defaults to None): only call it when given.
        img_ts = img if self.transform is None else self.transform(img)
        label_ts = self.labels[idx]
        return img_ts, label_ts

In [None]:
# Load the competition's training annotations from CFG.train_csv_path.
train_df = pd.read_csv(CFG.train_csv_path)

# Preview only the first rows instead of rendering the whole frame.
train_df.head()

In [None]:
# Pre-processing pipelines: resize to CFG.img_size, convert to tensor, then
# normalise each channel (the values match the widely used ImageNet mean/std).
# Both pipelines are currently identical (no augmentation on the train side).
train_transform = tsfm.Compose([
    tsfm.Resize(size=CFG.img_size),
    tsfm.ToTensor(),
    tsfm.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
valid_transform = tsfm.Compose([
    tsfm.Resize(size=CFG.img_size),
    tsfm.ToTensor(),
    tsfm.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

In [None]:
"""
Split train & validation into Cross-Validation Folds
"""

all_img_names: list = train_df["image"].values.tolist()
all_img_labels: list = train_df["labels"].values.tolist()
print("Befor reomve duplicates:", len(all_img_names), ", ", len(all_img_labels))

# Remove duplicated samples from the training images.
# Every name found in the first two columns of duplicates.csv is dropped.
dplct_csv_path = "../input/duplicate-images-csv/duplicates.csv"
dplct_pd = pd.read_csv(dplct_csv_path)
dplct_img_names = dplct_pd.iloc[:, 0].values.tolist() + dplct_pd.iloc[:, 1].values.tolist()
dplct_img_names = list(set(dplct_img_names))
print("Num of duplicated samples: ", len(dplct_img_names))

# Membership is tested once per training image, so look names up in a set
# (constant time) rather than scanning the list on every iteration.
dplct_img_name_set = set(dplct_img_names)
img_names_no_dplct = []
img_labels_no_dplct = []
for img_name, img_label in zip(all_img_names, all_img_labels):
    if img_name not in dplct_img_name_set:
        img_names_no_dplct.append(img_name)
        img_labels_no_dplct.append(img_label)

all_img_names = img_names_no_dplct
all_img_labels = img_labels_no_dplct
print("After reomve duplicates:", len(all_img_names), ", ", len(all_img_labels))

# Encode each space-separated label string as a multi-hot float vector of
# length CFG.num_classes; 'healthy' sets no bit, i.e. it is the all-zero vector.
all_img_labels_ts = []
for tmp_lb in all_img_labels:
    tmp_label = torch.zeros([CFG.num_classes], dtype=torch.float)
    for str_lb in tmp_lb.split(sep=" "):
        if str_lb != 'healthy':
            tmp_label[CFG.label_str2num[str_lb]] = 1.0
    all_img_labels_ts.append(tmp_label)

# Seeded, shuffled splitter so the fold membership is reproducible.
k_fold = KFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)

In [None]:
"""
Define F1 score metric
"""
class MyF1Score(Metric):
    """F1 score computed from label-name matches pooled over all samples.

    Logits and targets are both turned into per-sample lists of label names
    (``['healthy']`` when no class exceeds the threshold); true positives,
    false positives and false negatives are then counted by name and
    accumulated in the metric states ``tp``, ``fp`` and ``fn``.

    Args:
        cfg: configuration object; ``cfg.label_num2str`` is read.
        threshold: cut-off applied to sigmoid(preds) and to the targets.
        dist_sync_on_step: forwarded to the ``Metric`` base class.
    """

    def __init__(self, cfg, threshold: float = 0.5, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.cfg = cfg
        self.threshold = threshold
        # Three integer counters, summed across processes when reduced.
        for state_name in ("tp", "fp", "fn"):
            self.add_state(state_name, default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate tp/fp/fn for one batch of raw logits and targets."""
        assert preds.shape == target.shape
        pred_names_batch = self.num_to_str(torch.sigmoid(preds))
        true_names_batch = self.num_to_str(target)

        hits, false_alarms, misses = 0, 0, 0
        for pred_names, true_names in zip(pred_names_batch, true_names_batch):
            # Each predicted name is either a hit or a false alarm.
            matched = sum(1 for name in pred_names if name in true_names)
            hits += matched
            false_alarms += len(pred_names) - matched
            # Each true name that was not predicted is a miss.
            misses += sum(1 for name in true_names if name not in pred_names)

        self.tp += hits
        self.fp += false_alarms
        self.fn += misses

    def compute(self):
        """Return 2*tp / (2*tp + fn + fp) from the accumulated counters."""
        doubled_tp = 2.0 * self.tp
        return doubled_tp / (doubled_tp + self.fn + self.fp)

    def num_to_str(self, ts: torch.Tensor) -> list:
        """Convert a batch of scores into per-sample lists of label names.

        An entry is active when it is strictly greater than ``self.threshold``;
        a sample without any active entry becomes ``['healthy']``.
        """
        active_rows = (ts > self.threshold).detach().cpu().numpy().tolist()
        names_batch = []
        for row in active_rows:
            names = [self.cfg.label_num2str[idx] for idx, is_on in enumerate(row) if is_on]
            names_batch.append(names if len(names) > 0 else ['healthy'])
        return names_batch

In [None]:
"""
Define Focal-Loss
"""

class FocalLoss(nn.Module):
    """
    The focal loss for fighting against class-imbalance.

    Each class is treated as an independent binary problem (sigmoid on the
    logits); the per-element loss is ``-alpha * (1 - p_t) ** gamma * log(p_t)``
    and the mean over all elements is returned.

    Args:
        alpha: constant scale applied to every element of the loss.
        gamma: focusing exponent applied to ``(1 - p_t)``.
    """
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = 1e-12  # prevent training from Nan-loss error

    def forward(self, logits, target):
        """
        logits & target should be tensors with shape [batch_size, num_classes]

        Returns:
            Scalar tensor: the mean focal loss over every element.
        """
        # torch.sigmoid instead of F.sigmoid: `F` is never imported in this
        # file, so the original call raised NameError on the first forward.
        probs = torch.sigmoid(logits)
        one_subtract_probs = 1.0 - probs
        # add epsilon so neither logarithm below ever sees an exact zero
        probs_new = probs + self.epsilon
        one_subtract_probs_new = one_subtract_probs + self.epsilon
        # log(p_t): log(p) where target == 1, log(1 - p) where target == 0
        log_pt = target * torch.log(probs_new) + (1.0 - target) * torch.log(one_subtract_probs_new)
        pt = torch.exp(log_pt)
        # (1 - p_t) ** gamma down-weights elements that are already well classified
        focal_loss = -1.0 * (self.alpha * (1 - pt) ** self.gamma) * log_pt
        return torch.mean(focal_loss)

In [None]:
"""
Define neural network model
"""

class MyNetwork(pl.LightningModule):
    """Lightning wrapper around a timm backbone with focal loss and F1 metric.

    Args:
        cfg: configuration object. ``model_name``, ``num_classes``,
            ``fl_alpha`` and ``fl_gamma`` are read at construction time;
            ``use_swa``, ``lr``, ``t_max`` and ``min_lr`` are read only by
            ``configure_optimizers``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.model = timm.create_model(cfg.model_name, pretrained=True, num_classes=cfg.num_classes)
        self.criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
        self.metric = MyF1Score(cfg)

    def forward(self, x):
        """Return the backbone's raw logits for the image batch ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Build Adam (optionally wrapped in SWA) and a cosine LR schedule."""
        # NOTE(review): relies on cfg.use_swa / cfg.lr / cfg.t_max / cfg.min_lr;
        # confirm the config used for training defines them.
        use_swa = self.cfg.use_swa
        adam = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
        self.optimizer = SWA(adam) if use_swa else adam

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=self.cfg.t_max,
            eta_min=self.cfg.min_lr,
            verbose=True,
        )
        return {'optimizer': self.optimizer, 'lr_scheduler': self.scheduler}

    def _loss_and_f1(self, batch):
        """Run the backbone on one batch and return ``(loss, f1_score)``."""
        img_ts, lb_ts = batch
        pred_ts = self.model(img_ts)
        return self.criterion(pred_ts, lb_ts), self.metric(pred_ts, lb_ts)

    def training_step(self, batch, batch_idx):
        """Log epoch-level train loss / F1 / learning rate; return the loss."""
        loss, score = self._loss_and_f1(batch)
        current_lr = self.optimizer.param_groups[0]['lr']
        self.log_dict({'train_loss': loss, 'train_f1': score, 'lr': current_lr},
                      on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log epoch-level validation loss / F1; return the loss."""
        loss, score = self._loss_and_f1(batch)
        self.log_dict({'valid_loss': loss, 'valid_f1': score},
                      on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

In [None]:
for fold_idx, (train_indices, valid_indices) in enumerate(k_fold.split(all_img_names)):
    # Init dataset & dataloader for this fold.
    # NOTE(review): every iteration rebinds the same names, so once the loop
    # ends only the LAST fold's datasets / loaders remain - confirm that this
    # is the fold the evaluation below is meant to use.

    # image names and multi-hot labels belonging to each side of the split
    fold_train_img_names = list(map(all_img_names.__getitem__, train_indices))
    fold_valid_img_names = list(map(all_img_names.__getitem__, valid_indices))
    fold_train_img_labels = list(map(all_img_labels_ts.__getitem__, train_indices))
    fold_valid_img_labels = list(map(all_img_labels_ts.__getitem__, valid_indices))

    # datasets
    train_dataset = PlantDataset(CFG, fold_train_img_names, fold_train_img_labels, train_transform)
    valid_dataset = PlantDataset(CFG, fold_valid_img_names, fold_valid_img_labels, valid_transform)

    # dataloaders: training is shuffled and drops the last incomplete batch,
    # validation keeps dataset order and every sample
    train_loader = DataLoader(
        train_dataset,
        batch_size=CFG.batch_size,
        shuffle=True,
        num_workers=CFG.num_workers,
        drop_last=True,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=CFG.num_workers,
    )


In [None]:
"""
Init models
"""
# One trained network per requested fold, moved to the GPU and put in eval mode.
models_list = []
for fold_idx in CFG.needed_fold:
    ckpt_file = f"{CFG.which_to_load}.ckpt"
    ckpt_path = os.path.join(
        CFG.pretrained_dir,
        f"fold{fold_idx}_logs/{CFG.model_name}/version_0/checkpoints/" + ckpt_file,
    )

    model = MyNetwork.load_from_checkpoint(ckpt_path, cfg=CFG)
    model.cuda()
    models_list.append(model.eval())  # eval() returns the module itself

In [None]:
import matplotlib.pyplot as plt 
%matplotlib inline

# Per-class decision threshold applied to the averaged fold probabilities.
threshold = np.full(5, 0.4333)

def convert_num_to_str(pred: np.ndarray) -> str:
    """Convert one multi-hot prediction / label vector into its label string.

    Truthy positions are mapped through ``CFG.label_num2str`` and joined with
    single spaces in index order; a vector with no truthy position becomes
    ``'healthy'``.
    """
    active_names = [CFG.label_num2str[lb_idx] for lb_idx, is_on in enumerate(pred) if is_on]
    if not active_names:
        return 'healthy'
    return ' '.join(active_names)

# Ensemble inference over valid_loader: average the sigmoid outputs of all
# fold models, threshold per class and store one label string per sample.
preds_list = []
with torch.no_grad():
    for img_ts, lb_ts in valid_loader:
        img_ts = img_ts.cuda()

        # one probability tensor per fold model, moved back to the CPU
        n_fold_pred_list = []
        for model in models_list:
            n_fold_pred_list.append(torch.sigmoid(model(img_ts)).detach().cpu())

        # stack the folds on a new trailing axis and average over it
        pred_np_stack = torch.stack(n_fold_pred_list, dim=2)
        pred_np = pred_np_stack.mean(dim=2)

        preds = pred_np > torch.from_numpy(threshold)

        # convert each boolean row into its label string
        preds_list.extend(convert_num_to_str(pred) for pred in preds)

In [None]:
from sklearn.metrics import confusion_matrix

# Ground-truth strings are rebuilt with the same helper as the predictions so
# both sides spell label combinations identically. Rows / columns of the
# matrix follow sklearn's default order (sorted label strings).
true_labels_list = list(map(convert_num_to_str, valid_dataset.labels))
confusion_matrix(y_true=true_labels_list, y_pred=preds_list)

In [None]:
import itertools


# Flat list of every individual label occurring in the ground truth.
labels_all = [single_lb for lbs in true_labels_list for single_lb in lbs.split(" ")]

# Number of samples per distinct label combination (first-seen order).
labels_combine = {}
for comb in true_labels_list:
    if comb in labels_combine:
        labels_combine[comb] += 1
    else:
        labels_combine[comb] = 1

show_counts = '\n'.join(sorted('\t{}: {}'.format(k, v) for k, v in labels_combine.items()))
print("unique combinations: \n" + show_counts)
print("total: {}".format(sum(labels_combine.values())))

In [None]:
! pip install -q scikit-plot

In [None]:
import scikitplot as skplt

# Row-normalised confusion matrix of the ensemble predictions.
skplt.metrics.plot_confusion_matrix(
    y_true=true_labels_list,
    y_pred=preds_list,
    normalize=True,
    figsize=(8, 8),
)

In [None]:
from sklearn.metrics import classification_report

# Derive the class names from the data instead of hard-coding them: the labels
# are already strings, and sklearn reports them in sorted order. A hand-written
# list can mislabel rows (wrong order / misspelt names) and stops matching as
# soon as the set of observed label combinations changes.
names = sorted(set(true_labels_list) | set(preds_list))

print(classification_report(true_labels_list, preds_list, labels=names, target_names=names))