<a href="https://colab.research.google.com/github/quang-vo-ds/banana_leaf_disease_detection/blob/main/banana_leaf_disease_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Initial Setup

In [1]:
!pip -q install pydicom
!pip -q install timm
!pip -q install catalyst

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m24.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m446.7/446.7 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m244.2/244.2 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.7/101.7 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from glob import glob
from sklearn.model_selection import GroupKFold, StratifiedKFold
import cv2
from skimage import io
import torch
from torch import nn
import os
from datetime import datetime
import time
import random
import cv2
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.cuda.amp import autocast, GradScaler
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F

import sklearn
import warnings
import joblib
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
import warnings
import cv2
import pydicom
import timm
#from efficientnet_pytorch import EfficientNet
from scipy.ndimage import zoom
import pickle

In [14]:
# Mount Google Drive (Colab-only) and make the project folder the working
# directory; every data/output path below is built relative to it.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/Vin_ML_Course/Final_Project
root_dir = os.getcwd()
# Directory containing train.csv (used here only to count classes).
train_dir = os.path.join(root_dir, "data/train_test/train")
# Directory containing test.csv and the test images (default BananaDataset root).
test_dir = os.path.join(root_dir, "data/train_test/test")
# Per-fold / per-epoch model state_dicts loaded by the inference loop.
save_model_dir = os.path.join(root_dir, "output/checkpoints")
# Destination of test_pred.csv.
save_output_dir = os.path.join(root_dir, "output")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Vin_ML_Course/Final_Project


In [4]:
# Label tables: train is only used for the class count, test drives inference.
train_csv = os.path.join(train_dir, "train.csv")
test_csv = os.path.join(test_dir, "test.csv")
train, test = pd.read_csv(train_csv), pd.read_csv(test_csv)
test.head()

Unnamed: 0,id,label,label_name
0,sigatoka410.jpeg,2,sigatoka
1,sigatoka381.jpeg,2,sigatoka
2,pestalotiopsis63.jpeg,1,pestalotiopsis
3,xanthomonas565,4,xanthomonas
4,pestalotiopsis126.jpeg,1,pestalotiopsis


## Global Config

In [5]:
# Global configuration. Training-only keys (epochs, train_bs, lr, accum_iter,
# verbose_step) are kept so the dict matches the training notebook.
CFG = dict(
    fold_num=5,                    # number of CV folds; one checkpoint set per fold
    seed=719,                      # passed to seed_everything
    model_arch='tf_efficientnet_b4_ns',  # timm model name
    img_size=512,                  # square input side length (pixels)
    epochs=10,                     # training only
    train_bs=32,                   # training only
    valid_bs=32,                   # batch size of the inference DataLoader
    lr=1e-4,                       # training only
    num_workers=4,                 # DataLoader worker processes
    accum_iter=1,                  # training only: gradient-accumulation steps (effective batch = train_bs * accum_iter)
    verbose_step=1,                # training only
    device='cuda:0',
    tta=3,                         # inference passes per checkpoint
    used_epochs=[6, 7, 8, 9],      # epochs whose checkpoints are ensembled
    weights=[1, 1, 1, 1],          # blend weight for each entry of used_epochs
)

## Utils

In [6]:
def seed_everything(seed):
    """Seed every RNG this notebook touches and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch's CPU and
    CUDA generators (the CUDA call is silently ignored when no GPU is present).

    Note: setting ``PYTHONHASHSEED`` does not change the hash seed of the
    already-running interpreter; it only affects Python processes launched
    after this call.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly non-deterministic
    # algorithms, which would defeat deterministic=True above.
    torch.backends.cudnn.benchmark = False

def get_img(path):
    """Read an image file with OpenCV and return it in RGB channel order.

    ``cv2.imread`` decodes colour images as BGR; the last axis is reversed so
    downstream transforms receive RGB.

    Args:
        path: Filesystem path of the image.

    Returns:
        The decoded image array with channels reversed (BGR -> RGB).

    Raises:
        OSError: if OpenCV cannot read or decode ``path``.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the error surfaces later as an opaque
        # "'NoneType' object is not subscriptable".
        raise OSError(f"cv2.imread could not read image: {path}")
    return im_bgr[:, :, ::-1]

## Dataset

