## References

1. Focal loss (explainer): https://towardsdatascience.com/a-loss-function-suitable-for-class-imbalanced-data-focal-loss-af1702d75d75
2. Past plant-pathology competition (Plant Pathology 2021 - FGVC8 discussion): https://www.kaggle.com/competitions/plant-pathology-2021-fgvc8/discussion/242275

In [2]:
!pip install -qq albumentations==1.0.3
!pip install wandb --upgrade
!pip install timm
!pip install torch==1.10.0

# !pip install pretrainedmodels
# !pip install --upgrade efficientnet-pytorch

In [3]:
import pandas as pd
import numpy as np
import cv2
import os
import functools
import seaborn as sns
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch
import timm
# from efficientnet_pytorch import EfficientNet
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import AveragePrecision, Recall, F1
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import albumentations

from albumentations.pytorch.transforms import ToTensorV2

import wandb
# Authenticate with Weights & Biases. Supply the API key through the
# WANDB_API_KEY environment variable (or the interactive prompt) instead of
# writing it into the notebook; a key that was previously committed here
# should be revoked and regenerated.
wandb.login()

In [4]:
# Training image -> cultivar mapping, plus the Kaggle submission template;
# the trailing tuple displays both table shapes.
data, sub_csv = (
    pd.read_csv('../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv'),
    pd.read_csv('../input/sorghum-id-fgvc-9/sample_submission.csv'),
)
(data.shape, sub_csv.shape)

In [5]:
# Number of distinct cultivars (a missing value, if any, counts as one, as with unique()).
data['cultivar'].nunique(dropna=False)

In [6]:
# Images per cultivar, drawn as a horizontal bar chart (one bar per class).
class_counts = dict(data.cultivar.value_counts())
keys = list(class_counts.keys())
values = list(class_counts.values())
plt.figure(figsize=(10, 20))
# seaborn >= 0.12 accepts only `data` positionally; x and y must be keywords.
sns.barplot(x=values, y=keys, orient='h')

In [7]:
# Notebook-wide settings. Attributes initialised to None (currently only
# `transform`) are assigned later, once the augmentation pipelines exist.
class Config:
    # ---- general ----
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    proj_name = 'Sorghum Kaggle Competition'  # W&B project name
    submission_file = 'submission.csv'

    # ---- data ----
    train_folder = '../input/sorghum-id-fgvc-9/train_images'
    test_folder = '../input/sorghum-cultivar-identification-512512/test'  # read by the test collate fn
    sub_folder = '../input/sorghum-id-fgvc-9/test'  # not referenced elsewhere in the visible code
    val_percent = 0.2  # fraction of training rows held out for validation

    # ---- model ----
    model_name = 'tf_efficientnet_b4_ns'  # other options: tf_efficientnet_b0, vit_base_patch16_224
    img_dim = 512
    out_features = 100  # size of the classifier output; should match the number of cultivars
    in_channels = 3
    pretrained = True
    dropout = 0.5

    # ---- training ----
    epochs = 10
    batch_size = 8
    learning_rate = 2e-5
    label_smoothing = 0  # logged to W&B; the CrossEntropyLoss that used it is commented out
    penalty = 2  # focal-loss gamma
    transform = None  # dict of albumentations pipelines, set right before training
    

In [8]:
class sorghumDataset(Dataset):
    """Training/validation dataset pairing image file names with class ids.

    Items are ``(image_name, label_tensor)``; image decoding and augmentation
    happen later in the DataLoader's collate function.
    """

    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_name = self.images[index]
        label = torch.tensor(self.labels[index])
        return image_name, label
    
class sorghumTestDataset(Dataset):
    """Test-time dataset: each item is just an image file name (no label)."""

    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_name = self.images[index]
        return image_name

