![](https://storage.googleapis.com/kaggle-competitions/kaggle/13836/logos/header.png?t=2020-10-01-17-22-54)

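Motivation: This is an easy-to-use submission notebook, meant to be paired with the CassavaNet Training/Starter Notebook. Any model trained there can be loaded and run with minimal effort, and test-time augmentation (TTA) is performed efficiently: all augmented copies of an image are passed through each model as a single batch.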

Training Notebook: https://www.kaggle.com/capiru/cassavanet-starter-easy-gpu-tpu-cv-0-9

# Library

In [None]:
# ====================================================
# Library
# ====================================================
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm

import random
import os
import torch
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from torchvision import models as tvmodels
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
import time

import math
from matplotlib.pyplot import imread
import albumentations as A
from albumentations import Compose
from albumentations.pytorch import ToTensorV2
import numpy as np
import cv2
from sklearn.model_selection import GroupKFold, StratifiedKFold

import time
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Config

In [None]:
# ====================================================
# Config
# ====================================================
# Input locations: competition data and the directory holding trained weights.
DATA_PATH = '../input/cassava-leaf-disease-classification/'
TRAIN_DIR = f'{DATA_PATH}train_images/'
TEST_DIR = f'{DATA_PATH}test_images/'
MODEL_PATH = '../input/cassavanet-baseline-models/'

# Number of augmented copies generated per test image.
N_TTA = 8

# Shape of the tensors fed to the models (crop height/width and channel count).
HEIGHT, WIDTH, CHANNELS = 512, 512, 3

# Size of the classification head output.
N_CLASSES = 5

# Positions (within the MODEL_PATH listing) of the checkpoints to load.
MODEL_LIST = list(range(7))

# Normalisation statistics. ImageNet values are used here;
# the cassava-specific alternatives are
#   mean = [0.4303, 0.4967, 0.3134], std = [0.2142, 0.2191, 0.1954].
IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD = [0.229, 0.224, 0.225]

# Seed

In [None]:
# ====================================================
# Seed
# ====================================================
def seed_everything(seed):
    """Seed every random number generator used here and make cuDNN deterministic.

    Seeds Python's ``random``, the ``PYTHONHASHSEED`` environment variable,
    NumPy, and PyTorch (CPU and all CUDA devices), then configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick algorithms per run, which
    # can make results differ between runs; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False

# Fix all random number generators before anything random is created.
SEED = 1234
seed_everything(SEED)

# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Model

In [None]:
# ====================================================
# Model
# ====================================================
class CassavaNet(nn.Module):
    """Backbone wrapper whose classification layer is resized to ``N_CLASSES``.

    The backbone is created through ``torch.hub`` (DeiT variants) or ``timm``
    (everything else). For recognised architectures the final linear layer is
    replaced by a fresh ``nn.Linear`` with ``N_CLASSES`` outputs:

    * names containing ``'efficientnet'`` -> ``model.classifier``
    * ViT / DeiT names listed in ``_HEAD_MODELS`` -> ``model.head``
    * names containing ``'resnext'`` -> ``model.fc``

    Any other architecture is left with its original output layer.

    Args:
        model_name: Architecture name passed to ``timm`` / ``torch.hub``.
        pretrained: Forwarded to the backbone factory.
    """

    # Backbones fetched from the DeiT hub repository instead of timm.
    _HUB_MODELS = ('deit_base_patch16_224', 'deit_base_patch16_384')
    # Backbones whose classification layer is stored in ``.head``.
    _HEAD_MODELS = ('vit_large_patch16_384',) + _HUB_MODELS

    def __init__(self, model_name=None, pretrained=False):
        super().__init__()
        self.model_name = model_name
        if model_name in self._HUB_MODELS:
            self.model = torch.hub.load('facebookresearch/deit:main', model_name, pretrained=pretrained)
        else:
            self.model = timm.create_model(model_name, pretrained=pretrained)

        attr = self._classifier_attr()
        if attr is not None:
            # Swap the stock output layer for one sized to this task.
            self.n_features = getattr(self.model, attr).in_features
            setattr(self.model, attr, nn.Linear(self.n_features, N_CLASSES))

    def _classifier_attr(self):
        """Return the backbone attribute holding the classifier, or ``None``."""
        if 'efficientnet' in self.model_name:
            return 'classifier'
        if self.model_name in self._HEAD_MODELS:
            return 'head'
        if 'resnext' in self.model_name:
            return 'fc'
        return None

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output."""
        return self.model(x)

    def freeze(self):
        """Freeze every backbone parameter except those of the classifier layer.

        For architectures without a recognised classifier attribute, all
        parameters end up frozen.
        """
        for param in self.model.parameters():
            param.requires_grad = False

        # The original condition `name == 'vit...' or 'deit...'` was always
        # truthy, so ResNeXt models wrongly went down the `.head` path.
        attr = self._classifier_attr()
        if attr is not None:
            for param in getattr(self.model, attr).parameters():
                param.requires_grad = True

    def unfreeze(self):
        """Make every backbone parameter trainable again."""
        for param in self.model.parameters():
            param.requires_grad = True

# Dataset

In [None]:
# ====================================================
# Dataset
# ====================================================
class GetData(Dataset):
    """Dataset that reads images from a directory by file name.

    What ``__getitem__`` returns depends on which mode substring ``Type``
    contains (checked in the order of ``_MODES``; the first match wins):

    * ``"train"``  -> ``(train_transforms(image)['image'], label)``
    * ``"valid"``  -> ``(valid_transforms(image)['image'], label)``
    * ``"tr-tst"`` -> ``(raw image, label)``
    * ``"test"``   -> ``(raw image, file name)``

    NOTE(review): ``train_transforms`` / ``valid_transforms`` are looked up as
    module-level names and are not defined in the visible part of this
    notebook; only the ``"test"`` mode is used here.

    Args:
        Dir: Directory containing the image files.
        FNames: Indexable collection of file names relative to ``Dir``.
        labels: Indexable collection of labels aligned with ``FNames``.
        Type: Mode string, see above.
    """

    _MODES = ("train", "valid", "tr-tst", "test")

    def __init__(self, Dir, FNames, labels,Type):
        self.dir = Dir
        self.fnames = FNames
        self.lbs = labels
        self.type = Type

    def __len__(self):
        """Return the number of file names in the dataset."""
        return len(self.fnames)

    def _mode(self):
        """Return the first entry of ``_MODES`` contained in ``self.type``.

        Raises:
            ValueError: If ``self.type`` contains none of the known modes.
        """
        for mode in self._MODES:
            if mode in self.type:
                return mode
        raise ValueError(
            "Unknown dataset Type %r; expected it to contain one of %s"
            % (self.type, ", ".join(self._MODES))
        )

    def __getitem__(self, index):
        """Load the image at ``index`` and return it with its label or name.

        Raises:
            ValueError: If the dataset ``Type`` is not recognised (previously
                this case silently returned ``None``).
        """
        # Validate the mode first so a bad Type fails before any file I/O.
        mode = self._mode()
        x = imread(os.path.join(self.dir, self.fnames[index]))
        if mode == "train":
            return train_transforms(image = x)['image'], self.lbs[index]
        if mode == "valid":
            return valid_transforms(image = x)['image'], self.lbs[index]
        if mode == "tr-tst":
            return x, self.lbs[index]
        # mode == "test": pair the raw image with its file name.
        return x, self.fnames[index]

# Augmentation

In [None]:
# ====================================================
# Augmentation
# ====================================================
# Normalisation step using the configured mean/std, with 255 as the maximum
# pixel value.
Aug_Norm = A.Normalize(
    mean=IMG_MEAN,
    std=IMG_STD,
    max_pixel_value=255.0,
    p=1.0,
)

# Test-time augmentation pipeline: random flips, a shift/scale/rotate,
# a colour jitter and a random HEIGHT x WIDTH crop, followed by
# normalisation and conversion to a tensor.
test_aug = Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(p=1.0),
        A.ColorJitter(
            brightness=0.1,
            contrast=0.2,
            saturation=0.2,
            hue=0.00,
            always_apply=False,
            p=1.0,
        ),
        A.RandomCrop(height=HEIGHT, width=WIDTH, always_apply=True, p=1.0),
        Aug_Norm,
        ToTensorV2(p=1.0),
    ],
    p=1.,
)

In [None]:
# Show the entries found in MODEL_PATH (displayed as the cell's output).
os.listdir(MODEL_PATH)

# Model Loading

In [None]:
# ====================================================
# Model Loading
# ====================================================
# Load the checkpoints whose position in the MODEL_PATH listing is in MODEL_LIST.
# os.listdir returns entries in arbitrary order, so the listing is sorted to
# make the MODEL_LIST indices refer to the same files on every run.
models = []
for count, model_fpath in enumerate(sorted(os.listdir(MODEL_PATH))):
    if count not in MODEL_LIST:
        continue
    # The architecture name is the part of the file name before the first '_f'.
    model_name_split = model_fpath.split('_f')[0]
    model = CassavaNet(model_name_split, pretrained=False)
    # DEVICE is already a torch.device; map the stored tensors straight onto it.
    info = torch.load(os.path.join(MODEL_PATH, model_fpath), map_location=DEVICE)
    model.load_state_dict(info)
    models.append(model)
    print("Model Loaded:",model_fpath)

In [None]:
# Build the submission dataframe: one row per file found in TEST_DIR,
# with every label initialised to 0.
list_files = os.listdir(TEST_DIR)
submission = pd.DataFrame({'image_id': pd.Series(list_files)})
submission['label'] = 0
submission.head()

# TTA

In [None]:
# ====================================================
# TTA
# ====================================================
# Test-time augmentation: for every test image, build N_TTA augmented copies,
# run them through each loaded model as one batch, average the softmax
# probabilities over augmentations and models, and store the arg-max class.
start_time = time.time()
BATCH_SIZE = 1  # the loop below handles exactly one image per batch
test_set = GetData(TEST_DIR,submission['image_id'], submission['label'], Type = 'test')
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=8,pin_memory = True)

# Move the models to the device and switch to eval mode once, not per image.
for model in models:
    model.to(DEVICE)
    model.eval()

with torch.no_grad():
    # In 'test' mode the dataset yields (raw image, file name) pairs.
    for images, fnames in test_loader:
        # Take the single image of the batch as (H, W, C) without assuming a
        # fixed 600x800 resolution.
        img_np = images[0].numpy()

        # Each call draws a fresh random augmentation of the same image.
        aug_images = torch.stack(
            [test_aug(image=img_np)['image'] for _ in range(N_TTA)]
        )
        aug_images = aug_images.to(torch.float32).to(DEVICE)

        voting = np.zeros((len(models), N_TTA, N_CLASSES))
        for model_no, model in enumerate(models):
            logits = model(aug_images)
            # Softmax over the class dimension (dim must be given explicitly).
            voting[model_no] = F.softmax(logits, dim=1).cpu().numpy()

        # Mean probability over models and augmentations, then the best class.
        label = np.argmax(voting.mean(axis=(0, 1)))
        # Single .loc call: chained `df['label'].loc[...] = v` may write to a
        # temporary copy and leave the dataframe unchanged.
        submission.loc[submission['image_id'] == fnames[0], 'label'] = label
print(time.time()-start_time)

In [None]:
# Write the predictions to submission.csv without the index column,
# then display the first rows of the dataframe.
submission.to_csv('submission.csv',index=False)
submission.head()