In [None]:
!pip install "../input/pytorchlightning110/pytorch_lightning-1.1.0-py3-none-any.whl"

import os
import sys
# Make the model libraries bundled as ../input datasets importable
# (efficientnet_pytorch, timm from pytorch-image-models, microsoftvision).
sys.path.extend([
    "../input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master",
    '../input/pytorch-image-models/pytorch-image-models-master',
    "../input/microsoftvision",
])

In [None]:
# Project code: the cassava-code input dataset on Kaggle/Colab, ../kaggle_Cassava/code otherwise.
sys.path.append(
    "../input/cassava-code"
    if os.getcwd() in ("/kaggle/working", "/content")
    else "../kaggle_Cassava/code"
)

In [None]:
import cv2
import glob
import json
import math
import yaml
import pickle
import random
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, log_loss
from tqdm.notebook import tqdm
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
import timm
import pytorch_lightning as pl
from efficientnet_pytorch import EfficientNet
from deit_models import deit_base_patch16_224
import microsoftvision
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Log library versions for reproducibility of the submission.
for _lib_name, _lib_version in (("timm", timm.__version__), ("pytorch_lightning", pl.__version__)):
    print(f"{_lib_name} version:", _lib_version)

In [None]:
# True: run on the first 100 rows of train.csv (train_images/) instead of the test set.
DEBUG = False

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("device:", device)

In [None]:
def set_seed(seed: int=0) -> None:
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels.

    Args:
        seed: value used for every generator and for ``PYTHONHASHSEED``.
    """
    # Python-level randomness.
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # NumPy's global RNG (randomTTA draws from it).
    np.random.seed(seed)
    # PyTorch CPU generator and every CUDA device (no-op without CUDA).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Reproducible, non-autotuned cuDNN algorithms.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
# https://pypi.org/project/microsoftvision/
class MicrosoftVisionResnet50(nn.Module):
    """Microsoft Vision ResNet-50 backbone followed by a linear classification head.

    Args:
        pretrained: forwarded to ``microsoftvision.models.resnet50``.
        n_classes: number of output logits.
    """

    def __init__(self, pretrained=False, n_classes=5):
        super().__init__()
        # The attribute names `model` and `fc` become state-dict key prefixes,
        # so they must match the saved checkpoints.
        self.model = microsoftvision.models.resnet50(pretrained=pretrained)
        # Assumes the backbone returns 2048-d pooled features per image — TODO confirm.
        self.fc = nn.Linear(2048, n_classes)

    def forward(self, x):
        """Return the head's logits for input batch ``x``."""
        features = self.model(x)
        return self.fc(features)

In [None]:
class CassavaLite(pl.LightningModule):
    """Legacy inference wrapper: efficientnet_pytorch or timm backbone with an ``n_classes`` head.

    NOTE(review): the constructor reads the module-level globals ``source``, ``arch`` and
    ``n_classes``. ``n_classes`` is set in the main section of this notebook, but ``source``
    and ``arch`` are not defined anywhere in it, so construction (including through
    ``load_from_checkpoint``) presumably raises NameError unless they are set elsewhere.
    If ``source`` matches neither branch, ``self.net`` is never created and ``forward``
    fails.
    """
    def __init__(self):
        super().__init__()
        
        if source == 'efficientnet-pytorch':
            # Architecture only; the weights are expected to come from the Lightning checkpoint.
            self.net = EfficientNet.from_name(arch)
            self.net._fc = nn.Linear(self.net._fc.in_features, n_classes)
                
        elif source == 'timm':
            # pretrained=None (falsy): no pretrained weights are loaded here.
            self.net = timm.create_model(arch, pretrained=None)
            # Replace the classification head; its attribute name depends on the model family.
            if "eff" in arch:
                self.net.classifier = nn.Linear(self.net.classifier.in_features, n_classes) 
            elif "rexnet" in arch:
                self.net.head.fc = nn.Linear(self.net.head.fc.in_features, n_classes)
            elif "vit" in arch:
                self.net.head = nn.Linear(self.net.head.in_features, n_classes)
            else:
                self.net.fc = nn.Linear(self.net.fc.in_features, n_classes)
    
    def forward(self,x):
        """Return the raw (pre-softmax) outputs of ``self.net`` for batch ``x``."""
        out = self.net(x)
        return out

class CassavaLite2(pl.LightningModule):
    """Inference wrapper around a timm backbone whose head is replaced by an ``n_classes`` Linear.

    ``self.net`` is the full classifier used by ``forward``. ``self.feat_net`` is an
    ``nn.Sequential`` over a prefix of ``self.net``'s children; these are the same module
    objects, not copies. Both are registered, so checkpoints carry ``net.*`` and
    ``feat_net.*`` keys.

    NOTE(review): reads the module-level globals ``arch`` and ``n_classes``. ``arch`` is not
    defined anywhere in this notebook, so construction presumably raises NameError unless
    it is set elsewhere — confirm before using this class.
    """
    def __init__(self):
        super().__init__()
   
        # pretrained=None (falsy): architecture only; weights come from the checkpoint.
        self.net = timm.create_model(arch, pretrained=None)

        # The head is replaced BEFORE slicing children, so feat_net never contains the old head.
        if "eff" in arch:
            self.net.classifier = nn.Linear(self.net.classifier.in_features, n_classes)
            self.feat_net = nn.Sequential(*list(self.net.children())[:-2])  
        elif "rexnet" in arch:
            self.net.head.fc = nn.Linear(self.net.head.fc.in_features, n_classes)
            self.feat_net = nn.Sequential(*list(self.net.children())[:-1])
        elif "vit" in arch:
            self.net.head = nn.Linear(self.net.head.in_features, n_classes)
            self.feat_net = nn.Sequential(*list(self.net.children())[:-1])
        else:
            self.net.fc = nn.Linear(self.net.fc.in_features, n_classes)
            self.feat_net = nn.Sequential(*list(self.net.children())[:-2])  # exclude global_pool and the fc layer

    def forward(self, x):
        """Return the raw (pre-softmax) outputs of ``self.net`` for batch ``x``."""
        out = self.net(x)
        return out
    
