In [None]:
# Put testing data in data loader, wrap data loader in tqdm for loop, evaluate each batch and save
# predictions in an array. Copy array to a dataframe that has testing photo IDs. Save to submission.csv
import sys

# Extra source trees to make importable. These are offline copies attached as
# Kaggle input datasets, used instead of installing the packages.
# Previously used alternative: '../input/efficientnet-pytorch-07/efficientnet_pytorch-0.7.0'
package_paths = [
    '../input/pytorch-image-models/pytorch-image-models-master',
]

for pth in package_paths:
    # Guard against duplicate entries so re-running this cell is idempotent.
    if pth not in sys.path:
        sys.path.append(pth)

In [None]:
'''
IMPORTS
'''

import cv2
import torch
import os
from torch import nn
from datetime import datetime
import time
import random
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
from tqdm import tqdm


import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import SequentialSampler, RandomSampler, WeightedRandomSampler
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F

import timm

import sklearn
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
from sklearn.model_selection import GroupKFold, StratifiedKFold

In [None]:
# Read the sample submission CSV from the competition input directory and
# preview its first rows. `submission` is not referenced again in the cells
# visible below; the final submission is built from the `test` dataframe.
submission = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
submission.head()

In [None]:
'''
CONFIGURATION
'''

# Single source of truth for the notebook's tunable settings. Only a subset
# ('seed', 'img_size', 'num_classes', 'model_arch', 'test_bs', 'num_workers',
# 'device') is read by the inference cells visible in this notebook.
config = {
    'seed': 419,
    'img_size': 512,
    'tta': 3,
    'num_folds': 5,
    'num_classes': 5,

    # input_size = 3, 380, 380. pool_size = 12, 12.
    # DOC: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/efficientnet.py
    'model_arch': 'tf_efficientnet_b4_ns',

    'train_bs': 16,
    'valid_bs': 32,
    'test_bs': 4,
    # This key used to be listed twice (1, then 2). Python keeps the last
    # value of a duplicated dict-literal key, so 2 was the effective setting.
    'num_workers': 2,
    'epochs': 10,
    'device': 'cuda:0',

    'T1': 0.2,
    'T2': 1.0,
    'label_smooth': 0.2,

    'lr': 1e-4,
    'min_lr': 1e-6,
    'T_0': 10,
    'weight_decay': 1e-6,
    'ep_patience': 4,
    'factor': 0.2,
    #'accum_iter':2,
    'update_on_batch': True,
    'use_wrs': False,
}

In [None]:
'''
HELPER FUNCTIONS
'''

def seed_everything(seed):
    """Seed every random number generator used by the notebook.

    Seeds Python's `random`, the `PYTHONHASHSEED` environment variable, NumPy
    and PyTorch (CPU and all CUDA devices), and configures cuDNN for
    reproducible behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, which
    # defeats the determinism requested above, so it has to stay off.
    torch.backends.cudnn.benchmark = False
    
def get_img(path):
    """Load an image from disk and return it in RGB channel order.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        The decoded image array with its last axis converted from OpenCV's
        BGR order to RGB.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file (missing path or
            undecodable image), in which case `cv2.imread` returns None.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with a clear message instead of a TypeError further down.
        raise FileNotFoundError("Could not read image: {}".format(path))

    # cvtColor returns a fresh contiguous array, unlike a `[:, :, ::-1]` view
    # whose negative strides some consumers (e.g. torch.from_numpy) reject.
    im_rgb = cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB)

    return im_rgb

from albumentations import (
    ShiftScaleRotate, Normalize, Compose, CenterCrop, Resize, HorizontalFlip,
    VerticalFlip, Transpose, RandomResizedCrop, HueSaturationValue, RandomBrightnessContrast,
    CoarseDropout, Cutout
)

from albumentations.pytorch import ToTensorV2

def get_infer_transforms():
    """Build the albumentations pipeline applied to images at inference time.

    The pipeline center-crops to `config['img_size']` square, resizes to the
    same square size, normalizes each channel with fixed mean/std values and
    converts the result to a tensor via `ToTensorV2`. Every step is applied
    with probability 1.

    Returns:
        An albumentations `Compose` object.
    """
    side = config['img_size']

    # Per-channel statistics; these match the commonly used ImageNet values.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        CenterCrop(side, side, p=1.0),
        Resize(side, side),
        Normalize(mean=channel_mean, std=channel_std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.0)

In [None]:
'''
DATASET
'''

