In [1]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import os
import cv2
import timm
import torch
from torch import nn
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2

In /Users/wanjun/anaconda/envs/python36/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The savefig.frameon rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.
In /Users/wanjun/anaconda/envs/python36/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The verbose.level rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.
In /Users/wanjun/anaconda/envs/python36/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The verbose.fileo rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.


In [None]:
# Inference configuration shared by the cells below.
CFG = {
    'model_arch': 'tf_efficientnet_b4_ns',  # architecture name handed to timm.create_model
    'img_size': 512,                        # side length used by RandomResizedCrop
    'valid_bs': 32,                         # batch size of the test DataLoader
    'num_workers': 4,                       # DataLoader worker processes
    'device': 'cuda:0',                     # device the model and batches are moved to
    'tta': 3,                               # number of augmented passes per fold
    'used_epochs': [5,5,5,5,5],             # checkpoint epoch to load for each fold (index = fold)
    # One ensemble weight per fold: must have the same length as 'used_epochs',
    # because the inference loop indexes it with the fold number.
    'weights': [1,1,1,1,1]
}
# Load the competition's sample submission and preview its first rows.
sample_submission_path = '../input/cassava-leaf-disease-classification/sample_submission.csv'
submission = pd.read_csv(sample_submission_path)
submission.head()

In [None]:
def get_img(path):
    """Load an image from disk and return it in RGB channel order.

    Args:
        path: Path of the image file, passed to ``cv2.imread``.

    Returns:
        The decoded image after a BGR -> RGB conversion with ``cv2.cvtColor``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None``, which it does
            instead of raising when the file is missing or cannot be decoded.
    """
    img_bgr = cv2.imread(path)
    # Without this check the failure would only surface inside cvtColor as an
    # opaque OpenCV assertion that does not mention the offending path.
    if img_bgr is None:
        raise FileNotFoundError(
            'Could not read image (missing or unreadable file): {}'.format(path))
    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
# Sanity check: load one training image and display it.
sample_image_path = '../input/cassava-leaf-disease-classification/train_images/1000015157.jpg'
img = get_img(sample_image_path)
plt.imshow(img)
plt.show()

