In [None]:
import sys
sys.path.append('../input/git-repos/pretrained-models.pytorch/pretrained-models.pytorch')
sys.path.append('../input/git-repos/EfficientNet-PyTorch/EfficientNet-PyTorch')
sys.path.append('../input/git-repos/pytorch-image-models/pytorch-image-models')

In [None]:
import cv2
from skimage import io
import torch
from torch import nn
import os
import time
import random
import pandas as pd
import numpy as np
from tqdm import tqdm

from torch.utils.data import Dataset,DataLoader

from sklearn import metrics
import warnings
import timm #from efficientnet_pytorch import EfficientNet
from collections import OrderedDict
import pretrainedmodels

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2

In [None]:
# ---------------------------------------------------------------------------
# Per-model inference configurations.
#
# Keys, as consumed by the inference loop further down in this notebook:
#   weight       : this model's entry in WS (across-model blending weight);
#                  also used as the per-checkpoint weight when USE_WEIGHT is True
#   model_arch   : architecture name passed to timm.create_model
#   height/width : target size handed to the crop / resize transforms
#   valid_bs     : DataLoader batch size
#   tta          : number of inference passes per checkpoint
#                  (> 1 selects the randomised TTA transforms)
#   ckpt_path    : directory whose files are listed and loaded as checkpoints
#   clsw         : optional; rows of five per-class multipliers, one row is
#                  applied to each checkpoint's predictions
#   flag         : not read by any code visible in this notebook
#   which        : 'best' skips files whose name contains 'last',
#                  'last' keeps only those files,
#                  any other value (e.g. 'all') keeps every file
#
# The trailing comment on each 'model_arch' line is a list of alternative
# architecture names left by the author.
# ---------------------------------------------------------------------------
CFG01 = {
    'weight': 0.902,
    'model_arch': 'tf_efficientnet_b4_ns', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp37',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'all', # best last
}

CFG02 = {
    'weight': 0.900,
    'model_arch': 'seresnext50_32x4d', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp39',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'all', # best last
}

CFG04 = {
    'weight': 0.903,
    'model_arch': 'vit_base_patch16_384', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':384,
    'width':384,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-vit-best', # '../input/cassava-vit-cutmix',
    # Active per-checkpoint class multipliers (five rows of five values).
    'clsw': [[0.2, 0.15, 0, 0.1, 0.1], [0.4, 0.3, 0.2, 0.35, 0.25], [0, 0, 0.3, 0.3, 0], [0, 0.15, 0.1, 0.25, 0.4], [0.4, 0.4, 0.4, 0, 0.25]],
    # Earlier candidate multipliers, kept for reference (disabled):
#     'clsw': [[0.40, 0.30, 0.00, 0.20, 0.00],
#              [0.10, 0.30, 0.40, 0.20, 0.10],
#              [0.10, 0.10, 0.15, 0.20, 0.30],
#              [0.00, 0.00, 0.10, 0.20, 0.40],
#              [0.40, 0.30, 0.15, 0.20, 0.20]],
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'best', # best last
}

CFG05 = {
    'weight': 0.900,
    'model_arch': 'tf_efficientnet_b1_ns', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp56-b1',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'all', # best last
}

CFG06 = {
    'weight': 0.899,
    'model_arch': 'tf_efficientnet_b2_ns', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp56-b2',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'all', # best last
}

CFG07 = {
    'weight': 0.900,
    'model_arch': 'tf_efficientnet_b3_ns', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp56-b3',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'all', # best last
}

CFG08 = {
    'weight': 0.900,
    'model_arch': 'regnety_032', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-regnety032',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'best', # best last
}


CFG09 = {
    'weight': 0.902,
    'model_arch': 'tf_efficientnet_b4_ns', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp57-radam',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'all', # best last
}

CFG10 = {
    'weight': 0.902,
    'model_arch': 'tf_efficientnet_b4_ns', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    # A ckpt_path containing 'fenghan' makes the inference loop pick get_inference_transforms2.
    'ckpt_path': '../input/cassava-fenghan',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'best', # best last
}

