In [None]:
# NOTE(review): nothing in this notebook imports efficientnet_pytorch (the model built
# below is a torchvision ResNet-50), so this install looks removable. If it is kept,
# pin a version so the environment is reproducible.
!pip install efficientnet_pytorch
from prettytable import PrettyTable

def count_parameters(model):
    """Print a table of trainable parameter counts per tensor and return their sum.

    Only parameters with ``requires_grad`` set are listed and counted.

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        Total number of trainable scalar parameters.
    """
    summary = PrettyTable(["Modules", "Parameters"])
    trainable = [
        (tensor_name, tensor.numel())
        for tensor_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    for tensor_name, count in trainable:
        summary.add_row([tensor_name, count])
    grand_total = sum(count for _, count in trainable)
    print(summary)
    print(f"Total Trainable Params: {grand_total}")
    return grand_total

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import cv2
import numpy as np
import time
import random
from tqdm import tqdm_notebook as tqdm
import pandas as pd
import os

from torchvision import models

from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
from albumentations.pytorch import ToTensorV2

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

In [None]:
# parameters
train_dir = '../input/cassava-leaf-disease-classification/train_images'
test_dir = '../input/cassava-leaf-disease-classification/test_images'
cfg = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'epochs': 30,
    'batch_size': 42,
    'lr': 0.0001,
    'input_size': 256,  # side length images are resized/cropped to
    'seed': 42,         # used to seed the RNGs below so runs are reproducible
}

# Seed every RNG the notebook can draw from: torch (weight init, DataLoader shuffling)
# and python/numpy (augmentation libraries typically sample from these).
random.seed(cfg['seed'])
np.random.seed(cfg['seed'])
torch.manual_seed(cfg['seed'])

In [None]:
# Sort the listing: os.listdir returns names in arbitrary, filesystem-dependent order,
# and the train/validation split below slices this list by position.
imagenames = sorted(os.listdir(train_dir))
csv = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')

# CASSAVA looks up each label by file name, so every listed image needs a CSV row.
unlabeled = set(imagenames) - set(csv['image_id'])
assert not unlabeled, f'{len(unlabeled)} images in train_dir have no row in train.csv'

print(csv.head(10))
print(csv['label'].unique())

In [None]:
def get_train_transforms():
    """Build the albumentations augmentation pipeline for training images.

    Random resized crop to ``cfg['input_size']``, random transposes/flips,
    shift-scale-rotate and mild colour jitter. The pipeline is always applied (p=1).
    """
    side = cfg['input_size']
    augmentations = [
        RandomResizedCrop(side, side),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
        # NOTE(review): albumentations' HueSaturationValue limits are on the pixel-value
        # scale (library defaults are 20/30/20), so 0.2 is close to a no-op — confirm intent.
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.2),
        RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    ]
    return Compose(augmentations, p=1.)

def get_valid_transforms():
    """Build the deterministic albumentations pipeline for validation images.

    Center-crops and then resizes to ``cfg['input_size']``. Images produced by
    CASSAVA are already resized to that size before this runs, so for them both
    steps leave the image unchanged.
    """
    side = cfg['input_size']
    steps = [
        CenterCrop(side, side, p=1.),
        Resize(side, side),
    ]
    return Compose(steps, p=1.)


