In [1]:
import sys

# Register the bundled library checkouts on the import path, in this order. They are
# presumably the source trees behind the `pretrainedmodels`, `efficientnet_pytorch`
# and `timm` imports used by this notebook.
sys.path.extend([
    '../input/models/pretrained-models.pytorch-master/pretrained-models.pytorch-master',
    '../input/models/EfficientNet-PyTorch-master/EfficientNet-PyTorch-master',
    '../input/models/pytorch-image-models-master/pytorch-image-models-master',
])

In [2]:
import os
import cv2
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
import torchvision
import pretrainedmodels
import timm
from efficientnet_pytorch import EfficientNet
import albumentations as A

In [3]:
# Number of output classes; matches the five labels printed after inference.
NUM_CLASSES = 5

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [4]:
def GetPath(pth):
    """Return the path of checkpoint ``pth`` inside the weights input directory.

    Args:
        pth: checkpoint location relative to the weights directory,
            e.g. ``'22019613/22019613k0.pth'``.

    Returns:
        ``pth`` joined onto ``'../input/cassavapth/'`` with ``os.path.join``.
    """
    weights_root = '../input/cassavapth/'
    return os.path.join(weights_root, pth)

# Side length, in pixels, of the square images fed to the models (passed to GetAugment).
SIZE = 512

# Experiment log. Each run keeps one checkpoint per fold (k0..k4) at
# '<run>/<run>k<fold>.pth', resolved through GetPath(). Scores are copied as recorded
# (presumably LB = leaderboard, CV = per-fold validation, oof = out-of-fold;
# dashes mean no value was recorded).
#
#   run       name              LB     CV
#   22019613  efficientnet-b4   LB900  k0=8981, k1ep13=9030, k2ep13=8927, k3ep12=8983, k4ep14=8904 (oof=0.89639)
#   22019629  efficientnet-b4   LB---  k0=8951, k1=8963, k2=8923, k3=9002, k4=8897 (oof=0.89470)
#   22019630  se_resnet101      LB894  k0=8958, k1e=8953, k2=8869, k3=8997, k4=8885 (oof=0.89330)
#   22019631  se_resnext101     LB896  k0e=8965, k1=8986, k2=8913, k3=9007, k4=8876 (oof=0.89503)
#   22019632  timm-resnest101e  LB896  k0=9005, k1=9012, k2=8960, k3=8967, k4=8920 (oof=0.89728)
#   22019638  timm-regnety_032  LB896  k0=8974, k1=8984, k2=8906, k3=9009, k4=8918 (oof=0.89480)
#   22019639  efficientnet-b5   LB892  k0=8993, k1=8965, k2=8962, k3=8986, k4=8927 (oof=-.-----)
#   22019640  timm-resnest200e  LB---  k0=9009, k1=9002, k2=8937, k3=9021, k4=8904 (oof=0.89760)
#   22019717  timm-resnest101e  LB894  k0=8991, k1=9000, k2=8925, k3=8965, k4=8862 (oof=0.89480)
#   22019720  efficientnet-b4   LB895  k0=8993, k1=8972, k2=8939, k3=9021, k4=8941 (oof=0.89728)
#   22019725  efficientnet-b4   LB---  k0=8993, k1=8972, k2=8939, k3=9021, k4=8941 (oof=0.89629)
#
# Active ensemble: combine_set1 efficientnet-b4 (oof=0.89886), one checkpoint per fold
# picked across runs 22019613 / 22019720 / 22019725. To try another ensemble, replace
# the (name, checkpoint) pairs below.
modeldefs = [
    dict(name=arch, pth=GetPath(checkpoint))
    for arch, checkpoint in (
        ('efficientnet-b4', '22019725/22019725k0.pth'),
        ('efficientnet-b4', '22019613/22019613k1.pth'),
        ('efficientnet-b4', '22019725/22019725k2.pth'),
        ('efficientnet-b4', '22019720/22019720k3.pth'),
        ('efficientnet-b4', '22019720/22019720k4.pth'),
    )
]

In [5]:
# Number of test-time-augmentation passes averaged per batch. Pass i feeds TTA(x, i) to
# every model, so 3 covers ops 0-2: unchanged, flip along the last axis, and flip along
# the second-to-last axis.
TTA_ROUND = 3

