In [None]:
%matplotlib inline

import warnings
# Silence every warning for the whole notebook.
# NOTE(review): this also hides deprecation and pandas chained-assignment warnings.
warnings.filterwarnings('ignore')

# Plant2021 - PyTorch - Submission

# Overview

* Plant Pathology 2021 Competition
* Use pretrained PyTorch ResNet model
* Multi-label classification


The trained model was designed in the notebook *Plant2021 - PyTorch - ResNet*.

## Imports

In [None]:
from typing import List, Dict

import random
import os

import numpy as np
import pandas as pd
import PIL

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torchvision
import torch.onnx
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import transforms as T
from torchvision.transforms import functional as F

import skimage.io as io
import skimage.feature
from skimage import color
from skimage import segmentation

from tqdm.notebook import tqdm

## Configuration

In [None]:
import torch
# Show which PyTorch version this kernel is running.
print(torch.__version__)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Global matplotlib defaults: base font size, axes title size and tick label sizes.
plt.rc('font', size=15)
plt.rc('axes', titlesize=18)  
plt.rc('xtick', labelsize=10)  
plt.rc('ytick', labelsize=10)

In [None]:
class Config:
    """Namespace holding the paths, hyper-parameters and class names of this notebook.

    All values are class attributes; the class is never instantiated.
    """

    # Seed shared by torch, random, numpy and the pandas sampling calls.
    RANDOM_STATE = 2021

    # Run on the GPU when one is visible to torch, otherwise on the CPU.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Data-loading parameters.
    BATCH_SIZE = 64
    SAMPLE_FRAC = 0.01  # fraction of the training labels sampled for evaluation
    IMG_SIZE = 224      # images are resized to IMG_SIZE x IMG_SIZE

    # Input / output locations.
    INPUT_PATH = '../input/plant-pathology-2021-fgvc8'
    OUTPUT_PATH = './'
    TRAIN_DATA_FILE = os.path.join(INPUT_PATH, 'train.csv')
    SAMPLE_SUBMISSION_FILE = os.path.join(INPUT_PATH, 'sample_submission.csv')
    SUBMISSION_FILE = os.path.join(OUTPUT_PATH, 'submission.csv')
    # The weight file name carries the device name.
    MODEL_FILE = f'../input/plant2021-pytorch-resnet/plant2021_{DEVICE}.pth'

    # Class names; their order defines the column order of targets and predictions.
    CLASSES = [
        'rust',
        'complex',
        'healthy',
        'powdery_mildew',
        'scab',
        'frog_eye_leaf_spot',
    ]
    N_CLASSES = len(CLASSES)
    # Probability above which a class counts as predicted.
    CLASS_THRESHOLD = 0.3

    # Image folders, looked up through the ``kind`` argument of the helpers.
    folders = {
        'data': INPUT_PATH,
        'train': os.path.join(INPUT_PATH, 'train_images'),
        'test': os.path.join(INPUT_PATH, 'test_images'),
    }

    @staticmethod
    def set_seed():
        """Seed torch, ``random`` and numpy with ``RANDOM_STATE``."""
        seed = Config.RANDOM_STATE
        for seeder in (torch.manual_seed, random.seed, np.random.seed):
            seeder(seed)


Config.set_seed()

In [None]:
# Report which device Config selected.
print(f'Using {Config.DEVICE} device.')

