# Cassava classification

### Import the required libraries

In [171]:
from collections import defaultdict
import copy
import random
import numpy as np
import os
import shutil
from urllib.request import urlretrieve
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from sklearn.utils import resample
cudnn.benchmark = True
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
import timm
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from utils import Mixup, RandAugment, RAdam
from PIL import Image
SEED = 42

In [172]:
def seed_everything(SEED):
    random.seed(SEED)
    os.environ['PYTHONHASHSEED'] = str(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
seed_everything(SEED)

In [173]:
torch.__version__
os.environ['CUDA_VISIBLE_DEVICES'] ="0"

### Set the root directory dataset

In [174]:
root = os.path.join(os.environ["HOME"], "Workspace/datasets/taiyoyuden/cassava")

### Split files from the dataset into the train and validation sets

Some files in the dataset are broken, so we will use only those image files that OpenCV could load correctly. We will use 20000 images for training, 4936 images for validation, and 10 images for testing.

In [175]:
os.listdir(root)

['test_images',
 'label_num_to_disease_map.json',
 'external',
 'sample_submission.csv',
 'train_images',
 'label_num_to_disease_map.json.save',
 'train.csv']

### Define a function to visualize images and their labels

Let's define a function that will take a list of images' file paths and their labels and visualize them in a grid. Correct labels are colored green, and incorrectly predicted labels are colored red.

In [176]:
train = pd.read_csv(f'{root}/train.csv')
train_external = pd.read_csv(f'{root}/external/train_external.csv')
test_external = pd.read_csv(f'{root}/external/test_external.csv')
test_external_pseudo = pd.read_csv(f'{root}/external/test_external_pseudo.csv')
test = pd.read_csv(f'{root}/sample_submission.csv')
label_map = pd.read_json(f'{root}/label_num_to_disease_map.json', 
                         orient='index')
display(train.head())
display(test.head())
display(label_map)

Unnamed: 0,image_id,label
0,1000015157.jpg,0
1,1000201771.jpg,3
2,100042118.jpg,1
3,1000723321.jpg,1
4,1000812911.jpg,3


Unnamed: 0,image_id,label
0,2216849948.jpg,4
1,1000015157.jpg,0
2,1000201771.jpg,3


Unnamed: 0,0
0,Cassava Bacterial Blight (CBB)
1,Cassava Brown Streak Disease (CBSD)
2,Cassava Green Mottle (CGM)
3,Cassava Mosaic Disease (CMD)
4,Healthy


In [177]:
label_map.iloc[1].values
np.random.randint(50, 10000, 50)

array([7320,  910, 5440, 5241, 5784, 6315,  516, 4476, 5628, 8372, 1735,
        819, 6999, 2483, 5361, 5101, 6470, 1234, 4605, 3435, 6446, 8716,
       9324, 2608, 7899, 2097, 2797, 9217,  239, 2784, 3055, 4708, 1949,
       7784, 1317, 1578, 3606, 3940, 8888, 5443, 8842, 8483, 7563, 2662,
       7091, 9605, 6285, 5536, 7149, 9720])

In [178]:
def visualize_input_image_grid(filepaths, image_name, labels, cols=4):
    rows = 5
    figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(25, 15))
    for i, index in enumerate(np.random.randint(0, len(image_name), 20)):
        name = image_name.iloc[index]['image_id']
        label =  image_name.iloc[index]['label']
        image = cv2.imread(f'{filepaths}/train_images/{name}')
        if (i == 0): 
            print(image.shape)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.ravel()[i].imshow(image)
        ax.ravel()[i].set_title(label, color='GREEN')
        ax.ravel()[i].set_axis_off()
    plt.tight_layout()
    plt.show()

### Define the parameters of the whole process 