CFG11 = {
    'weight': 0.903,
    'model_arch': 'tf_efficientnet_b4_ns', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp67-sgd',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'best', # best last
}

CFG12 = {
    'weight': 0.902,
    'model_arch': 'tf_efficientnet_b4_ns', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp68-ts',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'best', # best last
}

CFG13 = {
    'weight': 0.901,
    'model_arch': 'tf_efficientnet_b4_ns', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp68-bce-best',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'best', # best last
}

CFG14 = {
    'weight': 0.899,
    'model_arch': 'resnest50d_4s2x40d', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp70-resnest50',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'best', # best last
}

CFG15 = {
    'weight': 0.902,
    'model_arch': 'tf_efficientnet_b4_ns', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':768,
    'width':768,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp67-768-16',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'best', # best last
}
CFG16 = {
    'weight': 0.902,
    'model_arch': 'seresnext50_32x4d', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':768,
    'width':768,
    'valid_bs': 32,
    'tta': 3,
    # A checkpoint path containing 'exp75' makes the inference loop build CassavaModelTimm
    # instead of a plain timm model.
    'ckpt_path': '../input/cassava-exp75-1',
    # Candidate per-checkpoint class multipliers (disabled):
#     'clsw': [[0.00, 0.30, 0.40, 0.00, 0.40],
#              [0.20, 0.10, 0.30, 0.40, 0.00],
#              [0.40, 0.20, 0.00, 0.20, 0.10],
#              [0.40, 0.20, 0.10, 0.20, 0.30],
#              [0.00, 0.20, 0.00, 0.20, 0.20]],
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'best', # best last
}

CFG17 = {
    'weight': 0.902,
    'model_arch': 'regnety_032', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':768,
    'width':768,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp70-2',
    # Candidate per-checkpoint class multipliers (disabled):
#     'clsw': [[0.00, 0.30, 0.40, 0.00, 0.40],
#              [0.20, 0.10, 0.30, 0.40, 0.00],
#              [0.40, 0.20, 0.00, 0.20, 0.10],
#              [0.40, 0.20, 0.10, 0.20, 0.30],
#              [0.00, 0.20, 0.00, 0.20, 0.20]],
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'best', # best last
}

CFG18 = {
    'weight': 0.902,
    'model_arch': 'tf_efficientnet_b4_ns', # se_resnext50_32x4d tf_efficientnet_b0_ns seresnext50_32x4d regnety_080
    'height':512,
    'width':512,
    'valid_bs': 32,
    'tta': 3,
    'ckpt_path': '../input/cassava-exp68-03',
    'flag': False, # use False for models trained with pytorch-lightning
    'which': 'best', # best last
}

# Models that take part in the ensemble, in blending order.
CFGS = [CFG18, CFG02, CFG04, CFG08]
# CFGS = [CFG12, CFG15, CFG08]
# Across-model blending weights: each model's averaged prediction is multiplied
# by its entry and the total is divided by sum(WS).
WS = [cfg['weight'] for cfg in CFGS]
# WS = [1] * len(CFGS)

# When False, every checkpoint of a model gets weight 1 in the within-model
# average; when True, the model's 'weight' value is used instead.
USE_WEIGHT = False
# Guards against WS being overridden by hand with a list of the wrong length.
assert len(WS) == len(CFGS)

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and configures cuDNN for repeatable results.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The benchmark autotuner may pick a different convolution algorithm on
    # each run, which defeats the deterministic setting above; keep it off.
    torch.backends.cudnn.benchmark = False
    
def get_img(path):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        path: path of the image file to read.

    Returns:
        A C-contiguous array with the channel axis reversed relative to what
        ``cv2.imread`` returned (BGR -> RGB).

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None``, i.e. the file
            is missing or could not be decoded.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("cv2.imread could not read image: {}".format(path))
    # Reversing the last axis yields a negative-stride view; copy it into a
    # contiguous buffer so downstream consumers (e.g. torch) can wrap it.
    return np.ascontiguousarray(im_bgr[:, :, ::-1])

