In [None]:
import random
import os
import pandas as pd
import numpy as np
import torch
import cv2
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.cuda.amp import autocast, GradScaler
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F
from torch import nn
import torchvision.models as models


from sklearn.model_selection import GroupKFold, StratifiedKFold

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform
from tqdm import tqdm
import time
from sklearn.metrics import log_loss

In [None]:
# Kaggle input directories holding the competition images.
TRAIN_PATH = '../input/cassava-leaf-disease-classification/train_images'
TEST_PATH = '../input/cassava-leaf-disease-classification/test_images'

# Training metadata; the rest of the notebook reads its `image_id` and
# `label` columns.
train = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
train.head()

# Run configuration shared by every cell below.
CFG = dict(
    fold_num=5,                          # number of StratifiedKFold splits
    seed=719,                            # passed to seed_everything and the fold splitter
    model_arch='tf_efficientnet_b4_ns',  # only used to name the saved checkpoint
    img_size=128,                        # NOTE(review): not read by the visible code; 'size' is
    epochs=1,
    train_bs=16,                         # training batch size
    valid_bs=32,                         # validation batch size
    size=256,                            # side length used by get_transforms
    T_0=10,                              # CosineAnnealingWarmRestarts period
    lr=1e-4,
    min_lr=1e-6,                         # eta_min of the scheduler
    weight_decay=1e-6,
    num_workers=4,                       # DataLoader worker processes
    accum_iter=2,                        # NOTE(review): not read by the visible code
    verbose_step=1,                      # NOTE(review): not read by the visible code
    device='cuda:0',
)