In [179]:
models_name = ["resnest26d","resnest50d","tf_efficientnet_b3_ns", "skresnet34" ,"cspresnet50", "vit_base_patch16_384"]
WEIGHTS = [
    "weights/resnest26d_fold0_best_epoch_28_final_1st.pth", #0
    "weights/resnest26d_fold0_best_epoch_13_final_mixup.pth",
    "weights/resnest26d_fold1_best_epoch_7_finale_ex_hnm.pth",
    "weights/resnest26d_fold2_best_epoch_3_final_hnm.pth",
    "weights/resnest26d_fold4_best_epoch_29_1st.pth",
    "weights/resnest26d_fold4_best_epoch_26_mix.pth", #5
    "weights/resnest26d_fold4_best_epoch_12_cutmix.pth", 
    "weights/resnest26d_fold4_best_epoch_3_external.pth",
    "weights/resnest26d_fold4_best_epoch_21_final_512.pth",
    "weights/tf_efficientnet_b3_ns_fold1_best_epoch_19_external.pth",
    "weights/tf_efficientnet_b3_ns_fold1_best_epoch_26_512.pth", #10
    "weights/tf_efficientnet_b3_ns_fold1_best_epoch_1_final_512.pth", 
    "weights/resnest50d_fold1_best_epoch_95_final_1st.pth",
    "weights/resnest50d_fold1_best_epoch_9_final_512.pth",
    "weights/resnest50d_fold1_best_epoch_82.pth", # test balance data (0.86)
    "weights/cspresnet50_fold1_best_epoch_24_final.pth", #15
        
]
# Single model
model_index = 1
ckpt_index = 12

# kfold
fold_model_index = 1
fold_ckpt_index = [12,13,14]
fold_ckpt_weight = [1,1,1]

# Ensemble
ensemble_models_name = ["resnest26d" ,"tf_efficientnet_b3_ns", "resnest50d", "cspresnet50"]
ensemble_ckpt_index = [2, 11, 12, 15]
ensemble_ckpt_weight = [1, 1, 1, 1]

params = {
    "visualize": False,
    "fold": 1,
    "train_external": True,
    "test_external": True,
    "load_pretrained": True,
    "resume": False,
    "image_size": 512,
    "num_classes": 5,
    "model": models_name[model_index],
    "device": "cuda",
    "lr": 5e-5,
    "lr_min":1e-6,
    "batch_size": 8,
    "num_workers": 8,
    "epochs": 100,
    "gradient_accumulation_steps": 8,
    "drop_block": 0.3,
    "drop_rate": 0.3,
    "mix_up": True,
    "cutmix":True,
    "rand_aug": False,
    "local_rank":0,
    "distributed": False,
    "hard_negative_sample": False,
    "tta": True,
    "crops_tta":False,
    "train_phase":False,
    "balance_data":False,
    "kfold_pred":True,
    "ensemble": True
}

### Define dataset with KFolds strategy

In [180]:
folds = train.copy()
Fold = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
for n, (train_index, val_index) in enumerate(Fold.split(folds, folds['label'])):
    folds.loc[val_index, 'fold'] = int(n)
folds['fold'] = folds['fold'].astype(int)
###
fold = params["fold"]
train_idx = folds[folds['fold'] != fold].index
val_idx = folds[folds['fold'] == fold].index

train_folds = folds.loc[train_idx].reset_index(drop=True)
val_folds = folds.loc[val_idx].reset_index(drop=True)