In [6]:
def TTA(img, ops):
    """Apply one test-time-augmentation variant to a batch of images.

    Args:
        img: tensor laid out as N x C x H x W.
        ops: variant selector.
            0 - leave the batch untouched
            1 - flip along the last axis (W)
            2 - flip along the second-to-last axis (H)
            3 - flip along both of those axes
            4 - one ``torch.rot90`` step over axes (2, 3)
            5 - three ``torch.rot90`` steps over axes (2, 3)
            Any other value leaves the batch untouched as well.

    Returns:
        The transformed tensor, or ``img`` itself when no transform applies.
    """
    if ops == 1:
        return torch.flip(img, [-1])
    if ops == 2:
        return torch.flip(img, [-2])
    if ops == 3:
        return torch.flip(img, [-1, -2])
    if ops == 4:
        return torch.rot90(img, 1, [2, 3])
    if ops == 5:
        return torch.rot90(img, 3, [2, 3])
    # ops == 0 and every unrecognised selector: identity.
    return img

In [7]:
def GetModel(name, param):
    """Build the architecture called ``name`` and load its checkpoint for inference.

    The classification head of every supported family is sized to ``NUM_CLASSES``
    outputs. Weights are read with ``torch.load(..., map_location=DEVICE)`` and
    loaded strictly, then the model is switched to eval mode. The model itself is
    not moved to ``DEVICE`` here; the caller does that.

    Args:
        name: architecture identifier. Accepted families:
            * torchvision: 'resnet18/34/50/101/152', 'resnext50', 'resnext101',
              'wide_resnet50', 'wide_resnet101'
            * pretrainedmodels: 'senet154', 'se_resnet50/101/152',
              'se_resnext50', 'se_resnext101' (with or without '_32x4d')
            * efficientnet_pytorch: any name starting with 'efficientnet-b'
            * timm: 'timm-<model_name>' (the prefix is stripped)
        param: checkpoint location handed to ``torch.load``; it must hold a
            state dict whose keys match the model exactly.

    Returns:
        The model with weights loaded, in eval mode.

    Raises:
        NameError: if ``name`` does not belong to any supported family.
    """
    num_classes = NUM_CLASSES
    if name in [ 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50', 'resnext101', 'wide_resnet50', 'wide_resnet101' ]:
        # Expand shorthand names to the constructor names looked up on torchvision.models.
        # NOTE(review): confirm torchvision.models really exposes 'resnext101_32x4d';
        # it is known to ship 'resnext101_32x8d', so 'resnext101' may fail here.
        if name == 'resnext50' or name == 'resnext101':
            name = name + '_32x4d'
        elif name == 'wide_resnet50' or name == 'wide_resnet101':
            name = name + '_2'
        model = getattr(torchvision.models, name)(pretrained=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif name in [ 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50', 'se_resnext101', 'se_resnext50_32x4d', 'se_resnext101_32x4d' ]:
        if name == 'se_resnext50' or name == 'se_resnext101':
            name = name + '_32x4d'
        model = getattr(pretrainedmodels, name)(pretrained=None)
        # Presumably swapped in so the pooling no longer depends on the input size.
        model.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        model.last_linear = nn.Linear(model.last_linear.in_features, num_classes)
    elif name.startswith('efficientnet-b'):
        model = EfficientNet.from_name(name)
        model._fc = nn.Linear(model._fc.in_features, num_classes)
    elif name.startswith('timm-'):
        model = timm.create_model(model_name=name[len('timm-'):], num_classes=num_classes, in_chans=3, pretrained=False)
    else:
        # Same exception type as before, but now it says which name was rejected.
        raise NameError('unsupported model name: {!r}'.format(name))
    state = torch.load(param, map_location=DEVICE)
    model.load_state_dict(state, strict=True)
    model.eval()
    print('model ({}) is loaded'.format(name))
    return model

In [8]:
def GetAugment(size):
    """Create the deterministic preprocessing pipeline used at inference time.

    Args:
        size: target height and width; images are resized to ``size`` x ``size``.

    Returns:
        An albumentations ``Compose`` (always applied, ``p=1.0``) that resizes
        the image and then runs ``A.Normalize()`` with the library defaults.
    """
    steps = [A.Resize(size, size), A.Normalize()]
    return A.Compose(steps, p=1.0)

In [9]:
def GetDataLoader(files, augops, batch=1, num_workers=2):
    """Wrap ``files`` in an ``InferDataset`` and return a sequential DataLoader.

    Args:
        files: image paths to serve.
        augops: preprocessing callable forwarded to ``InferDataset``.
        batch: number of samples per batch.
        num_workers: number of loader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` that keeps the order of ``files``
        (no shuffling) and also yields the final, possibly smaller, batch.
    """
    dataset = InferDataset(files, augops=augops)
    loader_options = dict(
        batch_size=batch,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )
    return torch.utils.data.DataLoader(dataset, **loader_options)

class InferDataset(Dataset):
    """Map-style dataset that reads image files and applies inference preprocessing.

    Each item is a ``(tensor, file_name)`` pair: the image is read with OpenCV,
    converted from BGR to RGB, passed through ``augops`` and moved from
    H x W x C to C x H x W before being wrapped in a tensor.
    """

    def __init__(self, files, augops):
        """Store the inputs.

        Args:
            files: sequence of image paths.
            augops: callable invoked as ``augops(force_apply=False, image=img)``
                that returns a mapping with the processed array under ``'image'``.
        """
        self.files = files
        self.augops = augops

    def __len__(self):
        """Return the number of image paths."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load and preprocess the image at position ``idx``.

        Returns:
            ``(tensor, name)`` where ``tensor`` is the channel-first image and
            ``name`` is the base name of the source file.

        Raises:
            OSError: if OpenCV cannot read the file (missing, not an image, ...).
        """
        path = self.files[idx]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None rather than raising;
            # fail here with the offending path instead of deep inside cvtColor.
            raise OSError('failed to read image: {}'.format(path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        out = self.augops(force_apply=False, image=img)['image']
        # H x W x C -> C x H x W, the layout the models consume.
        out = out.transpose(2, 0, 1)
        return torch.from_numpy(out), os.path.basename(path)

In [10]:
# Collect every file with an extension in the test image folder. glob gives no ordering
# guarantee, so sort the result to keep the inference order (and therefore the row
# order of the submission file) reproducible across runs and filesystems.
dfiles = sorted(glob.glob('../input/cassava-leaf-disease-classification/test_images/*.*'))
# With exactly one test image, the inference cell also prints its per-class probabilities.
SHOW_PRED = len(dfiles) == 1

In [11]:
# One loader over every discovered test image: each image goes through the
# GetAugment(SIZE) pipeline and is served in batches of 8, in file-list order.
loader = GetDataLoader(dfiles, augops=GetAugment(SIZE), batch=8)

In [12]:
# Load every checkpoint listed in modeldefs and move each network to DEVICE.
models = [
    GetModel(mdef['name'], mdef['pth']).to(DEVICE)
    for mdef in modeldefs
]

model (efficientnet-b4) is loaded
model (efficientnet-b4) is loaded
model (efficientnet-b4) is loaded
model (efficientnet-b4) is loaded
model (efficientnet-b4) is loaded


In [13]:
# Ensemble inference: for every batch, sum the softmax outputs of all models over all
# TTA passes (each model weighted by 1/len(models)), average over the passes, and keep
# the arg-max class per image.
names = [ ]
preds = np.array([], dtype=np.int32)
model_weight = 1.0 / len(models) if len(models) > 0 else 1.0
to_probs = nn.Softmax(dim=1)

with torch.no_grad():
    for images, batch_names in loader:
        images = images.to(DEVICE)
        scores = torch.zeros([images.shape[0], NUM_CLASSES], device=DEVICE)
        for tta in range(TTA_ROUND):
            view = TTA(images, tta)
            for model in models:
                scores = scores + to_probs(model(view)) * model_weight
        if TTA_ROUND > 1:
            scores = scores / TTA_ROUND
        scores = scores.detach().cpu().numpy()
        if SHOW_PRED:
            # Keep the averaged probabilities so they can be printed below.
            p = scores
        preds = np.append(preds, np.argmax(scores, axis=1))
        names.extend(batch_names)

if SHOW_PRED:
    # Single-image run: show the averaged probability of each class for that image.
    class_labels = [ '    CBB', '   CBSD', '    CGM', '    CMD', 'Healthy' ]
    for label, prob in zip(class_labels, p[0]):
        print('{}={:.6}'.format(label, prob))

    CBB=0.0245013
   CBSD=0.0428515
    CGM=0.404055
    CMD=0.0324383
Healthy=0.496154


In [14]:
# Write the submission file: a header row, then one "<file name>,<predicted class>" row
# per image, in inference order.
with open('submission.csv', mode='w') as f:
    f.write('image_id,label\n')
    f.writelines(
        '{},{}\n'.format(image_name, label)
        for image_name, label in zip(names, preds)
    )