# Directory settings

In [None]:
# When True, the `if debug:` cells at the end of the notebook (out-of-fold scoring) run as well.
debug = True
# debug = False

# Directory that is globbed for the single *.yml config and the *.pth checkpoints.
MODEL_DIR = '../input/34t-efficientnet-b5'

In [None]:
# ====================================================
# Directory settings
# ====================================================
import os
import glob

# Output directory; string-concatenated with file names elsewhere, so it must end with '/'.
OUTPUT_DIR = './'
# exist_ok=True replaces the check-then-create pair and is a no-op when the directory exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

TRAIN_PATH = '../input/cassava-leaf-disease-classification/train_images'
TEST_PATH = '../input/cassava-leaf-disease-classification/test_images'

# Exactly one YAML config is expected in MODEL_DIR: glob once, validate, then reuse the result.
config_candidates = glob.glob(f'{MODEL_DIR}/*.yml')
assert len(config_candidates) == 1, (
    f'expected exactly one *.yml in {MODEL_DIR}, found {len(config_candidates)}: {config_candidates}'
)
config_path = config_candidates[0]

# Library

In [None]:
# ====================================================
# Library
# ====================================================
import sys
sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')

import os
import math
import time
import random
import glob
import shutil
from pathlib import Path
from contextlib import contextmanager
from collections import defaultdict, Counter

import scipy as sp
import numpy as np
import pandas as pd

from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold

from tqdm.auto import tqdm
from functools import partial

import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau

from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, CenterCrop
    )
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import timm

import warnings 
warnings.filterwarnings('ignore')

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CFG

In [None]:
# # ====================================================
# # CFG
# # ====================================================
import yaml

# Load the configuration file found in MODEL_DIR.
# An explicit Loader is required: on PyYAML >= 6 a bare `yaml.load(f)` raises TypeError.
# FullLoader is what a bare `yaml.load` used on PyYAML 5.x, so parsing behaviour is unchanged.
# NOTE(review): if the config only contains plain scalars/lists/dicts, prefer yaml.safe_load.
with open(config_path) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

INFO = config['info']
TAG = config['tag']
CFG = config['cfg']

# This notebook only runs inference, whatever the stored config says.
CFG['train'] = False
CFG['inference'] = True
# Number of images per batch; every image is expanded into len(ttas) views inside the dataset.
inference_batch_size = 8


# if not os.path.exists('__notebook__.ipynb'):
#     CFG['debug'] = True

# class CFG:
#     debug=False
#     num_workers=4
#     model_name='tf_efficientnet_b4_ns'
#     size=256
#     batch_size=32
#     seed=42
#     target_size=5
#     target_col='label'
#     n_fold=5
#     trn_fold=[0, 1, 2, 3, 4]
#     train=False
#     inference=True

# Utils

In [None]:
# ====================================================
# Utils
# ====================================================
def get_score(y_true, y_pred):
    """Return ``accuracy_score(y_true, y_pred)`` -- the metric used throughout the notebook."""
    accuracy = accuracy_score(y_true, y_pred)
    return accuracy


@contextmanager
def timer(name):
    """Context manager that logs when the ``with`` body starts and how long it took.

    The elapsed time is logged in whole seconds via the module-level ``LOGGER``.
    The closing message is only emitted when the body finishes without raising.
    """
    started_at = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    # The elapsed time is evaluated inside the f-string, i.e. right when the message is built.
    LOGGER.info(f'[{name}] done in {time.time() - started_at:.0f} s.')


def init_logger(log_file=OUTPUT_DIR+'inference.log'):
    """Create the notebook logger: bare messages go to the console stream and to ``log_file``.

    ``getLogger(__name__)`` always returns the same logger object, so calling this
    function again (for example by re-running the cell) used to stack a second pair
    of handlers and duplicate every message. Handlers left over from a previous call
    are therefore removed and closed before the new ones are attached.

    Args:
        log_file: path of the log file the ``FileHandler`` writes to.

    Returns:
        The configured ``logging.Logger`` at level INFO.
    """
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # Drop handlers from earlier calls so each message is emitted exactly once.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = init_logger()


def seed_torch(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and current CUDA device) RNGs with ``seed``.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # The four seeding calls are independent of each other; apply them in one pass.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

# Seed every RNG with the seed stored in the loaded config.
seed_torch(seed=CFG['seed'])