In [9]:
class sorghumModel(nn.Module):
    """timm backbone whose final classifier is replaced by a two-layer MLP head.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        out_features: number of output logits (one per class).
        in_channels: number of input image channels.
        drop_prob: dropout probability used inside the new head.
        pretrained: whether timm loads pretrained weights.
    """

    def __init__(self, model_name, out_features, in_channels=3, drop_prob=0.2, pretrained=True):
        super().__init__()
        self.out_features = out_features
        self.in_channels = in_channels

        self.pre_model = timm.create_model(model_name, pretrained=pretrained, in_chans=in_channels)
        # Only works for backbones that expose their last layer as `.classifier`
        # (as the EfficientNet default in Config is expected to); other timm
        # families name it `head` or `fc`.
        in_features = self.pre_model.classifier.in_features

        self.pre_model.classifier = nn.Sequential(
            nn.Linear(in_features, in_features, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(drop_prob),
            nn.Linear(in_features, self.out_features, bias=True),
        )
        # The former `self.drop_layer` was never used in forward(); it was removed.
        # Dropout has no parameters, so state_dict keys are unchanged.

    def forward(self, image):
        """Return raw (un-normalised) class logits for a batch of images."""
        return self.pre_model(image)

In [10]:
# Candidate extra augmentation (not enabled): albumentations.Blur(blur_limit=3, p=0.5)
def train_transform_object(DIM = Config.img_dim):
    """Build the training augmentation pipeline.

    Args:
        DIM: side length in pixels that images are resized to. The default is
            ``Config.img_dim`` as it was when this function was defined.

    Returns:
        An ``albumentations.Compose`` that outputs an ImageNet-normalised
        CHW tensor of size ``DIM x DIM``.
    """
    return albumentations.Compose(
        [
            albumentations.Resize(DIM, DIM),
            albumentations.RandomBrightnessContrast(
                brightness_limit=(-0.1, 0.1),
                contrast_limit=(-0.1, 0.1), p=0.5
            ),
            albumentations.Flip(p=0.5),
            albumentations.Rotate(limit=90, always_apply=False, p=0.5),
            # The crop must follow DIM, not Config.img_dim. A crop larger than
            # the resized image fails, and a smaller one, applied with p=0.5,
            # gives mixed image sizes that torch.stack cannot batch. With
            # crop == DIM this step leaves the image unchanged.
            albumentations.CenterCrop(height=DIM, width=DIM, always_apply=False, p=0.5),
            albumentations.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(p=1.0),
        ]
    )

def valid_transform_object(DIM=384):
    """Build the deterministic evaluation pipeline: resize, normalise, to tensor.

    Args:
        DIM: side length in pixels that images are resized to.

    NOTE(review): the default 384 differs from Config.img_dim (512), which the
    training pipeline uses. Validation and test images are therefore seen at a
    lower resolution than during training. Confirm this is intentional.
    """
    steps = [
        albumentations.Resize(DIM, DIM),
        albumentations.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(p=1.0),
    ]
    return albumentations.Compose(steps)

def _read_image(folder, im_name):
    """Load ``folder/im_name`` as an RGB uint8 array, raising if it cannot be read."""
    path = os.path.join(folder, im_name)
    im = cv2.imread(path, cv2.IMREAD_COLOR)
    if im is None:
        # cv2.imread reports missing/corrupt files by returning None, not by raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    # OpenCV decodes to BGR; the ImageNet mean/std in the transforms (and the
    # pretrained backbone) assume RGB channel order.
    return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)


def collate_fn(batch, process):
    """Turn a list of ``(image_name, label_tensor)`` samples into one batch.

    Images are read from ``Config.train_folder``. They are passed through
    ``Config.transform['train_transform']`` when ``process == 'training'``,
    otherwise through ``Config.transform['valid_transform']``; no transform is
    applied if ``Config.transform`` is None.

    Returns:
        ``(images_tensor, labels_tensor)``, each stacked along a new batch dim.
    """
    transform = None
    if Config.transform is not None:
        key = 'train_transform' if process == 'training' else 'valid_transform'
        transform = Config.transform[key]

    images, labels = [], []
    for im_name, im_label in batch:
        im = _read_image(Config.train_folder, im_name)
        if transform is not None:
            im = transform(image=im)['image']
        images.append(im)
        labels.append(im_label)
    return torch.stack(images), torch.stack(labels)