class CassavaLite3(pl.LightningModule):
    """Inference wrapper around a timm backbone with a ``CFG.n_classes`` head and optional GeM.

    Args:
        CFG: config object; reads ``arch``, ``n_classes`` and ``gem_p``.

    ``self.net`` is the classifier used by ``forward``. ``self.feat_net`` is an
    ``nn.Sequential`` over a prefix of ``self.net``'s children; these are the same module
    objects, not copies.

    NOTE(review): ``GeMNet`` and ``GeM`` are not defined or imported anywhere in this
    notebook, so a config with ``gem_p > 0`` presumably raises NameError here — confirm.
    Statement order matters: the new head is attached before children are sliced, and the
    saved checkpoints depend on the resulting ``net.*`` / ``feat_net.*`` state-dict keys.
    """
    def __init__(self, CFG):
        super().__init__()

        # pretrained=None (falsy): architecture only; weights come from the checkpoint.
        self.net = timm.create_model(CFG.arch, pretrained=None)

        if "eff" in CFG.arch:
            self.net.classifier = nn.Linear(
                self.net.classifier.in_features, CFG.n_classes
            )
            if CFG.gem_p > 0.0:
                # Arguments are evaluated before `self.net` is reassigned, so the children and
                # in_features come from the timm model built above.
                self.net = GeMNet(
                    list(self.net.children())[:-2],
                    GeM(p=CFG.gem_p),
                    self.net.classifier.in_features,
                    CFG.n_classes,
                )
            # All but the last two children of self.net (which may now be the GeMNet wrapper).
            self.feat_net = nn.Sequential(*list(self.net.children())[:-2])
            
        elif "rexnet" in CFG.arch:
            self.net.head.fc = nn.Linear(self.net.head.fc.in_features, CFG.n_classes)
            if CFG.gem_p > 0.0:
                self.net = GeMNet(
                    list(self.net.children())[:-1],
                    GeM(p=CFG.gem_p),
                    self.net.head.fc.in_features,
                    CFG.n_classes,
                )
            self.feat_net = nn.Sequential(*list(self.net.children())[:-1])

        elif "vit" in CFG.arch:
            # ViT: no GeM option; only the head is replaced.
            self.net.head = nn.Linear(self.net.head.in_features, CFG.n_classes)
            self.feat_net = nn.Sequential(*list(self.net.children())[:-1])

        else:
            self.net.fc = nn.Linear(self.net.fc.in_features, CFG.n_classes)
            if CFG.gem_p > 0.0:
                self.net = GeMNet(
                    list(self.net.children())[:-2],
                    GeM(p=CFG.gem_p),
                    self.net.fc.in_features,
                    CFG.n_classes,
                )
            self.feat_net = nn.Sequential(
                *list(self.net.children())[:-2]
            )  # exclude global_pool and the fc layer

    def forward(self, x):
        """Return the raw (pre-softmax) outputs of ``self.net`` for batch ``x``."""
        out = self.net(x)
        return out
    
class CassavaLite4(pl.LightningModule):
    """Inference wrapper for Microsoft Vision ResNet-50, DeiT, or a timm backbone.

    Args:
        CFG: config object; reads ``arch`` and ``n_classes``.

    ``self.net`` is the classifier used by ``forward``. ``self.feat_net`` is an
    ``nn.Sequential`` over a prefix of ``self.net``'s children; these are the same module
    objects, not copies. Attribute names and the resulting state-dict keys must match the
    saved checkpoints.
    """

    def __init__(self, CFG):
        super().__init__()
        arch = CFG.arch

        if "microsoft" in arch:
            self.net = MicrosoftVisionResnet50(pretrained=False, n_classes=CFG.n_classes)
            self.feat_net = nn.Sequential(*list(self.net.children())[:-1])

        elif "deit" in arch and not any(key in arch for key in ("eff", "rexnet", "vit")):
            # Build DeiT directly. Creating a timm model from `arch` first (and then discarding
            # it) is wasted work and may fail if the installed timm does not register this arch.
            # The guard keeps the original precedence: eff/rexnet/vit were checked before deit.
            self.net = deit_base_patch16_224(pretrained=False)  # pretrained=True fails to load
            self.net.head = nn.Linear(self.net.head.in_features, CFG.n_classes)
            self.feat_net = nn.Sequential(*list(self.net.children())[:-1])

        else:
            # pretrained=None (falsy): architecture only; weights come from the checkpoint.
            self.net = timm.create_model(arch, pretrained=None)

            # The head is replaced BEFORE slicing children, so feat_net never holds the old head.
            if "eff" in arch:
                self.net.classifier = nn.Linear(
                    self.net.classifier.in_features, CFG.n_classes
                )
                self.feat_net = nn.Sequential(*list(self.net.children())[:-2])

            elif "rexnet" in arch:
                self.net.head.fc = nn.Linear(self.net.head.fc.in_features, CFG.n_classes)
                self.feat_net = nn.Sequential(*list(self.net.children())[:-1])

            elif "vit" in arch:
                self.net.head = nn.Linear(self.net.head.in_features, CFG.n_classes)
                self.feat_net = nn.Sequential(*list(self.net.children())[:-1])

            else:
                self.net.fc = nn.Linear(self.net.fc.in_features, CFG.n_classes)
                self.feat_net = nn.Sequential(
                    *list(self.net.children())[:-2]
                )  # exclude global_pool and the fc layer

    def forward(self, x):
        """Return the raw (pre-softmax) outputs of ``self.net`` for batch ``x``."""
        out = self.net(x)
        return out

In [None]:
class CassavaDataset(Dataset):
    """Test-time dataset returning, per image, one stacked tensor of augmented views per group.

    Args:
        df: frame with an ``image_id`` column of file names relative to ``imfolder``.
        imfolder: directory containing the images.
        transformgroups: list of transform lists; sample key ``x{i}`` stacks the outputs of
            every transform in ``transformgroups[i]`` along a new leading dimension.
    """

    def __init__(self, df: pd.DataFrame, imfolder: str, transformgroups):
        self.df = df
        self.imfolder = imfolder
        self.transformgroups = transformgroups

    def __getitem__(self, index):
        im_path = os.path.join(self.imfolder, self.df.iloc[index]['image_id'])
        x = cv2.imread(im_path, cv2.IMREAD_COLOR)
        if x is None:
            # cv2.imread returns None for missing/unreadable files instead of raising, which
            # would otherwise surface as an opaque cvtColor error.
            raise FileNotFoundError(f"could not read image: {im_path}")
        x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
        ret = {}
        for i, g in enumerate(self.transformgroups):
            # Each transform yields a tensor image (get_dataloader's pipelines end in
            # ToTensorV2); stacking once avoids re-concatenating on every view.
            views = [t(image=x)['image'] for t in g]
            ret[f'x{i}'] = torch.stack(views, dim=0).float()
        return ret

    def __len__(self):
        return len(self.df)

In [None]:
def get_models(model_paths, lite_type, CFG=None):
    """Load one Lightning checkpoint per path; return the models in eval mode on ``device``.

    Args:
        model_paths: iterable of checkpoint file paths.
        lite_type: name of the LightningModule class the checkpoints were saved from:
            "CassavaLite", "CassavaLite2", "CassavaLite3" or "CassavaLite4".
        CFG: config forwarded as the ``CFG`` constructor argument (CassavaLite3/4 only).

    Returns:
        list of models, in the order of ``model_paths``.

    Raises:
        ValueError: if ``lite_type`` is not one of the supported class names.
    """
    if lite_type not in ("CassavaLite", "CassavaLite2", "CassavaLite3", "CassavaLite4"):
        raise ValueError(f"unknown lite_type: {lite_type!r}")
    lite_cls = {
        "CassavaLite": CassavaLite,
        "CassavaLite2": CassavaLite2,
        "CassavaLite3": CassavaLite3,
        "CassavaLite4": CassavaLite4,
    }[lite_type]
    # CassavaLite3/4 take the config as a constructor argument; the older classes take none.
    ckpt_kwargs = {"CFG": CFG} if lite_type in ("CassavaLite3", "CassavaLite4") else {}

    models = []
    for model_path in model_paths:
        # load_from_checkpoint is a classmethod that constructs its own instance, so call it on
        # the class rather than building (and discarding) a full network first.
        model = lite_cls.load_from_checkpoint(model_path, **ckpt_kwargs).to(device)
        model.eval()
        models.append(model)
        print(f'{model_path} is loaded')
    return models

