# References
1. [Plant Pathology with Lightning ⚡ : by Jirka](https://www.kaggle.com/jirkaborovec/plant-pathology-with-lightning)
2. [Plant Pathology - PyTorch Lightning ⚡️: by Aniket](https://www.kaggle.com/aniketmaurya/plant-pathology-pytorch-lightning/comments) 

Installation (you might have to restart the kernel)
```
!pip install -U 'lightning-flash[image]'==0.5.0rc0 -q
!pip install -U torchvision
!pip install -U torchtext
```

In [None]:
!pip install -U 'lightning-flash[image]'==0.5.0rc0 -q
!pip install -U torchvision
!pip install -U torchtext

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
from sklearn.preprocessing import MultiLabelBinarizer

import flash
from flash.image import ImageClassificationData, ImageClassifier

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.metrics import FBeta
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger


import torch
import torchmetrics
import torchvision
from torch import nn
from torch.nn import functional as F

import os
from glob import glob
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, Dataset

from PIL import Image

from torchvision import transforms
from pathlib import Path

In [None]:
# Root directory of the competition data; train.csv, sample_submission.csv and
# the `*_images` folders are all resolved relative to this path.
data_dir = Path("/kaggle/input/plant-pathology-2021-fgvc8/")

In [None]:
# Load the training annotations. The raw label string is preserved in
# `label_org`, while `labels` becomes a list of class names per image
# (the raw string is split on whitespace).
df = pd.read_csv(data_dir / 'train.csv')
df['label_org'] = df['labels'].to_numpy()
df['labels'] = df['labels'].str.split()

df.head()

In [None]:
# ref: Jirka
import itertools
import seaborn as sns

# Flatten the raw label strings (split on single spaces) into one flat list of
# class names, one entry per (image, class) occurrence.
labels_all = list(itertools.chain(*[lbs.split(" ") for lbs in df['label_org']]))

# Count plot with the class names placed on the y axis.
# NOTE(review): `orient='v'` is passed together with `y=` — confirm which
# orientation is actually intended here.
ax = sns.countplot(y=sorted(labels_all), orient='v')
ax.grid()

In [None]:
# NOTE(review): BS is not referenced anywhere else in the visible code — the
# data module is built with its own default batch size. Confirm the intent.
BS = 32
# Side length (pixels) of the crops produced by the transform pipelines below.
IMAGE_SIZE = 128

In [None]:
# ref: Jirka
from torchvision import transforms as T

# Per-channel mean / std applied by the final Normalize step of every pipeline.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _deterministic_pipeline():
    """Build the augmentation-free pipeline: resize, centre crop, tensor, normalise."""
    steps = [
        T.Resize(256),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor(),
        T.Normalize(_NORM_MEAN, _NORM_STD),
    ]
    return T.Compose(steps)


# Training pipeline: random perspective / crop / flips before tensor conversion.
TRAIN_TRANSFORM = T.Compose([
    T.Resize(256),
    T.RandomPerspective(),
    T.RandomResizedCrop(IMAGE_SIZE),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.ToTensor(),
    T.Normalize(_NORM_MEAN, _NORM_STD),
])

# Validation and test use the same deterministic steps (separate instances).
VALID_TRANSFORM = _deterministic_pipeline()

TEST_TRANSFORM = _deterministic_pipeline()

In [None]:
# Fit the binarizer on the per-image label lists. With sparse_output=True,
# transform() hands back a sparse indicator matrix.
mlb = MultiLabelBinarizer(sparse_output=True).fit(df.labels)
def create_ohe(df, mlb):
    """Return ``df`` extended with one indicator column per class of ``mlb``.

    Args:
        df: Frame whose ``labels`` column is passed to ``mlb.transform``.
        mlb: Fitted binarizer exposing ``transform`` (returning a scipy sparse
            matrix) and ``classes_`` (used as the new column names).

    Returns:
        A new frame: the original columns followed by the sparse indicator
        columns, joined on the row index.
    """
    encoded = mlb.transform(df.labels)
    indicator_frame = pd.DataFrame.sparse.from_spmatrix(encoded, columns=mlb.classes_)
    # Index-on-index join: the indicator frame is numbered 0..n-1.
    return pd.merge(df, indicator_frame, left_index=True, right_index=True)
# Append the indicator columns, then shuffle all rows with a fixed seed and
# renumber the index from 0.
df = create_ohe(df, mlb)
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
df.head()

In [None]:
# Fraction of rows used for training; the remainder becomes the validation set.
split = 0.9
frac = int(split * len(df))  # number of training rows (a row count, not a fraction)

# Positional slices of the already-shuffled frame.
train_data = df[:frac]
val_data = df[frac:]

# Shuffle each subset again (fixed seed) and reset its index so that rows are
# numbered from 0 within the subset.
train_data = train_data.sample(frac=1, random_state=42).reset_index(drop=True)
val_data = val_data.sample(frac=1, random_state=42).reset_index(drop=True)

In [None]:
class PlantDataset(Dataset):
    """Map-style dataset yielding one image and its label-indicator vector.

    Args:
        data: DataFrame with an ``image`` column of file names. Every column
            from position ``LABEL_START`` onwards is read as a label
            indicator; a frame with fewer columns yields an empty target.
        transformation: Callable applied to the loaded PIL image, or a falsy
            value to return the image untransformed.
        folder: Prefix of the image directory; files are read from
            ``data_dir / f"{folder}_images"``.
    """

    # Position of the first label-indicator column in ``data``.
    LABEL_START = 3

    def __init__(self, data, transformation, folder='train'):
        self.data = data
        self.transform = transformation
        self.folder = folder

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"input": image, "target": labels}`` for the row at position ``idx``."""
        # Positional lookup for the file name, consistent with the positional
        # label lookup below, so a non-default frame index cannot pair an
        # image with another row's labels.
        file_name = self.data['image'].iloc[idx]
        file = data_dir / f"{self.folder}_images/{file_name}"
        # Image.open is lazy and keeps the file handle open; load inside a
        # context manager and force three channels so the 3-channel
        # normalisation in the transforms also works for non-RGB files.
        with Image.open(file) as raw:
            image = raw.convert("RGB")
        if self.transform:
            image = self.transform(image)
        labels = self.data.iloc[idx, self.LABEL_START:].to_numpy().astype(int)
        return {"input": image, "target": labels}

In [None]:
# Both datasets use the default folder ('train'); only the transform differs.
train_dataset = PlantDataset(train_data, TRAIN_TRANSFORM)
val_dataset = PlantDataset(val_data, VALID_TRANSFORM)

In [None]:
import multiprocessing as mproc
import pytorch_lightning as pl

class PlantPathologyDM(pl.LightningDataModule):
    """Lightning data module wrapping pre-built train / validation datasets.

    Args:
        train_dataset: Dataset served by ``train_dataloader`` (shuffled).
        val_dataset: Dataset served by ``val_dataloader`` (not shuffled).
        batch_size: Batch size used by both loaders.
        num_workers: Loader worker count; ``None`` means one per CPU core.
        num_classes: Number of target classes. When ``None``, the
            ``num_classes`` property falls back to the module-level
            ``num_classes`` variable, which must then exist by the time the
            property is read.
    """

    def __init__(
        self,
        train_dataset: Dataset = None,
        val_dataset: Dataset = None,
        batch_size: int = 64,
        num_workers: int = None,
        num_classes: int = None,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers if num_workers is not None else mproc.cpu_count()
        self.train_dataset = train_dataset
        self.valid_dataset = val_dataset
        self._num_classes = num_classes

    def prepare_data(self):
        """No download or preprocessing is needed; the datasets are passed in ready-made."""
        pass

    @property
    def num_classes(self) -> int:
        """Number of target classes (explicit value if given, else the module global)."""
        if self._num_classes is not None:
            return self._num_classes
        # Inside a method this name resolves to the module-level variable,
        # not to this property.
        return num_classes

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the settings shared by all splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return self._make_loader(self.valid_dataset, shuffle=False)

    def test_dataloader(self):
        """No test split is managed by this module; returns ``None``."""
        pass

In [None]:
# NOTE(review): batch_size is not passed, so the data module's default (64) is
# used; the BS constant defined earlier is never applied — confirm the intent.
dm = PlantPathologyDM(train_dataset, val_dataset)


In [None]:
# # quick view
# fig = plt.figure(figsize=(3, 7))
# for data in dm.train_dataloader():
#     imgs = data["input"]
#     lbs = data["target"]
#     print(f'batch labels: {torch.sum(lbs, axis=0)}')
#     print(f'image size: {imgs[0].shape}')
#     for i in range(3):
#         ax = fig.add_subplot(3, 1, i + 1, xticks=[], yticks=[])
#         # print(np.rollaxis(imgs[i].numpy(), 0, 3).shape)
#         ax.imshow(np.rollaxis(imgs[i].numpy(), 0, 3))
#         ax.set_title(lbs[i])
#     break

In [None]:
# Distinct class names across all per-image label lists, and how many there are.
labels = {name for row_labels in df.labels for name in row_labels}
num_classes = len(labels)
labels

In [None]:
def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Binary cross-entropy on logits ``x`` against targets ``y``.

    The targets are cast to ``float32`` first, so integer indicator vectors
    can be passed directly.
    """
    float_targets = y.to(torch.float32)
    return F.binary_cross_entropy_with_logits(input=x, target=float_targets)

In [None]:
# Multi-label image classifier on the 'ssl_resnet50' backbone. The custom loss
# casts the integer target vectors to float before BCE-with-logits.
model = ImageClassifier(
    dm.num_classes,
    'ssl_resnet50',
    loss_fn=binary_cross_entropy_with_logits,
    multi_label=True
)

In [None]:
# model.serializer = Labels(labels, multi_label=True, threshold=0.25)

NVIDIA PyTorch tips

In [None]:
# Requires a CUDA device.
# NOTE(review): presumably redundant with `gpus=1` on the Trainer below — confirm.
model = model.to('cuda')

In [None]:
# Single-GPU trainer, capped at 10 epochs.
# NOTE(review): `auto_lr_find=True` presumably only takes effect when
# `trainer.tune(...)` is called, which the visible code never does — confirm.
trainer = flash.Trainer(
    max_epochs=10,
    auto_lr_find=True,
    benchmark=True,
    gpus=1,
)

In [None]:
# Fine-tune the classifier on the data module using the "freeze_unfreeze" strategy.
trainer.finetune(model, datamodule=dm, strategy="freeze_unfreeze")

In [None]:
# Template listing the test images; its `labels` column is overwritten with
# the model's predictions further down.
submission_df = pd.read_csv(data_dir/'sample_submission.csv')
# submission_df.labels = None
submission_df.head()

In [None]:
# folder='test' makes the dataset read files from `test_images`.
submission_dataset = PlantDataset(submission_df, TEST_TRANSFORM, 'test')
# Batch size 16, 4 worker processes, default (sequential) order so predictions
# line up with the rows of submission_df.
submission_dataloader = DataLoader(submission_dataset, 16, num_workers=4)

In [None]:
# Switch the model to evaluation mode before running inference.
model = model.eval()

TODO: create submission data

In [None]:
@torch.no_grad()
def get_results(submission_dataloader, net=None, class_names=None, threshold=0.5):
    """Predict the class names for every sample served by ``submission_dataloader``.

    Args:
        submission_dataloader: Iterable of batches, each a mapping whose
            ``'input'`` entry is the image tensor for that batch.
        net: Model to call on each batch; defaults to the module-level ``model``.
        class_names: Class names in the order of the model outputs; defaults
            to ``df.columns[3:]`` (the indicator columns of the training frame).
        threshold: A class is predicted when its sigmoid score is strictly
            above this value.

    Returns:
        One list of class-name strings per sample, in loader order. A sample
        with no score above the threshold gets an empty list.
    """
    net = model if net is None else net
    # Hoisted out of the loop: the name lookup table never changes.
    names = np.asarray(df.columns[3:] if class_names is None else class_names)

    # Run inference on whatever device the model's weights live on; a model
    # without parameters leaves the batch where it is.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else None

    results = []
    for data in submission_dataloader:
        image = data['input']
        if device is not None:
            image = image.to(device)
        preds = net(image).sigmoid() > threshold
        # Bring the boolean masks back to CPU numpy so they can index `names`
        # regardless of the inference device.
        for mask in preds.cpu().numpy():
            results.append(names[mask].tolist())
    return results

In [None]:
# Join each image's predicted class names into one space-separated string,
# the same format the training CSV's `labels` column is split from.
submission_df["labels"] = [" ".join(names) for names in get_results(submission_dataloader)]
# index=False keeps the file to the template's own columns instead of
# prepending the RangeIndex as an extra unnamed column.
submission_df.to_csv("/kaggle/working/results.csv", index=False)