def collate_test_fn(batch):
    """Load a batch of test image names from ``Config.test_folder`` into one tensor.

    Uses ``Config.transform['valid_transform']`` when transforms are configured.
    """
    transform = None if Config.transform is None else Config.transform['valid_transform']
    images = []
    for im_name in batch:
        im = _read_image(Config.test_folder, im_name)
        if transform is not None:
            im = transform(image=im)['image']
        images.append(im)
    return torch.stack(images)

In [11]:
#https://github.com/gokulprasadthekkel/pytorch-multi-class-focal-loss/blob/master/focal_loss.py
# Adapted: the focal term is now applied per sample, before any reduction.

class FocalLoss(nn.modules.loss._WeightedLoss):
    """Multi-class focal loss.

    Per sample: ``loss_i = w[y_i] * (1 - p_i) ** gamma * (-log p_i)``, where
    ``p_i`` is the softmax probability of the true class. With ``gamma=0``
    this equals ``F.cross_entropy(input, target, weight=weight, reduction=reduction)``.

    Args:
        weight: optional per-class weights (the focal "alpha"), shape (C,).
        gamma: focusing parameter. ``None`` means use ``Config.penalty``, read
            when the loss is constructed.
        reduction: 'mean' | 'sum' | 'none'. A weighted 'mean' divides by the
            sum of the selected class weights, as ``F.cross_entropy`` does.

    Raises:
        ValueError: for an unknown ``reduction``.
    """

    def __init__(self, weight=None, gamma=None, reduction='mean'):
        super().__init__(weight, reduction=reduction)
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"Invalid reduction: {reduction!r}")
        # Resolved lazily so later edits to Config.penalty in the notebook take effect.
        self.gamma = Config.penalty if gamma is None else gamma

    def forward(self, input, target):
        # Per-sample *unweighted* CE is -log p_t, so pt is the true-class probability.
        ce = F.cross_entropy(input, target, reduction='none')
        pt = torch.exp(-ce)
        # clamp guards against 1 - pt being a tiny negative from rounding.
        loss = (1 - pt).clamp(min=0) ** self.gamma * ce
        if self.weight is not None:
            sample_weight = self.weight[target]
            loss = loss * sample_weight
        if self.reduction == 'none':
            return loss
        if self.reduction == 'sum':
            return loss.sum()
        if self.weight is not None:
            return loss.sum() / sample_weight.sum()
        return loss.mean()