class CassavaDataset(Dataset):
    """Dataset that loads images named in a DataFrame's ``image_id`` column.

    Args:
        df: DataFrame with an ``image_id`` column and, when ``output_label``
            is set, a ``label`` column. A re-indexed copy is stored.
        data_root: directory prefix joined with each ``image_id`` using "/".
        transforms: optional callable invoked as ``transforms(image=img)``;
            its ``'image'`` entry replaces the loaded image.
        output_label: when True, items are ``(image, label)`` pairs,
            otherwise just the image.
    """

    def __init__(self, df, data_root, transforms=None, output_label=True):
        super().__init__()
        self.data_root = data_root
        self.transforms = transforms
        self.output_label = output_label
        # Work on a private copy with a clean 0..n-1 index so positional
        # lookups line up with the sample index.
        self.df = df.reset_index(drop=True).copy()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index: int):
        record = self.df.iloc[index]

        # The label is only looked up when labels were requested.
        if self.output_label:
            target = record['label']

        img = get_img("{}/{}".format(self.data_root, record['image_id']))

        if self.transforms:
            img = self.transforms(image=img)['image']

        # The label is returned as stored; no smoothing is applied here.
        if self.output_label == True:
            return img, target
        return img

# transforms

In [None]:
def get_valid_transforms(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the non-random preprocessing pipeline used when TTA is disabled.

    The pipeline is: centre crop to ``CFG['height']`` x ``CFG['width']``,
    resize to the same size, normalise, convert to a tensor.

    Args:
        mean: per-channel mean passed to ``Normalize``.
        std: per-channel standard deviation passed to ``Normalize``.
            Both are parameters (matching ``get_inference_transforms``)
            because the inference loop calls this function as
            ``get_valid_transforms(mean, std)``; the defaults are the values
            that loop uses for non-ViT models.

    Returns:
        The composed albumentations pipeline.

    Note:
        Reads the module-level ``CFG`` for the target size.
    """
    return Compose([
            CenterCrop(CFG['height'], CFG['width'], p=1.),
            Resize(CFG['height'], CFG['width']),
            Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ], p=1.)

def get_inference_transforms(mean, std):
    """Build the randomised test-time-augmentation pipeline.

    Args:
        mean: per-channel mean passed to ``Normalize``.
        std: per-channel standard deviation passed to ``Normalize``.

    Returns:
        The composed albumentations pipeline: random resized crop to
        ``CFG['height']`` x ``CFG['width']``, geometric flips, colour jitter,
        normalisation and tensor conversion.

    Note:
        Reads the module-level ``CFG`` for the target size.
    """
    crop = RandomResizedCrop(CFG['height'], CFG['width'])

    # Each geometric augmentation is applied with probability 0.5.
    geometric = [
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
    ]

    # NOTE(review): the 0.2 shift limits look very small for this transform;
    # confirm they match what was used at training time.
    colour = [
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    ]

    to_model_input = [
        Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]

    return Compose([crop] + geometric + colour + to_model_input, p=1.)

def get_inference_transforms2(mean, std):
    """Build the lighter test-time-augmentation pipeline.

    Compared with ``get_inference_transforms`` this variant keeps only the
    random resized crop and the horizontal flip; transpose, vertical flip and
    the colour jitter steps are left out.

    Args:
        mean: per-channel mean passed to ``Normalize``.
        std: per-channel standard deviation passed to ``Normalize``.

    Returns:
        The composed albumentations pipeline.

    Note:
        Reads the module-level ``CFG`` for the target size.
    """
    pipeline = []
    pipeline.append(RandomResizedCrop(CFG['height'], CFG['width']))
    pipeline.append(HorizontalFlip(p=0.5))
    pipeline.append(Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0))
    pipeline.append(ToTensorV2(p=1.0))
    return Compose(pipeline, p=1.)

In [None]:
class CassavaModelTimm(nn.Module):
    """timm backbone whose 768-output classifier feeds a linear ``metric`` head.

    The backbone is created with ``num_classes=768`` so its own final layer
    produces a 768-dimensional vector; ``model.metric`` then maps that vector
    to ``n_class`` outputs. The head is attached to ``self.model`` so its
    parameters are stored under the ``model.metric.*`` state-dict keys.

    Args:
        model_arch: timm architecture name. The forward pass is selected by
            substring: 'efficientnet', 'seresnext', 'regnet', 'resnest', 'vit'.
        n_class: number of outputs of the ``metric`` head.
        pretrained: forwarded to ``timm.create_model``.
    """

    def __init__(self, model_arch, n_class=5, pretrained=False):
        super().__init__()
        self.model_arch = model_arch
        # Honour the caller's `pretrained` choice instead of hard-coding False.
        self.model = timm.create_model(model_arch, pretrained=pretrained, num_classes=768)
        self.model.metric = nn.Linear(768, n_class)

    def _forward_efficientnet(self, x):
        """Stem -> blocks -> head conv -> pool -> 768-d classifier -> metric."""
        net = self.model
        x = net.conv_stem(x)
        x = net.bn1(x)
        x = net.act1(x)
        x = net.blocks(x)
        x = net.conv_head(x)
        x = net.bn2(x)
        x = net.act2(x)
        x = net.global_pool(x)
        x = net.classifier(x)
        return net.metric(x)

    def _forward_resnet_like(self, x):
        """Stem -> layer1..layer4 -> pool -> 768-d fc -> metric (seresnext / resnest)."""
        net = self.model
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.act1(x)
        x = net.maxpool(x)
        x = net.layer1(x)
        x = net.layer2(x)
        x = net.layer3(x)
        x = net.layer4(x)
        x = net.global_pool(x)
        x = net.fc(x)
        return net.metric(x)

    def _forward_regnet(self, x):
        """Stem -> stages s1..s4 -> 768-d head -> metric."""
        net = self.model
        x = net.stem(x)
        x = net.s1(x)
        x = net.s2(x)
        x = net.s3(x)
        x = net.s4(x)
        x = net.head(x)
        return net.metric(x)

    def forward(self, x):
        """Run the architecture-specific forward pass.

        Raises:
            NotImplementedError: if ``model_arch`` matches none of the
                supported families (previously the input was returned
                unchanged, which silently produced meaningless outputs).
        """
        arch = self.model_arch
        # The order of these checks is kept as originally written.
        if 'efficientnet' in arch:
            return self._forward_efficientnet(x)
        if 'seresnext' in arch:
            return self._forward_resnet_like(x)
        if 'regnet' in arch:
            return self._forward_regnet(x)
        if 'resnest' in arch:
            return self._forward_resnet_like(x)
        if 'vit' in arch:
            # NOTE(review): unlike the other branches the `metric` head is not
            # applied here, so the output is the backbone's 768-d vector.
            # Left unchanged; confirm whether that is intended.
            return self.model(x)
        raise NotImplementedError(
            "CassavaModelTimm has no forward pass for architecture '{}'".format(arch)
        )

# Main Loop

In [None]:
def inference_one_epoch(model, data_loader, device):
    """Run one pass over ``data_loader`` and return softmax probabilities.

    Args:
        model: network to evaluate; it is switched to eval mode.
        data_loader: iterable with a length that yields image batches only
            (no labels).
        device: device each batch is moved to before the forward pass.

    Returns:
        A NumPy array holding the per-batch softmax outputs (taken over
        dimension 1) concatenated along axis 0.
    """
    model.eval()
    batch_probs = []
    # Disable autograd here rather than relying on the caller, so no graph is
    # built per batch even if this is called outside a no_grad context.
    with torch.no_grad():
        for imgs in tqdm(data_loader, total=len(data_loader)):
            imgs = imgs.to(device).float()
            logits = model(imgs)
            batch_probs.append(torch.softmax(logits, 1).detach().cpu().numpy())

    return np.concatenate(batch_probs, axis=0)
#     return image_preds_all**0.5 # tsharp

In [None]:
# Ensemble inference: for every config in CFGS, average the TTA predictions of
# each selected checkpoint, average those per model, blend the models with WS
# and write the arg-max class per test image to submission.csv.
seed_everything(719) # 719

TEST_IMG_DIR = '../input/cassava-leaf-disease-classification/test_images/'

test = pd.DataFrame()
# os.listdir returns names in arbitrary order; sort so the row order of the
# submission is stable from run to run.
test['image_id'] = sorted(os.listdir(TEST_IMG_DIR))
device = torch.device('cuda:0')
final_preds = []
tst_preds = []  # never filled; kept only so the name still exists
for idx, CFG in enumerate(CFGS):
    # NOTE: the loop variable has to be named CFG because the transform
    # factories read CFG as a module-level global.
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    if 'vit' in CFG['model_arch']:
        mean=[0.5,0.5,0.5]
        std=[0.5,0.5,0.5]

    if CFG['tta'] > 1:
        # Checkpoints from the 'fenghan' directory use the lighter TTA pipeline.
        if 'fenghan' in CFG['ckpt_path']:
            transforms = get_inference_transforms2(mean, std)
        else:
            transforms = get_inference_transforms(mean, std)
    else:
        transforms = get_valid_transforms(mean, std)

    test_ds = CassavaDataset(test, TEST_IMG_DIR, transforms=transforms, output_label=False)
    tst_loader = torch.utils.data.DataLoader(
            test_ds,
            batch_size=CFG['valid_bs'],
            num_workers=4,
            shuffle=False,
            pin_memory=False,
        )

    ws = []
    one_model_preds = []
    for name in sorted(os.listdir(CFG['ckpt_path'])):
        # 'best' drops the "last" checkpoints, 'last' keeps only them,
        # any other value (e.g. 'all') keeps every file.
        if CFG['which'] == 'best' and 'last' in name:
            continue
        if CFG['which'] == 'last' and 'last' not in name:
            continue

        ckpt_path = os.path.join(CFG['ckpt_path'], name)
        w = CFG['weight'] if USE_WEIGHT else 1
        ws.append(w)
        if 'exp75' in ckpt_path:
            # These checkpoints were saved from CassavaModelTimm and load as-is.
            model = CassavaModelTimm(CFG['model_arch']).to(device)
            state_dict = torch.load(ckpt_path)["state_dict"]
            model.load_state_dict(state_dict)
        else:
            model = timm.create_model(CFG['model_arch'], pretrained=False, num_classes=5).to(device)

            state_dict = torch.load(ckpt_path)["state_dict"]
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                # Loss-function entries are not part of the network.
                if 'criterion' in k:
                    continue
                # Strip up to two leading 'model.' wrappers. Only a real prefix
                # is removed; a key that merely contains "model" somewhere
                # else is left intact.
                for _ in range(2):
                    if k.startswith('model.'):
                        k = k[len('model.'):]
                new_state_dict[k] = v
            model.load_state_dict(new_state_dict)

        # Per-class multipliers are indexed by the position of this checkpoint
        # among the checkpoints actually used, so skipped files cannot shift
        # or overrun the rows of CFG['clsw'].
        if 'clsw' in CFG:
            cls_w = np.array(CFG['clsw'][len(one_model_preds)])
        else:
            cls_w = None

        tta_preds = []
        with torch.no_grad():
            for _ in range(CFG['tta']):
                if cls_w is None:
                    tta_preds.append(w*inference_one_epoch(model, tst_loader, device))
                else:
                    tta_preds.append(w*inference_one_epoch(model, tst_loader, device)*cls_w)

        tta_preds = np.sum(tta_preds, axis=0) / CFG['tta']
        one_model_preds.append(tta_preds)

        # Free GPU memory before the next checkpoint is loaded.
        del model
        torch.cuda.empty_cache()

    if not one_model_preds:
        # Without this check the average below silently becomes NaN and every
        # image is labelled as class 0.
        raise RuntimeError(
            "no checkpoint in '{}' matched which='{}'".format(CFG['ckpt_path'], CFG['which'])
        )
    one_model_preds = np.sum(one_model_preds, axis=0) / sum(ws)
    final_preds.append(one_model_preds*WS[idx])
final_preds = np.sum(final_preds, axis=0) / sum(WS)
labels = np.argmax(final_preds, axis=1)
test['label'] = labels
test.to_csv('submission.csv', index=False)
test.head()