In [None]:
# Run mode: with TRAINING = True the k-fold training loop runs and writes
# MODEL_SAVE_PATH; otherwise the model at MODEL_LOAD_PATH is loaded instead.
TRAINING = False
MODEL_LOAD_PATH = "../input/cassava-snapmix/model.pt"
MODEL_SAVE_PATH = "model.pt"

In [None]:
# Install timm 0.1.26 from a wheel inside an attached input dataset
# (presumably because the kernel runs without internet access — confirm).
!pip install ../input/timm-package/timm-0.1.26-py3-none-any.whl

In [None]:
import random
import numpy as np
import os
import torch
import pandas as pd
import albumentations as A
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from sklearn.model_selection import StratifiedKFold
from albumentations import Compose
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import cv2
import torch.nn.functional as F
import timm

In [None]:
import warnings
# Suppresses ALL Python warnings for the rest of the session.
# NOTE(review): this also hides library deprecation warnings; consider a
# narrower filter (by category/module) so API removals are noticed early.
warnings.filterwarnings("ignore")

In [None]:
# Competition data root (note the trailing slash: paths below are built by concatenation).
DATA_PATH = '../input/cassava-leaf-disease-classification/'

# Cross-validation and optimisation schedule.
NUM_FOLDS = 5
EPOCHS = 10
batch_size = 32
GRAD_ACCUM_STEPS = 1  # optimizer steps once every GRAD_ACCUM_STEPS batches

# Side length (pixels) of the square images fed to the network.
image_size = 512

# SnapMix: box sizes come from Beta(SNAPMIX_ALPHA, SNAPMIX_ALPHA); each training
# batch is SnapMix-ed with probability SNAPMIX_PCT.
SNAPMIX_ALPHA = 5.0
SNAPMIX_PCT = 0.5

# timm architecture name for the backbone.
TIMM_MODEL = 'resnet50'

In [None]:
def seed_everything(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and every CUDA
    device), and puts cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects Python processes started after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and pick possibly non-deterministic
    # kernels, which defeats deterministic=True above.
    torch.backends.cudnn.benchmark = False

SEED = 1234
seed_everything(SEED)

# Dataset

In [None]:
class CassavaDataset(torch.utils.data.Dataset):
    """Image/label dataset backed by a DataFrame.

    Column 0 of ``dataframe`` is the image file name (joined onto ``root_dir``)
    and a ``label`` column among the remaining columns holds the target.
    Each item is a dict ``{'image': ..., 'label': ...}``; the training and
    validation loops may unpack it positionally via ``data.values()``, so the
    key order ('image' first) is part of the contract.
    """

    def __init__(self, dataframe, root_dir, transforms=None):
        super().__init__()
        self.dataframe = dataframe
        self.root_dir = root_dir
        # Albumentations-style callable: transforms(image=arr)['image'].
        self.transforms = transforms

    def __len__(self):
        return len(self.dataframe)

    def get_img_bgr_to_rgb(self, path):
        """Read an image with OpenCV and return it in RGB channel order.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``path``. cv2.imread signals
                failure by returning None rather than raising.
        """
        im_bgr = cv2.imread(path)
        if im_bgr is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # cvtColor yields a contiguous array (a [:, :, ::-1] view has a negative stride).
        return cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        # DataLoader samplers may hand over a 0-d tensor index.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = os.path.join(self.root_dir, self.dataframe.iloc[idx, 0])
        image = self.get_img_bgr_to_rgb(img_name)
        if self.transforms:
            image = self.transforms(image=image)['image']
        csv_row = self.dataframe.iloc[idx, 1:]
        # Keep 'image' before 'label' (see class docstring).
        return {
            'image': image,
            'label': csv_row.label,
        }

In [None]:
# Training metadata: CassavaDataset reads column 0 as the image file name and the `label` column.
train_df = pd.read_csv(os.path.join(DATA_PATH, "train.csv"))