In [12]:
def train_one_epoch(train_loader, model, epoch, criterion, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Args:
        train_loader: iterable of ``(images, targets)`` batches.
        model: network to train; it is switched to train mode.
        epoch: current epoch index, used only for the progress-bar label.
        criterion: loss called as ``criterion(logits, int64_targets)``.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        Mean per-image training loss, or ``nan`` if the loader yields nothing.
    """
    model.train()
    total_loss = 0.0
    images_done = 0
    for im_tensors, tar_tensors in tqdm(train_loader, desc=f'train {epoch}'):
        im_tensors = im_tensors.to(Config.device, non_blocking=True)
        # Flatten to (batch,) int64 class ids and move straight to the device.
        # `.type(torch.LongTensor)` forced an extra CPU LongTensor copy first.
        tar_tensors = tar_tensors.view(-1).long().to(Config.device, non_blocking=True)

        optimizer.zero_grad()
        output = model(im_tensors)
        loss = criterion(output, tar_tensors)
        loss.backward()
        optimizer.step()

        batch_size = im_tensors.size(0)
        total_loss += loss.item() * batch_size
        images_done += batch_size

    return total_loss / images_done if images_done else float('nan')
        
def validate(val_loader, model, epoch, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        val_loader: iterable of ``(images, targets)`` batches.
        model: network to evaluate; it is switched to eval mode.
        epoch: current epoch index, used only for the progress-bar label.
        criterion: loss called as ``criterion(logits, int64_targets)``.

    Returns:
        ``(mean_loss, targets, outputs)``. ``targets`` is a CPU int64 tensor of
        shape (N,). ``outputs`` is the CPU float32 concatenation of the model
        outputs along dim 0, presumably (N, num_classes) logits. If the loader
        yields nothing, ``mean_loss`` is ``nan`` and both tensors are empty.
    """
    model.eval()
    all_targets, all_outputs = [], []
    total_loss = 0.0
    images_done = 0
    with torch.no_grad():
        for im_tensors, tar_tensors in tqdm(val_loader, desc=f'valid {epoch}'):
            im_tensors = im_tensors.to(Config.device, non_blocking=True)
            tar_tensors = tar_tensors.view(-1).long().to(Config.device, non_blocking=True)

            output = model(im_tensors)
            loss = criterion(output, tar_tensors)

            batch_size = im_tensors.size(0)
            total_loss += loss.item() * batch_size
            images_done += batch_size

            # Keep tensors (not Python lists) and concatenate once at the end.
            all_targets.append(tar_tensors.cpu())
            all_outputs.append(output.float().cpu())

    if not images_done:
        return float('nan'), torch.empty(0, dtype=torch.long), torch.empty(0)
    return total_loss / images_done, torch.cat(all_targets), torch.cat(all_outputs)

def test(model, test_loader, id_class):
    """Predict a cultivar for every test image and fill in the submission template.

    Args:
        model: trained network.
        test_loader: yields image batches, assumed to be in the same row order
            as ``sample_submission.csv``.
        id_class: mapping from integer class id to cultivar name.

    Returns:
        The sample-submission DataFrame with its ``cultivar`` column replaced
        by predicted cultivar names.
    """
    sub_csv = pd.read_csv('../input/sorghum-id-fgvc-9/sample_submission.csv')

    model.eval()
    batch_outputs = []
    with torch.no_grad():
        for im_tensors in tqdm(test_loader):
            im_tensors = im_tensors.to(Config.device, non_blocking=True)
            batch_outputs.append(model(im_tensors).float().cpu())

    classes = cvt_classes(torch.cat(batch_outputs))
    # Give pandas a NumPy array rather than a torch tensor, so the column gets
    # a plain integer dtype that `.map(id_class)` can look up.
    sub_csv['cultivar'] = classes.numpy()
    sub_csv['cultivar'] = sub_csv['cultivar'].map(id_class)
    return sub_csv

In [13]:
def remove_missing_images(data, folder=None):
    """Return ``data`` without the rows whose ``image`` file is absent on disk.

    Args:
        data: DataFrame with an ``image`` column of file names.
        folder: directory those names are relative to. Defaults to
            ``Config.train_folder``.

    Returns:
        A filtered copy with a fresh 0..n-1 index; ``data`` is not modified.
    """
    if folder is None:
        folder = Config.train_folder
    # Use a boolean mask, not `drop(positions)`: drop() takes index *labels*,
    # so positional indices removed the wrong rows whenever the index was not
    # a default RangeIndex.
    exists = np.fromiter(
        (os.path.exists(os.path.join(folder, im)) for im in data['image']),
        dtype=bool,
        count=len(data),
    )
    return data.loc[exists].reset_index(drop=True)

def setup_training(hyp):
    """Assemble everything a training run needs.

    Reads the training CSV, drops rows whose image file is missing, encodes
    cultivar names as integer ids (in order of first appearance), makes a
    stratified train/validation split, and builds the two DataLoaders, the
    model, the focal loss and an AdamW optimizer.

    Args:
        hyp: hyper-parameter object (e.g. ``wandb.config``) exposing
            ``model_name``, ``out_features``, ``in_channels``, ``drop_prob``,
            ``pretrained`` and ``lr_rate`` as attributes.

    Returns:
        ``(id_class, model, train_dataloader, val_dataloader, loss_fn, optimizer)``,
        where ``id_class`` maps integer id -> cultivar name.
    """
    frame = remove_missing_images(
        pd.read_csv('../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv')
    )

    id_class = {idx: name for idx, name in enumerate(frame.cultivar.unique())}
    class_id = {name: idx for idx, name in id_class.items()}
    frame['cultivar'] = frame['cultivar'].map(class_id)

    all_labels = frame.cultivar.values
    train_data, val_data, train_labels, val_labels = train_test_split(
        frame.image.values,
        all_labels,
        test_size=Config.val_percent,
        stratify=all_labels,
        random_state=42,
        shuffle=True,
    )

    def make_loader(split_images, split_labels, process, shuffle):
        return DataLoader(
            sorghumDataset(split_images, split_labels),
            shuffle=shuffle,
            batch_size=Config.batch_size,
            collate_fn=functools.partial(collate_fn, process=process),
        )

    train_dataloader = make_loader(train_data, train_labels, 'training', True)
    val_dataloader = make_loader(val_data, val_labels, 'validation', False)

    model = sorghumModel(
        model_name=hyp.model_name,
        out_features=hyp.out_features,
        in_channels=hyp.in_channels,
        drop_prob=hyp.drop_prob,
        pretrained=hyp.pretrained,
    )
    model.to(Config.device)

    # Focal loss is used in place of nn.CrossEntropyLoss(label_smoothing=hyp.label_smoothing).
    loss_fn = FocalLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=hyp.lr_rate, weight_decay=1e-6, amsgrad=False)

    return id_class, model, train_dataloader, val_dataloader, loss_fn, optimizer