In [181]:
# Dataset
class TrainDataset(Dataset):
    def __init__(self, df, transform=None, mosaic_mix = False):
        self.df = df
        self.file_names = df['image_id'].values
        self.labels = df['label'].values
        self.transform = transform
        self.mosaic_mix = mosaic_mix
        self.rand_aug_fn = RandAugment()
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        file_path = f'{root}/train_images/{file_name}'
        image = cv2.imread(file_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = torch.tensor(self.labels[idx]).long()
        if params["rand_aug"]:
            image = np.array(self.rand_aug_fn(Image.fromarray(image)))
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, label, file_name
    

class TestDataset(Dataset):
    def __init__(self, df, transform=None, valid_test=False, fcrops=False):
        self.df = df
        self.file_names = df['image_id'].values
        self.transform = transform
        self.valid_test = valid_test
        self.fcrops = fcrops
        if self.valid_test:
            self.labels = df['label'].values  
        else:
            assert ValueError("Test data does not have annotation, plz check!")
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        if self.valid_test:
            file_path = f'{root}/train_images/{file_name}'
            #file_path = f'{root}/external/extraimages/{file_name}'
        else:
            file_path = f'{root}/test_images/{file_name}'
        image = cv2.imread(file_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if isinstance(self.transform, list):
            outputs = {'images':[],
                       'labels':[],
                       'image_ids':[]}
            if self.fcrops:
                for trans in self.transform:
                    image_aug = Image.open(file_path)
                    image_aug = trans(image_aug)
                    outputs["images"].append(image_aug)
                    del image_aug
            else:
                for trans in self.transform:
                    augmented = trans(image=image)
                    image_aug = augmented['image']
                    outputs["images"].append(image_aug)
                    del image_aug

            if self.valid_test:
                label = torch.tensor(self.labels[idx]).long()
                outputs['labels'] = len(self.transform)*[label]
                outputs['image_ids'].append(file_name)
                
            else:
                outputs['labels'] = len(self.transform)*[-1]
                
            return outputs
        else:
            augmented = self.transform(image=image)
            image = augmented['image'] 
        return image

### Use Albumentations to define transformation functions for the train and validation datasets

In [182]:
train_transform = A.Compose(
    [
        A.RandomResizedCrop(height=params["image_size"], width=params["image_size"], p=1),
        A.OneOf([
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),], p=1.
        ),
#         A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.IAAAffine(rotate=0.2, shear=0.2,p=0.5),
        A.CoarseDropout(max_holes=20, max_height=int(params["image_size"]/15), max_width=int(params["image_size"]/15), p=0.5),
#         A.IAAAdditiveGaussianNoise(p=1.),
        A.MedianBlur(p=0.5),
        A.Equalize(p=0.2),
        A.GridDistortion(p=0.2),
#         A.RandomGridShuffle(p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)
val_transform = A.Compose(
    [
        A.CenterCrop(height=params["image_size"], width=params["image_size"], p=1),
        A.Resize(params["image_size"],params["image_size"]),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)
if params["cutmix"]:
    mixup_fn = Mixup(mixup_alpha=1., cutmix_alpha=1., label_smoothing=0.1, num_classes=params["num_classes"])
else:
    mixup_fn = Mixup(mixup_alpha=1., label_smoothing=0.1, num_classes=params["num_classes"])

val_dataset = TrainDataset(val_folds, transform=val_transform)

In [183]:
transform_tta0 = A.Compose(
    [
     A.CenterCrop(height=params["image_size"], width=params["image_size"], p=1),    
     A.Resize(height=params["image_size"], width=params["image_size"], p=1),
     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),   
     ToTensorV2()
    ]
)

transform_tta1 = A.Compose(
    [
     A.CenterCrop(height=params["image_size"], width=params["image_size"], p=1),
     A.Resize(height=params["image_size"], width=params["image_size"], p=1),
     A.HorizontalFlip(p=1.),
     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),   
     ToTensorV2()
    ]
)
transform_tta2 = A.Compose(
    [
     A.CenterCrop(height=params["image_size"], width=params["image_size"], p=1),
     A.Resize(height=params["image_size"], width=params["image_size"], p=1),
     A.VerticalFlip(p=1.),
     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
     ToTensorV2(),
    ]
)
transform_tta3 = A.Compose(
    [
     A.CenterCrop(height=params["image_size"], width=params["image_size"], p=1),
     A.Resize(height=params["image_size"], width=params["image_size"], p=1),
     A.RandomRotate90(p=1.),
     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),     
     ToTensorV2(),
    ]
)
transform_tta4 = A.Compose(
    [
     A.CenterCrop(height=params["image_size"], width=params["image_size"], p=1),
     A.Resize(height=params["image_size"], width=params["image_size"], p=1),
     A.RandomRotate90(p=1.),
     A.RandomRotate90(p=1.),
     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),     
     ToTensorV2(),
    ]
)

##### Test TTA with Five Crops
transform_crop_tta0 = T.Compose([
                T.FiveCrop(params["image_size"]),
                T.Lambda(lambda crops: ([T.ToTensor()(crop) for crop in crops])),
                T.Lambda(lambda norms: torch.stack([T.Normalize(mean=[0.5], std=[0.5])(norm) for norm in norms]))
        ]
)
    