In [None]:
class CassavaDataset(Dataset):
    """Dataset of images listed in a DataFrame, with optional FMix / CutMix blending.

    Each row of ``df`` names one image file (column ``image_id``) that is read
    from ``data_root`` with ``get_img``.

    Args:
        df: DataFrame with an ``image_id`` column and, when ``output_label`` is
            true, a ``label`` column. Its index is reset and the frame copied.
        data_root: Directory that ``image_id`` is joined onto.
        transforms: Optional callable invoked as ``transforms(image=img)`` whose
            result's ``'image'`` entry replaces the image.
        output_label: When true, ``__getitem__`` returns ``(img, target)``;
            otherwise it returns only ``img``.
        one_hot_label: When true (and ``output_label`` is true), labels are
            stored as one-hot rows with ``label.max() + 1`` columns.
        do_fmix: Enable FMix blending (applied with probability 0.5 per item).
        fmix_params: Keyword arguments forwarded to ``sample_mask``. ``None``
            selects the defaults listed in ``__init__``.
        do_cutmix: Enable CutMix blending (applied with probability 0.5 per item).
        cutmix_params: Dict with key ``'alpha'`` used for the Beta draw.
            ``None`` selects ``{'alpha': 1}``.

    Note:
        ``sample_mask`` and ``rand_bbox`` are not defined in the visible part of
        this notebook; they must be in scope before enabling FMix / CutMix.
    """

    def __init__(
            self,
            df,
            data_root,
            transforms=None,
            output_label=True,
            one_hot_label=False,
            do_fmix=False,
            fmix_params=None,
            do_cutmix=False,
            cutmix_params=None):

        super().__init__()
        # Build the default dicts per instance: dict literals in the signature
        # would be shared (and mutable) across every dataset object.
        if fmix_params is None:
            fmix_params = {
                'alpha': 1.,
                'decay_power': 3.,
                'shape': (512, 512),
                'max_soft': 0.3,
                'reformulate': False
            }
        if cutmix_params is None:
            cutmix_params = {
                'alpha': 1,
            }

        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root
        self.do_fmix = do_fmix
        self.fmix_params = fmix_params
        self.do_cutmix = do_cutmix
        self.cutmix_params = cutmix_params
        self.output_label = output_label
        self.one_hot_label = one_hot_label
        if output_label:
            self.labels = self.df['label'].values
            if one_hot_label:
                # Row i of the identity matrix is the one-hot vector of class i.
                self.labels = np.eye(self.df['label'].max() +
                                     1)[self.labels]

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return self.df.shape[0]

    def __getitem__(self, index):
        '''
        Args:
            index : int, position of the sample in the (re-indexed) DataFrame
        Returns:
            img, target(optional)
        '''
        # target stays None for unlabeled data so the mixing branches below can
        # still blend images without touching an undefined name.
        target = self.labels[index] if self.output_label else None

        img = get_img(
            os.path.join(self.data_root,
                         self.df.loc[index, 'image_id']))

        if self.transforms:
            img = self.transforms(image=img)['image']

        # NOTE(review): both mixing branches assume `img` is channel-first
        # (C, H, W), e.g. a tensor produced by ToTensorV2 -- confirm if the
        # transforms pipeline changes.
        if self.do_fmix and np.random.uniform(
                0., 1., size=1)[0] > 0.5:

            with torch.no_grad():
                lam, mask = sample_mask(
                    **self.fmix_params)

                # Pick the partner sample to blend with.
                fmix_ix = np.random.choice(self.df.index,
                                           size=1)[0]
                fmix_img = get_img(
                    os.path.join(self.data_root,
                                 self.df.loc[fmix_ix, 'image_id']))

                if self.transforms:
                    fmix_img = self.transforms(image=fmix_img)['image']

                mask_torch = torch.from_numpy(mask)

                # Blend the two images with the sampled mask.
                img = mask_torch * img + (1. - mask_torch) * fmix_img

                if self.output_label:
                    # Share of the image kept from the original sample.
                    # `img` is a tensor here, whose `.size` is a method, so
                    # the pixel count is taken from the last two dimensions.
                    height, width = img.shape[-2:]
                    rate = mask.sum() / float(height * width)
                    target = rate * target + (
                        1. - rate) * self.labels[fmix_ix]

        if self.do_cutmix and np.random.uniform(
                0., 1., size=1)[0] > 0.5:
            with torch.no_grad():
                cmix_ix = np.random.choice(self.df.index, size=1)[0]
                cmix_img = get_img(
                    os.path.join(self.data_root,
                                 self.df.loc[cmix_ix, 'image_id']))
                if self.transforms:
                    cmix_img = self.transforms(image=cmix_img)['image']

                lam = np.clip(
                    np.random.beta(self.cutmix_params['alpha'],
                                   self.cutmix_params['alpha']), 0.3, 0.4)
                # The box indexes axes 1 and 2 below, so it has to be sampled
                # from those two extents rather than from (channels, height).
                bbx1, bby1, bbx2, bby2 = rand_bbox(cmix_img.shape[-2:], lam)

                img[:, bbx1:bbx2, bby1:bby2] = cmix_img[:, bbx1:bbx2,
                                                        bby1:bby2]

                if self.output_label:
                    # Share of the image NOT covered by the pasted box.
                    height, width = img.shape[-2:]
                    rate = 1 - ((bbx2 - bbx1) *
                                (bby2 - bby1) / float(height * width))
                    target = rate * target + (
                        1. - rate) * self.labels[cmix_ix]

        if self.output_label:
            return img, target
        else:
            return img

In [None]:
def get_inference_transforms():
    """Build the randomized augmentation pipeline applied to test images for TTA.

    Returns:
        An albumentations ``Compose`` that crops/resizes to ``CFG['img_size']``,
        applies random flips/transposition and colour jitter, normalizes with
        the given mean/std, and converts the result with ``ToTensorV2``.
    """
    side = CFG['img_size']
    # Geometric augmentations first, then colour jitter, then tensor conversion.
    pipeline = [
        RandomResizedCrop(side, side),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        HueSaturationValue(
            hue_shift_limit=0.2,
            sat_shift_limit=0.2,
            val_shift_limit=0.2,
            p=0.5,
        ),
        RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1),
            contrast_limit=(-0.1, 0.1),
            p=0.5,
        ),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0,
        ),
        ToTensorV2(p=1.0),
    ]
    return Compose(pipeline, p=1.)

