In [None]:
from typing import List, Dict

import random
import os

import numpy as np
import pandas as pd
import PIL
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import seaborn as sns
import torchvision
import torch.onnx
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms as T

import skimage.io as io
from tqdm.notebook import tqdm

In [None]:
# Load the sample submission and key it by image file name; its index is later
# used as the list of test image ids.
submission_df = pd.read_csv(
    '/kaggle/input/plant-pathology-2021-fgvc8/sample_submission.csv'
)
submission_df = submission_df.set_index('image')
submission_df.head()

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the host.

    Tensors that track gradients are detached first, because ``.numpy()``
    refuses to run on a tensor with ``requires_grad=True``.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

In [None]:
def read_image_labels():
    """Read the competition's sample submission CSV, indexed by ``image``.

    Returns
    -------
    pd.DataFrame
        The CSV contents with the ``image`` column moved into the index.
    """
    csv_path = '../input/plant-pathology-2021-fgvc8/sample_submission.csv'
    labels_df = pd.read_csv(csv_path)
    return labels_df.set_index('image')

In [None]:
# Load the label table (indexed by image name) and preview its first rows.
img_labels = read_image_labels()

img_labels.head()

In [None]:
def get_image_infos(img_labels):
    """Summarise how often each ``labels`` value occurs.

    Parameters
    ----------
    img_labels : pd.DataFrame
        Frame with a ``labels`` column, one row per image. Any additional
        columns are ignored.

    Returns
    -------
    pd.DataFrame
        Indexed by ``disease`` (the distinct ``labels`` values) with two
        columns: ``count`` (number of rows carrying that value) and ``%``
        (share of all rows, as a whole-number percentage stored as float).
        Rows are sorted by ``count`` in descending order.
    """
    # Count rows per label directly; this does not depend on how many other
    # columns the frame has (a positional column rename would).
    counts = img_labels.groupby(by='labels').size()
    df = counts.rename('count').rename_axis('disease').to_frame()

    # Scale to percent *before* rounding so values come out as exact whole
    # numbers (57.0) instead of carrying float noise (56.99999999999999).
    total = img_labels.shape[0]
    df['%'] = (df['count'] / total * 100).round()

    return df.sort_values(by='count', ascending=False)

In [None]:
# Locations of the competition data, keyed by the `kind` names that
# `get_image` accepts ('train' and 'val' point at the same resized image folder).
folders = {
    'data': '../input/plant-pathology-2021-fgvc8',
    'train': '../input/resized-plant2021/img_sz_256',
    'val': '../input/resized-plant2021/img_sz_256',
    'test': '../input/plant-pathology-2021-fgvc8/test_images',
    'submiss': '../input/plant-pathology-2021-fgvc8/sample_submission.csv',
}

def get_image(image_id, kind='train'):
    """Load an image from the folder registered for ``kind``.

    Parameters
    ----------
    image_id : str
        File name of the image, joined onto ``folders[kind]``.
    kind : str, default 'train'
        Key into the module-level ``folders`` mapping.

    Returns
    -------
    PIL.Image.Image
        The decoded image. Its pixel data is already in memory and the
        underlying file handle has been closed.

    Raises
    ------
    KeyError
        If ``kind`` is not a key of ``folders``.
    """
    # Import the submodule explicitly: a bare `import PIL` only exposes
    # `PIL.Image` if some other library happened to import it first.
    from PIL import Image

    fname = os.path.join(folders[kind], image_id)
    # Image.open is lazy and keeps the file open until the data is read.
    # Decode inside the context manager so the handle is released right away
    # instead of leaking one descriptor per image.
    with Image.open(fname) as img:
        img.load()
    return img

def plot_image_counts(img_labels):
    """Draw a bar chart of how many rows carry each ``labels`` value.

    Bars are ordered from the most to the least frequent label. The figure is
    shown with ``plt.show()``; nothing is returned.
    """
    figure, _axes = plt.subplots(figsize=(18, 7))
    sns.set_style("whitegrid")
    blues = sns.color_palette("Blues_r", 12)
    most_frequent_first = img_labels['labels'].value_counts().index

    sns.countplot(
        data=img_labels,
        x='labels',
        order=most_frequent_first,
        palette=blues,
    )

    plt.ylabel("# of observations", size=20)
    plt.xlabel("Class names", size=20)
    plt.xticks(rotation=45)

    figure.tight_layout()
    plt.show()