transform_crop_tta1 = T.Compose([
                T.FiveCrop(params["image_size"]),
                T.Lambda(lambda crops: ([T.ToTensor()(crop) for crop in crops])),
                T.Lambda(lambda flips: ([T.RandomHorizontalFlip(p=1.)(flip) for flip in flips])),
                T.Lambda(lambda norms: torch.stack([T.Normalize(mean=[0.5], std=[0.5])(norm) for norm in norms]))
    ]
)
transform_crop_tta2 = T.Compose([
                T.FiveCrop(params["image_size"]),
                T.Lambda(lambda crops: ([T.ToTensor()(crop) for crop in crops])),
                T.Lambda(lambda flips: ([T.RandomVerticalFlip(p=1.)(flip) for flip in flips])),
                T.Lambda(lambda norms: torch.stack([T.Normalize(mean=[0.5], std=[0.5])(norm) for norm in norms]))
    ]
)
test_transform_tta = [transform_tta0, transform_tta1, transform_tta2, transform_tta3, transform_tta3, transform_tta4]
test_transform_tta_crops = [transform_crop_tta0, transform_crop_tta1, transform_crop_tta2]

Also let's define a function that takes a dataset and visualizes different augmentations applied to the same image.

In [184]:
def calculate_accuracy(output, target):
#     return torch.true_divide((target == output).sum(dim=0), output.size(0)).item()
    if params["mix_up"]:
        output = torch.argmax(torch.softmax(output, dim=1), dim=1)
        return accuracy_score(output.cpu(), target.argmax(1).cpu())
    
    output = torch.softmax(output, dim=1)
    return accuracy_score(output.argmax(1).cpu(), target.cpu())

In [185]:
class MetricMonitor:
    def __init__(self, float_precision=3):
        self.float_precision = float_precision
        self.reset()
        self.curr_acc = 0.
    def reset(self):
        self.metrics = defaultdict(lambda: {"val": 0, "count": 0, "avg": 0})

    def update(self, metric_name, val):
        metric = self.metrics[metric_name]

        metric["val"] += val
        metric["count"] += 1
        metric["avg"] = metric["val"] / metric["count"]
        self.curr_acc = metric["avg"]
    def __str__(self):
        return " | ".join(
            [
                "{metric_name}: {avg:.{float_precision}f}".format(
                    metric_name=metric_name, avg=metric["avg"], float_precision=self.float_precision
                )
                for (metric_name, metric) in self.metrics.items()
            ]
        )

In [224]:
def declare_model(name, index=None):
    model = timm.create_model(name,
            pretrained=False,
            num_classes=params["num_classes"],
            drop_rate=params["drop_rate"])
    model = model.to(params["device"])
    
    if params["distributed"]:
        assert ValueError("No need to implement in a single machine")
    else:
        model = torch.nn.DataParallel(model) 
        
    if params["load_pretrained"]:
        if index == None: 
            return model
        state_dict = torch.load(WEIGHTS[index])
        print(f"Load pretrained model: {name} ",state_dict["preds"])
        model.load_state_dict(state_dict["model"])
        best_acc = state_dict["preds"]   
    return model

model = declare_model(params["model"], ckpt_index)

weights/resnest50d_fold1_best_epoch_95_final_1st.pth resnest50d
Load pretrained model: resnest50d  0.8967


In [222]:
# n_features = model.fc.in_features
# model.fc = nn.Linear(n_features, params['num_classes'])

### Create all required objects and functions for training and validation

In [188]:
val_criterion = nn.CrossEntropyLoss().to(params["device"])
# criterion = nn.CrossEntropyLoss().to(params["device"])
criterion = LabelSmoothingCrossEntropy().to(params["device"])
if params["mix_up"]:
    criterion = SoftTargetCrossEntropy().to(params["device"])         

--------------------------------------------------------------------------------
## Make validation prediction w/TTA on validation and test set

In [189]:
if params["tta"]:
    val_pred_dataset = TestDataset(val_folds, transform=test_transform_tta, valid_test=True)
    test_pred_dataset = TestDataset(test, transform=test_transform_tta)