In [None]:
def seed_everything(seed):
    """Seed every RNG this notebook relies on and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to Python's `random`, NumPy and PyTorch
            (CPU and all CUDA devices); also exported as PYTHONHASHSEED.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, so it
    # must be off for the deterministic flag above to be meaningful.
    torch.backends.cudnn.benchmark = False
    

    
def get_transforms(*, data):
    """Build the albumentations pipeline for one dataset split.

    Args:
        data: 'train' for the augmenting pipeline (random resized crop,
            transpose, flips, shift/scale/rotate) or 'valid' for a plain
            resize. Both end with ImageNet mean/std normalisation and
            conversion to a tensor.

    Returns:
        A `Compose` pipeline producing CFG['size'] x CFG['size'] images.

    Raises:
        ValueError: if `data` is neither 'train' nor 'valid'. Returning None
            here would make the datasets silently skip all preprocessing.
    """
    if data not in ('train', 'valid'):
        raise ValueError(
            f"unknown split {data!r}; expected 'train' or 'valid'"
        )

    side = CFG['size']
    # Steps shared by both splits.
    tail = [
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ]

    if data == 'train':
        head = [
            RandomResizedCrop(side, side),
            Transpose(p=0.5),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            ShiftScaleRotate(p=0.5),
        ]
    else:
        head = [Resize(side, side)]

    return Compose(head + tail)



    
    
class TrainDataset(Dataset):
    """Labelled image dataset backed by a DataFrame.

    Args:
        df: Frame with an `image_id` column (file names relative to
            TRAIN_PATH) and a `label` column.
        transform: Optional callable invoked as `transform(image=...)` whose
            result exposes the processed image under the 'image' key.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['image_id'].values
        self.labels = df['label'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return `(image, label)` for row `idx`; label is a long tensor.

        Raises:
            FileNotFoundError: if the image cannot be read from disk.
        """
        file_path = f'{TRAIN_PATH}/{self.file_names[idx]}'
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the path rather than inside cvtColor.
            raise FileNotFoundError(f'could not read image: {file_path}')
        # OpenCV loads BGR; convert to RGB before the transforms run.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image=image)['image']
        label = torch.tensor(self.labels[idx]).long()
        return image, label
    

class TestDataset(Dataset):
    """Unlabelled image dataset backed by a DataFrame.

    Args:
        df: Frame with an `image_id` column (file names relative to
            TEST_PATH).
        transform: Optional callable invoked as `transform(image=...)` whose
            result exposes the processed image under the 'image' key.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['image_id'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image for row `idx`.

        Raises:
            FileNotFoundError: if the image cannot be read from disk.
        """
        file_path = f'{TEST_PATH}/{self.file_names[idx]}'
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the path rather than inside cvtColor.
            raise FileNotFoundError(f'could not read image: {file_path}')
        # OpenCV loads BGR; convert to RGB before the transforms run.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image=image)['image']
        return image


def prepare_dataloader(df, trn_idx, val_idx, data_root):
    """Create the train and validation DataLoaders for one fold.

    Args:
        df: Full training frame with `image_id` and `label` columns.
        trn_idx: Positional row indices of the training split.
        val_idx: Positional row indices of the validation split.
        data_root: Accepted for call-site compatibility; not used here
            (TrainDataset reads from TRAIN_PATH).

    Returns:
        `(train_loader, val_loader)`. The training loader shuffles and uses
        CFG['train_bs']; the validation loader keeps order and uses
        CFG['valid_bs'].
    """
    # The fold splitter yields positions, so select with iloc: loc would pick
    # the wrong rows (or raise) whenever df does not carry a default 0..n-1
    # index.
    train_ = df.iloc[trn_idx, :].reset_index(drop=True)
    valid_ = df.iloc[val_idx, :].reset_index(drop=True)

    train_ds = TrainDataset(train_, transform=get_transforms(data='train'))
    valid_ds = TrainDataset(valid_, transform=get_transforms(data='valid'))

    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=CFG['train_bs'],
        pin_memory=False,
        drop_last=False,
        shuffle=True,
        num_workers=CFG['num_workers'],
    )
    val_loader = torch.utils.data.DataLoader(
        valid_ds,
        batch_size=CFG['valid_bs'],
        num_workers=CFG['num_workers'],
        shuffle=False,
        pin_memory=False,
    )
    return train_loader, val_loader


class CassvaImgClassifier(nn.Module):
    """ResNeXt-50 (32x4d) backbone followed by a linear classification head.

    Args:
        n_class: Number of output classes (size of the logit vector).
        pretrained: Forwarded to `models.resnext50_32x4d`.
    """

    def __init__(self, n_class, pretrained=False):
        super().__init__()
        self.model = models.resnext50_32x4d(pretrained=pretrained)
        n_features = self.model.fc.in_features
        # Replace the stock head so the backbone returns its pooled features.
        self.model.fc = nn.Identity()
        # Not used by forward(); kept so the module's attributes stay as before.
        self.pooling = nn.AdaptiveAvgPool2d(1)
        # Stored under `model.classifier` so state_dict keys are unchanged.
        self.model.classifier = nn.Linear(n_features, n_class)

    def forward(self, x):
        """Return class logits of shape (batch, n_class)."""
        features = self.model(x)
        # The backbone's own forward never invokes `classifier`, so the head
        # has to be applied explicitly; without this the network would emit
        # the raw feature vector instead of n_class logits.
        return self.model.classifier(features)


        

In [None]:
def train_one_epoch(epoch, model, optimizer,criterion, train_loader, device, train_data):
    """Run one optimisation pass over `train_loader`.

    Args:
        epoch: Current epoch number (not used inside the function).
        model: Network to train; switched to train mode here.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss callable `criterion(outputs, target)`.
        train_loader: DataLoader yielding `(images, labels)` batches.
        device: Device the batches are moved to.
        train_data: Accepted for call-site compatibility; the progress-bar
            length is now taken from the loader itself.

    Returns:
        `(train_loss, train_accuracy)`: mean per-sample loss and accuracy in
        percent over `train_loader.dataset`.
    """
    print('Training')
    model.train()
    train_running_loss = 0.0
    train_running_correct = 0
    # len(train_loader) also counts a final partial batch.
    for data, target in tqdm(train_loader, total=len(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        # Assumes `criterion` returns the batch mean (as the default
        # CrossEntropyLoss does): weight it by the batch size so the division
        # by the dataset size below yields a true per-sample average.
        train_running_loss += loss.item() * target.size(0)
        preds = outputs.detach().argmax(dim=1)
        train_running_correct += (preds == target).sum().item()

    n_samples = len(train_loader.dataset)
    train_loss = train_running_loss / n_samples
    train_accuracy = 100. * train_running_correct / n_samples
    return train_loss, train_accuracy

        
def valid_one_epoch(epoch, model, criterion, val_loader, device, val_data):
    """Evaluate `model` on `val_loader` without tracking gradients.

    Args:
        epoch: Current epoch number (not used inside the function).
        model: Network to evaluate; switched to eval mode here.
        criterion: Loss callable `criterion(outputs, target)`.
        val_loader: DataLoader yielding `(images, labels)` batches.
        device: Device the batches are moved to.
        val_data: Accepted for call-site compatibility; the progress-bar
            length is now taken from the loader itself.

    Returns:
        `(val_loss, val_accuracy)`: mean per-sample loss and accuracy in
        percent over `val_loader.dataset`.
    """
    print('Validating')
    model.eval()
    val_running_loss = 0.0
    val_running_correct = 0
    with torch.no_grad():
        # len(val_loader) also counts a final partial batch.
        for data, target in tqdm(val_loader, total=len(val_loader)):
            data = data.to(device).float()
            target = target.to(device).long()
            outputs = model(data)
            loss = criterion(outputs, target)

            # Assumes `criterion` returns the batch mean (as the default
            # CrossEntropyLoss does): weight it by the batch size so the
            # division by the dataset size below is a per-sample average.
            val_running_loss += loss.item() * target.size(0)
            preds = outputs.argmax(dim=1)
            val_running_correct += (preds == target).sum().item()

    n_samples = len(val_loader.dataset)
    val_loss = val_running_loss / n_samples
    val_accuracy = 100. * val_running_correct / n_samples
    return val_loss, val_accuracy


In [None]:
if __name__ == '__main__':
    # Train on the first stratified fold only and keep the checkpoint with
    # the best validation accuracy.
    seed_everything(CFG['seed'])

    folds = StratifiedKFold(n_splits=CFG['fold_num'], shuffle=True, random_state=CFG['seed']).split(np.arange(train.shape[0]), train.label.values)

    # Best validation accuracy so far. Despite the *_auc names, the values
    # returned by train_one_epoch / valid_one_epoch are accuracies in percent.
    validation_auc = 0

    for fold, (trn_idx, val_idx) in enumerate(folds):
        if fold>0:
            break
        print('Training with {} started'.format(fold))

        print(len(trn_idx), len(val_idx))
        train_loader, val_loader = prepare_dataloader(train, trn_idx, val_idx, data_root=TRAIN_PATH)
        # Fall back to CPU when CUDA is absent, as the inference cell does.
        device = torch.device(CFG['device'] if torch.cuda.is_available() else 'cpu')
        model = CassvaImgClassifier(train.label.nunique(), pretrained=False).to(device)
        # NOTE(review): created but not used by train_one_epoch as written.
        scaler = GradScaler()
        optimizer = torch.optim.Adam(model.parameters(), lr=CFG['lr'], weight_decay=CFG['weight_decay'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CFG['T_0'], T_mult=1, eta_min=CFG['min_lr'], last_epoch=-1)

        criterion = torch.nn.CrossEntropyLoss()

        for epoch in range(CFG['epochs']):

            train_loss , train_auc = train_one_epoch(epoch, model, optimizer, criterion, train_loader, device,trn_idx)

            val_loss, val_auc = valid_one_epoch(epoch, model, criterion, val_loader, device, val_idx)
            # Advance the cosine schedule once per epoch; without this call
            # the learning rate never leaves its initial value.
            scheduler.step()

            print(f"=========Epoch: {epoch}/{CFG['epochs']}============")
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_auc:.2f}")
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_auc:.2f}')
            if val_auc>validation_auc:
                print("saving the best model")
                # File name comes from CFG['model_arch']; the inference cell
                # loads a checkpoint with this same name.
                torch.save(model.state_dict(),'{}_test'.format(CFG['model_arch']))
                validation_auc = val_auc
            

In [None]:
# Disabled scratch cell: a copy of the fold split / dataloader setup kept
# inside a string literal, so none of it executes.
'''
folds = StratifiedKFold(n_splits=CFG['fold_num'], shuffle=True, random_state=CFG['seed']).split(np.arange(train.shape[0]), train.label.values)
    
validation_loss = 99999

for fold, (trn_idx, val_idx) in enumerate(folds):
    if fold>0:
        break
    print('Training with {} started'.format(fold))

    print(len(trn_idx), len(val_idx))
    train_loader, val_loader = prepare_dataloader(train, trn_idx, val_idx, data_root=TRAIN_PATH)

'''

In [None]:
# Inference: predict one label per file found in the test directory.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the test frame from the files present in the test directory
# (the same directory TestDataset reads from).
test = pd.DataFrame()
test['image_id'] = list(os.listdir(TEST_PATH))
test_ds = TestDataset(test, transform=get_transforms(data='valid'))

# Restore the trained weights. map_location lets a checkpoint saved on GPU
# load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = CassvaImgClassifier(train.label.nunique(), pretrained=False)
model.load_state_dict(
    torch.load('../input/pre-trained-model/tf_efficientnet_b4_ns_test', map_location=device)
)
model.to(device)

tst_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=CFG['valid_bs'],
    num_workers=CFG['num_workers'],
    shuffle=False,
    # Pinned memory only helps host-to-GPU copies.
    pin_memory=torch.cuda.is_available(),
)

model.eval()
tk0 = tqdm(enumerate(tst_loader), total=len(tst_loader))
pred_label = []
# no_grad: inference needs no autograd graph, which would otherwise be built
# and held for every batch.
with torch.no_grad():
    for _, image in tk0:
        image_preds = model(image.to(device))
        # argmax over the class dimension labels every sample in the batch,
        # not just the first one; exp() is monotonic, so it is not needed to
        # find the winning class.
        pred_label.extend(image_preds.argmax(dim=1).cpu().tolist())


In [None]:
# Attach the predictions to the test frame and write the two-column
# submission file.
test['label'] = pred_label

submission_columns = ['image_id', 'label']
test[submission_columns].to_csv('submission.csv', index=False)