In [None]:
def to_numpy(tensor):
    """Convert a torch tensor into a numpy array on the host.

    Tensors that track gradients are detached first, because ``.numpy()``
    refuses to convert them otherwise.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

## Load image labels

In [None]:
def read_image_labels():
    """Read the training CSV and return it indexed by the ``image`` column."""
    return pd.read_csv(Config.TRAIN_DATA_FILE, index_col='image')

In [None]:
# Keep only a random SAMPLE_FRAC fraction of the training labels;
# random_state makes the selection repeatable.
img_labels = read_image_labels().sample(
    frac=Config.SAMPLE_FRAC, 
    random_state=Config.RANDOM_STATE
)

img_labels.head()

In [None]:
# Sample submission file, indexed by image name.
# NOTE(review): this global is not read by save_submission, which re-reads the file itself.
image_ids = pd.read_csv(Config.SAMPLE_SUBMISSION_FILE).set_index('image')
image_ids

## Label distribution

In [None]:
def get_image_infos(img_labels):
    """Summarise how often each (combined) label string occurs.

    Parameters
    ----------
    img_labels : pd.DataFrame
        Frame with a ``labels`` column; any other columns are ignored.

    Returns
    -------
    pd.DataFrame
        Indexed by ``disease`` (the label string) with the columns ``count``
        (number of rows) and ``%`` (share of all rows, rounded to a whole
        percent), sorted by ``count`` in descending order.
    """
    # size() counts rows per label whatever other columns the frame has.
    counts = img_labels.groupby('labels').size()

    df = counts.rename('count').to_frame()
    df.index.name = 'disease'

    # Scale before rounding: round(x, 2) * 100 produces float artifacts
    # such as 56.99999999999999 instead of 57.0.
    df['%'] = np.round(100 * df['count'] / len(img_labels))

    return df.sort_values(by='count', ascending=False)

In [None]:
# Label distribution of the sampled training rows.
get_image_infos(img_labels)

In [None]:
# NOTE(review): same preview as the cell that creates img_labels above.
img_labels.head()

## One hot encoding

In [None]:
def get_single_labels(unique_labels) -> List[str]:
    """Split space-separated multi-labels and return the distinct class names.

    Parameters
    ----------
    unique_labels : iterable of str
        Label strings such as ``'scab frog_eye_leaf_spot'``.

    Returns
    -------
    List[str]
        Each class name once, sorted alphabetically so the result (and any
        column order derived from it) is the same on every run.
    """
    return sorted({name for label in unique_labels for name in label.split()})

In [None]:
def get_one_hot_encoded_labels(dataset_df) -> pd.DataFrame:
    """Return a copy of ``dataset_df`` with one 0/1 column per single class.

    The ``labels`` column holds space-separated class names. Every class
    found in it becomes a column that is 1 for the rows whose label string
    contains that class and 0 otherwise. The input frame is not modified.
    """
    encoded = dataset_df.copy()

    label_combos = encoded.labels.unique()
    class_columns = get_single_labels(label_combos)

    # Start with all-zero indicator columns ...
    encoded[class_columns] = 0

    # ... then switch on the classes contained in each distinct label string.
    for combo in label_combos:
        matching_rows = encoded.index[encoded['labels'] == combo]
        encoded.loc[matching_rows, combo.split()] = 1

    return encoded

In [None]:
# One indicator column per class, appended to the sampled label frame.
one_hot_encoded_labels = get_one_hot_encoded_labels(img_labels)
one_hot_encoded_labels.head()

## Visualization of images

In [None]:
def get_image(image_id, kind='train'):
    """Open the image file ``image_id`` from the folder registered for ``kind``.

    ``kind`` is a key of ``Config.folders``; the opened PIL image is returned.
    """
    image_dir = Config.folders[kind]
    return PIL.Image.open(os.path.join(image_dir, image_id))

In [None]:
def visualize_images(image_ids, labels, nrows=1, ncols=4, kind='train', image_transform=None):
    """Show a grid of images with their labels as titles.

    Parameters
    ----------
    image_ids, labels : iterables
        Image file names and the label shown above each image. At most
        ``nrows * ncols`` pairs are drawn.
    nrows, ncols : int
        Shape of the subplot grid.
    kind : str
        Key of ``Config.folders`` selecting the image folder.
    image_transform : albumentations Compose, optional
        When given, its transforms are applied before display, except
        ``A.Normalize`` and ``ToTensorV2``, which are skipped.
    """
    # Build the display pipeline once instead of once per image.
    display_transform = None
    if image_transform:
        display_transform = A.Compose([
            t for t in image_transform.transforms
            if not isinstance(t, (A.Normalize, ToTensorV2))
        ])

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 8), squeeze=False)
    flat_axes = axes.flatten()

    n_shown = 0
    for image_id, label, ax in zip(image_ids, labels, flat_axes):
        image = np.array(get_image(image_id, kind=kind))

        if display_transform is not None:
            image = display_transform(image=image)['image']

        # Plain matplotlib call; skimage.io.imshow is deprecated.
        ax.imshow(image)

        ax.set_title(f"Class: {label}", fontsize=12)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        n_shown += 1

    # Hide grid cells that received no image.
    for ax in flat_axes[n_shown:]:
        ax.set_visible(False)

    plt.show()

In [None]:
# Untransformed preview: at most 2 x 4 images taken from the start of the sampled frame.
visualize_images(img_labels.index, img_labels.labels, nrows=2, ncols=4)

## Augmentation pipeline

In [None]:
# Preprocessing applied to every image before it reaches the model:
# resize to IMG_SIZE x IMG_SIZE, normalise each channel, convert to a tensor.
# NOTE(review): the name is misspelled ("transfom") but is referenced by later cells.
image_transfom = A.Compose([
    A.Resize(
        height=Config.IMG_SIZE,
        width=Config.IMG_SIZE,
    ),
    # presumably the ImageNet channel statistics used when the ResNet was trained — TODO confirm
    A.Normalize(
        mean=(0.485, 0.456, 0.406), 
        std=(0.229, 0.224, 0.225)
    ),
    ToTensorV2(),
])

In [None]:
# Fixed random_state (as for the label sample above) so re-running this cell
# always previews the same five images.
images = img_labels.sample(n=5, random_state=Config.RANDOM_STATE)

# Preview with the pipeline applied; visualize_images skips Normalize/ToTensorV2.
visualize_images(
    images.index, 
    images.labels, 
    nrows=1,
    ncols=5,
    image_transform=image_transfom
)

## Dataset

In [None]:
from scipy.stats import bernoulli
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class PlantDataset(Dataset):
    """Map-style dataset yielding ``(image, target)`` pairs.

    Parameters
    ----------
    image_ids : pandas Series or sequence of str
        Image file names, resolved through ``get_image``.
    targets : pandas Series, numpy array or sequence
        One target per image, aligned with ``image_ids`` by position.
    transform : callable, optional
        Called as ``transform(image=array)``; its ``'image'`` entry is used.
    target_transform : callable, optional
        Applied to the raw target.
    kind : str
        Folder key forwarded to ``get_image``.
    """
    def __init__(self, 
                 image_ids, 
                 targets,
                 transform=None, 
                 target_transform=None, 
                 kind='train'):
        self.image_ids = image_ids
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform
        self.kind = kind

    @staticmethod
    def _at(container, idx):
        """Return the element at position ``idx`` of a pandas object or plain sequence."""
        # Plain [] on a Series is label-based and breaks on a non-default
        # index; .iloc is positional regardless of the index.
        if hasattr(container, 'iloc'):
            return container.iloc[idx]
        return container[idx]

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load, transform and return the ``idx``-th ``(image, target)`` pair."""
        # load and transform image
        img = np.array(get_image(self._at(self.image_ids, idx), kind=self.kind))

        if self.transform:
            img = self.transform(image=img)['image']

        # get image target (same positional access as for the image id)
        target = self._at(self.targets, idx)
        if self.target_transform:
            target = self.target_transform(target)

        return img, target