else:
    val_pred_dataset = TestDataset(val_folds, transform=val_transform, valid_test=True)
    test_pred_dataset = TestDataset(test, transform=val_transform)

val_pred_dataset_crops = TestDataset(val_folds, transform=test_transform_tta_crops,valid_test=True, fcrops=True)


In [190]:
val_pred_loader = DataLoader(
    val_pred_dataset, batch_size=params['batch_size'], shuffle=False, num_workers=2, pin_memory=True,
)
val_pred_loader_crop = DataLoader(
    val_pred_dataset_crops, batch_size=params['batch_size'], shuffle=False, num_workers=2, pin_memory=True,
)
test_pred_loader = DataLoader(
    test_pred_dataset, batch_size=params['batch_size'], shuffle=False, num_workers=2, pin_memory=True,
)

In [191]:
def visualize_tta(data, pred):     
    unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    if params["tta"]:
        figure, ax = plt.subplots(nrows=1, ncols=num_tta, figsize=(12, 6))
        for i, image in enumerate(data["images"]):
            image = (unorm(image[0]).cpu().numpy()*255).astype(int)
            ax.ravel()[i].imshow(image.transpose(2,1,0))
            ax.ravel()[i].set_title(str(pred), color='GREEN')
            ax.ravel()[i].set_axis_off()
        plt.tight_layout()
        plt.show() 
    else:
        image = (unorm(data).cpu().numpy()*255).astype(int)
        imgplot = plt.imshow(image[0].transpose(2,1,0))
        imgplot = plt.title(str(pred), color='GREEN')
        plt.show()

In [192]:
def validate(val_loader, model, criterion, epoch, params, fold, best_acc):
    metric_monitor = MetricMonitor()
    model.eval()
    stream = tqdm(val_loader)
    with torch.no_grad():
        for i, (images, target, _) in enumerate(stream, start=1):
            images = images.to(params["device"], non_blocking=True)
            target = target.to(params["device"], non_blocking=True)#.view(-1,params['batch_size'])
            output = model(images)
            loss = val_criterion(output, target)
            output = torch.softmax(output, dim = 1)
            
            accuracy = accuracy_score(output.argmax(1).cpu(), target.cpu())

            stream.set_description(
                "Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)
            )           
            metric_monitor.update("Loss", loss.item())
            metric_monitor.update("Accuracy", accuracy)
        #to save weight
        if (metric_monitor.curr_acc > best_acc): # or epoch == params["epochs"]:
            print(f"Save best weight at acc {round(metric_monitor.curr_acc,4)}, epoch: {epoch}")
            torch.save({'model': model.state_dict(), 
                'preds': round(metric_monitor.curr_acc,4)},
                 f'weights/{params["model"]}_fold{fold}_best_epoch_{epoch}.pth')

            best_acc = metric_monitor.curr_acc
    return best_acc

In [193]:
num_tta = len(test_transform_tta)
incorrect_pred_list = {'image_id':[],
                       'label':[],
                       'pred':[]}
def tta_validate(loader, model, params, valid_test=False):
    metric_monitor = MetricMonitor()
    model.eval()
    stream = tqdm(loader)
    
    with torch.no_grad():
        for i, data in enumerate(stream, start=1):
            if params["visualize"]:
                visualize_tta(data, pred)
            tta_output = []   
            for i, image in enumerate(data["images"]):
                tta_output.append(model(image))
            tta_output = torch.stack(tta_output, dim=0).mean(dim=0)
            output = torch.softmax(tta_output, dim=1)
            for i in  range(len(data["labels"][0])):
                if data["labels"][0][i] != output.argmax(1).cpu()[i]:
                    incorrect_pred_list['label'].append(data["labels"][0][i])
                    incorrect_pred_list['pred'].append(output.argmax(1).cpu()[i])
                    incorrect_pred_list['label'].append(data["image_ids"][0][i])                        
            accuracy = accuracy_score(output.argmax(1).cpu(), data["labels"][0].cpu())
                
            metric_monitor.update("Accuracy", accuracy)            
            stream.set_description(
                "TTA Validation. {metric_monitor}".format(metric_monitor=metric_monitor)
            )   