def get_dataloader(img_size_desc):
    """Build the test DataLoader shared by several model groups.

    Args:
        img_size_desc: sequence of ``(models, tta_nums, img_size)`` tuples; only the TTA
            count (index 1) and input size (index 2) are used here.

    Returns:
        DataLoader over ``test_df``; sample key ``x{i}`` holds ``tta_nums`` augmented views
        for entry ``i``.
    """
    transforms = []
    for item in img_size_desc:
        tta_nums, img_size = item[1], item[2]
        # Inputs below 500 px get RandomResizedCrop views; larger ones a center crop.
        if img_size < 500:
            crop = A.RandomResizedCrop(img_size, img_size)
        else:
            crop = A.CenterCrop(img_size, img_size)
        pipeline = A.Compose(
            [
                crop,
                A.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0,
                ),
                ToTensorV2(p=1.0),
            ],
            p=1.0,
        )
        # The same pipeline object, once per TTA view.
        transforms.append([pipeline] * tta_nums)

    test_dataset = CassavaDataset(test_df, img_dir, transforms)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False,
    )
    return test_loader

# Main: configuration, data loading and inference

In [None]:
n_classes = 5
num_workers = 2

path = "../input/cassava-leaf-disease-classification/"

if DEBUG:
    # Smoke test on the first 100 labelled training images.
    test_df = pd.read_csv(f"{path}/train.csv").iloc[:100, :]
    img_dir = f'{path}train_images/'
    COMMIT = False
else:
    test_df = pd.read_csv(f"{path}/sample_submission.csv")
    img_dir = f'{path}test_images/'
    # True when exactly one test image is visible — presumably the interactive/commit run
    # rather than the hidden-test scoring run (TODO confirm).
    COMMIT = len(glob.glob(f'{img_dir}*.jpg')) == 1

# Class index -> disease name, with integer keys.
with open(f'{path}/label_num_to_disease_map.json', 'r') as f:
    name_mapping = {int(k): v for k, v in json.load(f).items()}

In [None]:
# TTA copies per image: n_tta for the CNN groups, n_tta_many for the ViT group.
n_tta, n_tta_many = 5, 8
# batch_size = 16 was the earlier setting.
batch_size = 1
test_preds_list = []

In [None]:
def randomTTA(x):
    """Randomly flip ``x`` along its last dim, then its second-to-last dim, each with p=0.5.

    Draws from NumPy's global RNG (one draw per axis, last dim first). Returns ``x`` itself
    when neither flip is applied.
    """
    for axis in (-1, -2):
        if np.random.rand() > 0.5:
            x = torch.flip(x, [axis])
    return x

In [None]:
class Config:
    """Plain attribute container for one model's inference settings.

    The keyword args ``arch``, ``gem_p`` and ``height`` are required (KeyError otherwise);
    any other keyword is ignored. ``seeds`` is fixed to ``[42]``, and ``n_classes`` is
    copied from the notebook-level ``n_classes``.
    """
    def __init__(self, **kwargs):
        self.seeds = [42]
        self.n_classes = n_classes
        for key in ("arch", "gem_p", "height"):
            setattr(self, key, kwargs[key])

def GetYaml(yamlpath):
    """Read a training-config YAML file into a ``Config``.

    Args:
        yamlpath: path to a YAML mapping with at least ``arch``, ``gem_p`` and ``height``.

    Returns:
        Config built from the mapping's keys.
    """
    with open(yamlpath, 'r') as f:
        # yaml.load without an explicit Loader is deprecated: it warns on PyYAML 5.x and
        # raises TypeError on 6.0. FullLoader is what 5.x used by default, so parsing is
        # unchanged. Only load trusted config files: FullLoader still builds some Python types.
        cdict = yaml.load(f, Loader=yaml.FullLoader)
    cfg = Config(**cdict)
    return cfg

In [None]:
%%time

_RUN_ROOT = "../input/cassava-efficientnetwithpytorchlightning"


def _checkpoint_paths(run_dir):
    """Sorted ``model_seed_*.ckpt`` paths in ``run_dir``.

    Raises FileNotFoundError when nothing matches. An empty group would otherwise add
    all-zero probabilities to the ensemble without any error.
    """
    paths = sorted(glob.glob(f"{run_dir}/model_seed_*.ckpt"))
    if not paths:
        raise FileNotFoundError(f"no model_seed_*.ckpt files found in {run_dir}")
    return paths


# tf_efficientnet_b4_ns
# oof 0.894
_run1 = f"{_RUN_ROOT}/kaggle_upload_tf_efficientnet_b4_ns_BiTemperedLoss"
cfg1 = GetYaml(f"{_run1}/hparams.yaml")
mdl1 = _checkpoint_paths(_run1)
models1 = get_models(mdl1, "CassavaLite3", CFG=cfg1)

# byol + seresnext50_32x4d
# oof -.---
_run2 = f"{_RUN_ROOT}/kaggle_upload_byol_seresnext50_32x4d_cutmix_labelsmooth_half"
cfg2 = GetYaml(f"{_run2}/cfg.yaml")
mdl2 = _checkpoint_paths(_run2)
models2 = get_models(mdl2, "CassavaLite4", CFG=cfg2)

# vit_b16_224_fold10
# oof -.---
_run3 = f"{_RUN_ROOT}/kaggle_upload_vit_b16_224_fold10/kaggle_upload_vit_b16_224_fold10"
cfg3 = GetYaml(f"{_run3}/cfg.yaml")
mdl3 = _checkpoint_paths(_run3)
models3 = get_models(mdl3, "CassavaLite4", CFG=cfg3)


## seresnext50_32x4d_fmix
## oof -.---
#cfg_se50 = Config(**{"arch":"seresnext50_32x4d", "gem_p":0, "height":512})
#mdl_se50 = sorted(
#    glob.glob(
#        "../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_seresnext50_32x4d_fmix/model_seed_*.ckpt"
#    )
#)
#models_se50 = get_models(mdl_se50, "CassavaLite4", CFG=cfg_se50)


## resnest101e
## oof 
#cfg_r101e = Config(**{"arch":"resnest101e", "gem_p":0, "height":512})
#mdl_r101e = sorted(
#    glob.glob(
#        "../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_resnest101e_bi_tempered_loss/model_seed_*.ckpt"
#    )
#)
#models_r101e = get_models(mdl_r101e, "CassavaLite3", CFG=cfg_r101e)


## resnest101e_cleanlab
## oof 
#cfg_r101e_c = GetYaml("../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_resnest101e_cleanlab/cfg.yaml")
#mdl_r101e_c = sorted(
#    glob.glob(
#        "../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_resnest101e_cleanlab/model_seed_*.ckpt"
#    )
#)
#models_r101e_c = get_models(mdl_r101e_c, "CassavaLite4", CFG=cfg_r101e_c)


## resnest101e_cleanlab_noise_cutmix
## oof -.---
#cfg_r101e_cnc = GetYaml("../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_resnest101e_cleanlab_noise_cutmix/cfg.yaml")
#mdl_r101e_cnc = sorted(
#    glob.glob(
#        "../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_resnest101e_cleanlab_noise_cutmix/model_seed_*.ckpt"
#    )
#)
#models_r101e_cnc = get_models(mdl_r101e_cnc, "CassavaLite4", CFG=cfg_r101e_cnc)


## tf_efficientnet_b4_ns_fold3
## oof -.---
#cfg_b4_f3 = GetYaml("../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_tf_efficientnet_b4_ns_fold3/cfg.yaml")
#mdl_b4_f3 = sorted(
#    glob.glob(
#        "../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_tf_efficientnet_b4_ns_fold3/model_seed_*.ckpt"
#    )
#)
#models_b4_f3 = get_models(mdl_b4_f3, "CassavaLite4", CFG=cfg_b4_f3)