def setup_testing():
    """Return an unshuffled DataLoader over the ``filename`` column of the sample submission."""
    filenames = pd.read_csv('../input/sorghum-id-fgvc-9/sample_submission.csv')['filename'].values
    return DataLoader(
        sorghumTestDataset(filenames),
        shuffle=False,
        batch_size=Config.batch_size,
        collate_fn=collate_test_fn,
    )

In [14]:
def cvt_classes(outputs):
    """Convert an (N, C) tensor of class scores into (N,) int32 predicted class ids on the CPU.

    Softmax is monotonic, so the arg-max of the raw scores picks the same class
    (up to floating-point ties). This also avoids the deprecated implicit-``dim``
    softmax call.
    """
    return outputs.detach().argmax(dim=1).cpu().int()

def train(model, tr_loader, val_loader, criterion, optimizer, params):
    """Train for ``params.epochs`` epochs, logging to W&B and checkpointing each epoch.

    After every epoch the model is validated, accuracy is computed from the
    arg-max predictions, metrics are sent with ``wandb.log``, and the weights
    are saved to ``{model_name}_epoch{epoch}.pth`` in the working directory.
    """
    wandb.watch(model, criterion, log='all', log_freq=10)
    for epoch in range(params.epochs):
        train_loss = train_one_epoch(tr_loader, model, epoch, criterion, optimizer)
        val_loss, targets, predictions = validate(val_loader, model, epoch, criterion)
        outputs = cvt_classes(predictions)
        # sklearn's signature is accuracy_score(y_true, y_pred).
        accuracy = accuracy_score(targets.numpy(), outputs.numpy())
        wandb.log({ 'epoch' : epoch, 'train_loss' : train_loss, 'val_loss' : val_loss, 'accuracy' : accuracy }, step=epoch)
        torch.save(model.state_dict(), '{}_epoch{}.pth'.format(params.model_name, epoch))