num_tta_crop = len(test_transform_tta_crops)
if params["tta"]:
    tta_validate(val_pred_loader, model, params, valid_test=True)


TTA Validation. Accuracy: 0.902: 100%|██████████| 535/535 [03:26<00:00,  2.59it/s]


In [194]:
def tta_crops_validate(loader, model, params, valid_test=False):
    metric_monitor = MetricMonitor()
    model.eval()
    stream = tqdm(loader)
    
    with torch.no_grad():
        for i, data in enumerate(stream, start=1):
            if params["visualize"]:
                visualize_tta(data, pred)
            tta_output = []   
            for i, image in enumerate(data["images"]):
                batch, ncrops, c, h, w = image.size()
                image = image.view(-1, c, h, w)
                output = model(image)
                tta_output.append(output.view(batch, ncrops, -1).mean(1))
            tta_output = torch.stack(tta_output, dim=0).mean(dim=0)
            output = torch.softmax(tta_output, dim=1)
            accuracy = accuracy_score(output.argmax(1).cpu(), data["labels"][0].cpu())
                
            metric_monitor.update("Accuracy", accuracy)            
            stream.set_description(
                "Crops TTA Validation. {metric_monitor}".format(metric_monitor=metric_monitor)
            )   
## Re test the validation accuracy with TTA
if params["crops_tta"]:
    tta_crops_validate(val_pred_loader_crop, model, params, valid_test=True)

In [195]:
# CV w/wo fold predict on validation set
def kfold_tta_validate(loader, params, valid_test=False):
    model = declare_model(models_name[fold_model_index])
    metric_monitor = MetricMonitor()
    model.eval()
    stream = tqdm(loader)
    models = []
    for ckpt, weight in zip(fold_ckpt_index, fold_ckpt_weight):
        state_dict = torch.load(WEIGHTS[ckpt])
        fold_acc = state_dict["preds"]
        print(f"Load pretrained model: {WEIGHTS[ckpt]} with acc: {fold_acc}")
        model.load_state_dict(state_dict["model"])
        models.append(model)
        
    with torch.no_grad():
        for i, data in enumerate(stream, start=1):
            if params["visualize"]:
                visualize_tta(data, pred)
            tta_output = []   
            for i, image in enumerate(data["images"]):
                kfout = [model(image) for model in models]
                tta_output.append(torch.stack(kfout, dim=0).mean(dim=0))
            tta_output = torch.stack(tta_output, dim=0).mean(dim=0)
            output = torch.softmax(tta_output, dim=1)
            accuracy = accuracy_score(output.argmax(1).cpu(), data["labels"][0].cpu())
                
            metric_monitor.update("Accuracy", accuracy)            
            stream.set_description(
                "Folds TTA Validation. {metric_monitor}".format(metric_monitor=metric_monitor)
            ) 
if params["tta"] and params["fold"]:
    kfold_tta_validate(val_pred_loader, params, valid_test=True)

  0%|          | 0/535 [00:00<?, ?it/s]

Load pretrained model: weights/resnest50d_fold1_best_epoch_95_final_1st.pth with acc: 0.8967
Load pretrained model: weights/resnest50d_fold1_best_epoch_9_final_512.pth with acc: 0.8998
Load pretrained model: weights/resnest50d_fold1_best_epoch_82.pth with acc: 0.8666


Folds TTA Validation. Accuracy: 0.892: 100%|██████████| 535/535 [10:44<00:00,  1.20s/it]


In [228]:
# Ensemble CV w/ fold predict on validation set
def ensemble_w_kfold_tta_validate(loader, models, params, valid_test=False):
    metric_monitor = MetricMonitor()
    stream = tqdm(loader)
    with torch.no_grad():
        for i, data in enumerate(stream, start=1):
            if params["visualize"]:
                visualize_tta(data, pred)
            tta_output = []   
            for i, image in enumerate(data["images"]):
                kfout = [model(image) for model in models]
                tta_output.append(torch.stack(kfout, dim=0).mean(dim=0))
            tta_output = torch.stack(tta_output, dim=0).mean(dim=0)
            output = torch.softmax(tta_output, dim=1)
            accuracy = accuracy_score(output.argmax(1).cpu(), data["labels"][0].cpu())
                
            metric_monitor.update("Accuracy", accuracy)            
            stream.set_description(
                "Ensemble Validation. {metric_monitor}".format(metric_monitor=metric_monitor)
            ) 