class LeafDataset(Dataset):
    """Dataset that serves images (and optionally labels) listed in a dataframe.

    Args:
        df: DataFrame with an 'image_id' column holding file names and, when
            `include_labels` is True, a 'label' column.
        img_dir: Directory that contains the image files.
        transforms: Optional callable invoked as `transforms(image=img)` whose
            result is a mapping with an 'image' entry (albumentations style).
        include_labels: When True, `__getitem__` returns `(img, label)`;
            otherwise it returns only `img`.
    """

    def __init__(self, df, img_dir, transforms=None, include_labels=True):
        super().__init__()
        self.df = df
        self.img_dir = img_dir
        self.transforms = transforms
        self.include_labels = include_labels

        if include_labels:
            self.labels = self.df['label'].values

    def __len__(self):
        """Return the number of rows (samples) in the dataframe."""
        return len(self.df)

    def __getitem__(self, index: int):
        """Load the sample at positional `index`.

        Returns:
            `(img, label)` when labels are included, otherwise `img`.
        """
        # Samplers hand out positions in [0, len), so look the row up by
        # position. Label-based `.loc` breaks whenever the dataframe index is
        # not the default 0..n-1 range (e.g. a fold subset that was not
        # reset), and would disagree with the positional `self.labels`.
        image_id = self.df['image_id'].iloc[index]
        img = get_img(os.path.join(self.img_dir, image_id))

        if self.transforms is not None:
            img = self.transforms(image=img)['image']

        if self.include_labels:
            return img, self.labels[index]
        return img

In [None]:
'''
MODEL
'''

class LeafDiseaseClassifier(nn.Module):
    """timm backbone whose classification head is replaced for this task.

    Args:
        model_arch: Architecture name passed to `timm.create_model`.
        num_classes: Number of output units of the new linear head.
        pretrained: Forwarded to `timm.create_model`.

    The backbone is expected to expose its head as a `classifier` attribute
    with an `in_features` field.
    """

    def __init__(self, model_arch, num_classes, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        n_features = self.model.classifier.in_features
        # Swap the original head for one sized to our number of classes.
        self.model.classifier = nn.Linear(n_features, num_classes)

    def forward(self, x):
        """Run the backbone (including the replaced head) on `x`."""
        return self.model(x)

    def freeze_batch_norm(self):
        """Disable gradients for the backbone's BatchNorm2d parameters.

        Covers BatchNorm2d layers that are direct children of the backbone and
        those that are direct children of a top-level `nn.Sequential`.
        BatchNorm layers nested deeper than that are left untouched.
        """
        for layer in self.model.children():
            if isinstance(layer, nn.BatchNorm2d):
                bn_layers = [layer]
            elif isinstance(layer, nn.Sequential):
                # Inspect each child of the Sequential itself. The previous
                # code tested the Sequential container here, which is never a
                # BatchNorm2d, so this branch froze nothing.
                bn_layers = [
                    sub_layer for sub_layer in layer.children()
                    if isinstance(sub_layer, nn.BatchNorm2d)
                ]
            else:
                bn_layers = []

            for bn in bn_layers:
                for param in bn.parameters():
                    param.requires_grad = False

In [None]:
'''
MAIN
'''

if __name__ == '__main__':
    # Inference entry point: list the test images, run the trained classifier
    # over them batch by batch and collect per-class softmax probabilities in
    # `preds` (one row per image, in the same order as `test`).
    seed_everything(config['seed'])

    test_img_dir = '../input/cassava-leaf-disease-classification/test_images/'

    # os.listdir returns entries in arbitrary order; sort so the row order of
    # `test` (and of the submission) is reproducible.
    test = pd.DataFrame({'image_id': sorted(os.listdir(test_img_dir))})
    test_ds = LeafDataset(test, test_img_dir, transforms=get_infer_transforms(), include_labels=False)

    test_loader = torch.utils.data.DataLoader(
        test_ds,
        batch_size=config['test_bs'],
        num_workers=config['num_workers'],
        shuffle=False,  # keep predictions aligned with the rows of `test`
    )

    device = torch.device(config['device'])
    model = LeafDiseaseClassifier(config['model_arch'], config['num_classes']).to(device)
    # map_location puts the checkpoint tensors on the target device regardless
    # of which device they were saved from.
    model.load_state_dict(
        torch.load(
            '../input/effnet-b4/tf_efficientnet_b4_ns_Fold4_Epoch5_Acc_0.8952.pth',
            map_location=device,
        )
    )
    model.eval()

    preds = []
    # No gradients are needed for inference; disabling autograd avoids
    # building the graph and keeps GPU memory usage down.
    with torch.no_grad():
        for test_batch in tqdm(test_loader, total=len(test_loader)):
            test_batch = test_batch.to(device).float()
            test_preds = model(test_batch)
            preds.append(torch.softmax(test_preds, 1).cpu().numpy())

    preds = np.concatenate(preds, axis=0)

In [None]:
# The predicted class for each image is the index of its highest probability.
test['label'] = np.argmax(preds, axis=1)
test.to_csv('submission.csv', index=False)

# Keep the preview as the cell's last expression so the notebook renders it.
test.head()

In [None]:
# Drop the notebook's reference to the model, then ask PyTorch to release the
# GPU memory held by its caching allocator.
del model
torch.cuda.empty_cache()