In [None]:
import sys

# Make the vendored timm source tree importable (Kaggle offline dataset).
# Alternative previously used: '../input/efficientnet-pytorch-07/efficientnet_pytorch-0.7.0'
package_path = '../input/timm-pytorch-image-models/pytorch-image-models-master'
# Guard so re-running this cell in the same kernel does not stack duplicate entries.
if package_path not in sys.path:
    sys.path.append(package_path)

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor , RandomResizedCrop , Compose
from torch.utils.data import random_split
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader , Dataset
import pandas as pd
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from sklearn import model_selection, metrics
import time
from torch.cuda.amp import autocast, GradScaler
from torch.nn.modules.loss import _WeightedLoss
%matplotlib inline


In [None]:
# Hyper-parameters shared by the training / inference cells below.
# NOTE(review): 'seed', 'fold_num' and 'device' are not read by any uncommented
# code visible in this notebook.
CFG = dict(
    fold_num=5,
    seed=719,
    # other architectures tried: "vit_base_patch16_224", 'vit_base_patch32_384'
    model_arch='tf_efficientnet_b4_ns',
    img_size=512,
    epochs=4,
    train_bs=16,
    valid_bs=32,
    T_0=10,
    lr=1e-4,
    min_lr=1e-6,
    weight_decay=1e-6,
    num_workers=4,
    accum_iter=2,  # supports batch accumulation for backprop with an effectively larger batch size
    verbose_step=1,
    device='cuda:0',
)

In [None]:
def get_default_device():
    """Return ``torch.device('cuda')`` when torch sees a GPU, else ``torch.device('cpu')``."""
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(backend)


device = get_default_device()
device

In [None]:
# Training label table; CassavaDataset below reads its 'image_id' and 'label' columns.
df = pd.read_csv(filepath_or_buffer='../input/cassava-leaf-disease-classification/train.csv')
df.shape[0]

In [None]:
# Dataset locations inside the Kaggle input directory.
DATA_PATH = "../input/cassava-leaf-disease-classification"
TRAIN_PATH = f"{DATA_PATH}/train_images/"
TEST_PATH = f"{DATA_PATH}/test_images/"
# NOTE(review): MODEL_PATH is only referenced from commented-out code.
MODEL_PATH = "../input/vit-base/jx_vit_base_p16_224-80ecf9dd.pth"

# Model-specific globals.
# NOTE(review): IMG_SIZE (224) differs from CFG['img_size'] (512); only the
# torchvision augmentation pipeline reads IMG_SIZE, the albumentations
# pipelines use CFG['img_size']. BATCH_SIZE / LR / GAMMA / N_EPOCHS are not
# read by any uncommented code visible here.
IMG_SIZE = 224
BATCH_SIZE = 16
LR = 2e-05
GAMMA = 0.7
N_EPOCHS = 3

In [None]:
import cv2