In [7]:
class BananaDataset(Dataset):
    """Image dataset driven by a DataFrame whose ``id`` column holds file names.

    Item ``i`` is the image ``data_root/df.iloc[i]['id']`` loaded via
    ``get_img``, optionally passed through ``transforms``, and paired with its
    label when ``output_label`` is truthy.

    Args:
        df: DataFrame with an ``id`` column and, when ``output_label`` is
            truthy, an integer ``label`` column. It is copied, so later edits
            to the caller's frame do not affect the dataset.
        data_root: Directory holding the image files. The default is bound to
            the value of ``test_dir`` at the moment the class is defined.
        transforms: Callable invoked as ``transforms(image=img)`` returning a
            dict with an ``"image"`` entry, or ``None`` for no transform.
        output_label: If truthy, items are ``(img, target)``; otherwise ``img``.
        one_hot_label: If truthy (and ``output_label`` is truthy), targets are
            rows of an identity matrix instead of integer class ids.
        n_classes: One-hot width. Defaults to ``df['label'].max() + 1``, which
            under-counts when ``df`` is a subset (e.g. a fold) missing the
            highest class, so pass it explicitly in that case.
    """

    def __init__(self, df,
                 data_root=test_dir,
                 transforms=None,
                 output_label=True,
                 one_hot_label=False,
                 n_classes=None,
                ):

        super().__init__()
        self.df = df.copy()
        self.data_root = data_root
        self.transforms = transforms
        self.output_label = output_label
        self.one_hot_label = one_hot_label

        # Truthiness is used here and in __getitem__ so both agree; the old
        # mix of `== True` and plain truthiness crashed on e.g. output_label="yes".
        if output_label:
            self.labels = self.df['label'].values
            if one_hot_label:
                if n_classes is None:
                    n_classes = int(self.df['label'].max()) + 1
                self.labels = np.eye(n_classes)[self.labels]

    def __len__(self):
        """Number of rows in ``df``."""
        return self.df.shape[0]

    def __getitem__(self, index: int):
        """Return ``img`` or ``(img, target)`` for positional row ``index``."""
        img_path = os.path.join(self.data_root, self.df.iloc[index]['id'])
        img = get_img(img_path)

        if self.transforms:
            img = self.transforms(image=img)['image']

        if self.output_label:
            return img, self.labels[index]
        return img

## Image Preprocessing (deterministic inference transforms)

In [8]:
from albumentations import Normalize, Resize, Compose
from albumentations.pytorch import ToTensorV2

def get_inference_transforms():
    """Build the preprocessing pipeline applied to every test image.

    Resizes to a ``CFG['img_size']`` square, normalizes with the standard
    ImageNet channel mean/std (pixel values divided by 255 first), then
    converts the numpy image to a torch tensor. It contains no random ops, so
    repeated passes over the same image produce identical tensors.

    Returns:
        An albumentations ``Compose``, called as ``pipeline(image=img)``.
    """
    side = CFG['img_size']
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        Resize(side, side),
        Normalize(mean=imagenet_mean, std=imagenet_std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)

## Model