if params["ensemble"] and params["tta"] and params["fold"]:
    models = []
    print(f"Ensemble below models: {ensemble_models_name}")
    for name, ckpt, weight in zip(ensemble_models_name, ensemble_ckpt_index, ensemble_ckpt_weight):
        print(ckpt)
        m = declare_model(name, ckpt)  
        models.append(m)
        del m
    ensemble_w_kfold_tta_validate(val_pred_loader, models, params, valid_test=True)

Ensemble below models: ['resnest26d', 'tf_efficientnet_b3_ns', 'resnest50d', 'cspresnet50']
2
weights/resnest26d_fold1_best_epoch_7_finale_ex_hnm.pth resnest26d
Load pretrained model: resnest26d  0.8997
11
weights/tf_efficientnet_b3_ns_fold1_best_epoch_1_final_512.pth tf_efficientnet_b3_ns
Load pretrained model: tf_efficientnet_b3_ns  0.897
12
weights/resnest50d_fold1_best_epoch_95_final_1st.pth resnest50d
Load pretrained model: resnest50d  0.8967
15


  0%|          | 0/535 [00:00<?, ?it/s]

weights/cspresnet50_fold1_best_epoch_24_final.pth cspresnet50
Load pretrained model: cspresnet50  0.8883


Ensemble Validation. Accuracy: 0.896: 100%|██████████| 535/535 [11:04<00:00,  1.24s/it]


------------------------
## Make Test prediction w/wo Kfold check

In [229]:
num_tta = len(test_transform_tta)
def predict(loader, model, params, valid_test=False):
    outputs = []
    preds = []
    model.eval()
    stream = tqdm(loader)
    
    with torch.no_grad():
        for i, data in enumerate(stream, start=1):
            if params["tta"]:
                ## TTA output
                tta_output = []   
                for i, image in enumerate(data["images"]):
                    tta_output.append(model(image))
                tta_output = torch.stack(tta_output, dim=0).mean(dim=0)
                output = torch.softmax(tta_output, dim=1)
                outputs.append(output)
            else:
                data = data.to(params["device"], non_blocking=True)
                output = model(data)
                output = torch.softmax(output, dim = 1)
                outputs.append(output)
                
            pred = torch.argmax(output, dim=1).cpu().numpy()
            preds.extend(pred)
            ## For visualize the TTA 
            if params["visualize"]:
                visualize_tta(data, pred)
    if params["kfold_pred"]:
        return torch.stack(outputs, dim=0)
    else:
        return preds

In [232]:
# CV w/wo fold predict
if params["kfold_pred"]:
    fold_preds = []
    for ckpt, weight in zip(fold_ckpt_index, fold_ckpt_weight):
        state_dict = torch.load(WEIGHTS[ckpt])
        fold_acc = state_dict["preds"]
        print(f"Load pretrained model: {weights[ckpt]} with acc: {fold_acc}")
        model.load_state_dict(state_dict["model"])
        best_acc = state_dict["preds"]
        out = predict(test_pred_loader, model, params, valid_test=True).mul(weight)
        fold_preds.append(out)
        del(out)
    final_out_fold = torch.softmax(torch.stack(fold_preds).mean(dim=0), dim=2).view(-1,5).argmax(1)
    final_out = final_out_fold.cpu().numpy()
else:
    final_out = predict(test_pred_loader, model, params, valid_test=True)
    
test['label'] = final_out 
test.to_csv('submission.csv', index=False)
test.head(n=10)

  0%|          | 0/1 [00:00<?, ?it/s]

Load pretrained model: weights/resnest50d_fold1_best_epoch_9_final_512.pth with acc: 0.8967