# Transforms

In [None]:
def train_transforms():
    """Build the training augmentation pipeline (albumentations).

    The pipeline applies, in order:
    - a random resized crop to ``image_size`` x ``image_size``;
    - horizontal/vertical flips;
    - brightness jitter;
    - grid distortion;
    - random gamma;
    - ImageNet mean/std normalisation;
    - conversion to a torch tensor.
    """
    return Compose([
            A.RandomResizedCrop(image_size, image_size),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # Brightness-only jitter. A.RandomBrightness is deprecated (removed in
            # newer albumentations); it was a wrapper around exactly this call
            # (limit 0.2, contrast disabled).
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=0.5),
            A.GridDistortion(p=0.5),
            A.RandomGamma(p=0.5),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ], p=1.)


def valid_transforms():
    """Build the deterministic evaluation pipeline.

    Resizes to ``image_size`` x ``image_size``, applies ImageNet mean/std
    normalisation and converts to a torch tensor. No random augmentation.
    """
    steps = [
        A.Resize(image_size, image_size),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0,
        ),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)

# Model

In [None]:
# TIMM_MODEL is set once, in the hyperparameter cell. It is deliberately not
# reassigned here: a second assignment silently overrode any change made there.

In [None]:
class CassavaNet(nn.Module):
    """timm backbone with a linear head that also exposes its last feature maps.

    ``forward`` returns ``(logits, feats)``. The SnapMix code uses ``feats``
    together with ``self.classifier`` to build class activation maps.
    """

    def __init__(self, num_classes=5, model_name=None,
                 checkpoint_path='../input/timm-pretrained-resnet/resnet/resnet50_ram-a26f946b.pth'):
        """
        Args:
            num_classes: output size of the linear head (5 by default, as before).
            model_name: timm architecture name; ``None`` means the notebook's
                ``TIMM_MODEL``. The model must expose an ``fc`` head, and its
                last two children are dropped (pool + fc in timm ResNets).
            checkpoint_path: weights file handed to ``timm.create_model``.
        """
        super().__init__()
        if model_name is None:
            model_name = TIMM_MODEL
        backbone = timm.create_model(model_name, pretrained=False, checkpoint_path=checkpoint_path)
        n_features = backbone.fc.in_features
        # Keep everything except the final two children so spatial feature maps survive.
        self.backbone = nn.Sequential(*backbone.children())[:-2]
        self.classifier = nn.Linear(n_features, num_classes)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward_features(self, x):
        """Return the backbone's (un-pooled) feature maps."""
        return self.backbone(x)

    def forward(self, x):
        """Return ``(logits, feats)`` where feats are the backbone feature maps."""
        feats = self.forward_features(x)
        pooled = torch.flatten(self.pool(feats), 1)
        return self.classifier(pooled), feats

In [None]:
# First GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# SnapMix Augmentations

In [None]:
def rand_bbox(size, lam):
    """Sample a random box over the last two dims of a 4-D batch shape.

    Args:
        size: batch shape such as ``inputs.size()``; only ``size[2]`` and
            ``size[3]`` are read.
        lam: ratio in [0, 1]. Before clipping, each box side is
            ``sqrt(1 - lam)`` times the corresponding dimension.

    Returns:
        (bbx1, bby1, bbx2, bby2): bbx* index dim 2 and bby* index dim 3,
        clipped to the valid range. The box may be empty (e.g. lam == 1).
    """
    # Names follow the original CutMix code. Callers slice dim 2 with bbx* and
    # dim 3 with bby*, so "W"/"H" are labels only.
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # Builtin int(): the np.int alias was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre drawn uniformly over the grid.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

