In [1]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import os
import cv2
import timm
import torch
from torch import nn
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize, RandomCrop
)

from albumentations.pytorch import ToTensorV2

In [2]:
# Central configuration for the ensemble inference cells below.
CFG = {
    # timm architecture names of the three ensemble members.
    'model_arch1': 'resnest101e',
    'model_arch2': 'tf_efficientnet_b4_ns',
    'model_arch7': 'vit_base_patch16_384',
    'img_size': 512,       # crop side used by get_inference_transforms()
    'valid_bs': 32,        # DataLoader batch size
    'num_workers': 4,      # DataLoader worker processes
    'device': 'cuda:0',    # used for torch.device(...) and as map_location when loading checkpoints
    'tta': 5,              # test-time-augmentation passes per fold (models 1 and 7; model 2 hard-codes its own count)
    # One entry per fold: the epoch number that appears in that fold's checkpoint filename.
    'used_epochs_model1': [9,9,6,6,8],
    'used_epochs_model2': [11,13,13,11,12],
    'used_epochs_model7': [9,8,9,8,8],
    # Per-fold weights; each prediction is scaled by weight / sum(weights).
    'weights': [1,1,1,1,1]
}
# NOTE(review): `submission` is only displayed here; the visible cells build the
# final table from the test-image folder instead - confirm it is still needed.
submission = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
submission.head()

Unnamed: 0,image_id,label
0,2216849948.jpg,4


In [3]:
def get_img(path):
    '''Load an image with OpenCV and return it in RGB channel order.

    OpenCV decodes images in BGR order, so the channels are converted
    before the image is returned.

    Args:
        path : str  path of the image file, e.g. '../data/train_img/1.jpg'

    Returns:
        The image produced by ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``.

    Raises:
        FileNotFoundError: if OpenCV could not read the file at ``path``.
    '''
    img_bgr = cv2.imread(path)
    if img_bgr is None:
        # cv2.imread reports a missing/unreadable file by returning None; fail
        # here with the offending path instead of an opaque cvtColor assertion.
        raise FileNotFoundError('Could not read image: {!r}'.format(path))
    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