def get_img(path):
    """Load an image from disk and return it in RGB channel order.

    ``cv2.imread`` decodes images as BGR; the channels are converted to RGB so
    downstream transforms receive RGB images.

    Args:
        path: File path of the image.

    Returns:
        The decoded RGB image as returned by ``cv2.cvtColor`` (a newly
        allocated, contiguous array).

    Raises:
        FileNotFoundError: if OpenCV cannot read/decode ``path``
            (``cv2.imread`` signals failure by returning ``None``).
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        # Without this check the failure surfaces later as an opaque
        # "'NoneType' object is not subscriptable" inside a DataLoader worker.
        raise FileNotFoundError(f"cv2.imread could not read image: {path}")
    # cvtColor returns a contiguous copy, unlike the negative-stride view
    # produced by im_bgr[:, :, ::-1], which torch.from_numpy rejects.
    return cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB)
    
# create image augmentations
# Alternative torchvision augmentation pipeline (IMG_SIZE crops).
# It used to be bound to ``transforms_train`` and was then silently overwritten
# by the albumentations pipeline further down; giving it its own name removes
# that shadowing so it is clear which pipeline the datasets actually use.
# NOTE(review): CassavaDataset calls ``self.transforms(image=...)`` (the
# albumentations calling convention), so this Compose cannot be passed to it
# as-is.
transforms_train_torchvision = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(p=0.3),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomResizedCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2

# Training augmentations (albumentations): random resized crop to
# CFG['img_size'], transpose/flips, shift-scale-rotate, colour jitter,
# ImageNet normalisation, random dropout patches, then tensor conversion.
# NOTE(review): Cutout is deprecated in newer albumentations releases in favour
# of CoarseDropout — confirm against the installed version.
transforms_train = Compose([
    RandomResizedCrop(CFG['img_size'], CFG['img_size']),
    Transpose(p=0.5),
    HorizontalFlip(p=0.5),
    VerticalFlip(p=0.5),
    ShiftScaleRotate(p=0.5),
    HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
    RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
    CoarseDropout(p=0.5),
    Cutout(p=0.5),
    ToTensorV2(p=1.0),
], p=1.)

# Validation transforms: deterministic centre crop + resize + normalisation.
transforms_valid = Compose([
    CenterCrop(CFG['img_size'], CFG['img_size'], p=1.),
    Resize(CFG['img_size'], CFG['img_size']),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
], p=1.)

# Test-time transforms.
# Inference in this notebook does a single forward pass per image and takes the
# argmax (no test-time-augmentation averaging), so random crops/flips/colour
# jitter here only inject noise into the predictions and make the submission
# non-reproducible. Use the same deterministic preprocessing as validation so
# test images are seen exactly like the images the model was scored on.
transforms_test = Compose([
    CenterCrop(CFG['img_size'], CFG['img_size'], p=1.),
    Resize(CFG['img_size'], CFG['img_size']),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
], p=1.)

class CassavaDataset(Dataset):
    """Image dataset backed by a DataFrame of file names (and optional labels).

    Args:
        df: DataFrame with an ``image_id`` column (file name relative to
            ``data_root``) and, when ``output_label`` is True, a ``label``
            column. A copy with a fresh 0..n-1 index is stored.
        data_root: Directory containing the image files.
        transforms: Optional albumentations-style callable, invoked as
            ``transforms(image=img)["image"]``.
        output_label: If True, items are ``(img, label)``; otherwise ``img``.
    """

    def __init__(
        self, df, data_root, transforms=None, output_label=True
    ):
        super().__init__()
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root
        self.output_label = output_label

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index: int):
        # One positional row lookup instead of one per column.
        row = self.df.iloc[index]

        # os.path.join avoids the "dir//file" paths produced when data_root
        # already ends with a slash (as the image roots in this notebook do).
        path = os.path.join(self.data_root, row['image_id'])
        img = get_img(path)
        if self.transforms is not None:
            img = self.transforms(image=img)['image']

        # The raw class label is returned as-is (no label smoothing here).
        if self.output_label:
            return img, row['label']
        return img


# Training-image directory, plus a dataset over the full label table.
# NOTE(review): df_ds is not used by any later visible cell.
datapath = '../input/cassava-leaf-disease-classification/train_images/'
df_ds = CassavaDataset(
    df=df,
    data_root=datapath,
    transforms=transforms_train,
    output_label=True,
)

In [None]:
#!pip install timm

In [None]:
import timm
print("Available Vision Transformer Models: ")
# Glob pattern over timm's model registry; the cell displays the returned list.
timm.list_models(filter="vit*")
# other families can be browsed the same way, e.g. timm.list_models("vo*")

In [None]:
class ViTBase16(nn.Module):
    """A timm backbone with a freshly initialised ``n_classes``-way classifier head.

    NOTE: despite its name, the backbone is whatever ``model_arch`` (default
    ``CFG['model_arch']``, currently ``tf_efficientnet_b4_ns``) names; the class
    name is kept so existing callers keep working.

    Args:
        n_classes: Number of output classes of the new head.
        pretrained: Whether timm should load pretrained backbone weights.
        model_arch: timm model name; ``None`` means ``CFG['model_arch']``.
    """

    def __init__(self, n_classes, pretrained=False, model_arch=None):
        super().__init__()
        arch = CFG['model_arch'] if model_arch is None else model_arch
        self.model = timm.create_model(arch, pretrained=pretrained)
        # reset_classifier() swaps the head for any timm architecture
        # (``classifier`` on EfficientNet, ``head`` on ViT, ...); hard-coding
        # ``self.model.classifier`` only worked for some families. The new head
        # is still an nn.Linear(num_features, n_classes), so state_dict keys and
        # shapes are expected to match checkpoints saved by the old code —
        # verify with the installed timm version.
        self.model.reset_classifier(n_classes)

    def forward(self, x):
        """Return the backbone's class logits for the input batch ``x``."""
        return self.model(x)