In [None]:
def get_spm(inputs, target, model):
    """Compute SnapMix semantic percentage maps (CAMs) for a batch.

    ``model(inputs)`` must return ``(logits, feature_maps)``, and
    ``model.classifier`` must be the linear layer applied to the pooled maps.

    Args:
        inputs: image batch, (N, C, H, W).
        target: (N,) class indices; the map for each sample's own class is used.
        model: network as described above.

    Returns:
        outmaps: (N, H, W) maps upsampled to the input's spatial size. Each map
            is shifted to a minimum of 0 and divided by its sum, so a map whose
            sum is 0 becomes NaN (snapmix falls back to an area ratio then).
        clslogit: (N,) softmax probability of each sample's target class.
    """
    # Spatial size of the batch; equals (image_size, image_size) for this notebook's crops.
    imgsize = tuple(inputs.shape[-2:])
    batch_size = inputs.size(0)
    rows = torch.arange(batch_size, device=inputs.device)
    with torch.no_grad():
        # NOTE(review): the model is usually in train() mode here, so this extra
        # forward pass also updates BatchNorm running statistics.
        _, fms = model(inputs)

        clsw = model.classifier
        # The linear head reshaped as a 1x1 conv scores every spatial position.
        weight = clsw.weight.data
        bias = clsw.bias.data
        weight = weight.view(weight.size(0), weight.size(1), 1, 1)

        fms = F.relu(fms)
        # flatten(1), not squeeze(): squeeze also drops the batch dim when N == 1.
        poolfea = F.adaptive_avg_pool2d(fms, (1, 1)).flatten(1)
        clslogit = F.softmax(clsw(poolfea), dim=1)[rows, target]

        out = F.conv2d(fms, weight, bias=bias)
        outmaps = out[rows, target].unsqueeze(1)  # (N, 1, h, w)
        outmaps = F.interpolate(outmaps, imgsize, mode='bilinear', align_corners=False)
        outmaps = outmaps.squeeze(1)  # (N, H, W), batch dim kept even for N == 1

        # Per-sample normalisation: subtract the minimum, then divide by the sum.
        mins = outmaps.reshape(batch_size, -1).min(dim=1).values.view(-1, 1, 1)
        outmaps = outmaps - mins
        sums = outmaps.reshape(batch_size, -1).sum(dim=1).view(-1, 1, 1)
        outmaps = outmaps / sums

    return outmaps, clslogit

In [None]:
def snapmix(inputs, target, alpha, model=None):
    """SnapMix: paste a patch from a shuffled copy of the batch and derive
    per-sample label weights from class activation maps.

    Args:
        inputs: image batch indexed as (N, C, dim2, dim3). Modified IN PLACE
            when a patch is pasted.
        target: (N,) class labels.
        alpha: parameter of the Beta(alpha, alpha) draws that size the two boxes.
        model: network passed to get_spm to build the maps.

    Returns:
        (inputs, target, target_b, lam_a, lam_b): the (possibly mixed) batch,
        the original labels, the labels of the images patches were taken from,
        and per-sample weights for each label set. lam_a and lam_b are on
        inputs.device. When either box is empty nothing is pasted, and
        lam_a = 1, lam_b = 0.
    """
    device = inputs.device
    batch_size = inputs.size(0)
    lam_a = torch.ones(batch_size, device=device)
    lam_b = 1 - lam_a

    wfmaps, _ = get_spm(inputs, target, model)
    lam = np.random.beta(alpha, alpha)
    lam1 = np.random.beta(alpha, alpha)
    # Drawn on the CPU generator and moved, like the previous `.cuda()`; works on CPU too.
    rand_index = torch.randperm(batch_size).to(device)
    wfmaps_b = wfmaps[rand_index, :, :]
    target_b = target[rand_index]

    same_label = target == target_b
    # Box "1" (no suffix) is the destination in `inputs`; box "_1" is the source region.
    bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
    bbx1_1, bby1_1, bbx2_1, bby2_1 = rand_bbox(inputs.size(), lam1)

    area = (bby2 - bby1) * (bbx2 - bbx1)
    area1 = (bby2_1 - bby1_1) * (bbx2_1 - bbx1_1)

    if area1 > 0 and area > 0:
        ncont = inputs[rand_index, :, bbx1_1:bbx2_1, bby1_1:bby2_1].clone()
        # Source and destination boxes generally differ in size: resize the patch.
        ncont = F.interpolate(ncont, size=(bbx2 - bbx1, bby2 - bby1), mode='bilinear', align_corners=True)
        inputs[:, :, bbx1:bbx2, bby1:bby2] = ncont
        # Weights = share of each map's mass that stays (a) / is pasted in (b).
        lam_a = 1 - wfmaps[:, bbx1:bbx2, bby1:bby2].sum(2).sum(1) / (wfmaps.sum(2).sum(1) + 1e-8)
        lam_b = wfmaps_b[:, bbx1_1:bbx2_1, bby1_1:bby2_1].sum(2).sum(1) / (wfmaps_b.sum(2).sum(1) + 1e-8)
        # When both labels coincide, each term gets the combined weight.
        tmp = lam_a.clone()
        lam_a[same_label] += lam_b[same_label]
        lam_b[same_label] += tmp[same_label]
        # NaN weights (e.g. from a map normalised to NaN in get_spm) fall back to the area ratio.
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs.size()[-1] * inputs.size()[-2]))
        lam_a[torch.isnan(lam_a)] = lam
        lam_b[torch.isnan(lam_b)] = 1 - lam

    return inputs, target, target_b, lam_a.to(device), lam_b.to(device)