In [4]:
class CassavaDataset(Dataset):
    '''Dataset that loads the cassava leaf competition images.

    Attributes:
        __len__ : number of samples.
        __getitem__ : indexing function, returns ``img`` or ``(img, target)``.

    Notes:
        * The optional FMix / CutMix branches call ``sample_mask`` and
          ``rand_bbox``. Neither is defined in this class, so both must be
          available in the enclosing module before those branches are enabled.
        * Both mix branches index the image as ``img[:, rows, cols]``, i.e.
          they expect the transformed image to be channel-first. The mix
          ratio is therefore computed from the last two dimensions.
    '''

    def __init__(
            self,
            df,
            data_root,
            transforms=None,
            output_label=True,
            one_hot_label=False,
            do_fmix=False,
            fmix_params=None,
            do_cutmix=False,
            cutmix_params=None):
        '''
        Args:
            df : DataFrame with an 'image_id' column (and 'label' when
                ``output_label`` is True).
            data_root : str, folder that contains the image files.
            transforms : optional callable invoked as ``transforms(image=img)``
                and expected to return a mapping with an 'image' entry.
            output_label : bool, return ``(img, target)`` instead of ``img``.
            one_hot_label : bool, one-hot encode the labels.
            do_fmix : bool, apply FMix to roughly half of the samples.
            fmix_params : dict of keyword arguments for ``sample_mask``;
                ``None`` selects the defaults below.
            do_cutmix : bool, apply CutMix to roughly half of the samples.
            cutmix_params : dict with the beta-distribution 'alpha';
                ``None`` selects ``{'alpha': 1}``.

        Raises:
            ValueError: if mixing is requested while ``output_label`` is False
                (the mix branches need labels to blend the targets).
        '''
        super().__init__()
        if (do_fmix or do_cutmix) and not output_label:
            # Previously this combination failed later, inside __getitem__,
            # with an unbound `target`; report it up front instead.
            raise ValueError('do_fmix / do_cutmix require output_label=True')

        # None sentinels replace the former dict defaults so that instances
        # never share (and accidentally mutate) one default dictionary.
        if fmix_params is None:
            fmix_params = {
                'alpha': 1.,
                'decay_power': 3.,
                'shape': (512, 512),
                'max_soft': 0.3,
                'reformulate': False
            }
        if cutmix_params is None:
            cutmix_params = {
                'alpha': 1,
            }

        self.df = df.reset_index(drop=True).copy()  # rebuild a 0..n-1 index
        self.transforms = transforms
        self.data_root = data_root
        self.do_fmix = do_fmix
        self.fmix_params = fmix_params
        self.do_cutmix = do_cutmix
        self.cutmix_params = cutmix_params
        self.output_label = output_label
        self.one_hot_label = one_hot_label
        if output_label:
            self.labels = self.df['label'].values
            if one_hot_label:
                # Rows of the identity matrix are the one-hot encodings.
                self.labels = np.eye(self.df['label'].max() + 1)[self.labels]

    def __len__(self):
        return self.df.shape[0]

    def _load_image(self, index):
        '''Read the image of row ``index`` and apply ``self.transforms`` if set.'''
        img = get_img(
            os.path.join(self.data_root, self.df.loc[index, 'image_id']))
        if self.transforms:
            img = self.transforms(image=img)['image']
        return img

    @staticmethod
    def _spatial_area(img):
        '''Return height * width of a channel-first image as a float.'''
        height, width = img.shape[-2:]
        return float(height * width)

    def __getitem__(self, index):
        '''
        Args:
            index : int , row index
        Returns:
            img, target(optional)
        '''
        if self.output_label:
            target = self.labels[index]

        img = self._load_image(index)

        if self.do_fmix and np.random.uniform(
                0., 1., size=1)[0] > 0.5:  # fmix is triggered with 50% probability

            with torch.no_grad():
                lam, mask = sample_mask(**self.fmix_params)

                # Randomly pick the partner image to mix with.
                fmix_ix = np.random.choice(self.df.index, size=1)[0]
                fmix_img = self._load_image(fmix_ix)

                mask_torch = torch.from_numpy(mask)

                img = mask_torch * img + (1. - mask_torch) * fmix_img  # mix the images

                # Fraction of pixels kept from the original image. The former
                # `float(img.size)` fails on tensors (`size` is a method).
                rate = mask.sum() / self._spatial_area(img)
                target = rate * target + (
                    1. - rate) * self.labels[fmix_ix]  # mix the targets

        if self.do_cutmix and np.random.uniform(
                0., 1., size=1)[0] > 0.5:  # cutmix is triggered with 50% probability
            with torch.no_grad():
                cmix_ix = np.random.choice(self.df.index, size=1)[0]
                cmix_img = self._load_image(cmix_ix)

                lam = np.clip(
                    np.random.beta(self.cutmix_params['alpha'],
                                   self.cutmix_params['alpha']), 0.3, 0.4)
                # The box indexes dims 1 and 2 below, so it has to be drawn
                # from the spatial size, not from (channels, height).
                bbx1, bby1, bbx2, bby2 = rand_bbox(
                    tuple(cmix_img.shape[-2:]), lam)

                img[:, bbx1:bbx2, bby1:bby2] = cmix_img[:, bbx1:bbx2,
                                                        bby1:bby2]

                # Fraction of the image area that still comes from `img`.
                rate = 1 - ((bbx2 - bbx1) *
                            (bby2 - bby1) / self._spatial_area(img))
                target = rate * target + (
                    1. - rate) * self.labels[cmix_ix]  # mix the targets

        if self.output_label:
            return img, target
        else:
            return img

In [5]:
class CassvaImgClassifier(nn.Module):
    '''Thin ``nn.Module`` wrapper around a model created by ``timm.create_model``.

    Args:
        model_arch : str, timm architecture name.
        n_class : int, number of output classes of the classification head.
        pretrained : bool, forwarded to ``timm.create_model``.
    '''
    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        # Honour n_class: the head size used to be hard-coded to 5, which
        # silently ignored the argument.
        self.model = timm.create_model(model_arch, pretrained=pretrained, num_classes=n_class)

    def forward(self, x):
        '''Return the wrapped model's output for the batch ``x``.'''
        return self.model(x)

In [6]:
def inference_one_epoch(model, data_loader, device):
    '''Run one full pass over ``data_loader`` and return softmax probabilities.

    The model is switched to eval mode and the pass runs under
    ``torch.no_grad()``, so the function is safe to call on its own.

    Args:
        model : nn.Module producing class logits for a batch of images.
        data_loader : sized iterable yielding image batches (no labels).
        device : device the batches are moved to before the forward pass.

    Returns:
        numpy array with the softmax (over dim 1) outputs of all batches,
        concatenated along axis 0.
    '''
    model.eval()
    batch_probs = []
    with torch.no_grad():
        for imgs in tqdm(data_loader, total=len(data_loader)):
            imgs = imgs.to(device).float()
            logits = model(imgs)
            batch_probs.append(torch.softmax(logits, 1).cpu().numpy())
    return np.concatenate(batch_probs, axis=0)