In [None]:
class CassvaImgClassifier(nn.Module):
    """timm backbone whose ``classifier`` layer is replaced by a new linear head.

    Args:
        model_arch: Architecture name passed to ``timm.create_model``.
        n_class: Number of output units of the replacement head.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_arch, pretrained=pretrained)
        # Swap the stock head for one sized to this task, keeping its input width.
        head_inputs = backbone.classifier.in_features
        backbone.classifier = nn.Linear(head_inputs, n_class)
        # Stored as `model` so parameter names in saved state dicts are unchanged.
        self.model = backbone

    def forward(self, x):
        """Return the wrapped model's output for the batch ``x``."""
        return self.model(x)

In [None]:
def inference_one_epoch(model, data_loader, device):
    """Run ``model`` over every batch of ``data_loader`` and return softmax outputs.

    Args:
        model: Module called on each batch; it is switched to eval mode.
        data_loader: Sized iterable yielding image batches (no labels).
        device: Device each batch is moved to before the forward pass.

    Returns:
        A NumPy array with the per-batch softmax outputs (taken over
        dimension 1) concatenated along axis 0.

    Raises:
        ValueError: From ``np.concatenate`` when ``data_loader`` yields no batches.
    """
    model.eval()
    batch_probs = []
    pbar = tqdm(data_loader, total=len(data_loader))
    # Disable autograd here rather than relying on the caller, so no graph is
    # built during inference regardless of where this function is called from.
    with torch.no_grad():
        for imgs in pbar:
            imgs = imgs.to(device).float()
            image_preds = model(imgs)
            batch_probs.append(
                torch.softmax(image_preds, 1).detach().cpu().numpy())
    return np.concatenate(batch_probs, axis=0)

In [None]:
# Build the test set from whatever files are present in the test image folder.
TEST_IMAGE_DIR = '../input/cassava-leaf-disease-classification/test_images/'

test = pd.DataFrame()
test['image_id'] = os.listdir(TEST_IMAGE_DIR)

# No labels are available for test images, so the dataset yields images only.
test_ds = CassavaDataset(
    test,
    TEST_IMAGE_DIR,
    transforms=get_inference_transforms(),
    output_label=False,
)

# shuffle=False keeps predictions aligned with the row order of `test`.
tst_loader = DataLoader(
    test_ds,
    batch_size=CFG['valid_bs'],
    shuffle=False,
    num_workers=CFG['num_workers'],
    pin_memory=False,
)

device = torch.device(CFG['device'])
# Five output units, one per class; weights are loaded per fold later on.
model = CassvaImgClassifier(CFG['model_arch'], 5).to(device)

In [None]:
val_preds = []
tst_preds = []

# The loop below indexes CFG['weights'] by fold number, so a length mismatch
# would otherwise only surface as an IndexError after earlier folds have
# already been run. Check up front instead.
if len(CFG['weights']) != len(CFG['used_epochs']):
    raise ValueError(
        "CFG['weights'] has {} entries but CFG['used_epochs'] has {}; "
        "provide exactly one weight per fold.".format(
            len(CFG['weights']), len(CFG['used_epochs'])))

# Loop-invariant normalizer for the fold weights.
weight_total = sum(CFG['weights'])

# For each fold, load the checkpoint of the selected epoch and run CFG['tta']
# augmented passes; every pass contributes its fold weight divided by the
# weight total and by the number of passes.
for fold, epoch in enumerate(CFG['used_epochs']):
    checkpoint_path = '../input/{}-fold-{}-{}/{}_fold_{}_{}'.format('-'.join(CFG['model_arch'].split('_')), fold, epoch,CFG['model_arch'], fold, epoch)
    # map_location places the tensors on the configured device even if the
    # checkpoint was saved from a different one.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    with torch.no_grad():
        for _ in range(CFG['tta']):
            tst_preds += [CFG['weights'][fold]/weight_total/CFG['tta']*inference_one_epoch(model, tst_loader, device)]

# Sum the weighted passes into one prediction array.
tst_preds = np.sum(tst_preds, axis=0)

# Release the model and any cached GPU memory once inference is finished.
del model
torch.cuda.empty_cache()

In [None]:
# The predicted class of each image is the arg-max over axis 1 of the
# ensembled predictions.
predicted_labels = tst_preds.argmax(axis=1)
test['label'] = predicted_labels
test.to_csv('submission.csv', index=False)