## tf_efficientnet_b4_ns_fold10
## oof -.---
#cfg_b4_f10 = GetYaml("../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_tf_efficientnet_b4_ns_fold10/cfg.yaml")
#mdl_b4_f10 = sorted(
#    glob.glob(
#        "../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_tf_efficientnet_b4_ns_fold10/model_seed_*.ckpt"
#    )
#)
#models_b4_f10 = get_models(mdl_b4_f10, "CassavaLite4", CFG=cfg_b4_f10)


## tf_efficientnet_b4_ns_cleanlab_noise_cutmix_fmix_n_over
## oof -.---
#cfg_b4_cncfn = GetYaml("../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_tf_efficientnet_b4_ns_cleanlab_noise_cutmix_fmix_n_over/cfg.yaml")
#mdl_b4_cncfn = sorted(
#    glob.glob(
#        "../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_tf_efficientnet_b4_ns_cleanlab_noise_cutmix_fmix_n_over/model_seed_*.ckpt"
#    )
#)
#models_b4_cncfn = get_models(mdl_b4_cncfn, "CassavaLite4", CFG=cfg_b4_cncfn)


## vit-base-patch16-224
## oof -.---
#cfg_vp16 = Config(**{"arch":"vit_base_patch16_224", "gem_p":0, "height":224})
#mdl_vp16 = sorted(
#    glob.glob(
#        "../input/cassava-vit-base-patch16-224-fit/model_seed_*.ckpt"
#    )
#)
#models_vp16 = get_models(mdl_vp16, "CassavaLite4", CFG=cfg_vp16)


## vit-base-patch32-384
## oof -.---
#cfg_vp32 = GetYaml("../input/cassava-vit-base-patch32-384-fit/cfg.yaml")
#mdl_vp32 = sorted(
#    glob.glob(
#        "../input/cassava-vit-base-patch32-384-fit/model_seed_*.ckpt"
#    )
#)
#models_vp32 = get_models(mdl_vp32, "CassavaLite4", CFG=cfg_vp32)


## deit-base-patch16-224
## oof -.---
#cfg_deit_224 = GetYaml("../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_deit-base-patch16-224/cfg.yaml")
#mdl_deit_224 = sorted(
#    glob.glob(
#        "../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_deit-base-patch16-224/model_seed_*.ckpt"
#    )
#)
#models_deit_224 = get_models(mdl_deit_224, "CassavaLite4", CFG=cfg_deit_224)


## microsoftvisionresnet50
## oof -.---
#cfg_mic = GetYaml("../input/cassava-microsoftvisionresnet50-ipynb/cfg.yaml")
#mdl_mic = sorted(
#    glob.glob(
#        "../input/cassava-microsoftvisionresnet50-ipynb/model_seed_*.ckpt"
#    )
#)
#models_mic = get_models(mdl_mic, "CassavaLite4", CFG=cfg_mic)

In [None]:
# Lightning model groups to ensemble, as (loaded checkpoints, TTA copies per image, input size).
# Disabled alternatives are kept as comments inside the list.
models = [
    (models1, n_tta, 512),
    (models2, n_tta, cfg2.height),
    (models3, n_tta_many, cfg3.height),
    # (models_se50, n_tta, cfg_se50.height),
    # (models_r101e, n_tta, cfg_r101e.height),
    # (models_r101e_c, n_tta, cfg_r101e_c.height),
    # (models_r101e_cnc, n_tta, cfg_r101e_cnc.height),
    # (models_b4_f3, n_tta, cfg_b4_f3.height),
    # (models_b4_f10, n_tta, cfg_b4_f10.height),
    # (models_b4_cncfn, n_tta, cfg_b4_cncfn.height),
    # (models_vp16, n_tta_many, cfg_vp16.height),
    # (models_vp32, n_tta_many, cfg_vp32.height),
    # (models_deit_224, n_tta_many, cfg_deit_224.height),
    # (models_mic, n_tta, cfg_mic.height),
]

models_num = len(models)

In [None]:
# One test DataLoader shared by all Lightning model groups: each sample's key f"x{i}"
# holds the stacked TTA views prepared for models[i].
loader = get_dataloader(models)

In [None]:
# Lightning inference. For each image and model group, average the softmax outputs over the
# group's checkpoints and over its randomly flipped TTA copies.
test_preds_list = []
set_seed(seed=42)
with torch.no_grad():
    # One list of per-image probability vectors per model group.
    preds = [[] for _ in range(models_num)]
    actfn = nn.Softmax(dim=1)
    for batch in loader:
        for i, mdesc in enumerate(models):
            mdlgrp = mdesc[0]
            nummdl = len(mdlgrp)
            # batch[f'x{i}'] is (batch, n_tta, C, H, W) because CassavaDataset stacks each
            # image's TTA views. Iterate over every image so batch_size > 1 is handled too.
            for img in batch[f'x{i}'].to(device):
                # Independent random h/v flip for each TTA copy of this image.
                x = torch.stack([randomTTA(view) for view in img], dim=0)
                y = torch.zeros([x.shape[0], n_classes], device=device)
                for mdl in mdlgrp:
                    y = y + (actfn(mdl(x)) / nummdl)
                # Average over TTA copies -> one (n_classes,) vector for this image.
                preds[i].append(y.detach().cpu().numpy().mean(axis=0))
    for i in range(models_num):
        preds[i] = np.concatenate(preds[i], axis=0).reshape(-1, n_classes)
        test_preds_list.append(preds[i])

### --- SiNpcw Module ---

In [None]:
# SiNpcw models: square input side in pixels, and number of classes.
SIZE, NUM_CLASSES = 512, 5

In [None]:
def GetPath(pth):
    """Path of a SiNpcw checkpoint inside the ``cassavapth`` input dataset."""
    return os.path.join('../input/cassavapth/', pth)


def _fold_group(name, run_id, n_folds=5):
    """Model defs for the per-fold checkpoints ``<run_id>/<run_id>k<fold>.pth`` of one run."""
    return [
        {'name': name, 'pth': GetPath(f'{run_id}/{run_id}k{fold}.pth')}
        for fold in range(n_folds)
    ]