In [7]:
# Build the test table from the files present in the test-image folder.
# os.listdir returns names in arbitrary order, so sort them to keep the row
# order (and therefore submission.csv) reproducible between runs.
test = pd.DataFrame()
test['image_id'] = sorted(os.listdir('../input/cassava-leaf-disease-classification/test_images/'))
device = torch.device(CFG['device'])

In [8]:
# Accumulators for the ensemble. Every TTA pass of every fold/model appends one
# weighted probability array to tst_preds.
val_preds = []  # NOTE(review): never appended to in the visible cells - confirm it is still needed
tst_preds = []



def get_inference_transforms():
    '''Build the test-time augmentation pipeline at ``CFG['img_size']``.

    Steps, in order: random resized crop, random transpose / horizontal flip /
    vertical flip (each with p=0.5), normalisation, conversion to a tensor.
    '''
    side = CFG['img_size']
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [RandomResizedCrop(side, side)]
    steps.append(Transpose(p=0.5))
    steps.append(HorizontalFlip(p=0.5))
    steps.append(VerticalFlip(p=0.5))
    steps.append(Normalize(mean=channel_mean, std=channel_std, max_pixel_value=255.0, p=1.0))
    steps.append(ToTensorV2(p=1.0))
    return Compose(steps, p=1.)

# --- Ensemble member 1 (CFG['model_arch1']) ---------------------------------
# One dataset/loader for the whole member; the augmentations are random, so
# every pass over the loader is a different TTA view.
test_ds = CassavaDataset(
    test,
    '../input/cassava-leaf-disease-classification/test_images/',
    transforms=get_inference_transforms(),
    output_label=False,
)
tst_loader = DataLoader(
    test_ds,
    batch_size=CFG['valid_bs'],
    num_workers=CFG['num_workers'],
    shuffle=False,
    pin_memory=False,
)

model = CassvaImgClassifier(CFG['model_arch1'], 5)
model = model.to(device)

for fold, epoch in enumerate(CFG['used_epochs_model1']):
    # Checkpoints are named '<arch>_fold_<fold>_<epoch>'.
    checkpoint_path = '../input/baseline210110/resnest101e_bs8x2_1e-4_0111/{}_fold_{}_{}'.format(
        CFG['model_arch1'], fold, epoch)
    state_dict = torch.load(checkpoint_path, map_location=CFG['device'])
    model.load_state_dict(state_dict)

    with torch.no_grad():
        for _ in range(CFG['tta']):
            # Fold weight, normalised over folds and averaged over TTA passes.
            pass_scale = CFG['weights'][fold] / sum(CFG['weights']) / CFG['tta']
            tst_preds.append(pass_scale * inference_one_epoch(model, tst_loader, device))

del model