In [None]:
def get_single_labels(unique_labels) -> List[str]:
    """Return the distinct whitespace-separated tokens found in the labels.

    Parameters
    ----------
    unique_labels : iterable of str
        Label strings; each may hold several class names separated by
        whitespace (e.g. ``'scab frog_eye_leaf_spot'``).

    Returns
    -------
    List[str]
        Every distinct token exactly once, in order of first appearance.
    """
    tokens = (token for label in unique_labels for token in label.split())
    # dict.fromkeys de-duplicates while keeping insertion order, so the result
    # is the same on every run (set ordering of strings is not).
    return list(dict.fromkeys(tokens))

In [None]:
def get_one_hot_encoded_labels(dataset_df) -> pd.DataFrame:
    """Add one 0/1 indicator column per class to a copy of ``dataset_df``.

    Parameters
    ----------
    dataset_df : pd.DataFrame
        Frame with a ``labels`` column whose values are whitespace-separated
        class names (e.g. ``'scab'`` or ``'scab frog_eye_leaf_spot'``).

    Returns
    -------
    pd.DataFrame
        Copy of the input with six extra int64 columns (``rust``, ``complex``,
        ``healthy``, ``powdery_mildew``, ``scab``, ``frog_eye_leaf_spot``).
        A column holds 1 when that class name appears among the row's label
        tokens and 0 otherwise; missing labels give all zeros. The input frame
        is not modified.
    """
    df = dataset_df.copy()

    class_names = [
        'rust',
        'complex',
        'healthy',
        'powdery_mildew',
        'scab',
        'frog_eye_leaf_spot',
    ]

    # Tokenise every row once. Matching on tokens (rather than comparing the
    # whole string to a class name) is what makes multi-label rows such as
    # 'scab frog_eye_leaf_spot' light up both columns.
    row_tokens = [
        set() if pd.isna(value) else set(str(value).split())
        for value in df['labels']
    ]

    for name in class_names:
        # Plain ndarray assignment is positional, so it is unaffected by
        # duplicate or unusual index values.
        df[name] = np.array(
            [int(name in tokens) for tokens in row_tokens], dtype='int64'
        )

    return df

In [None]:
# Add the per-class indicator columns to the label table and preview the result.
one_hot_encoded_labels = get_one_hot_encoded_labels(img_labels)
one_hot_encoded_labels.head()

In [None]:
# Inference-time preprocessing: resize to 224x224, normalise each channel, then
# convert to a torch tensor.
# NOTE(review): the mean/std values look like the usual ImageNet statistics --
# confirm they match what the checkpoint was trained with.
test_transform = A.Compose(
    [
        A.Resize(height=224, width=224),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
        ToTensorV2(),
    ]
)

In [None]:
from scipy.stats import bernoulli
from torch.utils.data import Dataset