# Each inner list is one model group; the group's models are averaged at inference.
modeldefs = [
    # Gr1. B4+FINE+UDA+TTA3 (LB900), CV: k0=8981, k1ep13=9030, k2ep13=8927, k3ep12=8983, k4ep14=8904
    _fold_group('efficientnet-b4', '22019613'),
    # Gr2. resnest101e (LB896), CV: k0=9005, k1=9012, k2=8960, k3=8967, k4=8920
    _fold_group('timm-resnest101e', '22019632'),
    # Disabled alternatives (uncomment and add a matching `modelttas` entry to use one):
    # B4 (LB---), CV: k0=8951, k1=8963, k2=8923, k3=9002, k4=8897
    # _fold_group('efficientnet-b4', '22019629'),
    # se_resnext101 (LB896), CV: k0e=8965, k1=8986, k2=8913, k3=9007, k4=8876
    # _fold_group('se_resnext101', '22019631'),
    # se_resnet101 (LB894), CV: k0=8958, k1e=8953, k2=8869, k3=8997, k4=8885
    # _fold_group('se_resnet101', '22019630'),
    # regnety_032 (LB896), CV: k0=8974, k1=8984, k2=8906, k3=9009, k4=8918
    # _fold_group('timm-regnety_032', '22019638'),
    # B5 (LB892), CV: k0=8993, k1=8965, k2=8962, k3=8986, k4=8927
    # _fold_group('efficientnet-b5', '22019639'),
    # resnest200e (LB---), CV: k0=9009, k1=9002, k2=8937, k3=9021, k4=8904
    # _fold_group('timm-resnest200e', '22019640'),
    # resnest101e (LB---), CV: k0=8991, k1=9000, k2=8925, k3=8965, k4=8862 (oof=0.89480)
    # _fold_group('timm-resnest101e', '22019717'),
    # efficientnet-b4 (LB---), CV: k0=8993, k1=8972, k2=8939, k3=9021, k4=8941 (oof=0.89728)
    # _fold_group('efficientnet-b4', '22019720'),
    # efficientnet-b4 (LB---), CV: k0=8993, k1=8972, k2=8939, k3=9021, k4=8941 (oof=0.89629)
    # _fold_group('efficientnet-b4', '22019725'),
    # combine_set1 (oof=0.89886): per-fold mix of runs 22019725 / 22019613 / 22019720
    # [
    #     {"name": "efficientnet-b4", "pth": GetPath("22019725/22019725k0.pth")},
    #     {"name": "efficientnet-b4", "pth": GetPath("22019613/22019613k1.pth")},
    #     {"name": "efficientnet-b4", "pth": GetPath("22019725/22019725k2.pth")},
    #     {"name": "efficientnet-b4", "pth": GetPath("22019720/22019720k3.pth")},
    #     {"name": "efficientnet-b4", "pth": GetPath("22019720/22019720k4.pth")},
    # ],
    # efficientnet-b4 (oof=0.89961)
    # _fold_group('efficientnet-b4', 'ex02C'),
]

# TTA rounds per group, in the same order as `modeldefs`.
modelttas = [
    3,
    3,
]
# The inference loop zips groups with their TTA counts, so a length mismatch would
# silently drop groups.
assert len(modelttas) == len(modeldefs), "modelttas needs exactly one entry per modeldefs group"

TTA_ROUND = np.max(modelttas)