# 
def get_inference_transforms_eff():
    '''Build the test-time augmentation pipeline used for the second model.

    Steps, in order: 512x512 random crop, random transpose / flips, light
    hue-saturation-value and brightness-contrast jitter (each with p=0.5),
    normalisation with this pipeline's own mean/std, conversion to a tensor.
    '''
    channel_mean = [0.42984136, 0.49624753, 0.3129598]
    channel_std = [0.21417203, 0.21910103, 0.19542212]

    geometric = [
        RandomCrop(512, 512),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
    ]
    colour = [
        HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    ]
    finalise = [
        Normalize(mean=channel_mean, std=channel_std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(geometric + colour + finalise, p=1.)

# --- Ensemble member 2 (CFG['model_arch2']) ---------------------------------
test_ds1 = CassavaDataset(
    test,
    '../input/cassava-leaf-disease-classification/test_images/',
    transforms=get_inference_transforms_eff(),
    output_label=False,
)
tst_loader1 = DataLoader(
    test_ds1,
    batch_size=CFG['valid_bs'],
    num_workers=CFG['num_workers'],
    shuffle=False,
    pin_memory=False,
)

model = CassvaImgClassifier(CFG['model_arch2'], 5)
model = model.to(device)

# NOTE(review): this member runs a fixed 6 TTA passes rather than CFG['tta'] -
# confirm that the difference is intentional.
n_tta_passes = 6

for fold, epoch in enumerate(CFG['used_epochs_model2']):
    checkpoint_path = '../input/train1212fold5/tf_efficientnet_b4_ns_1212_fold_{}_{}'.format(fold, epoch)
    state_dict = torch.load(checkpoint_path, map_location=CFG['device'])
    model.load_state_dict(state_dict)

    with torch.no_grad():
        for _ in range(n_tta_passes):
            # Fold weight, normalised over folds and averaged over TTA passes.
            pass_scale = CFG['weights'][fold] / sum(CFG['weights']) / n_tta_passes
            tst_preds.append(pass_scale * inference_one_epoch(model, tst_loader1, device))

del model







# 
def get_inference_transforms_384():
    '''Build the 384x384 test-time augmentation pipeline.

    Steps, in order: random resized crop to 384x384, random transpose /
    horizontal flip / vertical flip (each with p=0.5), normalisation,
    conversion to a tensor.
    '''
    crop_side = 384
    pipeline = [RandomResizedCrop(crop_side, crop_side)]
    # The three geometric augmentations share the same probability.
    pipeline += [augment(p=0.5) for augment in (Transpose, HorizontalFlip, VerticalFlip)]
    pipeline.append(
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0))
    pipeline.append(ToTensorV2(p=1.0))
    return Compose(pipeline, p=1.)
# --- Ensemble member 3 (CFG['model_arch7']) ---------------------------------
test_ds2 = CassavaDataset(
    test,
    '../input/cassava-leaf-disease-classification/test_images/',
    transforms=get_inference_transforms_384(),
    output_label=False,
)
tst_loader2 = DataLoader(
    test_ds2,
    batch_size=CFG['valid_bs'],
    num_workers=CFG['num_workers'],
    shuffle=False,
    pin_memory=False,
)

model = CassvaImgClassifier(CFG['model_arch7'], 5)
model = model.to(device)

for fold, epoch in enumerate(CFG['used_epochs_model7']):
    # Checkpoints are named '<arch>_fold_<fold>_<epoch>'.
    checkpoint_path = '../input/vit210211/vit_base_patch16_384_bs8x2/{}_fold_{}_{}'.format(
        CFG['model_arch7'], fold, epoch)
    state_dict = torch.load(checkpoint_path, map_location=CFG['device'])
    model.load_state_dict(state_dict)

    with torch.no_grad():
        for _ in range(CFG['tta']):
            # Fold weight, normalised over folds and averaged over TTA passes.
            pass_scale = CFG['weights'][fold] / sum(CFG['weights']) / CFG['tta']
            tst_preds.append(pass_scale * inference_one_epoch(model, tst_loader2, device))

del model

# Release cached GPU memory now that the last ensemble member is finished.
torch.cuda.empty_cache()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:01<00:00,  1.11s/it]
100%|██████████| 1/1 [00:00<00:00,  7.01it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.23it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.56it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.68it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  7.37it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.67it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.49it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.81it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.30it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  5.48it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.55it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.48it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.77it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.89it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  7.28it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.09it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.73it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.13it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.85it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  7.17it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  4.80it/s]
100%|██████████| 1/1 [00:00<00:00,  6.04it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  6.54it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.90it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.55it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  5.87it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.58it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  7.30it/s]
100%|██████████| 1/1 [00:00<00:00,  8.53it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  8.03it/s]
100%|██████████| 1/1 [00:00<00:00,  7.92it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  8.19it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.58it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.22it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.55it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  8.41it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  7.49it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  7.71it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.25it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.39it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.44it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  7.66it/s]
100%|██████████| 1/1 [00:00<00:00,  7.94it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  8.33it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.85it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.78it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.76it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.97it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.09it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  7.03it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.68it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.01it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.79it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.43it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.63it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  7.30it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.85it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.00it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.74it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.58it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  6.14it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.00it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.15it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.01it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.89it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  8.06it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.47it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.96it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.18it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.09it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  6.10it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  6.66it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  6.95it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.63it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.02it/s]


  0%|          | 0/1 [00:00<?, ?it/s]


100%|██████████| 1/1 [00:00<00:00,  7.95it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.70it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.87it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  7.69it/s]

  0%|          | 0/1 [00:00<?, ?it/s]



100%|██████████| 1/1 [00:00<00:00,  8.10it/s]

  0%|          | 0/1 [00:00<?, ?it/s]





In [9]:
# Collapse the list of weighted per-pass probability arrays into a single
# (n_images, n_classes) array. The guard keeps this cell idempotent: re-running
# it used to sum the already-reduced array again, leaving a 1-D array that
# breaks the argmax(axis=1) in the next cell.
if isinstance(tst_preds, list):
    tst_preds = np.sum(tst_preds, axis=0)

In [10]:
# The predicted class of each image is the one with the highest ensemble score.
predicted_labels = np.argmax(tst_preds, axis=1)
test['label'] = predicted_labels

# Write the submission file (image_id, label) without the DataFrame index.
test.to_csv('submission.csv', index=False)