In [None]:
#model = ViTBase16(n_classes=5, pretrained=True).to(device)

In [None]:
# Summary of the label table: column dtypes, non-null counts and memory usage.
df.info()

In [None]:
# Class balance of the training labels, with bars ordered by label id.
label_counts = df["label"].value_counts().sort_index()
ax = label_counts.plot(kind="bar", title="Training images per label")
ax.set_xlabel("label")
ax.set_ylabel("number of images");

In [None]:
# Hold out a stratified 10% of the labelled images for validation.
train_df, valid_df = model_selection.train_test_split(
    df,
    stratify=df.label.values,
    test_size=0.1,
    random_state=42,
)
train_dataset, valid_dataset = (
    CassavaDataset(split, datapath, transforms=tfm, output_label=True)
    for split, tfm in ((train_df, transforms_train), (valid_df, transforms_valid))
)

In [None]:
# Test set: every file in the test image directory. os.listdir order is
# arbitrary, so sort it to keep the submission row order reproducible.
test = pd.DataFrame({'image_id': sorted(os.listdir(TEST_PATH))})
test_dataset = CassavaDataset(test, TEST_PATH, transforms=transforms_test, output_label=False)
len(test_dataset)

In [None]:
img, label = train_dataset[56]
# Shape of one augmented sample; expected (C, H, W) after ToTensorV2.
print(tuple(img.size()))

In [None]:
import matplotlib.pyplot as plt
import numpy as np
# Show one augmented training sample.
# Fetch it once: every dataset access re-runs the random augmentations and
# reloads the image from disk.
index = 7989
sample_img, sample_label = train_dataset[index]
# The tensor was normalised with these ImageNet statistics; undo that so imshow
# gets float RGB values in [0, 1] instead of clipping and distorting colours.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
img_hwc = np.transpose(sample_img.numpy(), (1, 2, 0)) * imagenet_std + imagenet_mean
plt.imshow(np.clip(img_hwc, 0.0, 1.0))
print("label : ", sample_label)

In [None]:
#dataset = train_ds
#val_size = int(0.7*len(dataset))
#train_size = len(dataset) - val_size

#train_ds, val_ds = random_split(dataset, [train_size, val_size])
#len(train_ds), len(val_ds)


#dataset = train_ds
#val_size = int(0.5*len(dataset))
#train_size = len(dataset) - val_size

#train_ds, val_ds = random_split(dataset, [train_size, val_size])
#len(train_ds), len(val_ds)

In [None]:
def prepare_dataloader(df, trn_idx, val_idx, data_root='../input/cassava-leaf-disease-classification/train_images/',
                       train_transforms=None, valid_transforms=None):
    """Build train/validation DataLoaders for one split of ``df``.

    Args:
        df: Label table with ``image_id`` and ``label`` columns.
        trn_idx: Index labels (used with ``df.loc``) of the training rows.
        val_idx: Index labels (used with ``df.loc``) of the validation rows.
        data_root: Directory containing the training images.
        train_transforms: Pipeline for the training set; ``None`` means the
            module-level ``transforms_train``.
        valid_transforms: Pipeline for the validation set; ``None`` means the
            module-level ``transforms_valid``.

    Returns:
        ``(train_loader, val_loader)``: the training loader shuffles, the
        validation loader keeps row order.
    """
    # Resolved at call time so the notebook's current pipelines are used.
    if train_transforms is None:
        train_transforms = transforms_train
    if valid_transforms is None:
        valid_transforms = transforms_valid

    train_ = df.loc[trn_idx, :].reset_index(drop=True)
    valid_ = df.loc[val_idx, :].reset_index(drop=True)

    # Only pass keyword arguments that CassavaDataset actually accepts.
    train_ds = CassavaDataset(train_, data_root, transforms=train_transforms, output_label=True)
    valid_ds = CassavaDataset(valid_, data_root, transforms=valid_transforms, output_label=True)

    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=CFG['train_bs'],
        pin_memory=False,
        drop_last=False,
        shuffle=True,
        num_workers=CFG['num_workers'],
        # Class-balanced sampling would need catalyst's BalanceClassSampler and
        # shuffle=False (DataLoader rejects sampler together with shuffle=True):
        # sampler=BalanceClassSampler(labels=train_['label'].values, mode="downsampling")
    )
    val_loader = torch.utils.data.DataLoader(
        valid_ds,
        batch_size=CFG['valid_bs'],
        num_workers=CFG['num_workers'],
        shuffle=False,
        pin_memory=False,
    )
    return train_loader, val_loader