class PlantDataset(Dataset):
    """Map-style dataset pairing images loaded via ``get_image`` with targets.

    Parameters
    ----------
    image_ids : pandas.Series or sequence of str
        Image file names. A Series is indexed positionally (``.iloc``), so its
        index labels do not matter; any other sequence is indexed with ``[]``.
    targets : sequence
        Per-image targets, indexed positionally with ``targets[idx]``.
    transform : callable, optional
        Called as ``transform(image=array)``; the ``'image'`` entry of its
        result replaces the image (the albumentations calling convention).
    target_transform : callable, optional
        Applied to the target before it is returned.
    kind : str, default 'train'
        Passed through to ``get_image`` to select the image folder.
    """
    def __init__(self,
                 image_ids,
                 targets,
                 transform=None,
                 target_transform=None,
                 kind='train'):
        self.image_ids = image_ids
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform
        self.kind = kind

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the sample at position ``idx``."""
        ids = self.image_ids
        # Positional lookup for pandas objects, plain indexing for lists/arrays.
        image_id = ids.iloc[idx] if hasattr(ids, 'iloc') else ids[idx]

        # Force three channels so grayscale / RGBA / palette files produce the
        # same (H, W, 3) layout that a 3-value mean/std normalisation expects.
        img = np.array(get_image(image_id, kind=self.kind).convert('RGB'))

        if self.transform:
            img = self.transform(image=img)['image']

        target = self.targets[idx]
        if self.target_transform:
            target = self.target_transform(target)

        return img, target

In [None]:

# Test image ids, in the order they appear in the sample submission.
X_test = pd.Series(submission_df.index)

# Matching target matrix: one row per image, one column per class, taken from
# the indicator columns of `one_hot_encoded_labels`.
y_test = np.array(
    one_hot_encoded_labels.loc[
        :,
        [
            'rust',
            'complex',
            'healthy',
            'powdery_mildew',
            'scab',
            'frog_eye_leaf_spot',
        ],
    ]
)
print(len(y_test))

In [None]:
# Wrap the test ids/targets in a dataset that reads images from the 'test' folder.
test_set = PlantDataset(X_test, y_test, transform=test_transform, kind='test')

batch_size = 32
# shuffle=False keeps the batches in the same order as X_test.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
# Cell output: the test image ids.
X_test

In [None]:
def create_model(pretrained=False):
    """Build a ResNet-101 whose final layer emits six sigmoid activations.

    Parameters
    ----------
    pretrained : bool, default False
        Forwarded to ``torchvision.models.resnet101``.

    Returns
    -------
    torch.nn.Module
        The network, placed on the module-level ``device``.
    """
    network = torchvision.models.resnet101(pretrained=pretrained)

    # Swap the stock classifier for a 6-unit linear layer followed by a
    # sigmoid, so each output is an independent value in (0, 1).
    classifier = torch.nn.Sequential(
        torch.nn.Linear(in_features=network.fc.in_features, out_features=6),
        torch.nn.Sigmoid(),
    )
    network.fc = classifier

    return network.to(device)

In [None]:
# Instantiate the network with the default pretrained=False.
model = create_model()

In [None]:
load_path = '../input/resnetv5/v5.pkl'
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads when `device` fell back to the CPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(load_path, map_location=device))

In [None]:
# Put the network in inference mode: layers such as BatchNorm and Dropout
# behave differently in training mode, which would make each prediction depend
# on the other images in its batch.
model.eval()

# Collect per-batch arrays and stack once at the end (stacking inside the loop
# re-copies everything accumulated so far on every iteration).
true_batches = []
pred_batches = []

stream = tqdm(test_loader)
# No gradients are needed for prediction; disabling autograd saves memory.
with torch.no_grad():
    for X, y in stream:
        X = X.float().to(device)
        pred = to_numpy(model(X))

        true_batches.append(to_numpy(y))
        pred_batches.append(pred)

# Fall back to empty (0, 6) arrays when the loader yields nothing. Concrete
# dtypes are used because the `np.int` alias no longer exists in NumPy >= 1.24.
y_true = (
    np.vstack(true_batches)
    if true_batches
    else np.empty(shape=(0, 6), dtype=np.int64)
)
y_pred_proba = (
    np.vstack(pred_batches)
    if pred_batches
    else np.empty(shape=(0, 6), dtype=np.float32)
)

In [None]:
# Cell output: the stacked per-class model outputs, one row per test image.
y_pred_proba

In [None]:
# Convert to nested Python lists. Going through np.asarray keeps this cell
# re-runnable: it works whether y_pred_proba is still an array or already a list.
y_pred_proba = np.asarray(y_pred_proba).tolist()

# For every image keep the positions of all classes scoring at least 0.25;
# when none qualifies, fall back to the single highest-scoring class.
indices = []
for pred in y_pred_proba:
    # enumerate yields each score's own position. Looking positions up with
    # list.index would return the first match and so repeat an index whenever
    # two classes share exactly the same score.
    temp = [position for position, category in enumerate(pred) if category >= 0.25]
    if not temp:
        temp = [int(np.argmax(pred))]
    indices.append(temp)

print(indices)

In [None]:
# Class names in the same order as the model's output columns.
labels = [
    'rust',
    'complex',
    'healthy',
    'powdery_mildew',
    'scab',
    'frog_eye_leaf_spot',
]

# Turn each image's list of class positions into a space-separated label string.
testlabels = []
for image in indices:
    testlabels.append(' '.join(str(labels[i]) for i in image))

print(testlabels)

In [None]:
# Re-read the sample submission, overwrite its `labels` column with the
# predicted label strings, and write the result out without the row index.
sub = pd.read_csv('../input/plant-pathology-2021-fgvc8/sample_submission.csv')
sub = sub.assign(labels=testlabels)
sub.to_csv('submission.csv', index=False)