In [9]:
class MyImgClassifier(nn.Module):
    """A timm backbone whose classification head is resized to ``n_class`` logits.

    Args:
        model_arch: Model name passed to ``timm.create_model``.
        n_class: Number of output classes.
        pretrained: Whether timm loads its pretrained weights before the head
            is replaced.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        # reset_classifier replaces the architecture's own head (named
        # `classifier` on EfficientNet, `fc`/`head` on others) with a new
        # n_class-way layer, so this is not tied to EfficientNet. On
        # EfficientNet the head keeps the name `classifier`, so checkpoint keys
        # `model.classifier.weight` / `.bias` are unchanged.
        self.model.reset_classifier(n_class)

    def forward(self, x):
        """Return raw (pre-softmax) class logits for the image batch ``x``."""
        return self.model(x)

## Main: ensemble inference over folds, checkpoints and TTA passes

In [10]:
def inference_one_epoch(model, data_loader, device, progress=True):
    """Run ``model`` over every batch of ``data_loader`` and collect softmax probabilities.

    Args:
        model: Module mapping a float image batch to logits with classes on dim 1.
        data_loader: Iterable of image batches. A batch may also be an
            ``(images, labels)`` list/tuple (as a labelled dataset yields);
            only the images are used.
        device: Device each batch is moved to (should match ``model``'s).
        progress: Show a tqdm progress bar over batches (default, as before).

    Returns:
        ``numpy.ndarray`` with one softmax row per sample, in loader order.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    batches = tqdm(data_loader, total=len(data_loader)) if progress else data_loader

    probs = []
    # no_grad inside the function (not only at the call site) so no caller can
    # accidentally build an autograd graph over the whole test set.
    with torch.no_grad():
        for batch in batches:
            imgs = batch[0] if isinstance(batch, (list, tuple)) else batch
            logits = model(imgs.to(device).float())
            probs.append(torch.softmax(logits, 1).cpu().numpy())

    if not probs:
        raise ValueError("inference_one_epoch: data_loader produced no batches")
    return np.concatenate(probs, axis=0)

In [11]:
if __name__ == '__main__':

    seed_everything(CFG['seed'])

    assert len(CFG['weights']) == len(CFG['used_epochs']), \
        "CFG['weights'] needs one entry per CFG['used_epochs'] entry"

    # Fall back to CPU when no GPU is present; checkpoints are remapped onto
    # this device by torch.load(map_location=...) below.
    device = torch.device(CFG['device'] if torch.cuda.is_available() else 'cpu')

    # The test set and its transforms do not depend on the fold, so the
    # dataset/loader are built once instead of once per fold.
    test_ds = BananaDataset(test, transforms=get_inference_transforms(), output_label=False)
    tst_loader = torch.utils.data.DataLoader(
        test_ds,
        batch_size=CFG['valid_bs'],
        num_workers=CFG['num_workers'],
        shuffle=False,
        pin_memory=False,
    )

    weight_total = sum(CFG['weights'])
    tst_preds_all = []

    for fold in range(CFG['fold_num']):
        print('Inferencing with {} started'.format(fold))
        model = MyImgClassifier(CFG['model_arch'], train.label.nunique()).to(device)

        tst_preds = []

        for i, epoch in enumerate(CFG['used_epochs']):
            ckpt_path = os.path.join(save_model_dir, '{}_fold_{}_{}'.format(CFG['model_arch'], fold, epoch))
            model.load_state_dict(torch.load(ckpt_path, map_location=device))

            with torch.no_grad():
                # NOTE: get_inference_transforms has no random ops, so every
                # TTA pass yields the same predictions; CFG['tta'] = 1 gives the
                # same result faster unless random augmentations are added.
                for _ in range(CFG['tta']):
                    # Each term is pre-scaled by its share of the blend
                    # (epoch weight / total weight / number of TTA passes) ...
                    tst_preds += [CFG['weights'][i] / weight_total / CFG['tta'] * inference_one_epoch(model, tst_loader, device)]

        # ... so the per-fold blend is their SUM. Taking np.mean here divided
        # the already-normalised blend again by len(used_epochs) * tta, leaving
        # probability rows that summed to 1/12 instead of 1.
        tst_preds_all += [np.sum(tst_preds, axis=0)]

        del model
        torch.cuda.empty_cache()

    # Average the per-fold probability blends.
    tst_preds_all = np.mean(tst_preds_all, axis=0)

Inferencing with 0 started


  model = create_fn(
100%|██████████| 4/4 [00:22<00:00,  5.73s/it]
100%|██████████| 4/4 [00:04<00:00,  1.04s/it]
100%|██████████| 4/4 [00:03<00:00,  1.23it/s]
100%|██████████| 4/4 [00:03<00:00,  1.21it/s]
100%|██████████| 4/4 [00:03<00:00,  1.20it/s]
100%|██████████| 4/4 [00:03<00:00,  1.03it/s]
100%|██████████| 4/4 [00:03<00:00,  1.24it/s]
100%|██████████| 4/4 [00:03<00:00,  1.22it/s]
100%|██████████| 4/4 [00:04<00:00,  1.04s/it]
100%|██████████| 4/4 [00:03<00:00,  1.18it/s]
100%|██████████| 4/4 [00:03<00:00,  1.24it/s]
100%|██████████| 4/4 [00:03<00:00,  1.19it/s]


Inferencing with 1 started


100%|██████████| 4/4 [00:03<00:00,  1.21it/s]
100%|██████████| 4/4 [00:03<00:00,  1.11it/s]
100%|██████████| 4/4 [00:03<00:00,  1.20it/s]
100%|██████████| 4/4 [00:04<00:00,  1.02s/it]
100%|██████████| 4/4 [00:03<00:00,  1.19it/s]
100%|██████████| 4/4 [00:03<00:00,  1.19it/s]
100%|██████████| 4/4 [00:04<00:00,  1.00s/it]
100%|██████████| 4/4 [00:03<00:00,  1.21it/s]
100%|██████████| 4/4 [00:03<00:00,  1.20it/s]
100%|██████████| 4/4 [00:04<00:00,  1.02s/it]
100%|██████████| 4/4 [00:03<00:00,  1.20it/s]
100%|██████████| 4/4 [00:03<00:00,  1.19it/s]


Inferencing with 2 started


100%|██████████| 4/4 [00:04<00:00,  1.01s/it]
100%|██████████| 4/4 [00:03<00:00,  1.04it/s]
100%|██████████| 4/4 [00:03<00:00,  1.21it/s]
100%|██████████| 4/4 [00:03<00:00,  1.18it/s]
100%|██████████| 4/4 [00:04<00:00,  1.05s/it]
100%|██████████| 4/4 [00:03<00:00,  1.21it/s]
100%|██████████| 4/4 [00:03<00:00,  1.19it/s]
100%|██████████| 4/4 [00:04<00:00,  1.06s/it]
100%|██████████| 4/4 [00:03<00:00,  1.12it/s]
100%|██████████| 4/4 [00:03<00:00,  1.17it/s]
100%|██████████| 4/4 [00:03<00:00,  1.18it/s]
100%|██████████| 4/4 [00:04<00:00,  1.05s/it]


Inferencing with 3 started


100%|██████████| 4/4 [00:03<00:00,  1.18it/s]
100%|██████████| 4/4 [00:03<00:00,  1.15it/s]
100%|██████████| 4/4 [00:04<00:00,  1.04s/it]
100%|██████████| 4/4 [00:03<00:00,  1.00it/s]
100%|██████████| 4/4 [00:03<00:00,  1.19it/s]
100%|██████████| 4/4 [00:04<00:00,  1.04s/it]
100%|██████████| 4/4 [00:03<00:00,  1.18it/s]
100%|██████████| 4/4 [00:03<00:00,  1.16it/s]
100%|██████████| 4/4 [00:03<00:00,  1.15it/s]
100%|██████████| 4/4 [00:03<00:00,  1.11it/s]
100%|██████████| 4/4 [00:03<00:00,  1.19it/s]
100%|██████████| 4/4 [00:03<00:00,  1.23it/s]


Inferencing with 4 started


100%|██████████| 4/4 [00:04<00:00,  1.01s/it]
100%|██████████| 4/4 [00:03<00:00,  1.19it/s]
100%|██████████| 4/4 [00:03<00:00,  1.19it/s]
100%|██████████| 4/4 [00:04<00:00,  1.06s/it]
100%|██████████| 4/4 [00:03<00:00,  1.19it/s]
100%|██████████| 4/4 [00:03<00:00,  1.20it/s]
100%|██████████| 4/4 [00:04<00:00,  1.05s/it]
100%|██████████| 4/4 [00:03<00:00,  1.16it/s]
100%|██████████| 4/4 [00:03<00:00,  1.18it/s]
100%|██████████| 4/4 [00:03<00:00,  1.14it/s]
100%|██████████| 4/4 [00:04<00:00,  1.04s/it]
100%|██████████| 4/4 [00:03<00:00,  1.21it/s]


## Evaluate results

In [12]:
# Hard class prediction = most probable class of the ensembled probabilities.
pred_labels = np.argmax(tst_preds_all, axis=1)
test['preds'] = pred_labels
pred_csv_path = os.path.join(save_output_dir, "test_pred.csv")
test.to_csv(pred_csv_path, index=False)
test.head()

Unnamed: 0,id,label,label_name,preds
0,sigatoka410.jpeg,2,sigatoka,2
1,sigatoka381.jpeg,2,sigatoka,2
2,pestalotiopsis63.jpeg,1,pestalotiopsis,1
3,xanthomonas565,4,xanthomonas,4
4,pestalotiopsis126.jpeg,1,pestalotiopsis,1


In [13]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score

y_true = test["label"]
y_pred = test["preds"]
# roc_auc_score(multi_class='ovr') expects probability rows summing to 1,
# so the ensemble output is rescaled row-wise before scoring.
row_probs = tst_preds_all / np.sum(tst_preds_all, axis=1).reshape(-1, 1)

print("Multi-class accuracy: ", accuracy_score(y_true, y_pred))
print("F1-score: ", f1_score(y_true, y_pred, average="weighted"))
print("Precision: ", precision_score(y_true, y_pred, average="weighted"))
print("Recall: ", recall_score(y_true, y_pred, average="weighted"))
print("AUC: ", roc_auc_score(y_true, row_probs, multi_class='ovr'))

Multi-class accuracy:  0.984375
F1-score:  0.984541538489403
Precision:  0.9853980654761905
Recall:  0.984375
AUC:  1.0