In [None]:
class SnapMixLoss(nn.Module):
    """Per-sample weighted combination of two loss terms.

    ``criterion`` must return one loss value per sample (reduction='none').
    The result is ``mean(loss(outputs, ya) * lam_a + loss(outputs, yb) * lam_b)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, criterion, outputs, ya, yb, lam_a, lam_b):
        weighted_a = criterion(outputs, ya) * lam_a
        weighted_b = criterion(outputs, yb) * lam_b
        return (weighted_a + weighted_b).mean()

# Training

In [None]:
if TRAINING:
    model = CassavaNet()
    print("Created New Model")
else:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # SECURITY: torch.load unpickles a whole nn.Module; only load trusted files.
    # NOTE: torch >= 2.6 defaults to weights_only=True and would additionally
    # need weights_only=False to load a pickled module like this one.
    model = torch.load(MODEL_LOAD_PATH, map_location=device)
    print("Loaded Model from", MODEL_LOAD_PATH)

model = model.to(device)

In [None]:
# Display the module tree as this cell's output (architecture sanity check).
model

In [None]:
# reduction='none' keeps one loss per sample: SnapMixLoss weights them per sample,
# and the plain training path averages them with torch.mean.
criterion = nn.CrossEntropyLoss(reduction='none')
criterion = criterion.to(device)
snapmix_criterion = SnapMixLoss()
snapmix_criterion = snapmix_criterion.to(device)

In [None]:
# Backbone group uses lr=1e-2; the classifier group inherits the optimizer default lr=1e-1.
param_groups = [
    dict(params=model.backbone.parameters(), lr=1e-2),
    dict(params=model.classifier.parameters()),
]

optimizer = torch.optim.SGD(
    param_groups,
    lr=1e-1,
    momentum=0.9,
    weight_decay=1e-4,
    nesterov=True,
)
# Every group's lr is multiplied by 0.1 after epoch 1.
# NOTE(review): milestones 20 and 40 are never reached while EPOCHS = 10; confirm the intended schedule.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[1, 20, 40], gamma=0.1, last_epoch=-1
)
scaler = GradScaler()

In [None]:
def get_datasets(train_split, valid_split):
    """Build the train and validation DataLoaders for one CV fold.

    Args:
        train_split: positional row indices of ``train_df`` used for training.
        valid_split: positional row indices of ``train_df`` used for validation.

    Returns:
        (train_dl, valid_dl). The training loader shuffles and drops the last
        incomplete batch; the validation loader does neither.
    """
    image_dir = DATA_PATH + 'train_images'
    loader_kwargs = dict(batch_size=batch_size, num_workers=8, pin_memory=True)

    train_ds = CassavaDataset(
        dataframe=train_df.iloc[train_split].reset_index(drop=True),
        root_dir=image_dir,
        transforms=train_transforms(),
    )
    valid_ds = CassavaDataset(
        dataframe=train_df.iloc[valid_split].reset_index(drop=True),
        root_dir=image_dir,
        transforms=valid_transforms(),
    )

    train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, drop_last=True, **loader_kwargs)
    valid_dl = torch.utils.data.DataLoader(valid_ds, shuffle=False, **loader_kwargs)
    return train_dl, valid_dl

In [None]:
def train_single_epoch(train_dl, epoch):
    """Train the notebook-level ``model`` for one pass over ``train_dl``.

    Reads the globals ``model``, ``optimizer``, ``scaler``, ``criterion``,
    ``snapmix_criterion`` and ``device``. Each batch is SnapMix-ed with
    probability SNAPMIX_PCT. Gradients accumulate over GRAD_ACCUM_STEPS batches;
    pending gradients are flushed on the final batch.

    Args:
        train_dl: DataLoader yielding dicts with 'image' and 'label'.
        epoch: 1-based epoch number, used only in the progress-bar text.

    Returns:
        (mean batch loss, mean batch accuracy). Accuracy is measured against the
        un-mixed labels, so it is only indicative on SnapMix batches.
    """
    model.train()
    train_loss = 0
    train_accuracy = 0

    with tqdm(total=len(train_dl)) as t:
        for batch_idx, data in enumerate(train_dl, 1):
            X = data['image'].to(device).float()
            y = data['label'].to(device).long()

            with autocast():
                if np.random.rand() > (1.0 - SNAPMIX_PCT):
                    X, ya, yb, lam_a, lam_b = snapmix(X, y, SNAPMIX_ALPHA, model)
                    outputs, _ = model(X)
                    loss = snapmix_criterion(criterion, outputs, ya, yb, lam_a, lam_b)
                else:
                    outputs, _ = model(X)
                    loss = torch.mean(criterion(outputs, y))

            # Divide by the window length so accumulated gradients average rather than sum
            # (a no-op for GRAD_ACCUM_STEPS == 1).
            scaler.scale(loss / GRAD_ACCUM_STEPS).backward()

            if (batch_idx % GRAD_ACCUM_STEPS == 0) or (batch_idx == len(train_dl)):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            train_loss += loss.item()

            # argmax of the logits equals argmax of the softmax probabilities.
            preds = outputs.argmax(dim=1)
            train_accuracy += (preds == y).float().mean().item()

            t.set_description(f"Epoch: {epoch}/{EPOCHS}\tLoss: {train_loss/batch_idx:0.4f}\tAccuracy: {train_accuracy/batch_idx:.4f}")
            t.update()

    return train_loss/len(train_dl), train_accuracy/len(train_dl)

In [None]:
def validate_one_epoch(valid_dl):
    """Evaluate the notebook-level ``model`` on ``valid_dl`` without gradients.

    Loss and accuracy are averaged per SAMPLE rather than per batch, so a
    smaller final batch is not over-weighted. The validation loader keeps its
    last partial batch.
    Assumes ``criterion`` returns per-sample losses (reduction='none', as
    configured in this notebook).

    Args:
        valid_dl: DataLoader yielding dicts with 'image' and 'label'.

    Returns:
        (mean validation loss, validation accuracy).

    Raises:
        ValueError: if ``valid_dl`` yields no samples.
    """
    model.eval()

    total_loss = 0.0
    total_correct = 0
    total_seen = 0

    with torch.no_grad():
        for data in valid_dl:
            X = data['image'].to(device)
            y = data['label'].to(device)
            outputs, _ = model(X)
            total_loss += criterion(outputs, y).sum().item()
            total_correct += (outputs.argmax(dim=1) == y).sum().item()
            total_seen += y.size(0)

    if total_seen == 0:
        raise ValueError("validation loader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen

In [None]:
if TRAINING:
    # Stratify on the label so every fold keeps the overall class balance.
    folds = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED).split(
        np.arange(train_df.shape[0]), train_df.label.values
    )

    # Best validation accuracy over ALL folds; only that checkpoint is kept at MODEL_SAVE_PATH.
    overall_best = 0

    for fold_num, (train_split, valid_split) in enumerate(folds):
        train_dl, valid_dl = get_datasets(train_split, valid_split)

        # Start every fold from a fresh model, optimizer, scheduler and scaler.
        # Continuing from the previous fold would mean this fold's validation
        # images were already trained on, inflating its accuracy. These are
        # deliberately module-level assignments: train_single_epoch reads the
        # globals `model`, `optimizer` and `scaler`, and validate_one_epoch
        # reads `model`. Hyperparameters mirror the optimizer cell.
        model = CassavaNet().to(device)
        optimizer = torch.optim.SGD(
            [
                {'params': model.backbone.parameters(), 'lr': 1e-2},
                {'params': model.classifier.parameters()},
            ],
            lr=1e-1, momentum=0.9, weight_decay=1e-4, nesterov=True,
        )
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1, 20, 40], gamma=0.1, last_epoch=-1)
        scaler = GradScaler()

        best_metric = 0

        for epoch in range(EPOCHS):
            train_loss, train_accuracy = train_single_epoch(train_dl, epoch+1)
            scheduler.step()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            valid_loss, valid_accuracy = validate_one_epoch(valid_dl)

            print(f"Epoch: {epoch+1}/{EPOCHS}\tLoss: {train_loss:0.4f}\tAccuracy: {train_accuracy:.4f}\tValid Loss: {valid_loss:0.4f}\tValid Accuracy: {valid_accuracy:0.4f}")

            best_metric = max(best_metric, valid_accuracy)
            # Only overwrite the checkpoint when beating every earlier fold, not just this one.
            if valid_accuracy > overall_best:
                print(f"Accuracy increased from {overall_best} to {valid_accuracy}, Saving Model")
                torch.save(model, MODEL_SAVE_PATH)
                overall_best = valid_accuracy

        print(f"Fold {fold_num}: best valid accuracy {best_metric:0.4f}")

    print("Best Model Accuracy:", overall_best)
    model = torch.load(MODEL_SAVE_PATH, map_location=device)

In [None]:
# Write the current model to the working directory. With TRAINING = False this
# re-saves the model loaded from MODEL_LOAD_PATH.
torch.save(model, f=MODEL_SAVE_PATH)

# Prediction

In [None]:
# Every file in the test_images folder gets a prediction row.
test_images_path = os.path.join(DATA_PATH, "test_images")
test_image_id = os.listdir(test_images_path)

In [None]:
# One row per test image; label -1 is a placeholder until predictions are filled in.
test_df = pd.DataFrame({"image_id": test_image_id, "label": -1})

In [None]:
test_ds = CassavaDataset(
    dataframe=test_df,
    root_dir=DATA_PATH + 'test_images',
    transforms=valid_transforms(),
)

# shuffle=False keeps predictions in test_df row order (they are assigned back by position).
test_dl = torch.utils.data.DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

In [None]:
model.eval()
labels = []
with torch.no_grad():
    for data in test_dl:
        outputs, _ = model(data['image'].to(device))
        # argmax over logits picks the same class as argmax over softmax
        # (and avoids F.softmax's deprecated implicit-dim call).
        labels.extend(outputs.argmax(dim=1).cpu().numpy())

test_df['label'] = labels
test_df.head()

In [None]:
# Competition submission: columns image_id, label; no index column.
test_df.to_csv(path_or_buf="submission.csv", index=False)