In [None]:
def get_result(result_df):
    """Score a prediction frame, log the score and return it.

    Args:
        result_df: frame with a ``preds`` column (predicted class) and a
            ``label`` column (ground truth).

    Returns:
        The value of ``get_score(labels, preds)``.
    """
    predicted, actual = (result_df[column].values for column in ('preds', 'label'))
    score = get_score(actual, predicted)
    LOGGER.info(f'Score: {score:<.5f}')
    return score


def get_aug_name(compose):
    """Return the class name of every item obtained by iterating ``compose``, in order."""
    return [aug.__class__.__name__ for aug in compose]


def get_aug_score(aug_preds, labels, tta_list):
    """Print and log the score of each augmentation slice separately.

    Args:
        aug_preds: array indexed as ``[sample, augmentation, class]``.
        labels: ground-truth labels, one per sample.
        tta_list: sequence whose i-th entry is the pipeline that produced
            ``aug_preds[:, i, :]``; only its step class names are reported.
    """
    n_augmentations = aug_preds.shape[1]
    for aug_idx in range(n_augmentations):
        class_probs = aug_preds[:, aug_idx, :]
        step_names = get_aug_name(tta_list[aug_idx])
        score = get_score(labels, class_probs.argmax(1))
        print(score, step_names)
        LOGGER.info(f"========== aug: {step_names} result ==========")
        LOGGER.info(f'Score: {score:<.5f}')
        
        
        
def get_aug_csv(aug_preds, oof_df, tta_list):
    """Write one CSV per augmentation with that augmentation's class probabilities.

    Each file contains ``image_id``, ``label`` and ``fold`` from ``oof_df``, one
    probability column per class (named ``'0'``, ``'1'``, ...) and a ``preds``
    column holding the argmax class. The file is saved as
    ``{OUTPUT_DIR}{step names joined by '-'}.csv``.

    Args:
        aug_preds: array indexed as ``[sample, augmentation, class]``; rows must
            be in the same order as ``oof_df``.
        oof_df: frame providing the ``image_id``, ``label`` and ``fold`` columns.
            It is not modified.
        tta_list: sequence whose i-th entry is the pipeline behind
            ``aug_preds[:, i, :]``; its step class names form the file name.
    """
    # Derive the class count from the data instead of hard-coding 5.
    n_classes = aug_preds.shape[2]
    prob_columns = [str(c) for c in range(n_classes)]

    for i in range(aug_preds.shape[1]):
        # Select first, then copy: only the three needed columns are duplicated and
        # the result is an independent frame that is safe to add columns to.
        base_df = oof_df[['image_id', 'label', 'fold']].copy()
        aug_pred = aug_preds[:, i, :]
        base_df[prob_columns] = aug_pred
        base_df['preds'] = aug_pred.argmax(1)
        csv_name = '-'.join(get_aug_name(tta_list[i]))
        base_df.to_csv(f'{OUTPUT_DIR}{csv_name}.csv', index=False)

# Data Loading

In [None]:
# The sample submission supplies the list of test images; the datasets read its
# 'image_id' column (and TTADataset also its 'label' column).
test = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
test.head()

# Dataset