def normalize_and_to_tensor(img):
    """Normalise ``img`` with the standard ImageNet mean/std and convert it to a tensor.

    Pixel values are divided by 255 before normalisation. ``img`` is presumably an
    HWC uint8 array as returned by cv2 / albumentations; ToTensorV2 produces a
    CHW torch tensor.
    """
    steps = [
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
                  max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    pipeline = Compose(steps, p=1.0)
    converted = pipeline(image=img)
    return converted['image']

class CASSAVA(Dataset):
    """Cassava leaf images paired with one-hot labels taken from the train CSV.

    Args:
        imagenames: image file names (relative to ``root_dir``) forming this split.
        csv: DataFrame with ``image_id`` and ``label`` columns.
        root_dir: directory containing the image files.
        input_size: side length every image is resized to before augmentation.
        transforms: optional albumentations pipeline applied after resizing.
        train: stored on the instance; not used by this class.
    """

    NUM_CLASSES = 5

    def __init__(self,
                 imagenames,
                 csv,
                 root_dir,
                 input_size=cfg['input_size'],
                 transforms=None,
                 train=True):
        self.imagenames = imagenames
        self.csv = csv
        self.root_dir = root_dir
        self.input_size = input_size
        self.transforms = transforms
        self.train = train
        # image_id -> label, built once. Filtering the whole DataFrame in __getitem__
        # cost a full scan of the CSV for every sample.
        self._labels = dict(zip(csv['image_id'], csv['label']))

    def __len__(self):
        return len(self.imagenames)

    def get_onehot(self, label):
        """Return a float32 one-hot vector of length NUM_CLASSES with a 1 at ``label``."""
        # float32 matches the model's output dtype; float64 targets silently promoted
        # the whole loss computation to double precision.
        onehot = np.zeros(self.NUM_CLASSES, dtype=np.float32)
        onehot[label] = 1
        return onehot

    def __getitem__(self, idx):
        """Return ``(image_tensor, onehot_label_tensor)`` for sample ``idx``.

        Raises:
            KeyError: the image has no row in the labels CSV (previously this
                silently produced an all-zero target).
            FileNotFoundError: the image file could not be read.
        """
        imagename = self.imagenames[idx]
        try:
            label = self._labels[imagename]
        except KeyError:
            raise KeyError(f'{imagename!r} has no entry in the labels CSV') from None

        path = os.path.join(self.root_dir, imagename)
        image = cv2.imread(path)
        if image is None:  # cv2 signals an unreadable/missing file by returning None
            raise FileNotFoundError(f'could not read image {path!r}')
        # NOTE(review): cv2.imread yields BGR, while the Normalize statistics applied
        # later are the usual RGB ImageNet values. Left unchanged because the loaded
        # backbone checkpoint may have been trained on this same BGR pipeline — confirm.
        image = cv2.resize(image, (self.input_size, self.input_size))

        if self.transforms is not None:
            image = self.transforms(image=image)['image']
        image = normalize_and_to_tensor(image)
        return image, torch.from_numpy(self.get_onehot(label))

In [None]:
# Positional split: the first 300 names are held out for validation, the rest train.
valid_names, train_names = imagenames[:300], imagenames[300:]

train_transforms = get_train_transforms()
t_dataset = CASSAVA(train_names, csv, train_dir, transforms=train_transforms)
train_loader = DataLoader(
    dataset=t_dataset,
    batch_size=cfg['batch_size'],
    shuffle=True,
    num_workers=2,
)

valid_transforms = get_valid_transforms()
v_dataset = CASSAVA(valid_names, csv, train_dir, transforms=valid_transforms)
valid_loader = DataLoader(
    dataset=v_dataset,
    batch_size=cfg['batch_size'],
    shuffle=False,
    num_workers=0,
)

In [None]:
# Randomly initialised ResNet-50 with its last two children (global pool + fc head)
# removed, so it outputs the final convolutional feature map.
resnet = models.resnet50(pretrained=False)
backbone = nn.Sequential(*list(resnet.children())[:-2])
print(backbone)

In [None]:

class BACKBONE(nn.Module):
    """Feature extractor: convolutional trunk -> global average pool -> flatten.

    Args:
        encoder: module producing an (N, C, H, W) feature map. Defaults to the
            notebook-level ``backbone`` (ResNet-50 without its pool/fc head).
    """

    def __init__(self, encoder=None):
        super().__init__()
        # Attribute name kept as `effn`: checkpoints loaded into this module use
        # 'effn.*' state_dict keys.
        self.effn = backbone if encoder is None else encoder
        # Adaptive 1x1 pooling equals AvgPool2d((8, 8)) on the 8x8 map a 256px input
        # gives, but also handles other input sizes. A fixed 8x8 kernel produced a
        # wider flattened vector that no longer matched the MLP head.
        self.average = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Return globally pooled features of shape (N, C)."""
        x = self.effn(x)
        x = self.average(x)
        return self.flatten(x)
    
class MLP(nn.Module):
    """Two-layer classification head: Linear -> ReLU -> Linear, returning logits.

    Args:
        in_features: width of the pooled backbone features (2048 for ResNet-50).
        hidden_features: width of the hidden layer.
        num_classes: number of output classes.
    """

    def __init__(self, in_features=2048, hidden_features=1024, num_classes=5):
        super().__init__()
        self.dense1 = nn.Linear(in_features, hidden_features)
        self.dense2 = nn.Linear(hidden_features, num_classes)

    def forward(self, X):
        """Map (N, in_features) features to (N, num_classes) logits."""
        X = F.relu(self.dense1(X))
        return self.dense2(X)

class CLASSIFIER(nn.Module):
    """BACKBONE features -> MLP head -> softmax class probabilities.

    Args:
        weights_path: checkpoint loaded (CPU-mapped) into the backbone. Pass None to
            skip loading and keep the backbone's current weights.
    """

    def __init__(self, weights_path='../input/simsaim-weights/I_am_trained_14.pt'):
        super().__init__()
        self.backbone = BACKBONE()
        if weights_path is not None:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            self.backbone.load_state_dict(torch.load(weights_path, map_location='cpu'))
        self.mlp = MLP()

    def forward(self, X):
        """Return per-class probabilities (softmax over the last dim), not logits."""
        X = self.backbone(X)
        X = self.mlp(X)
        return F.softmax(X, dim=-1)
    
model = CLASSIFIER()
model.to(cfg['device'])

# Shape smoke test. A freshly built module is in training mode, so a forward pass on
# random noise would update the BatchNorm running statistics before training even
# starts. Run it in eval mode without autograd instead; the training loop calls
# model.train() itself.
model.eval()
with torch.no_grad():
    x = torch.randn((1, 3, cfg['input_size'], cfg['input_size']))
    x = x.to(cfg['device'])
    y = model(x)
print(y.size())

In [None]:
def Loss(y_true, y_pred, gamma=4.0, eps=1e-7):
    """Focal loss on class probabilities, summed over the last (class) dimension.

    loss = -sum_c y_true[c] * (1 - y_pred[c]) ** gamma * log(y_pred[c])

    Args:
        y_true: one-hot (or soft) targets, same shape as ``y_pred``.
        y_pred: predicted class probabilities, e.g. a softmax output.
        gamma: focusing exponent (the previously hard-coded 4.0 is the default).
        eps: lower bound applied to ``y_pred`` before taking the log.

    Returns:
        Tensor with the last dimension reduced: one loss value per sample.
    """
    # Clamp before the log. A softmax probability that underflows to exactly 0 gives
    # log(0) = -inf, and even where y_true is 0 the product 0 * -inf is NaN,
    # which would poison the whole batch loss.
    log_prob = torch.log(y_pred.clamp(min=eps))
    per_class = -y_true * ((1 - y_pred) ** gamma) * log_prob
    return torch.sum(per_class, dim=-1)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=cfg['lr'])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, verbose=True)
model.to(cfg['device'])

for epoch in range(cfg['epochs']):
    model.train()
    start = time.time()
    epoch_loss = 0  # running sum of per-batch mean losses
    n_batches = 0

    for image1, label in tqdm(train_loader, total=len(train_loader), ncols=500):
        image1 = image1.to(cfg['device'])
        label = label.to(cfg['device'])

        scores = model(image1)
        # Focal loss gives one value per sample; optimise the batch mean.
        batch_loss = Loss(label, scores).mean()

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        epoch_loss += batch_loss.item()
        n_batches += 1

    # The old code divided by (i + 1) and hit a NameError on an empty loader.
    if n_batches == 0:
        raise RuntimeError('train_loader yielded no batches')
    mean_loss = epoch_loss / n_batches

    # Plateau detection on the mean training loss (no validation metric is computed).
    scheduler.step(mean_loss)
    print('Epoch {:03}: | Loss: {:.3f} | Training time: {}'.format(
        epoch + 1, mean_loss, str(time.time() - start)[:7]))
    # Per-epoch checkpoint. This used to run hidden as a surplus argument to .format().
    torch.save(model.state_dict(), 'trained_{}.pt'.format(epoch))


In [None]:
def evaluate(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` over ``loader``.

    Replaces a commented-out validation draft. That draft compared the maximum
    *values* (``.max(dim=-1)[0]``) of scores and one-hot labels instead of their
    argmax indices, and it ran with autograd enabled.

    Usage: ``evaluate(model, valid_loader)``.

    Args:
        model: module mapping an image batch to per-class scores of shape (N, num_classes).
        loader: iterable of ``(images, onehot_labels)`` batches.
        device: device to run the model on; defaults to the device of its first parameter.

    Returns:
        Fraction of samples whose predicted class equals the labelled class.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                preds = model(images.to(device)).argmax(dim=-1).cpu()
                targets = labels.argmax(dim=-1).cpu()
                correct += (preds == targets).sum().item()
                total += targets.shape[0]
    finally:
        # Restore the caller's train/eval mode even if a batch fails.
        model.train(was_training)
    if total == 0:
        raise ValueError('loader yielded no samples')
    return correct / total

In [None]:
# NOTE(review): leftover scratch arithmetic (510 * 42 = 21420). 42 matches
# cfg['batch_size'], so this may be a batches-per-epoch x batch-size estimate —
# confirm its purpose or delete this cell.
510*42