In [None]:
import os
import zipfile
from copy import deepcopy
import random
import math
import shutil
import gc as garbage
import collections

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import functional as FM
from torch.utils.data import Dataset, DataLoader

import torchvision

import PIL
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import seaborn as sns

from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_nb

from sklearn.model_selection import KFold, StratifiedKFold
from sklearn import preprocessing

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Utils

In [None]:
def get_pil_image(dirname, path):
    """Open ``../input/<dirname>/<path>`` as a ``PIL.Image``.

    ``Image.open`` is lazy: the file stays open until the pixels are loaded or the
    image is closed, so prefer using the result as a context manager.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule; import it explicitly
    # instead of relying on another library having imported it first.
    from PIL import Image
    return Image.open(os.path.join(f'../input/{dirname}', path))

def get_image(dirname, path):
    """Load ``../input/<dirname>/<path>`` into a numpy array and close the file.

    Returns:
        ``np.ndarray`` holding the decoded pixels (``H x W [x C]``).
    """
    # The `with` block closes the underlying file handle once the pixels are copied.
    with get_pil_image(dirname, path) as pil_image:
        return np.array(pil_image)

def image_histplot(img, **kwargs):
    """Show an image next to histograms of its first three channels (scaled by 1/255).

    Args:
        img: ``H x W x C`` array with at least 3 channels (presumably uint8 RGB,
            given the /255 scaling — TODO confirm for other dtypes).
        **kwargs: forwarded to ``plt.subplots``; ``figsize`` defaults to ``(12, 3)``.
    """
    kwargs.setdefault('figsize', (12, 3))
    fig, axes = plt.subplots(1, 4, **kwargs)
    axes[0].imshow(img)
    for channel, color in enumerate('rgb'):
        ax = axes[1 + channel]
        ax.set_yticks([])
        sns.histplot(img[:, :, channel].flatten() / 255, ax=ax, color=color, alpha=0.33)
    # tight_layout has to run after the axes are filled so it can account for their labels.
    fig.tight_layout()
    plt.show()

def image_gridplot(images, rows=None, cols=None, transform=None, **kwargs):
    """Show ``images`` on a ``rows x cols`` grid with all axes hidden.

    If neither ``rows`` nor ``cols`` is given, 6 columns are used. A missing
    dimension is derived from ``len(images)``.

    Args:
        images: sized sequence of images (anything ``imshow`` accepts after ``transform``).
        rows: number of grid rows (optional).
        cols: number of grid columns (optional).
        transform: optional callable applied to each image before plotting.
        **kwargs: forwarded to ``plt.figure``; ``figsize`` defaults to 3 inches per cell.
    """
    n_images = len(images)
    if n_images == 0:
        return  # ImageGrid cannot be built with zero rows; nothing to draw.
    if rows is None and cols is None:
        cols = 6
    if rows is None:
        rows = math.ceil(n_images / cols)
    elif cols is None:
        cols = math.ceil(n_images / rows)
    kwargs.setdefault('figsize', (3 * cols, 3 * rows))
    fig = plt.figure(**kwargs)
    grid = ImageGrid(fig, 111, nrows_ncols=(rows, cols), axes_pad=0.1)
    for ax in grid:
        ax.axis('off')  # also hides the unused trailing cells of the grid
    for ax, im in zip(grid, images):
        if transform is not None:
            im = transform(im)
        ax.imshow(im)
    plt.show()

class ImageConv:
    """Conversions between HWC numpy images and CHW torch tensors."""

    @staticmethod
    def tensor2np(tensor):
        """CHW tensor -> HWC ``uint8`` array.

        Values are clipped to [0, 255] first: casting out-of-range floats to uint8
        wraps around (or is undefined), which shows up as speckle in plots.
        """
        return tensor.permute(1, 2, 0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)

    @staticmethod
    def np2tensor(nparr):
        """HWC array -> CHW ``float32`` tensor; pixel values keep their scale (no /255)."""
        return torch.tensor(nparr).permute(2, 0, 1).float()

    @staticmethod
    def to_plt(x):
        """Convert a CHW tensor or an array to a ``uint8`` array suitable for ``imshow``."""
        if isinstance(x, torch.Tensor):
            return ImageConv.tensor2np(x)
        return np.asarray(x).clip(0, 255).astype(np.uint8)

# Configuration

In [None]:
# CUDA set-up. These variables are read when CUDA is initialised, so they are set
# before the first CUDA query (`torch.cuda.is_available()` in `Hypers` below);
# set afterwards they may have no effect.
# CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous: clearer error
# locations, but slower training.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


class Hypers:
    """Notebook-wide hyper-parameters and runtime settings, read as class attributes."""
    verbose = True

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    image_size = (224, 224)       # final model input size used by T.Resize

    max_epoches = 80              # passed to Trainer(max_epochs=...)
    batch_size = 128
    learning_rate = 1e-3
    optimizer = torch.optim.Adam  # optimizer class; instantiated with the trainable params

    patience = 5                  # early-stopping patience (epochs)

# Dataset

In [None]:
# Training image -> cultivar mapping (columns used below: 'image', 'cultivar').
df_train = pd.read_csv(filepath_or_buffer='../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv')
df_train

## Distribution

In [None]:
# Number of images per cultivar, sorted ascending, plus its skewness.
image_count_df = (
    df_train
    .groupby('cultivar')
    .count()
    .sort_values('image')
)
print('Skew: ', image_count_df.skew())

fig, ax = plt.subplots(figsize=(20, 8))
sns.barplot(x=image_count_df.index, y=image_count_df['image'], ax=ax)
ax.tick_params(axis='x', labelrotation=90)
plt.show()

### Color Histogram

In [None]:
# Color histograms of the first three training images.
for i, (image_path, cultivar) in df_train.head(3).iterrows():
    image_histplot(get_image('sorghum-id-fgvc-9/train_images', image_path), figsize=(15, 3))

**NOTE:** The data contains a dummy `.DS_Store` entry (a macOS metadata file), not a real image.

In [None]:
# Largest number of missing values found in any single column.
print('train set has', df_train.isnull().sum().max(), 'null(s)')

## Outliers

In [None]:
# Timestamps (hh-mm-ss-ms suffixes) of outlier images taken on 2017-06-11 around 13h.
outliers = [
    '29-33-477', '29-34-965', '29-36-468', '29-43-957', '29-45-460',
    '29-46-961', '29-48-469', '29-49-960', '29-51-465', '30-06-475',
    '30-07-971', '30-09-467', '30-59-251', '31-00-751', '31-02-238',
    '31-17-229', '31-18-730', '31-20-230', '31-21-751', '31-23-234',
    '31-24-729', '31-30-733', '31-32-233', '31-33-750', '31-35-254',
]
outliers = [f'2017-06-11__13-{suffix}.png' for suffix in outliers]
outliers_img = [get_image('sorghum-id-fgvc-9/train_images', name) for name in outliers]
image_gridplot(outliers_img, cols=10, figsize=(12, 3))

In [None]:
# Drop the outlier images and rows with missing values, then renumber rows 0..n-1.
df_train = (
    df_train[~df_train['image'].isin(outliers)]
    .dropna()
    .reset_index(drop=True)
)

# Difficult cases

These are the images that look hardest to classify — a subjective selection (my own opinion).

In [None]:
# (directory, file name) pairs of hand-picked hard images: two train, eight test.
hard_pathes = (
    [('sorghum-id-fgvc-9/train_images', name) for name in (
        '2017-06-02__13-44-49-948.png',
        '2017-06-02__16-46-57-375.png',
    )]
    + [('sorghum-id-fgvc-9/test', name) for name in (
        '181578.png',
        '13937931.png',
        '17713602.png',
        '21517068.png',
        '45744669.png',
        '326836917.png',
        '1368455764.png',
        '1119726178.png',
    )]
)
hard_imgs = [get_image(dirname, path) for dirname, path in hard_pathes]
image_gridplot(hard_imgs, cols=5)

## CLAHE (Contrast Limited Adaptive Histogram Equalization)

In [None]:
import cv2

def CLAHE(image, clip_limit=10, tile_grid_size=(8, 8)):
    """Equalize local contrast of an RGB image with CLAHE on its HSV value channel.

    Only the V (brightness) channel is equalized; hue and saturation are kept, so
    colors are preserved.

    Args:
        image: RGB array accepted by ``cv2.cvtColor(..., cv2.COLOR_RGB2HSV)``
            (presumably ``H x W x 3`` uint8 — TODO confirm for other dtypes).
        clip_limit: CLAHE contrast clip limit.
        tile_grid_size: CLAHE tile grid; ``(8, 8)`` is OpenCV's default, which the
            original call used implicitly.

    Returns:
        A new RGB array; ``image`` itself is not modified.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    hsv[:, :, 2] = clahe.apply(hsv[:, :, 2])  # channel 2 of HSV is V (value)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