In [None]:
# Image names and their one-hot targets, columns ordered like Config.CLASSES.
X_val = pd.Series(img_labels.index)
# reindex adds an all-zero column for any class missing from the small sample,
# where plain column selection would raise a KeyError.
y_val = one_hot_encoded_labels.reindex(columns=Config.CLASSES, fill_value=0).to_numpy()

In [None]:
# The validation images live in the training folder, hence kind='train'.
val_set = PlantDataset(X_val, y_val, transform=image_transfom, kind='train')
# Evaluation gains nothing from shuffling; a fixed order keeps runs comparable.
val_loader = DataLoader(val_set, batch_size=Config.BATCH_SIZE, shuffle=False)

## Create model and load weights

In [None]:
def load_weights(model, load_path=Config.MODEL_FILE):
    """Load a saved state dict into ``model`` and switch it to eval mode.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose architecture matches the saved state dict.
    load_path : str
        Path of the ``.pth`` file; defaults to ``Config.MODEL_FILE``.
    """
    # map_location puts the tensors on the configured device even when the
    # file was written on a different one (e.g. GPU weights on a CPU kernel).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(load_path, map_location=Config.DEVICE)
    model.load_state_dict(state_dict)
    model.eval()

def create_model(pretrained=False):
    """Build a ResNet-50 whose final layer emits one sigmoid score per class.

    ``pretrained`` is passed straight to ``torchvision.models.resnet50``.
    The original ``fc`` layer is replaced by a linear layer with
    ``Config.N_CLASSES`` outputs followed by a sigmoid; the model is
    returned on ``Config.DEVICE``.
    """
    backbone = torchvision.models.resnet50(pretrained=pretrained).to(Config.DEVICE)

    classifier_head = torch.nn.Sequential(
        torch.nn.Linear(
            in_features=backbone.fc.in_features,
            out_features=Config.N_CLASSES,
        ),
        torch.nn.Sigmoid(),
    )
    backbone.fc = classifier_head.to(Config.DEVICE)

    return backbone

In [None]:
# Build the network without pretrained weights, then load the trained ones from Config.MODEL_FILE.
# NOTE(review): create_model already moves the model to Config.DEVICE, so the extra .to() is redundant.
model = create_model(pretrained=False).to(Config.DEVICE);
load_weights(model)

## Confusion matrix