def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, device, scheduler=None, schd_batch_update=False,
                    grad_scaler=None):
    """Run one training epoch with mixed precision and gradient accumulation.

    Gradients are accumulated over ``CFG['accum_iter']`` mini-batches; the
    optimizer steps after each such group and on the last batch of the epoch.

    Args:
        epoch: Epoch number, used only in the progress-bar description.
        model: Module to train (switched to train mode).
        loss_fn: Callable ``loss_fn(preds, labels)`` returning a scalar tensor.
        optimizer: Optimizer, stepped through the gradient scaler.
        train_loader: Iterable of ``(imgs, labels)`` batches supporting ``len()``.
        device: Device the batches are moved to.
        scheduler: Optional LR scheduler.
        schd_batch_update: If True, step ``scheduler`` after every optimizer
            step; otherwise once at the end of the epoch.
        grad_scaler: ``GradScaler`` to use; ``None`` falls back to the
            notebook-level ``scaler`` global.
    """
    scaler_ = scaler if grad_scaler is None else grad_scaler
    accum_iter = CFG['accum_iter']
    n_batches = len(train_loader)
    model.train()

    running_loss = None

    pbar = tqdm(enumerate(train_loader), total=n_batches)
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        # Only forward pass + loss run under autocast; backward runs outside it,
        # as the torch.cuda.amp documentation recommends.
        with autocast():
            image_preds = model(imgs)
            loss = loss_fn(image_preds, image_labels)

        # Divide by accum_iter so the accumulated gradient is the mean over the
        # accumulated mini-batches rather than their sum.
        scaler_.scale(loss / accum_iter).backward()

        # Exponential moving average of the (unscaled) batch loss, for display.
        if running_loss is None:
            running_loss = loss.item()
        else:
            running_loss = running_loss * .99 + loss.item() * .01

        if ((step + 1) % accum_iter == 0) or ((step + 1) == n_batches):
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
            scaler_.step(optimizer)
            scaler_.update()
            optimizer.zero_grad()

            if scheduler is not None and schd_batch_update:
                scheduler.step()

        if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == n_batches):
            pbar.set_description(f'epoch {epoch} loss: {running_loss:.4f}')

    if scheduler is not None and not schd_batch_update:
        scheduler.step()
        
def valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False):
    """Evaluate ``model`` on ``val_loader`` and print the multi-class accuracy.

    The evaluation runs under ``torch.no_grad()`` internally, so no autograd
    graph is built even if the caller forgets to wrap the call.

    Args:
        epoch: Epoch number, used only in the progress-bar description.
        model: Module to evaluate (switched to eval mode).
        loss_fn: Callable ``loss_fn(preds, labels)`` returning a batch-mean loss.
        val_loader: Iterable of ``(imgs, labels)`` batches supporting ``len()``.
        device: Device the batches are moved to.
        scheduler: Optional LR scheduler stepped once after evaluation.
        schd_loss_update: If True, ``scheduler.step`` receives the mean loss
            (ReduceLROnPlateau-style); otherwise it is called with no argument.

    Returns:
        The accuracy (fraction of correct argmax predictions) as a float.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()

    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []

    with torch.no_grad():
        pbar = tqdm(enumerate(val_loader), total=len(val_loader))
        for step, (imgs, image_labels) in pbar:
            imgs = imgs.to(device).float()
            image_labels = image_labels.to(device).long()

            image_preds = model(imgs)
            image_preds_all.append(torch.argmax(image_preds, 1).cpu().numpy())
            image_targets_all.append(image_labels.cpu().numpy())

            loss = loss_fn(image_preds, image_labels)
            # Weight by batch size so loss_sum / sample_num is a per-sample mean.
            loss_sum += loss.item() * image_labels.shape[0]
            sample_num += image_labels.shape[0]

            if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == len(val_loader)):
                pbar.set_description(f'epoch {epoch} loss: {loss_sum/sample_num:.4f}')

    if sample_num == 0:
        raise ValueError('val_loader yielded no samples')

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    accuracy = float((image_preds_all == image_targets_all).mean())
    print('validation multi-class accuracy = {:.4f}'.format(accuracy))

    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(loss_sum/sample_num)
        else:
            scheduler.step()

    return accuracy

In [None]:
class MyCrossEntropyLoss(_WeightedLoss):
    """Cross-entropy accepting either integer class indices or one-hot/soft targets.

    Per sample: ``-sum_c targets[c] * weight[c] * log_softmax(inputs)[c]``.

    Args:
        weight: Optional per-class weight tensor, multiplied into the
            log-probabilities.
        reduction: ``'mean'`` or ``'sum'`` over the batch; any other value
            returns the per-sample losses.

    Note: with ``weight`` set, ``'mean'`` divides by the batch size, not by the
    sum of the selected weights as ``nn.CrossEntropyLoss`` does.
    """

    def __init__(self, weight=None, reduction='mean'):
        # _WeightedLoss registers ``weight`` as a buffer and stores ``reduction``.
        super().__init__(weight=weight, reduction=reduction)

    def forward(self, inputs, targets):
        lsm = F.log_softmax(inputs, -1)

        # Class-index targets (one dim fewer than the logits) are expanded to
        # one-hot; multiplying raw indices into ``lsm`` would broadcast wrongly.
        if targets.dim() == lsm.dim() - 1:
            targets = F.one_hot(targets.long(), num_classes=lsm.size(-1)).to(lsm.dtype)

        if self.weight is not None:
            lsm = lsm * self.weight.unsqueeze(0)

        loss = -(targets * lsm).sum(-1)

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()

        return loss

In [None]:
#train_loader, val_loader = prepare_dataloader(train, trn_idx, val_idx, data_root='../input/cassava-leaf-disease-classification/train_images/')
#train_dataset = CassavaDataset(df, datapath, transforms=transforms_train, output_label=True)

# Loaders over the hold-out split (same settings as prepare_dataloader).
train_loader = DataLoader(
    train_dataset,
    batch_size=CFG['train_bs'],
    shuffle=True,
    drop_last=False,
    num_workers=CFG['num_workers'],
    pin_memory=False,
    # sampler=BalanceClassSampler(labels=train_['label'].values, mode="downsampling")
)
val_loader = DataLoader(
    valid_dataset,
    batch_size=CFG['valid_bs'],
    shuffle=False,
    num_workers=CFG['num_workers'],
    pin_memory=False,
)

# Mixed-precision gradient scaler; train_one_epoch reads this global.
scaler = GradScaler()

# ---------------------------------------------------------------------------
# Training recipe (disabled; the inference cell loads a saved checkpoint):
# optimizer = torch.optim.Adam(model.parameters(), lr=CFG['lr'], weight_decay=CFG['weight_decay'])
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
#     optimizer, T_0=CFG['T_0'], T_mult=1, eta_min=CFG['min_lr'], last_epoch=-1)
# loss_tr = nn.CrossEntropyLoss().to(device)  # or MyCrossEntropyLoss().to(device)
# loss_fn = nn.CrossEntropyLoss().to(device)
# for epoch in range(N_EPOCHS):
#     train_one_epoch(epoch, model, loss_tr, optimizer, train_loader, device,
#                     scheduler=scheduler, schd_batch_update=False)
#     with torch.no_grad():
#         valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False)
#     torch.save(model.state_dict(), '{}_epoch_{}'.format(CFG['model_arch'], epoch))
# del model, optimizer, train_loader, val_loader, scaler, scheduler
# torch.cuda.empty_cache()
#
# Other variants tried: torch.optim.Adam(model.parameters(), lr=LR);
# torch.optim.lr_scheduler.StepLR(optimizer, gamma=GAMMA, step_size=N_EPOCHS-1).

In [None]:
# Load the fine-tuned checkpoint and predict one class per test image.
model = ViTBase16(n_classes=5, pretrained=False)
epoch = 2
checkpoint_path = '../input/vit-base/{}_epoch_{}'.format(CFG['model_arch'], epoch)
# map_location lets a checkpoint saved from a GPU load on a CPU-only kernel too.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=CFG['valid_bs'],
    num_workers=CFG['num_workers'],
    shuffle=False,  # keeps predictions aligned with the rows of `test`
    pin_memory=False,
)

model.eval()
image_preds_all = []
with torch.no_grad():
    for imgs in tqdm(test_loader, total=len(test_loader)):
        imgs = imgs.to(device).float()
        image_preds = model(imgs)
        # Hard prediction: argmax over the class logits.
        image_preds_all.append(torch.argmax(image_preds, 1).cpu().numpy())

image_preds_all = np.concatenate(image_preds_all)

test['label'] = image_preds_all
test.head()

In [None]:
#vit_*_patch16_224
#0.8621 large (epoch 2) (0.8467, 0.8379 , 0.8621 , 0.8579 , 0.8509 ,0.8561)
#base (0.8318,0.8449,0.8519,0.8537,0.8430,0.8547,0.8561,0.8551,0.8421,0.8542)


In [None]:
#'lr': 1e-4   vit_base_patch32_384
'''

epoch 0 loss: 0.5096: 100%|██████████| 1204/1204 [06:01<00:00,  3.33it/s]
epoch 0 loss: 0.4842: 100%|██████████| 67/67 [00:24<00:00,  2.69it/s]

validation multi-class accuracy = 0.8355


epoch 1 loss: 0.4736: 100%|██████████| 1204/1204 [06:00<00:00,  3.34it/s]
epoch 1 loss: 0.4327: 100%|██████████| 67/67 [00:24<00:00,  2.78it/s]

validation multi-class accuracy = 0.8575


epoch 2 loss: 0.4539: 100%|██████████| 1204/1204 [06:00<00:00,  3.34it/s]
epoch 2 loss: 0.4418: 100%|██████████| 67/67 [00:24<00:00,  2.77it/s]

validation multi-class accuracy = 0.8463


epoch 3 loss: 0.4201: 100%|██████████| 1204/1204 [06:01<00:00,  3.33it/s]
epoch 3 loss: 0.4239: 100%|██████████| 67/67 [00:24<00:00,  2.76it/s]

validation multi-class accuracy = 0.8505


epoch 4 loss: 0.4262: 100%|██████████| 1204/1204 [06:01<00:00,  3.33it/s]
epoch 4 loss: 0.4455: 100%|██████████| 67/67 [00:24<00:00,  2.76it/s]

validation multi-class accuracy = 0.8453


epoch 5 loss: 0.4044: 100%|██████████| 1204/1204 [06:00<00:00,  3.34it/s]
epoch 5 loss: 0.4322: 100%|██████████| 67/67 [00:24<00:00,  2.77it/s]

validation multi-class accuracy = 0.8514


epoch 6 loss: 0.3910: 100%|██████████| 1204/1204 [06:01<00:00,  3.33it/s]
epoch 6 loss: 0.4282: 100%|██████████| 67/67 [00:24<00:00,  2.79it/s]

validation multi-class accuracy = 0.8598


epoch 7 loss: 0.3573: 100%|██████████| 1204/1204 [06:00<00:00,  3.34it/s]
epoch 7 loss: 0.4278: 100%|██████████| 67/67 [00:23<00:00,  2.87it/s]

validation multi-class accuracy = 0.8612


epoch 8 loss: 0.3483: 100%|██████████| 1204/1204 [06:00<00:00,  3.34it/s]
epoch 8 loss: 0.4324: 100%|██████████| 67/67 [00:24<00:00,  2.76it/s]

validation multi-class accuracy = 0.8519


epoch 9 loss: 0.3072: 100%|██████████| 1204/1204 [06:01<00:00,  3.33it/s]
epoch 9 loss: 0.4389: 100%|██████████| 67/67 [00:24<00:00,  2.79it/s]

validation multi-class accuracy = 0.8621

'''

In [None]:
#'lr': 1e-4  

'''
epoch 0 loss: 0.4371: 100%|██████████| 1204/1204 [14:30<00:00,  1.38it/s]
epoch 0 loss: 0.3726: 100%|██████████| 67/67 [00:36<00:00,  1.84it/s]

validation multi-class accuracy = 0.8757


epoch 1 loss: 0.4025: 100%|██████████| 1204/1204 [14:30<00:00,  1.38it/s]
epoch 1 loss: 0.3566: 100%|██████████| 67/67 [00:35<00:00,  1.90it/s]

validation multi-class accuracy = 0.8780


epoch 2 loss: 0.3661: 100%|██████████| 1204/1204 [14:29<00:00,  1.38it/s]
epoch 2 loss: 0.3388: 100%|██████████| 67/67 [00:34<00:00,  1.92it/s]

validation multi-class accuracy = 0.8911


epoch 3 loss: 0.3390: 100%|██████████| 1204/1204 [14:29<00:00,  1.38it/s]
epoch 3 loss: 0.3506: 100%|██████████| 67/67 [00:35<00:00,  1.91it/s]

validation multi-class accuracy = 0.8771


epoch 4 loss: 0.3354: 100%|██████████| 1204/1204 [14:30<00:00,  1.38it/s]
epoch 4 loss: 0.3348: 100%|██████████| 67/67 [00:35<00:00,  1.87it/s]

validation multi-class accuracy = 0.8883


epoch 5 loss: 0.3473: 100%|██████████| 1204/1204 [14:30<00:00,  1.38it/s]
epoch 5 loss: 0.3494: 100%|██████████| 67/67 [00:35<00:00,  1.90it/s]

validation multi-class accuracy = 0.8799


epoch 6 loss: 0.3234: 100%|██████████| 1204/1204 [14:29<00:00,  1.38it/s]
epoch 6 loss: 0.3438: 100%|██████████| 67/67 [00:35<00:00,  1.90it/s]

validation multi-class accuracy = 0.8841


epoch 7 loss: 0.2839: 100%|██████████| 1204/1204 [14:30<00:00,  1.38it/s]
epoch 7 loss: 0.3568: 100%|██████████| 67/67 [00:35<00:00,  1.90it/s]

validation multi-class accuracy = 0.8748

'''

In [None]:
#'lr': 2e-5

'''
epoch 0 loss: 0.5750: 100%|██████████| 1204/1204 [14:33<00:00,  1.38it/s]
epoch 0 loss: 0.5037: 100%|██████████| 67/67 [00:36<00:00,  1.84it/s]

validation multi-class accuracy = 0.8201


epoch 1 loss: 0.4907: 100%|██████████| 1204/1204 [14:27<00:00,  1.39it/s]
epoch 1 loss: 0.4051: 100%|██████████| 67/67 [00:35<00:00,  1.88it/s]

validation multi-class accuracy = 0.8621


epoch 2 loss: 0.4327: 100%|██████████| 1204/1204 [14:28<00:00,  1.39it/s]
epoch 2 loss: 0.3753: 100%|██████████| 67/67 [00:35<00:00,  1.89it/s]

validation multi-class accuracy = 0.8762


epoch 3 loss: 0.4160: 100%|██████████| 1204/1204 [14:27<00:00,  1.39it/s]
epoch 3 loss: 0.3637: 100%|██████████| 67/67 [00:35<00:00,  1.90it/s]

validation multi-class accuracy = 0.8813


epoch 4 loss: 0.3774: 100%|██████████| 1204/1204 [14:27<00:00,  1.39it/s]
epoch 4 loss: 0.3495: 100%|██████████| 67/67 [00:35<00:00,  1.90it/s]

validation multi-class accuracy = 0.8883


epoch 5 loss: 0.3772: 100%|██████████| 1204/1204 [14:27<00:00,  1.39it/s]
epoch 5 loss: 0.3554: 100%|██████████| 67/67 [00:34<00:00,  1.91it/s]

validation multi-class accuracy = 0.8822


epoch 6 loss: 0.3493: 100%|██████████| 1204/1204 [14:28<00:00,  1.39it/s]
epoch 6 loss: 0.3519: 100%|██████████| 67/67 [00:36<00:00,  1.84it/s]

validation multi-class accuracy = 0.8808


epoch 7 loss: 0.3564: 100%|██████████| 1204/1204 [14:28<00:00,  1.39it/s]
epoch 7 loss: 0.3431: 100%|██████████| 67/67 [00:35<00:00,  1.89it/s]

validation multi-class accuracy = 0.8879

'''

In [None]:
# Write the submission file (image_id, predicted label) without the index column.
test.to_csv(path_or_buf='submission.csv', index=False)