In [None]:
# Side-by-side pairs: original image followed by its CLAHE-equalized version.
images = []
for i in (0, 20, 40):
    img_path, c = df_train.loc[i]
    img = get_image('sorghum-id-fgvc-9/train_images', img_path)
    images.extend([img, CLAHE(img)])

image_gridplot(images, cols=6)

In [None]:
image_gridplot([CLAHE(hard_img) for hard_img in hard_imgs], cols=5)

In [None]:
# Histograms of the last pair above: the original image, then its CLAHE version.
for im in images[-2:]:
    image_histplot(im)

Good.

# DataLoader

In [None]:
from sklearn.preprocessing import LabelEncoder

class SorghumDataset(Dataset):
    """Sorghum images indexed through a DataFrame.

    Train mode expects columns ``image`` (file name) and ``cultivar``, and adds an
    integer ``label`` column by fitting ``encoder`` on ``cultivar``. Test mode expects
    ``filename`` and ``cultivar`` and returns the file name in place of the label.

    Args:
        df: source frame. It is copied, so the caller's frame is never modified.
        is_train: selects the train/test column layout and image sub-directory.
        transform: callable applied to the loaded image array; ``None`` returns it unchanged.
        encoder: object with ``fit_transform`` (and ``inverse_transform`` for decoding
            predictions). A fresh ``LabelEncoder`` is created when ``None``.
        dataset_name: directory under ``../input`` holding ``train_images/`` and ``test/``
            (the raw competition data lives in ``'sorghum-id-fgvc-9'``).
    """

    def __init__(self, df, is_train=True, transform=None, encoder=None,
                 dataset_name='sorghum-fgvc9-clahe-256'):
        super().__init__()

        self.is_train = is_train
        # Copy so adding 'label' never mutates (or warns about) the caller's frame.
        self.df = df.copy()
        self.transform = transform
        # A `LabelEncoder()` default argument would be created once and shared by every
        # instance; build a new one per dataset instead.
        self.encoder = encoder if encoder is not None else LabelEncoder()
        self.dataset_name = dataset_name
        if self.is_train:
            self.df['label'] = self.encoder.fit_transform(self.df['cultivar'])

    @property
    def data(self):
        """All rows as a 2-D numpy array (used by the fold splitter for its length)."""
        return self.df.values

    @property
    def targets(self):
        """Encoded labels (train mode only)."""
        return self.df['label']

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, cultivar, label)``; in test mode ``label`` is the file name."""
        # Positional lookup: samplers and DataLoaders produce positions 0..len-1, which
        # only match `.loc` labels when the frame has a fresh RangeIndex.
        row = self.df.iloc[idx]
        if self.is_train:
            image_path, cultivar, label = row[['image', 'cultivar', 'label']]
            subdir = 'train_images'
        else:
            image_path, cultivar, label = row[['filename', 'cultivar', 'filename']]
            subdir = 'test'
        image = get_image(f'{self.dataset_name}/{subdir}', image_path)
        if self.transform is not None:
            image = self.transform(image)
        return image, cultivar, label

In [None]:
# First rows of the dataset frame, including the encoded 'label' column.
SorghumDataset(df_train, transform=np.array, is_train=True).df.head(n=15)

# Transforming

In [None]:
%%time
from torchvision import transforms as T

# ImageNet statistics expected by the pretrained backbones. `ImageConv.np2tensor`
# keeps pixels in [0, 255] (no /255), so the statistics are scaled by 255 here.
# This equals ToTensor() followed by Normalize(mean, std) on [0, 1] inputs.
mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)
normalize = T.Normalize((mean * 255).tolist(), (std * 255).tolist())
# Exact inverse of `normalize`: maps model inputs back to [0, 255] for plotting.
unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / (std * 255)).tolist())

transform_train = T.Compose([
    ImageConv.np2tensor,
    T.Resize(512),
    T.RandomRotation(45, fill=(82, 83, 65)),  # fill is in the tensor's [0, 255] scale
    T.CenterCrop(384),
    T.RandomCrop(320),
    T.Resize(Hypers.image_size),
    normalize,
])

dataset = SorghumDataset(df_train, is_train=True, transform=transform_train)
dataloader = DataLoader(dataset, batch_size=16)

# Show the first batch twice: raw (still normalized) model inputs, then with the
# normalization undone.
images, _, labels = next(iter(dataloader))
for view in (T.Compose([ImageConv.to_plt]), T.Compose([unnormalize, ImageConv.to_plt])):
    image_gridplot(images[:12], transform=view)

# Free the preview objects; the training cells build their own dataset.
del dataset, dataloader, images, labels, view

# Define Model

In [None]:
class Classifier(nn.Module):
    """Small CNN baseline (the training cells below use DenseNet-161 instead).

    Feature map sizes for a 128 x 128 input, using
    out = floor((in - kernel + 2 * padding) / stride) + 1:
    (3, 128, 128) -> conv3 p1 -> (16, 128, 128) -> conv5 s3 -> (128, 42, 42)
    -> maxpool2 s1 -> (128, 41, 41) -> conv3 -> (256, 39, 39) -> conv3 -> (256, 37, 37)
    -> maxpool2 s2 -> (256, 18, 18) -> conv3 -> (512, 16, 16) -> maxpool2 s2 -> (512, 8, 8).

    The adaptive pool fixes the final map to 8 x 8, so any input size (e.g.
    ``Hypers.image_size`` = 224) reaches the head with 512*8*8 features. For 128 x 128
    inputs it is an identity. It has no parameters, so ``state_dict`` keys are unchanged.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=100):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 128, 5, stride=3),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=1),

            nn.Conv2d(128, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(256, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.pool = nn.AdaptiveAvgPool2d((8, 8))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 8 * 8, 1024),
            nn.ReLU(),
#             nn.Dropout(0.2),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        """Map a ``(batch, 3, H, W)`` tensor to ``(batch, num_classes)`` logits."""
        x = self.layer(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


# Define Trainer

In [None]:
from datetime import datetime

class Trainer:
    """Epoch-based training loop with optional progress bar, metric log and callbacks.

    Args:
        loss_function: callable ``(y_hat, y) -> scalar loss`` (assumed mean-reduced,
            e.g. ``nn.CrossEntropyLoss()``).
        optimizer: optimizer already bound to the parameters to train.
        max_epochs: maximum number of epochs run by :meth:`fit`.
        accelerator: device the model and every batch are moved to.
        min_epochs: early-stop requests are ignored until this many epochs have finished.
        progress: show a notebook ``tqdm`` bar during :meth:`fit`.
        batch_parser: maps a loader batch to ``(x, y)``; defaults to unpacking a pair.
        logging: record per-epoch metrics in ``self.log`` (name -> list of values).
        verbose: print one summary line per epoch.
        callbacks: ``CheckpointCallback`` / ``EarlyStopCallback`` instances; other
            objects are ignored.
        **kwargs: accepted for compatibility and ignored.
    """

    def __init__(
        self,
        loss_function,
        optimizer,
        max_epochs,
        accelerator='cpu',
        min_epochs=1,
        progress=False,
        batch_parser=None,
        logging=True,
        verbose=False,
        callbacks=None,
        **kwargs
    ):
        self.log = {}
        self.criterion = loss_function
        self.optimizer = optimizer
        self.device = accelerator
        self.min_epochs = min_epochs
        self.max_epochs = max_epochs
        self.has_progress = progress
        self.has_logging = logging
        self.verbose = verbose
        self.pbar = None
        self.batch_parser = batch_parser if batch_parser else Trainer.default_parser
        self.start_time = datetime.now()
        # A `callbacks=[]` default would be one list shared by every Trainer.
        self.callbacks = callbacks if callbacks is not None else []

    @staticmethod
    def default_parser(batch):
        """Default batch parser: the batch is already an ``(x, y)`` pair."""
        x, y = batch
        return x, y

    def set_parser(self, function):
        """Replace the batch parser (``batch -> (x, y)``)."""
        self.batch_parser = function

    def logging(self, name, value):
        """Append ``value`` to ``self.log[name]``; no-op when logging is disabled."""
        if not self.has_logging:
            return
        self.log.setdefault(name, []).append(value)

    def step_loop(self, model, dataloader, is_train=True):
        """Run one pass over ``dataloader``; update the weights only when ``is_train``.

        Returns:
            ``(mean_loss, accuracy)`` averaged over samples, with accuracy in [0, 1];
            ``(nan, nan)`` when the loader yields no batches.
        """
        # synchronize() raises on machines without CUDA; only call it for CUDA devices.
        if torch.device(self.device).type == 'cuda' and torch.cuda.is_available():
            torch.cuda.synchronize()
        model.to(self.device)
        model.train(is_train)  # train(False) is eval(): BatchNorm/Dropout in inference mode
        loss_total = 0.0
        correct = 0
        seen = 0
        # Autograd is only needed for training; skip graph construction otherwise.
        with torch.set_grad_enabled(is_train):
            for batch in dataloader:
                x, y = self.batch_parser(batch)
                x, y = x.to(self.device), y.to(self.device)
                y_hat = model(x)
                loss = self.criterion(y_hat, y)
                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                batch_size = y.size(0)
                # Weight by batch size so a short last batch doesn't skew the mean.
                loss_total += loss.item() * batch_size
                correct += (y_hat.argmax(dim=1) == y).sum().item()
                seen += batch_size
                if self.pbar is not None:
                    self.pbar.update(1)
        if seen == 0:
            return float('nan'), float('nan')
        return loss_total / seen, correct / seen

    def train(self, model, dataloader):
        """One training pass (weights are updated)."""
        return self.step_loop(model, dataloader, True)

    def valid(self, model, dataloader):
        """One validation pass; must never update the weights."""
        return self.step_loop(model, dataloader, False)

    def test(self, model, dataloader):
        """One evaluation pass (weights are not updated)."""
        return self.step_loop(model, dataloader, False)

    def fit(self, model, load_dataloaders):
        """Train for up to ``max_epochs`` epochs.

        Args:
            model: module to train (moved to ``self.device``).
            load_dataloaders: zero-argument callable returning ``(train_loader, valid_loader)``;
                it is called once per epoch, so it may rotate folds.
        """
        if self.has_progress:
            self.progress_start().set_description('on ready to fit')
        self.start_time = datetime.now()

        garbage.collect()

        for cb in self.callbacks:
            if isinstance(cb, CheckpointCallback):
                cb.before_fit(model)

        for e in range(self.max_epochs):
            train_loader, valid_loader = load_dataloaders()

            if self.has_progress:
                self.pbar.reset(len(train_loader) + len(valid_loader))
                self.pbar.set_description(f'[epoch={e+1}/{self.max_epochs}] train')

            train_loss, train_acc = self.train(model, train_loader)
            self.logging('train_loss', train_loss)
            self.logging('train_acc', train_acc)

            if self.has_progress:
                self.pbar.set_description(f'[epoch={e+1}/{self.max_epochs}] valid')

            valid_loss, valid_acc = self.valid(model, valid_loader)
            self.logging('valid_loss', valid_loss)
            self.logging('valid_acc', valid_acc)

            time_elapsed = datetime.now() - self.start_time
            msg = 'train/valid loss: {:.05f}/{:.05f}'.format(train_loss, valid_loss)

            if self.has_progress:
                self.pbar.set_postfix_str(msg)

            if self.verbose:
                print(f'[Trainer(epoch={e+1})]', str(time_elapsed), msg)

            # Every callback runs each epoch, so a checkpoint still sees the last epoch
            # even when an early-stop callback fires.
            early_stop = False
            for cb in self.callbacks:
                if isinstance(cb, CheckpointCallback):
                    cb.on_train_end(model, np.mean([train_loss, valid_loss]))
                elif isinstance(cb, EarlyStopCallback):
                    # total_seconds(): timedelta.seconds wraps around after a day.
                    # bool(): a hook returning None must not break `|=`.
                    early_stop |= bool(cb.on_train_end(e, int(time_elapsed.total_seconds()), self.log))

            # `e` is 0-based: honour early stopping only once min_epochs epochs are done.
            if early_stop and e + 1 >= self.min_epochs:
                break

            garbage.collect()

        self.progress_end()

    def progress_start(self, total=None, desc=None):
        """Create (or reuse) the notebook progress bar and reset it."""
        if self.pbar is None:
            self.pbar = tqdm_nb()
        self.pbar.reset(total)
        self.pbar.set_description(desc)
        return self.pbar

    def progress_end(self):
        """Close the progress bar if one exists."""
        if self.pbar is not None:
            self.pbar.close()
        self.pbar = None

# Define Callback classes

In [None]:
class Callback:
    """Base class for Trainer callbacks; every hook is a no-op by default.

    Args:
        verbose: when true, :meth:`message` prints.
    """
    def __init__(self, verbose=False):
        self.verbose = verbose

    def before_fit(self, model=None):
        """Hook run before training. Checkpoint callbacks receive the model here."""
        pass

    def message(self, msg, prefix=''):
        """Print ``[ClassName<prefix>] msg`` when verbose."""
        if self.verbose:
            print(f'[{self.__class__.__name__}{prefix}] {msg}')

    def on_train_end(self, epoch: int, time_elapsed: int, *args, **kwargs):
        """Hook run after each epoch. Extra arguments (e.g. the metric log) are accepted
        so subclasses can extend the signature consistently."""
        pass

class EarlyStopCallback(Callback):
    """Base class for callbacks that may ask ``Trainer.fit`` to stop early."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def on_train_end(self, epoch: int, time_elapsed: int, info: dict) -> bool:
        """Return ``True`` to request a stop; the base policy never stops.

        Returns ``False`` rather than ``None``, because the trainer combines results with
        ``early_stop |= ...`` and ``bool | None`` raises ``TypeError``.
        """
        return False

class CheckpointCallback(Callback):
    """Save ``model.state_dict()`` to ``path/name`` whenever the monitored loss improves.

    Args:
        name: checkpoint file name.
        path: directory for the checkpoint file.
        metric: a truthy ``metric(best, loss + min_delta)`` means "no improvement".
            With the default ``np.less``, the old best is kept if it is still lower.
        min_delta: margin added to the new loss before comparing.
        load_from_checkpoint: controls what is loaded before fitting:
            - ``True``: load ``path/name``;
            - a ``str``: load from that file path;
            - ``False``: keep the current weights.
        **kwargs: forwarded to :class:`Callback` (e.g. ``verbose``).
    """
    def __init__(
        self,
        name,
        path='./',
        metric=np.less,
        min_delta=0.0,
        load_from_checkpoint=False, # str: filepath
        **kwargs
    ):
        super().__init__(**kwargs)
        self.name = name
        self.path = path
        self.filepath = os.path.join(self.path, self.name)
        self.min_delta = min_delta
        self.metric = metric
        self.load_from_checkpoint = load_from_checkpoint
        self.best = np.inf  # np.Inf was removed in NumPy 2.0

    def load_model(self, model: nn.Module, filepath: str):
        """Load weights from ``filepath`` into ``model`` if that file exists."""
        if not os.path.exists(filepath):
            self.message(f'Failed to load model from {filepath}')
            return
        self.message(f'model loaded from {filepath}')
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
        # load_state_dict then copies the tensors onto the model's own device.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        model.load_state_dict(torch.load(filepath, map_location='cpu'))

    def before_fit(self, model: nn.Module):
        """Optionally restore weights, as configured by ``load_from_checkpoint``."""
        lfc = self.load_from_checkpoint
        if isinstance(lfc, bool):
            if lfc:
                self.load_model(model, self.filepath)
        elif isinstance(lfc, str):
            self.load_model(model, lfc)
        else:
            # message() takes a single string (the old call passed three positional args).
            self.message(f'Disallowed checkpoint parameter: {type(lfc)} {lfc}')

    def on_train_end(self, model: nn.Module, loss: float):
        """Save the model when ``loss`` improves on the best seen so far."""
        if self.metric(self.best, loss + self.min_delta):
            return None
        self.message(f'update best: {loss}', prefix=f'({self.filepath})')
        self.best = loss
        torch.save(model.state_dict(), self.filepath)

In [None]:
class MyEarlyStopCallback(EarlyStopCallback):
    """Early-stopping policy: a wall-clock budget, plus patience on loss trends (currently disabled).

    Args:
        patience: consecutive "bad" epochs before stopping (only read by the
            unreachable patience logic in ``on_train_end``).
        metric: comparison used by ``diff`` as ``metric(mean_loss_change, min_delta)``.
        min_delta: threshold the mean loss change is compared against.
        max_seconds: stop once ``time_elapsed`` reaches this many seconds.
        last_k_epochs: number of recent losses used for the trend estimate.
        **kwargs: forwarded to ``Callback`` (e.g. ``verbose``).
    """
    def __init__(
        self,
        patience=5,
        metric=np.less,
        min_delta=0.0,
        max_seconds=math.inf,
        last_k_epochs=3,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.patience = patience
        self.min_delta = min_delta
        self.metric = metric
        self.max_seconds = max_seconds
        self.last_k_epochs = last_k_epochs
        self.bad_epochs = 0  # consecutive epochs flagged by `diff`

    # Trend test: is the mean step-to-step change of the last k losses "bad"?
    def diff(self, losses):
        """Return ``metric(mean(np.diff(losses[-k:])), min_delta)``, or ``False`` with
        fewer than ``k + 1`` losses.

        NOTE(review): ``before`` (the previous window's mean change) is computed but only
        used by the commented-out comparison. With ``last_k_epochs=1`` the diff is empty
        and its mean is NaN.
        """
        k = self.last_k_epochs
        if k + 1 > len(losses): return False
        before, after = losses[-k-1:-1], losses[-k:]
        before, after = np.diff(before).mean(), np.diff(after).mean()
        # return self.metric(before, after + self.min_delta)
        return self.metric(after, self.min_delta)
    
    def on_train_end(self, epoch: int, time_elapsed: int, info: dict):
        """Return ``True`` to stop training; ``info`` is the trainer's metric log."""
        super().on_train_end(epoch, time_elapsed, info)
        # Stop once the wall-clock budget is exhausted.
        if time_elapsed >= self.max_seconds: return True
        # NOTE(review): this unconditional return makes everything below unreachable.
        # As a result, `patience`, `metric`, `min_delta` and `last_k_epochs` currently
        # have no effect. Before re-enabling, check the comparison direction: with
        # np.less and min_delta=-1 (as configured in the training cell), an epoch counts
        # as "bad" when the loss falls by more than 1 per epoch.
        return False
        if self.diff(info['train_loss']) or self.diff(info['valid_loss']):
            self.bad_epochs += 1
        else:
            self.bad_epochs = 0
        self.message(f'remain {self.patience - self.bad_epochs} patience(s)')
        return self.bad_epochs >= self.patience

## K-Fold Loader

In [None]:
from torch.utils.data import SubsetRandomSampler

class KFoldDataLoader:
    """Serve one stratified ``(train_loader, valid_loader)`` fold per :meth:`fetch` call.

    Folds come from ``StratifiedKFold(**kwargs)`` applied to ``dataset.data`` /
    ``dataset.targets``. Once every fold has been served, the splits are recomputed,
    and re-shuffled when ``shuffle=True`` is used without a fixed ``random_state``.

    NOTE(review): ``Trainer.fit`` calls ``fetch`` once per epoch. Each epoch therefore
    validates on a different fold, i.e. on samples trained on in earlier epochs, so
    validation metrics are optimistic.

    Args:
        dataset: exposes ``data`` and ``targets`` and is indexable by position.
        batch_size: batch size of both loaders.
        **kwargs: forwarded to ``StratifiedKFold`` (e.g. ``n_splits``, ``shuffle``).
    """

    def __init__(self, dataset, batch_size=16, **kwargs):
        self.dataset = dataset
        self.kwargs = kwargs
        self.batch_size = batch_size
        self.reset(**kwargs)

    def reset(self, **kwargs):
        """Recompute the folds and restart from the first one."""
        self.generator = StratifiedKFold(**kwargs)
        self.splits = list(self.generator.split(self.dataset.data, self.dataset.targets))
        self.current = 0

    def fetch(self):
        """Return loaders for the next fold; each loader samples its indices in random order."""
        if self.current >= len(self.splits):
            self.reset(**self.kwargs)
        train_idx, valid_idx = self.splits[self.current]
        # The former "FOR DEBUG" stride sub-sampling (rate 0.75) kept every index for
        # folds of 3+ samples and divided by zero for 1-sample folds, so it was removed.
        print('Image count: {} trains, {} valids'.format(len(train_idx), len(valid_idx)))
        self.current += 1
        train_loader = DataLoader(self.dataset, batch_size=self.batch_size,
                                  sampler=SubsetRandomSampler(train_idx))
        valid_loader = DataLoader(self.dataset, batch_size=self.batch_size,
                                  sampler=SubsetRandomSampler(valid_idx))
        return train_loader, valid_loader

# Create the dataset and the K-fold loaders for training

In [None]:
# Full training dataset; the K-fold loader hands out a new (train, valid) fold per fetch().
dataset = SorghumDataset(df_train, transform=transform_train, is_train=True)
dataloader_kfold = KFoldDataLoader(dataset, n_splits=4, shuffle=True, batch_size=Hypers.batch_size)
# train_loader, valid_loader = dataloader_kfold.fetch()

## How long does it take to load all batches?

In [None]:
# %%time
# for title, loader in [('train', train_loader), ('valid', valid_loader)]:
#     labels = []
#     for batch in loader:
#         images, cultivars, _ = batch
#         labels.extend(cultivars)
#     df1 = pd.DataFrame(labels, columns=['cultivars'])
#     df2 = df1['cultivars'].value_counts().sort_index()
#     print('Count of classes:', len(df2))
#     print('Skew: ', df2.skew())

#     plt.figure(figsize=(20, 4))
#     plt.title(f'Distribution by {title} loader')
#     plt.xticks(rotation=90)
#     sns.barplot(x=df2.index, y=df2.values)
#     plt.show()

# Train

In [None]:
# Clear CUDA memory
# torch.cuda.empty_cache()

In [None]:
# from torch.utils.data import random_split

# model = Classifier().to(Hypers.device)
# print(model)

### DenseNet 161

In [None]:
densenet161 = torch.hub.load('pytorch/vision:v0.10.0', 'densenet161', pretrained=True)

# Freeze the pretrained backbone: only the new classification head is trained.
for param in densenet161.parameters():
    param.requires_grad = False

# Replace the head *before* moving to the device. Replacing it after `.to()` left the
# new layer on the CPU. It is created after the freeze loop, so it stays trainable.
densenet161.classifier = nn.Linear(densenet161.classifier.in_features, 100)  # 100 output classes
model = densenet161.to(Hypers.device)

print(model)

### EfficientNet

In [None]:
# efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b0', pretrained=True)

# # Freeze our feature parameters
# for param in efficientnet.parameters():
#     param.requires_grad = False

# model = efficientnet.to(Hypers.device)
# model.classifier.fc = nn.Linear(1280, 100)

# print(model)

In [None]:
# Collect (and list) the parameters the optimizer should update.
print('Parameters to learn:')
parameters_learn = []
for name, param in filter(lambda item: item[1].requires_grad == True, model.named_parameters()):
    print('    ',name)
    parameters_learn.append(param)

In [None]:
def parser_train(batch):
    """Adapt a ``(image, cultivar, label)`` batch to the trainer's ``(x, y)`` pair.

    Raises ``ValueError`` unless the batch has exactly three items.
    """
    image, _cultivar, label = batch
    return (image, label)


args = dict(
    accelerator=Hypers.device,
    loss_function=nn.CrossEntropyLoss(),
    optimizer=Hypers.optimizer(parameters_learn, lr=Hypers.learning_rate),
    min_epochs=2,
    max_epochs=Hypers.max_epoches,
    progress=True,
    batch_parser=parser_train,
    verbose=Hypers.verbose,
    callbacks=[
        MyEarlyStopCallback(
            patience=Hypers.patience,
            min_delta=-1,  # gradient to be added
            max_seconds=5 * 3600,  # 5-hour wall-clock budget
            verbose=True,
        ),
        # CheckpointCallback(
        #     name='sorghum_model.pth',
        #     min_delta=0.05,
        #     load_from_checkpoint=True,
        #     verbose=True),
    ],
)

print(args)

trainer = Trainer(**args)
trainer.fit(model, dataloader_kfold.fetch)

# Visualization

In [None]:
# One figure for the loss curves, one for the accuracy curves.
logging_df = pd.DataFrame(trainer.log)

for metric_cols in (['train_loss', 'valid_loss'], ['train_acc', 'valid_acc']):
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.lineplot(data=logging_df[metric_cols], ax=ax)
    plt.show()

# Submit

In [None]:
df_submit = pd.read_csv('../input/sorghum-id-fgvc-9/sample_submission.csv')
df_submit.head(5)

# Match the training geometry. Training resizes to 512 and crops a 320 px patch from
# the central 384 px before resizing to the model size, so evaluate on the central
# 320 px of the same 512 resize. Squashing the whole image would show objects at a
# different scale.
transform_test = T.Compose([
#     CLAHE,
    ImageConv.np2tensor,
    T.Resize(512),
    T.CenterCrop(320),
    T.Resize(Hypers.image_size),
    normalize,
])

test_dataset = SorghumDataset(df_submit, is_train=False, transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=Hypers.batch_size, shuffle=False)

images, _, _ = next(iter(test_loader))
image_gridplot(images[:12], transform=T.Compose([
    unnormalize,
    ImageConv.to_plt,
]))

In [None]:
def parser_test(batch):
    """Unpack a test-split batch into ``(images, cultivar placeholders, filenames)``.

    Raises ``ValueError`` unless the batch has exactly three items.
    """
    images, placeholders, filenames = batch
    parsed = (images, placeholders, filenames)
    return parsed

# TODO: implement
# trainer.test(model, test_loader)

result = {
    'filename': [],
    'cultivar': [],
}
# Inference mode: eval() makes BatchNorm use its running statistics (training may have
# left the model in train mode), and no_grad() avoids building the autograd graph.
model.to(Hypers.device)
model.eval()
with torch.no_grad():
    for batch in tqdm_nb(test_loader, desc='Get predictions'):
        # The test split yields (image, cultivar placeholder, filename).
        images, l, filenames = parser_test(batch)
        outputs = model(images.to(Hypers.device)).cpu()
        preds = outputs.argmax(dim=1)
        # Decode with the encoder fitted by the training `dataset` (SorghumDataset).
        preds_decode = dataset.encoder.inverse_transform(preds.numpy())
        result['filename'].extend(filenames)
        result['cultivar'].extend(preds_decode)

result_df = pd.DataFrame(result)
result_df.to_csv('submission.csv', index=False)
result_df