def model_pipeline(train_parameters):
    """Run setup, training, test prediction and CSV export inside one W&B run.

    ``train_parameters`` is passed as ``wandb.init(config=...)``; the resulting
    ``wandb.config`` is what setup and training read. The predictions are
    written to ``Config.submission_file``.
    """
    with wandb.init(project=Config.proj_name, config=train_parameters):
        cfg = wandb.config
        id_class, model, train_loader, val_loader, criterion, optimizer = setup_training(cfg)
        train(model, train_loader, val_loader, criterion, optimizer, cfg)
        submission = test(model, setup_testing(), id_class)
        submission.to_csv(Config.submission_file, index=False)

In [15]:
# Hyper-parameters recorded in the W&B run; wandb.config hands them back to setup/training.
train_params = dict(
    model_name=Config.model_name,
    out_features=Config.out_features,
    in_channels=Config.in_channels,
    drop_prob=Config.dropout,
    pretrained=Config.pretrained,
    epochs=Config.epochs,
    lr_rate=Config.learning_rate,
    label_smoothing=Config.label_smoothing,
)

# The collate functions read these pipelines, so build them before any DataLoader runs.
Config.transform = dict(
    train_transform=train_transform_object(),
    valid_transform=valid_transform_object(),
)

model_pipeline(train_params)

In [35]:
# Raw training mapping (missing-image filtering is not applied here) and its cultivar names.
data = pd.read_csv('../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv')
classes = data['cultivar'].unique().tolist()

In [36]:
# Row counts before and after dropping rows whose image file is missing.
# (`x` was undefined here because its assignment had been commented out.)
data.shape, remove_missing_images(data).shape

In [37]:
# Number of distinct cultivars in the raw mapping.
n_classes = len(classes)
n_classes

In [None]:
#label smoothing. https://github.com/OpenNMT/OpenNMT-py/blob/e8622eb5c6117269bb3accd8eb6f66282b5e67d9/onmt/utils/loss.py#L186

class LabelSmoothingLoss(nn.Module):
    """Label-smoothed loss computed as KL(q || p), adapted from OpenNMT-py.

    The target distribution q puts ``1 - label_smoothing`` on the true class
    and spreads ``label_smoothing`` uniformly over the other classes. If
    ``ignore_index`` is a valid class id (OpenNMT's padding token), that class
    gets no mass and rows whose target equals it contribute nothing. An
    out-of-range value such as the default ``-100`` excludes no class.

    Args:
        label_smoothing: smoothing mass, in (0, 1].
        tgt_vocab_size: number of classes.
        ignore_index: class id to exclude from smoothing and mask out as a target.

    Raises:
        ValueError: if ``label_smoothing`` is outside (0, 1] or there are too
            few classes to spread it over.
    """

    def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
        super().__init__()
        if not 0.0 < label_smoothing <= 1.0:
            raise ValueError(f"label_smoothing must be in (0, 1], got {label_smoothing}")
        self.ignore_index = ignore_index

        # Only treat ignore_index as a class when it is in range. Previously
        # `one_hot[-100] = 0` silently zeroed class 0 for exactly 100 classes
        # and raised IndexError for fewer.
        ignore_in_vocab = 0 <= ignore_index < tgt_vocab_size
        n_other = tgt_vocab_size - (2 if ignore_in_vocab else 1)
        if n_other <= 0:
            raise ValueError("tgt_vocab_size is too small to spread label smoothing over")
        one_hot = torch.full((tgt_vocab_size,), label_smoothing / n_other)
        if ignore_in_vocab:
            one_hot[ignore_index] = 0
        self.register_buffer('one_hot', one_hot.unsqueeze(0))

        self.confidence = 1.0 - label_smoothing

    def forward(self, output, target):
        """Return the summed (not batch-averaged) KL divergence.

        Args:
            output: (batch_size, n_classes) **log-probabilities**. ``F.kl_div``
                expects log-space input, so apply ``log_softmax`` to raw logits first.
            target: (batch_size,) int64 class ids.
        """
        model_prob = self.one_hot.repeat(target.size(0), 1)
        model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
        model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)

        return F.kl_div(output, model_prob, reduction='sum')

# TODO: reduce image size.
# TODO: train for longer.
# TODO: increase model depth.