In [None]:
def GetModel(name, param):
    """Build a SiNpcw classifier by name and load its weights from ``param``.

    Args:
        name: torchvision ResNet/ResNeXt/WideResNet short name, pretrainedmodels SE-Net
            name, ``'efficientnet-b*'`` (efficientnet_pytorch) or ``'timm-<arch>'``.
        param: path to a state dict loadable with ``torch.load``.

    Returns:
        The model with a ``NUM_CLASSES``-way head, weights loaded with ``strict=True``, in
        eval mode. It is not moved to ``device``; the caller does that.

    Raises:
        NameError: if ``name`` is not a supported model name.
    """
    num_classes = NUM_CLASSES
    if name in [ 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50', 'resnext101', 'wide_resnet50', 'wide_resnet101' ]:
        # Expand short names to torchvision constructor names.
        # NOTE(review): torchvision provides resnext101_32x8d, not resnext101_32x4d, so
        # 'resnext101' likely fails here — confirm which variant the checkpoint used.
        if name == 'resnext50' or name == 'resnext101':
            name = name + '_32x4d'
        elif name == 'wide_resnet50' or name == 'wide_resnet101':
            name = name + '_2'
        # NOTE(review): `torchvision` is never imported in this notebook; this branch raises
        # NameError unless it is imported elsewhere.
        model = getattr(torchvision.models, name)(pretrained=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif name in [ 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50', 'se_resnext101', 'se_resnext50_32x4d', 'se_resnext101_32x4d' ]:
        if name == 'se_resnext50' or name == 'se_resnext101':
            name = name + '_32x4d'
        # NOTE(review): `pretrainedmodels` is likewise never imported in this notebook.
        model = getattr(pretrainedmodels, name)(pretrained=None)
        # Adaptive pooling, presumably so inputs of SIZE x SIZE (not just the library's
        # default resolution) still pool to 1x1 before the head.
        model.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        model.last_linear = nn.Linear(model.last_linear.in_features, num_classes)
    elif name.startswith('efficientnet-b'):
        model = EfficientNet.from_name(name)
        model._fc = nn.Linear(model._fc.in_features, num_classes)
    elif name.startswith('timm-'):
        model = timm.create_model(model_name=name[len('timm-'):], num_classes=num_classes, in_chans=3, pretrained=False)
    else:
        raise NameError(f'unsupported model name: {name!r}')
    state = torch.load(param, map_location=device)
    model.load_state_dict(state, strict=True)
    model.eval()
    print('model ({}) is loaded'.format(name))
    return model

In [None]:
def GetAugment(size):
    """Albumentations pipeline: resize to ``size`` x ``size``, then the default Normalize."""
    steps = [A.Resize(size, size), A.Normalize()]
    return A.Compose(steps, p=1.0)

In [None]:
def TTA(img, ops):
    """Return test-time view number ``ops`` of an N x C x H x W batch.

    0: identity, 1: flip last dim, 2: flip H dim, 3: flip both,
    4: ``torch.rot90`` k=1 over dims (2, 3), 5: ``torch.rot90`` k=3 over dims (2, 3).
    Any other value returns ``img`` unchanged (the same object).
    """
    views = {
        1: lambda t: torch.flip(t, [-1]),
        2: lambda t: torch.flip(t, [-2]),
        3: lambda t: torch.flip(t, [-1, -2]),
        4: lambda t: torch.rot90(t, 1, [2, 3]),
        5: lambda t: torch.rot90(t, 3, [2, 3]),
    }
    make_view = views.get(ops)
    return img if make_view is None else make_view(img)

In [None]:
def GetDataLoader(batch=8, num_workers=2):
    """DataLoader over the image files listed in ``test_df``'s first column.

    Args:
        batch: batch size.
        num_workers: DataLoader worker processes.

    Returns:
        DataLoader yielding normalized tensors of shape (batch, 3, SIZE, SIZE), in
        ``test_df`` order (no shuffling, last partial batch kept).
    """
    files = test_df.iloc[:, 0].values
    dataset = InferDataset(files, augops=GetAugment(SIZE))
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )

class InferDataset(Dataset):
    """RGB images from the global ``img_dir``, augmented with ``augops`` and returned as CHW.

    Args:
        files: sequence of image file names relative to ``img_dir``.
        augops: albumentations pipeline applied to the RGB image.
    """

    def __init__(self, files, augops):
        self.files = files
        self.augops = augops

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = os.path.join(img_dir, self.files[idx])
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None for missing/unreadable files instead of raising.
            raise FileNotFoundError(f"could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        out = self.augops(force_apply=False, image=img)['image']
        # HWC (albumentations) -> CHW (PyTorch).
        out = out.transpose(2, 0, 1)
        return torch.from_numpy(out)

loader = GetDataLoader(batch=8)

In [None]:
# Load every SiNpcw model group onto `device`, one list of models per `modeldefs` entry.
# A dedicated name is used for each group's list: reusing the global `models` here would
# overwrite the (models, n_tta, size) list that the Lightning inference cell iterates over.
modelgroups = []
for i, groupdef in enumerate(modeldefs):
    print(f'--- Group{i} ---')
    group_models = [GetModel(mdef['name'], mdef['pth']).to(device) for mdef in groupdef]
    modelgroups.append(group_models)

In [None]:
# SiNpcw inference. In TTA round `tta`, every group with tta < its TTA count adds its
# models' softmax outputs. The ratio 1 / (n_models * n_tta) makes each group's result a
# plain average over its models and TTA views.
with torch.no_grad():
    gnums = len(modelgroups)
    # One independent list of per-batch (b, NUM_CLASSES) arrays per group.
    preds = [[] for _ in range(gnums)]
    actfn = nn.Softmax(dim=1)
    for x in loader:
        b = x.shape[0]
        x = x.to(device)
        y = torch.zeros([gnums, b, NUM_CLASSES], device=device)
        for tta in range(TTA_ROUND):
            xi = TTA(x, tta)  # TTA view of the whole batch for this round
            for g, (mgrp, mtta) in enumerate(zip(modelgroups, modelttas)):
                # mtta is this group's TTA count; skip the group once it is used up.
                if tta < mtta:
                    ratio = 1.0 / (len(mgrp) * mtta)
                    for model in mgrp:
                        y[g, :, :] = y[g, :, :] + actfn(model(xi)) * ratio
        y = y.detach().cpu().numpy()
        for g in range(gnums):
            preds[g] += [y[g, :, :]]
for g in range(gnums):
    preds[g] = np.concatenate(preds[g], axis=0).reshape(-1, NUM_CLASSES)
    test_preds_list.append(preds[g])

In [None]:
if COMMIT:
    # Commit runs only: show each group's per-class probabilities, labelled by disease name.
    for group_preds in test_preds_list:
        display(pd.DataFrame(group_preds, columns=name_mapping.values()))

## Ensemble: per-class weighted average of all model groups

In [None]:
### LBの値をアンサンブル重みにする場合
##
### oof=0.9050801514230967
### LB=0.905
###ens_weights = [0.8904051969902322, 0.8941440388839557, 0.8943309809786418, 0.8874608589989251, 0.8963873440201897]
##
### oof=0.9063
### LB=タイムアウト
###ens_weights = [0.8904051969902322, 0.893302799457868, 0.8941440388839557, 0.8943309809786418, 0.8874608589989251, 0.8847501986259756]
#
## oof=0.9057811842781698
## LB=
#ens_weights = [0.8904051969902322, 0.893302799457868, 0.8947983362153573, 0.897275318969949, 0.8997990372482124]
#
#
#test_preds_ensemble = np.zeros((test_df.shape[0], n_classes))
#for i, pre in enumerate(test_preds_list):
#    test_preds_ensemble += (ens_weights[i] / sum(ens_weights)) * pre
#
#display(pd.DataFrame(test_preds_ensemble, columns=name_mapping.values()))

In [None]:
# Per-class ensemble weights derived from the OOF confusion matrix.
# oof=0.9053605645651259
# LB=0.905
# One row per entry of test_preds_list (same order), one column per class. Each column sums
# to ~1, so every class score stays a weighted average of the groups' probabilities.
ens_weights = np.array(
    [
        [0.21396141, 0.20027513, 0.20551536, 0.19993096, 0.19595167],  # tf_efficientnet_b4_ns
        [0.20459705, 0.19993122, 0.20834644, 0.20013494, 0.20400593],  # byol_seresnext50_32x4d
        [0.1753689 , 0.1870916 , 0.18139876, 0.1981736 , 0.18853328],  # vit_b16_224_fold10
        [0.2122588 , 0.20382896, 0.20069204, 0.20110776, 0.20379398],  # 22019613_efficientnet-b4.pkl
        [0.19381385, 0.20887309, 0.20404739, 0.20065273, 0.20771513],  # 22019632_timm-resnest101e.pkl
    ]
)
# Fail loudly rather than silently ignore extra prediction sets (e.g. an inference cell
# re-run out of order) or crash with an IndexError on missing weight rows.
assert len(test_preds_list) == len(ens_weights), (
    f"{len(test_preds_list)} prediction sets but {len(ens_weights)} weight rows; "
    "re-run the inference cells from a fresh kernel or update ens_weights"
)
test_preds_ensemble = np.zeros((test_df.shape[0], n_classes))
for group_weights, group_preds in zip(ens_weights, test_preds_list):
    # (n_classes,) weights broadcast over the (n_images, n_classes) probabilities.
    test_preds_ensemble += group_weights * group_preds
display(pd.DataFrame(test_preds_ensemble, columns=name_mapping.values()))

In [None]:
## oofのcfmにもとづいた重み  20210212_objective2
## oof=0.9055
## LB=
#ens_weights = np.array(
#    [
#        [0.21269394, 0.20043598, 0.20435825, 0.2000628, 0.19597244],   # sub_tf_efficientnet_b4_ns
#        [0.20338505, 0.20009179, 0.20717339, 0.20026692, 0.20402756],  # sub_byol_seresnext50_32x4d
#        [0.17433004, 0.18724185, 0.18037744, 0.19830429, 0.18855326],  # sub_vit_b16_224_fold10
#        [0.20056417, 0.2077788, 0.20300282, 0.20056524, 0.20222576],   # 22019631_se_resnext101.pkl
#        [0.2090268, 0.20445158, 0.2050881, 0.20080075, 0.20922099],    # ex02C_efficientnet-b4.pkl
#    ]
#)
#test_preds_ensemble = np.zeros((test_df.shape[0], n_classes))
#for i,pre in enumerate(test_preds_list):
#    test_preds_ensemble += ens_weights[i] * pre
#display(pd.DataFrame(test_preds_ensemble, columns=name_mapping.values()))

In [None]:
## oofのcfmにもとづいた重み  20210212_objective3
## oof_tta=0.9048
## LB=0.904
#ens_weights = np.array(
#    [
#        [0.2059718, 0.19812068, 0.20305518, 0.1998686, 0.19315833],    # tf_efficientnet_b4_ns_tta
#        [0.19159524, 0.20038492, 0.20388652, 0.19919596, 0.20556015],  # resnest101e_cleanlab_noise_cutmix_tta
#        [0.18495991, 0.19551681, 0.18819495, 0.20066639, 0.1988425],   # vit_b16_224_fold10_tta
#        [0.20873652, 0.20298879, 0.19910631, 0.20047867, 0.19832575],  # 22019613_tta3_efficientnet
#        [0.20873652, 0.20298879, 0.20575704, 0.19979039, 0.20411327],  # ex02C_tta5_efficientnet
#    ]
#)
#test_preds_ensemble = np.zeros((test_df.shape[0], n_classes))
#for i,pre in enumerate(test_preds_list):
#    test_preds_ensemble += ens_weights[i] * pre
#display(pd.DataFrame(test_preds_ensemble, columns=name_mapping.values()))

In [None]:
## oofのcfmにもとづいた重み  20210212_objective5
## oof_tta=0.90578
## LB=
#ens_weights = np.array(
#    [
#        [0.21317501, 0.19585202, 0.20264682, 0.19972413, 0.1907759],   # tf_efficientnet_b4_ns_tta
#        [0.19790783, 0.19932735, 0.20212986, 0.19942631, 0.2042922],   # resnest101e_cleanlab_noise_cutmix_tta
#        [0.1931015, 0.20426009, 0.20119934, 0.20044516, 0.20222864],   # 22019632_tta3_timm
#        [0.20949958, 0.19977578, 0.20337055, 0.20046083, 0.20367313],  # ex02C_tta5_efficientnet
#        [0.18631609, 0.20078475, 0.19065343, 0.19994357, 0.19903013],  # vit_base_patch16_384_TTA
#    ]
#)
#test_preds_ensemble = np.zeros((test_df.shape[0], n_classes))
#for i,pre in enumerate(test_preds_list):
#    test_preds_ensemble += ens_weights[i] * pre
#display(pd.DataFrame(test_preds_ensemble, columns=name_mapping.values()))

In [None]:
## oofのcfmにもとづいた重み  20210212_objective6
## oof_tta=0.90596
## LB=
#ens_weights = np.array(
#    [
#        [0.19648426, 0.19977427, 0.20473756, 0.19918661, 0.20534792],  # resnest101e_cleanlab_noise_cutmix_tta
#        [0.21122767, 0.19751693, 0.20390274, 0.19985922, 0.19295891],  # tf_efficientnet_b4_ns_tta
#        [0.18627729, 0.19525959, 0.1829281, 0.20114187, 0.19574644],   # vit_base_patch32_384
#        [0.19364899, 0.20564334, 0.20306793, 0.20003128, 0.20235391],  # 22019632_timm
#        [0.21236178, 0.20180587, 0.20536366, 0.19978101, 0.20359281],  # ex02C_tta3_efficientnet
#    ]
#)
#test_preds_ensemble = np.zeros((test_df.shape[0], n_classes))
#for i,pre in enumerate(test_preds_list):
#    test_preds_ensemble += ens_weights[i] * pre
#display(pd.DataFrame(test_preds_ensemble, columns=name_mapping.values()))

In [None]:
## oofのcfmにもとづいた重み  20210212_objective7
## oof_tta=0.90578
## LB=
#ens_weights = np.array(
#    [
#        [0.20922478, 0.19727891, 0.19879455, 0.19959025, 0.20360584],  # resnest101e_tta
#        [0.20283412, 0.20147392, 0.20721189, 0.19943385, 0.19096467],  # tf_efficientnet_b4_ns_fold3_tta
#        [0.18588497, 0.19580499, 0.18819495, 0.20062245, 0.19935758],  # vit_b16_224_fold10_tta
#        [0.1961656, 0.20340136, 0.2013925, 0.20034094, 0.20153352],    # 22019720_tta3_efficientnet
#        [0.20589053, 0.20204082, 0.20440611, 0.20001251, 0.20453839],  # ex02C_efficientnet
#    ]
#)
#test_preds_ensemble = np.zeros((test_df.shape[0], n_classes))
#for i,pre in enumerate(test_preds_list):
#    test_preds_ensemble += ens_weights[i] * pre
#display(pd.DataFrame(test_preds_ensemble, columns=name_mapping.values()))

In [None]:
## optunaで最適化した重み
## oof = 0.9065289526569145  20210212_objective7
## LB=0.902
#ens_weights = np.array(   
#    [
#        [0.15752952526994954,0.22514403632959823,0.18918843864327994,0.2479556449888305,0.24244176464181127], # 'resnest101e_tta'
#        [0.23845076912265822,0.1925789461846951,0.1591652375899216,0.22242568481767636,0.2139059538229616], # 'tf_efficientnet_b4_ns_fold3_tta'
#        [0.21851624337733444,0.16902723642770212,0.206556619321817,0.21702286601089407,0.19426374746503489], # 'vit_b16_224_fold10_tta'
#        [0.21952535985009275,0.19421036061074562,0.19408755305253653,0.2073030212336284,0.1825570805293047], # '22019720_tta3_efficientnet-b4
#        [0.1999030377835539,0.20443143909871822,0.22043869750054165,0.1686958731252343,0.1847032538931056],  # ex02C_efficientnet-b4.pkl
#    ]
#)
#test_preds_ensemble = np.zeros((test_df.shape[0], n_classes))
#for i,pre in enumerate(test_preds_list):
#    test_preds_ensemble += ens_weights[i] * pre
#display(pd.DataFrame(test_preds_ensemble, columns=name_mapping.values()))

In [None]:
# Hard class labels from the weighted ensemble; writing the submission file here is
# disabled (the stacking prediction cell writes submission.csv instead).
test_df['label'] = np.argmax(test_preds_ensemble, axis=1)
#test_df.to_csv('submission.csv', index=False)
display(test_df)

## stacking

In [None]:
import sys
sys.path.append("../input/cassava-code")
from lightning_cassava_stacking import (
    train_stacking,
    pred_stacking,
    StackingConfig,
)

In [None]:
# Option: remove training samples that look like label noise before stacking.
# `base_data` is only referenced by the commented-out code below.
from params import base_data

#df = pd.read_csv(f"{path}/train.csv")
#noise_image_id = base_data.noise_image_id
#noise_idx = list(df[df["image_id"].isin(noise_image_id)].index)

# None disables noise filtering; this value is passed as `noise_idx=` to
# train_stacking in the training cell.
noise_idx = None

# The helpers below are what the training cell calls when `is_del_df_noise_image_id`
# is True; unless they are defined elsewhere, enabling that flag raises NameError.
# NOTE(review): `df["noise_image_id"]` would hold image_id strings (NaN on other rows),
# so `== False` never matches and `not_noise_image_id` would be empty — verify the
# selection logic before re-enabling.
#df = pd.read_csv(f"{path}/train.csv")
#df["noise_image_id"] = df.iloc[noise_idx]["image_id"]
#not_noise_image_id = df[df["noise_image_id"] == False].index.values
#print("len(noise_image_id):", len(noise_image_id))
#print("len(not_noise_image_id):", len(not_noise_image_id))
#
#
#def del_df_noise_image_id(preds, y, not_noise_image_id):
#    preds = [p[not_noise_image_id] for p in preds]
#    y = y[not_noise_image_id]
#    return preds, y

In [None]:
# Hyper-parameters for the cnmn2d stacking model.
StackingCFG = StackingConfig()

# Network settings from search run 20210212_objective13.
_cnn2d_params = {
    'kwargs_head': {
        'drop_rate': 0.9,
        'n_features_list': [-1, 5],
        'use_bn': False,
        'use_tail_as_out': True,
        'use_wn': False,
    },
    'n_channels_list': [1, 8],
    'n_classes': n_classes,
    # NOTE(review): presumably must equal the number of OOF files stacked in the training cell.
    'n_models': 5,
    'use_bias': False,
}

# Applied in this exact order so StackingCFG.__dict__ prints in a stable order.
_overrides = {
    "device": device,
    "num_workers": num_workers,
    "n_classes": n_classes,
    "arch": "cnmn2d",
    "cnn2d_params": _cnn2d_params,
    "weight_decay": 1e-05,
    "smoothing": 0.2,
    "t1": 0.8,
    "t2": 1.3,
    "gauss_scale": 0.24,
    "cutmix_p": 0.0,
    "alpha": 0.2,
    "seeds": list(range(2)),
}
for _name, _value in _overrides.items():
    setattr(StackingCFG, _name, _value)
del _overrides, _cnn2d_params, _name, _value

# Experiment switches read by the training / prediction cells.
# s_type: 0 = no pseudo labels; 1 = append pseudo-labelled test rows to the training data;
#         2 = pseudo-labelled test rows are extra training data only (kept out of validation).
s_type = 1
gauss_scale = 0.0                 # std of Gaussian noise added to the stacking inputs (0 disables)
is_noise_only_test = False        # s_type 2 only: add that noise to the test rows only
is_del_df_noise_image_id = False  # drop suspected-noise train rows before stacking
pseudo_th = 0.0                   # keep pseudo labels whose ensemble score exceeds this (0 disables)

# Seed the global RNGs when Gaussian noise will be drawn with np.random.normal.
if gauss_scale > 0.0:
    set_seed(seed=42)

#StackingCFG.max_epochs = 5

print(StackingCFG.__dict__)

In [None]:
%%time
# cnmn2d semi supervised learning
StackingCFG.out_dir = StackingCFG.arch
os.makedirs(StackingCFG.out_dir, exist_ok=True)


# load oof
preds = []

# --------------------------------------------------------------------------------------------
# ---- anonamename ----
tta_dir = "../input/cassava-efficientnetwithpytorchlightning/kaggle_upload_oof_tta"

pkl = f"{tta_dir}/tf_efficientnet_b4_ns_tta.pkl"
pred = pickle.load(open(pkl, "rb"))
preds.append(pred.values)

pkl = f"{tta_dir}/byol_seresnext50_32x4d_cutmix_labelsmooth_half_tta.pkl"
pred = pickle.load(open(pkl, "rb"))
preds.append(pred.values)

pkl = f"{tta_dir}/vit_b16_224_fold10_tta.pkl"
pred = pickle.load(open(pkl, "rb"))
preds.append(pred.values)

#pkl = f"{tta_dir}/resnest101e_cleanlab_noise_cutmix_tta.pkl"
#pred = pickle.load(open(pkl, "rb"))
#preds.append(pred.values)

#pkl = f"{tta_dir}/resnest101e_tta.pkl"
#pred = pickle.load(open(pkl, "rb"))
#preds.append(pred.values)

#pkl = f"{tta_dir}/tf_efficientnet_b4_ns_fold3_tta.pkl"
#pred = pickle.load(open(pkl, "rb"))
#preds.append(pred.values)

#pkl = f"{tta_dir}/vit_base_patch32_384.fit_tta.pkl"
#pred = pickle.load(open(pkl, "rb"))
#preds.append(pred.values)

# ---- SiNpcw ----
m_dir = "../input/cassavapkl"

#pkl = f"{m_dir}/22019613_efficientnet-b4.pkl"
pkl = f"{m_dir}/22019613_tta3_efficientnet-b4.pkl"
pred = pickle.load(open(pkl, "rb"))
preds.append(pred.values)

#pkl = f"{m_dir}/22019632_timm-resnest101e.pkl"
pkl = f"{m_dir}/22019632_tta3_timm-resnest101e.pkl"
pred = pickle.load(open(pkl, "rb"))
preds.append(pred.values)

#pkl = f"{m_dir}/22019631_tta3_se_resnext101.pkl"
#pred = pickle.load(open(pkl, "rb"))
#preds.append(pred.values)

#pkl = f"{m_dir}/22019717_tta3_timm-resnest101e.pkl"
#pred = pickle.load(open(pkl, "rb"))
#preds.append(pred.values)

#pkl = f"{m_dir}/22019720_tta3_efficientnet-b4.pkl"
#pred = pickle.load(open(pkl, "rb"))
#preds.append(pred.values)

#pkl = f"{m_dir}/ex02C_efficientnet-b4.pkl"
#pkl = f"{m_dir}/ex02C_tta3_efficientnet-b4.pkl"
#pkl = f"{m_dir}/ex02C_tta5_efficientnet-b4.pkl"
#pred = pickle.load(open(pkl, "rb"))
#preds.append(pred.values)
# --------------------------------------------------------------------------------------------


df = pd.read_csv(f"{path}/train.csv")
y = df["label"].values


# 過去コンペのデータは除く
preds = [p[:21397] for p in preds]


if is_del_df_noise_image_id:
    # --- LB0.905の5モデルの内3モデルすべて同じ予測で間違ってるレコード削除 ---    
    preds, y = del_df_noise_image_id(preds, y, not_noise_image_id)

    
if pseudo_th > 0.0 and len(test_df) > 1:
    # --- スコア pseudo_th 以上のサンプルだけ pseudo label に採用する ---
    test_df_orig = test_df.copy()
    test_preds_list_orig = test_preds_list[:]
    test_df['logit'] = test_preds_ensemble.max(1)
    test_df = test_df[test_df['logit'] > pseudo_th]
    test_y = test_df['label'].values
    test_preds_list = [t[test_df['label'].index.values] for t in test_preds_list]
else:
    test_y = test_df['label'].values
    

if s_type == 0:
    # --- pseudo label使わない ---
    x = np.array(preds).transpose(1,2,0)
    x = x.reshape(len(x), 1, StackingCFG.n_classes, len(preds))
    
    oof, oof_loss = train_stacking(x, y, StackingCFG, is_check_model=False, 
                                   noise_idx=noise_idx)

elif s_type == 1:
    # --- pseudo labelそのまま突っ込む場合 ---
    preds = [np.vstack((preds[ii], test_preds_list[ii])) for ii in range(len(preds))]
    y = np.concatenate([y, test_y])

    x = np.array(preds).transpose(1,2,0)
    x = x.reshape(len(x), 1, StackingCFG.n_classes, len(preds))
    
    if gauss_scale > 0.0:
        # oofとtestにガウスノイズ加算
        # https://www.kaggle.com/c/stanford-covid-vaccine/discussion/189709
        x += np.random.normal(0.0, scale=gauss_scale, size=x.shape)
    
    oof, oof_loss = train_stacking(x, y, StackingCFG, is_check_model=False, 
                                   noise_idx=noise_idx)

elif s_type == 2:
    # --- pseudo labelのデータvalidationには入れない ---
    test_x = np.array(test_preds_list).transpose(1,2,0)
    test_x = test_x.reshape(len(test_x), 1, StackingCFG.n_classes, len(preds))
    
    x = np.array(preds).transpose(1,2,0)
    x = x.reshape(len(x), 1, StackingCFG.n_classes, len(preds))
    
    if gauss_scale > 0.0:
        if is_noise_only_test:
            # testにガウスノイズ加算
            test_x += np.random.normal(0.0, scale=gauss_scale, size=test_x.shape)
        else:
            # oofとtestにガウスノイズ加算
            test_x += np.random.normal(0.0, scale=gauss_scale, size=test_x.shape)
            x += np.random.normal(0.0, scale=gauss_scale, size=x.shape)
            
    oof, oof_loss = train_stacking(x, y, StackingCFG, is_check_model=False, 
                                   add_train_x=test_x, add_train_y=test_y, 
                                   noise_idx=noise_idx)
    

if pseudo_th > 0.0 and len(test_df) > 1:
    # 抜いたtestデータ戻す
    test_df = test_df_orig
    test_preds_list = test_preds_list_orig
    
    
## predict check
#pred_cnn2d = pred_stacking(x, StackingCFG)
#accuracy_score(y, pred_cnn2d.argmax(1))

In [None]:
# Inspect the out-of-fold predictions returned by train_stacking in the training cell.
oof

In [None]:
# Predict the test set with the trained cnmn2d stacker and write the submission.
# Per-model test predictions are stacked on a trailing model axis, then given a
# singleton channel axis: (N, 1, n_classes, n_models).
x = np.stack(test_preds_list, axis=-1)
x = x.reshape(x.shape[0], 1, StackingCFG.n_classes, len(test_preds_list))
print(x.shape)

# Optional Gaussian input noise, controlled by gauss_scale from the config cell.
if gauss_scale > 0.0:
    x += np.random.normal(0.0, scale=gauss_scale, size=x.shape)

pred_cnn2d = pred_stacking(x, StackingCFG)
display(pd.DataFrame(pred_cnn2d, columns=name_mapping.values()))

test_df['label'] = pred_cnn2d.argmax(axis=1)
test_df[["image_id", "label"]].to_csv('submission.csv', index=False)
display(test_df)

In [None]:
#import os, shutil
#if os.path.exists("lightning_logs/"):
#    shutil.rmtree("lightning_logs/")
#    os.mkdir("lightning_logs/")