In [None]:
# ====================================================
# Dataset
# ====================================================
class TrainDataset(Dataset):
    """Labelled image dataset reading files from ``TRAIN_PATH``.

    Args:
        df: frame with an ``image_id`` column (file name relative to
            ``TRAIN_PATH``) and a ``label`` column.
        transform: optional callable invoked as ``transform(image=image)`` that
            returns a mapping with the transformed image under ``'image'``.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['image_id'].values
        self.labels = df['label'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``; the image is RGB.

        Raises:
            FileNotFoundError: if the image file is missing or cannot be decoded.
        """
        file_name = self.file_names[idx]
        file_path = f'{TRAIN_PATH}/{file_name}'
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError(f'cv2.imread could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        label = torch.tensor(self.labels[idx]).long()
        return image, label
    

class TestDataset(Dataset):
    """Unlabelled image dataset reading files from ``TEST_PATH``.

    Args:
        df: frame with an ``image_id`` column (file name relative to ``TEST_PATH``).
        transform: optional callable invoked as ``transform(image=image)`` that
            returns a mapping with the transformed image under ``'image'``.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['image_id'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image for row ``idx``.

        Raises:
            FileNotFoundError: if the image file is missing or cannot be decoded.
        """
        file_name = self.file_names[idx]
        file_path = f'{TEST_PATH}/{file_name}'
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError(f'cv2.imread could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image
    
    
class TTADataset(Dataset):
    """Dataset that returns every test-time-augmented view of an image at once.

    Args:
        df: frame with an ``image_id`` column (file name relative to
            ``image_path``) and a ``label`` column.
        image_path: directory that contains the image files.
        ttas: sequence of callables, each invoked as ``tta(image=image)`` and
            returning a mapping whose ``'image'`` entry is a tensor. All views
            must have the same shape because they are stacked.
    """

    def __init__(self, df, image_path, ttas):
        self.df = df
        self.file_names = df['image_id'].values
        self.labels = df['label'].values
        self.image_path = image_path
        self.ttas = ttas

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(views, label)`` where ``views`` stacks one tensor per TTA pipeline.

        Raises:
            FileNotFoundError: if the image file is missing or cannot be decoded.
        """
        file_name = self.file_names[idx]
        file_path = f'{self.image_path}/{file_name}'
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError(f'cv2.imread could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Every pipeline receives the same decoded RGB image.
        imglist = [tta(image=image)['image'] for tta in self.ttas]

        # New leading axis = augmentation index.
        image = torch.stack(imglist)
        label = torch.tensor(self.labels[idx]).long()

        return image, label

# Transforms

In [None]:
def _get_augmentations(aug_list):
    process = []
    for aug in aug_list:
        if aug ==  'Resize':
            process.append(Resize(CFG['size'], CFG['size']))
        elif aug == 'RandomResizedCrop':
            process.append(RandomResizedCrop(CFG['size'], CFG['size']))
        elif aug == 'CenterCrop':
            process.append(CenterCrop(CFG['size'], CFG['size']))
        elif aug == 'Transpose':
            process.append(Transpose(p=0.5))
        elif aug == 'HorizontalFlip':
            process.append(HorizontalFlip(p=0.5))
        elif aug == 'VerticalFlip':
            process.append(VerticalFlip(p=0.5))
        elif aug == 'ShiftScaleRotate':
            process.append(ShiftScaleRotate(p=0.5))
        elif aug == 'Normalize':
            process.append(Normalize(
                            mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225],
                        ))
        else:
            raise ValueError(f'{aug} is not suitable')

    process.append(ToTensorV2())

    return process

In [None]:
# ====================================================
# Transforms
# ====================================================
def get_transforms(*, aug_list):
    """Build a ``Compose`` pipeline from augmentation names (keyword-only argument).

    The names are resolved by ``_get_augmentations``.
    """
    steps = _get_augmentations(aug_list)
    return Compose(steps)

In [None]:
# Normalisation statistics passed to every Normalize step below.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Two alternative ways of bringing an image to CFG['size'] x CFG['size'];
# each one is the first step of its own group of TTA pipelines.
oneof_augs = [
    CenterCrop(CFG['size'], CFG['size']), 
    Resize(CFG['size'], CFG['size'])
]

# For every sizing op build four pipelines: as-is, transposed, horizontally
# flipped and vertically flipped (the flip/transpose steps use p=1).
# The result is a list of groups, one group per sizing op.
ttas = [[
    Compose([
        oneof_aug,
        Normalize(mean=norm_mean, std=norm_std, p=1.),
        ToTensorV2()
    ]),
    Compose([
        oneof_aug,
        Transpose(p=1),
        Normalize(mean=norm_mean, std=norm_std, p=1.),
        ToTensorV2()
    ]),
    Compose([
        oneof_aug,
        HorizontalFlip(p=1),
        Normalize(mean=norm_mean, std=norm_std, p=1.),
        ToTensorV2()
    ]),
    Compose([
        oneof_aug,
        VerticalFlip(p=1),
        Normalize(mean=norm_mean, std=norm_std, p=1.),
        ToTensorV2()
    ])
] for oneof_aug in oneof_augs]

# Flatten the list of groups into one flat list of pipelines (2 sizing ops x 4 variants = 8).
ttas = sum(ttas, [])


In [None]:
# Display the TTA pipelines for a visual check.
ttas

# MODEL

In [None]:
# ====================================================
# MODEL
# ====================================================
class CustomModel(nn.Module):
    """timm backbone whose classification layer is resized to ``CFG['target_size']``.

    The first attribute found among ``classifier``, ``fc`` and ``head`` (checked
    in that order) is replaced by a fresh ``nn.Linear`` with the same input width.
    If the backbone has none of them it is left untouched.

    Args:
        model_name: name passed to ``timm.create_model``.
        pretrained: forwarded to ``timm.create_model``.
    """

    # Candidate attribute names of the final layer, in lookup order.
    _HEAD_ATTRS = ('classifier', 'fc', 'head')

    def __init__(self, model_name, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        for head_attr in self._HEAD_ATTRS:
            if not hasattr(self.model, head_attr):
                continue
            in_width = getattr(self.model, head_attr).in_features
            setattr(self.model, head_attr, nn.Linear(in_width, CFG['target_size']))
            break  # only the first matching attribute is replaced

    def forward(self, x):
        """Run the wrapped backbone on ``x`` and return its output."""
        return self.model(x)

# Helper functions

In [None]:
# ====================================================
# Helper functions
# ====================================================
def inference(model, states, test_loader, device):
    """Predict class probabilities for every TTA view, averaged over checkpoints.

    Args:
        model: network that ``state['model']`` can be loaded into.
        states: sequence of checkpoints, each a mapping with a ``'model'`` state dict.
        test_loader: iterable with ``len()`` that yields ``(images, labels)``
            batches; ``images`` is 5-D ``(batch, views, C, H, W)``. Labels are ignored.
        device: device the model and images are moved to.

    Returns:
        NumPy array of shape ``(n_samples, views, n_classes)`` holding softmax
        probabilities averaged across ``states``.
    """
    model.to(device)
    collected = []
    progress = tqdm(enumerate(test_loader), total=len(test_loader))
    for _step, (images, _labels) in progress:
        images = images.to(device)
        n_images, n_views, channels, height, width = images.size()
        # Fold the view axis into the batch axis so the model sees ordinary 4-D input.
        flat_images = images.view(-1, channels, height, width)

        per_checkpoint = []
        for state in states:
            # Weights are swapped in for every batch so only one model lives on the device.
            model.load_state_dict(state['model'])
            model.eval()
            with torch.no_grad():
                view_probs = model(flat_images).softmax(1)
                view_probs = view_probs.view(n_images, n_views, -1)
            per_checkpoint.append(view_probs.to('cpu').numpy())
        mean_probs = np.mean(per_checkpoint, axis=0)
        collected.append(mean_probs)

        # Release per-batch tensors before the next iteration.
        del images, flat_images, _labels, view_probs, mean_probs
        torch.cuda.empty_cache()
    return np.concatenate(collected)

# inference

In [None]:
# ====================================================
# inference
# ====================================================
model = CustomModel(TAG['model_name'], pretrained=False)
model_paths = sorted(glob.glob(f'{MODEL_DIR}/*.pth'))
assert model_paths, f'no *.pth checkpoints found in {MODEL_DIR}'
# map_location=device lets checkpoints that were saved on a GPU load on a CPU-only
# machine as well, and otherwise places the tensors on the device used for inference.
states = [torch.load(path, map_location=device) for path in model_paths]
# test_dataset = TestDataset(test, transform=get_transforms(aug_list=['Resize', 'Normalize']))
test_dataset = TTADataset(test, TEST_PATH, ttas=ttas)
test_loader = DataLoader(test_dataset, batch_size=inference_batch_size, shuffle=False, 
                         num_workers=2, pin_memory=True)
# predictions is indexed as [sample, tta view, class], already averaged over checkpoints.
predictions = inference(model, states, test_loader, device)

# Average over the TTA views.
prediction = predictions.mean(1)


# submission
test['label'] = prediction.argmax(1)    
test[['image_id', 'label']].to_csv(OUTPUT_DIR+'submission.csv', index=False)
test.head()

# debug

In [None]:
def valid_inference(model, state, test_loader, device):
    """Predict class probabilities for every TTA view using a single checkpoint.

    Args:
        model: network that ``state['model']`` can be loaded into.
        state: checkpoint mapping with a ``'model'`` state dict.
        test_loader: iterable with ``len()`` that yields ``(images, labels)``
            batches; ``images`` is 5-D ``(batch, views, C, H, W)``. Labels are ignored.
        device: device the model and images are moved to.

    Returns:
        NumPy array of shape ``(n_samples, views, n_classes)`` with softmax probabilities.
    """
    model.to(device)
    # The checkpoint is the same for every batch: load it and switch to eval mode once
    # instead of repeating both for each batch.
    model.load_state_dict(state['model'])
    model.eval()

    tk0 = tqdm(enumerate(test_loader), total=len(test_loader))
    probs = []
    for i, (images, labels) in tk0:
        # Labels are not needed for prediction, so they are not copied to the device.
        images = images.to(device)
        batch_size, n_crops, c, h, w = images.size()
        # Fold the view axis into the batch axis so the model sees ordinary 4-D input.
        images = images.view(-1, c, h, w)
        with torch.no_grad():
            y_preds = model(images).softmax(1)
            y_preds = y_preds.view(batch_size, n_crops, -1)
        avg_preds = y_preds.to('cpu').numpy()
        probs.append(avg_preds)
        # Release per-batch tensors before the next iteration.
        del images, labels, y_preds, avg_preds
        torch.cuda.empty_cache()
    probs = np.concatenate(probs)
    return probs

## Score per out-of-fold (OOF) split

In [None]:
if debug:
#     train = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv').head(100)
    train = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
    # Assign every row to a validation fold with a stratified split driven by
    # CFG['n_fold'] / CFG['seed'].
    # NOTE(review): presumably this mirrors the split used at training time -- confirm.
    folds = train.copy()
    Fold = StratifiedKFold(n_splits=CFG['n_fold'], shuffle=True, random_state=CFG['seed'])
    for n, (train_index, val_index) in enumerate(Fold.split(folds, folds[CFG['target_col']])):
        folds.loc[val_index, 'fold'] = int(n)
    folds['fold'] = folds['fold'].astype(int)
    
    model_paths = sorted(glob.glob(f'{MODEL_DIR}/*.pth'))
    assert model_paths, f'no *.pth checkpoints found in {MODEL_DIR}'
    # map_location=device lets GPU-saved checkpoints load on a CPU-only machine as well.
    states = [torch.load(path, map_location=device) for path in model_paths]
    
    # Collect per-fold results and concatenate once after the loop instead of
    # growing a DataFrame inside it.
    oof_frames = []
    oof_aug_preds = []
    # NOTE(review): the i-th checkpoint (sorted by path) is scored on fold i -- verify
    # that the checkpoint file names sort in fold order.
    for fold, state in enumerate(states):
        
        # ====================================================
        # loader
        # ====================================================
        val_idx = folds[folds['fold'] == fold].index
        valid_folds = folds.loc[val_idx].reset_index(drop=True)
        valid_dataset = TTADataset(valid_folds, TRAIN_PATH, ttas=ttas)
        valid_loader = DataLoader(valid_dataset, 
                                  batch_size=inference_batch_size, 
                                  shuffle=False, 
                                  num_workers=CFG['num_workers'], pin_memory=True)
        # valid_preds is indexed as [sample, tta view, class]; average over the views.
        valid_preds = valid_inference(model, state, valid_loader, device)
        valid_pred = valid_preds.mean(1)
        # One probability column per class, sized from the predictions rather than hard-coded.
        valid_folds[[str(c) for c in range(valid_pred.shape[1])]] = valid_pred
        valid_folds['preds'] = valid_pred.argmax(1)
        oof_frames.append(valid_folds)
        oof_aug_preds.append(valid_preds)
        LOGGER.info(f"========== fold: {fold} result ==========")
        _ = get_result(valid_folds)
    oof_df = pd.concat(oof_frames)
    # total score
    LOGGER.info(f"========== CV result ==========")
    score = get_result(oof_df)
    # NOTE(review): fold 3 is excluded here by a hard-coded index -- reason not visible in this notebook.
    score_rem3 = get_result(oof_df.query('fold!=3'))
    
    # Row order matches oof_df: both are concatenated fold by fold.
    oof_aug_preds = np.concatenate(oof_aug_preds)

## Score per TTA

In [None]:
if debug:
    # Score every TTA pipeline separately on the pooled out-of-fold predictions,
    # then write one CSV of per-class probabilities per pipeline.
    LOGGER.info(f"========== augmentation result ==========")
    get_aug_score(oof_aug_preds, oof_df['label'], ttas)
    get_aug_csv(oof_aug_preds, oof_df, ttas)