In [None]:
def predict(model, loader):
    """Run ``model`` over ``loader`` and collect targets and predictions.

    Parameters
    ----------
    model : torch.nn.Module
        Model producing one score per class for every image.
    loader : iterable
        Yields ``(X, y)`` batches of image tensors and target tensors.

    Returns
    -------
    (np.ndarray, np.ndarray)
        ``y_true`` and ``y_pred_proba``, the batches concatenated along the
        first axis. For an empty loader both have shape
        ``(0, Config.N_CLASSES)``.
    """
    true_batches, proba_batches = [], []

    # Inference only: eval mode plus no_grad avoids building autograd graphs.
    model.eval()
    with torch.no_grad():
        for X, y in tqdm(loader):
            proba_batches.append(to_numpy(model(X.to(Config.DEVICE))))
            # Targets are only collected, so they never need to visit the device.
            true_batches.append(to_numpy(y))

    if not true_batches:
        # The builtin int/float replace the np.int alias removed in NumPy 1.24.
        empty_shape = (0, Config.N_CLASSES)
        return np.empty(empty_shape, dtype=int), np.empty(empty_shape, dtype=float)

    # One concatenation at the end instead of re-copying the arrays per batch.
    return np.concatenate(true_batches), np.concatenate(proba_batches)

In [None]:
# Targets and predicted scores for the sampled validation images.
y_true, y_pred_proba = predict(model, val_loader)

In [None]:
from sklearn.metrics import multilabel_confusion_matrix

def plot_confusion_matrix(
    y_test, 
    y_pred_proba, 
    threshold=Config.CLASS_THRESHOLD, 
    label_names=Config.CLASSES
)-> None:
    """Plot one 2x2 confusion matrix per class as a grid of heatmaps.

    Parameters
    ----------
    y_test : array-like of shape (n_samples, n_classes)
        Binary ground-truth indicators.
    y_pred_proba : np.ndarray of shape (n_samples, n_classes)
        Predicted scores; values above ``threshold`` count as positive.
    threshold : float
        Decision threshold, ``Config.CLASS_THRESHOLD`` by default.
    label_names : sequence of str
        Titles of the panels, in column order.
    """
    y_pred = np.where(y_pred_proba > threshold, 1, 0)
    c_matrices = multilabel_confusion_matrix(y_test, y_pred)

    # Size the grid from the number of classes rather than a fixed 2x3,
    # so additional classes are not silently dropped.
    n_plots = min(len(c_matrices), len(label_names))
    ncols = 3
    nrows = max(1, -(-n_plots // ncols))  # ceiling division

    cmap = plt.get_cmap('Blues')
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(15, 4 * nrows), squeeze=False
    )
    flat_axes = axes.flatten()

    for cm, label, ax in zip(c_matrices, label_names, flat_axes):
        sns.heatmap(cm, annot=True, fmt='g', ax=ax, cmap=cmap)

        ax.set_xlabel('Predicted labels')
        ax.set_ylabel('True labels')
        ax.set_title(f'{label}')

    # Hide panels left over when the class count is not a multiple of ncols.
    for ax in flat_axes[n_plots:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

In [None]:
# Per-class confusion matrices at the default Config.CLASS_THRESHOLD.
plot_confusion_matrix(y_true, y_pred_proba)

## Submission

In [None]:
def save_submission(model):
    """Predict labels for the test images and write the submission CSV.

    The image names come from ``Config.SAMPLE_SUBMISSION_FILE``. Every class
    whose score exceeds ``Config.CLASS_THRESHOLD`` is written, space
    separated, into the ``labels`` column; when no class passes the
    threshold the highest-scoring class is used so the label is never empty.

    Parameters
    ----------
    model : torch.nn.Module
        Trained model returning one score per class in ``Config.CLASSES``.

    Returns
    -------
    pd.DataFrame
        The submission frame indexed by ``image``, as written to
        ``Config.SUBMISSION_FILE``.
    """
    submission = pd.read_csv(Config.SAMPLE_SUBMISSION_FILE)

    dataset = PlantDataset(
        submission['image'], 
        submission['labels'],  # placeholder targets; ignored below
        transform=image_transfom, 
        kind='test'
    )

    # shuffle=False keeps predictions aligned with the rows of `submission`.
    loader = DataLoader(dataset, batch_size=Config.BATCH_SIZE, shuffle=False)

    pred_labels = []
    model.eval()
    with torch.no_grad():
        for X, _ in loader:
            scores = to_numpy(model(X.float().to(Config.DEVICE)))
            for row in scores:
                # Multi-label decision: keep every class above the threshold.
                selected = np.flatnonzero(row > Config.CLASS_THRESHOLD)
                if selected.size == 0:
                    selected = [int(np.argmax(row))]
                pred_labels.append(' '.join(Config.CLASSES[i] for i in selected))

    # Assign the whole column at once; chained `iloc[idx]['labels'] = ...`
    # may write to a temporary copy and leave the frame unchanged.
    submission['labels'] = pred_labels

    # save data frame as csv
    submission = submission.set_index('image')
    submission.to_csv(Config.SUBMISSION_FILE)

    return submission

In [None]:
# Write submission.csv and display the resulting frame.
save_submission(model)   