100%|██████████| 1/1 [00:00<00:00,  2.33it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

Load pretrained model: weights/resnest50d_fold1_best_epoch_78.pth with acc: 0.8998


100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

Load pretrained model: weights/cspresnet50_fold1_best_epoch_24_final.pth with acc: 0.8666


100%|██████████| 1/1 [00:00<00:00,  2.32it/s]


Unnamed: 0,image_id,label
0,2216849948.jpg,4
1,1000015157.jpg,2
2,1000201771.jpg,3


In [244]:
# CV w/wo fold predict
def ensemble_w_kfold_tta_pred(loader, models, params, valid_test=False):
    metric_monitor = MetricMonitor()
    stream = tqdm(loader)
    outputs = []
    with torch.no_grad():
        for i, data in enumerate(stream, start=1):
            if params["visualize"]:
                visualize_tta(data, pred)
            tta_output = []   
            for i, image in enumerate(data["images"]):
                kfout = [model(image) for model in models]
                tta_output.append(torch.stack(kfout, dim=0).mean(dim=0))
            tta_output = torch.stack(tta_output, dim=0).mean(dim=0)
            output = torch.softmax(tta_output, dim=1)
            outputs.append(output.argmax(1).cpu().numpy())
    return outputs

if params["ensemble"]:
    models = []
    print(f"Ensemble below models: {ensemble_models_name}")
    for name, ckpt, weight in zip(ensemble_models_name, ensemble_ckpt_index, ensemble_ckpt_weight):
        m = declare_model(name, ckpt)  
        models.append(m)
        del m
    final_out_ensemble = ensemble_w_kfold_tta_pred(test_pred_loader, models, params, valid_test=True)[0]
    
test['label'] = final_out_ensemble 
test.to_csv('submission.csv', index=False)
test.head(n=10)

Ensemble below models: ['resnest26d', 'tf_efficientnet_b3_ns', 'resnest50d', 'cspresnet50']
weights/resnest26d_fold1_best_epoch_7_finale_ex_hnm.pth resnest26d
Load pretrained model: resnest26d  0.8997
weights/tf_efficientnet_b3_ns_fold1_best_epoch_1_final_512.pth tf_efficientnet_b3_ns
Load pretrained model: tf_efficientnet_b3_ns  0.897
weights/resnest50d_fold1_best_epoch_95_final_1st.pth resnest50d
Load pretrained model: resnest50d  0.8967


  0%|          | 0/1 [00:00<?, ?it/s]

weights/cspresnet50_fold1_best_epoch_24_final.pth cspresnet50
Load pretrained model: cspresnet50  0.8883


100%|██████████| 1/1 [00:00<00:00,  1.28it/s]


Unnamed: 0,image_id,label
0,2216849948.jpg,2
1,1000015157.jpg,1
2,1000201771.jpg,3


In [243]:
final_out_ensemble

[array([2, 1, 3])]

In [None]:
params["tta"] = True
create_pseudo = False
num_tta = len(test_transform_tta)
def pseudo_predict(loader, model, params, valid_test=False):
    preds = {'image_id':[],
             'label': []}
    model.eval()
    stream = tqdm(loader)
    
    with torch.no_grad():
        for i, data in enumerate(stream, start=1):
            ## TTA output
            tta_output = []   
            for i, image in enumerate(data["images"]):
                tta_output.append(model(image))
            tta_output = torch.stack(tta_output, dim=0).mean(dim=0)
            output = torch.softmax(tta_output, dim=1)
            pred = torch.argmax(output, dim=1).cpu().numpy()
            for j,p in enumerate(pred):
                if output[j][p] > 0.7:
                    preds['image_id'].append(data["image_ids"][0][j])
                    preds['label'].append(p)
    return preds
if create_pseudo:
    ## add external data
    test_external_dataset = TestDataset(test_external, transform=test_transform_tta, valid_test=True)
    test_external_loader = DataLoader(
        test_external_dataset, batch_size=params['batch_size'], shuffle=False, num_workers=4, pin_memory=True,
    )  
    pred_test_external = pseudo_predict(test_external_loader, model, params, valid_test=True)
    pseudo = pd.DataFrame(pred_test_external)
    pseudo.to_csv(f'{root}/external/test_external_pseudo.csv' ,index=False)
    visualize_class_dis(pseudo, 'external pseudo data')