# Plant Pathology 2021 - FGVC8

In [None]:
# Install through %pip (not !pip) so packages land in the environment of the running kernel.
# torchtext is removed first, presumably because the preinstalled build would not match the
# upgraded torch — TODO confirm.
# NOTE(review): torch/torchvision/timm/segmentation-models-pytorch are unpinned, so re-runs may
# resolve different versions; pin them once a known-good set is found.
%pip uninstall -y torchtext
%pip install -q --upgrade torch torchvision
%pip install -q "lightning-flash[image]" "torchmetrics<0.8"
%pip install -q -U timm segmentation-models-pytorch

# Record the resolved package versions and the visible GPU(s) for reproducibility.
%pip list | grep torch
%pip list | grep lightning
!nvidia-smi -L

## Data exploration

Check what data is available and how the labels are distributed.

In [None]:
%matplotlib inline

import os
import json
import pandas as pd
from pprint import pprint

# Dataset root: defaults to the Kaggle input mount, but the PLANT_PATHOLOGY_DIR environment
# variable can redirect it when running outside Kaggle.
base_path = os.environ.get('PLANT_PATHOLOGY_DIR', '/kaggle/input/plant-pathology-2021-fgvc8-960px')
path_csv = os.path.join(base_path, 'train.csv')
train_data = pd.read_csv(path_csv)

# Later cells read the 'image' (file name) and 'labels' (space-delimited classes) columns;
# fail early with a clear message if the CSV does not have them.
missing_cols = {'image', 'labels'} - set(train_data.columns)
assert not missing_cols, f"train.csv is missing expected columns: {sorted(missing_cols)}"
display(train_data.head())

Each image can have multiple labels, so let's check how many labels an image typically has.

*The target classes, a space delimited list of all diseases found in the image.
Unhealthy leaves with too many diseases to classify visually will have the complex class, and may also have a subset of the diseases identified.*

In [None]:
import numpy as np

# Number of labels attached to each image. Labels are space-delimited; split() with no
# argument also tolerates repeated/trailing whitespace instead of counting empty strings.
train_data['nb_classes'] = [len(lbs.split()) for lbs in train_data['labels']]
# Map "labels per image" -> "number of images". enumerate keeps every bin that bincount
# produces, so there is no hard-coded upper limit on the label count.
lb_hist = dict(enumerate(np.bincount(train_data['nb_classes'])))
pprint(lb_hist)

Plot the label distribution over all label occurrences in the dataset: if an image has two labels, both are counted in this statistic.

In [None]:
import itertools
import seaborn as sns

# Flatten every label occurrence: an image with two labels contributes to both bars.
labels_all = list(itertools.chain.from_iterable(lbs.split() for lbs in train_data['labels']))
# Canonical per-image label combination (labels sorted alphabetically), so the same set of
# labels always yields the same string.
train_data['labels_sorted'] = [" ".join(sorted(lbs.split())) for lbs in train_data['labels']]

sns.set()
# Passing the data as `y` already produces horizontal bars, so no `orient` is given
# (the previous orient='v' contradicted it and was overridden by seaborn).
# Bars are ordered from most to least frequent label.
label_order = pd.Series(labels_all).value_counts().index
ax = sns.countplot(y=labels_all, order=label_order)
ax.set(title="Label occurrences in train.csv", xlabel="count", ylabel="label")
# sns.set() (darkgrid) already draws a grid, and a bare ax.grid() *toggles* it (turning the
# count-axis grid off). Request the grid along the count axis explicitly instead.
ax.grid(True, axis="x")

## Flash finetuning

In [None]:
import flash
import torch
import pytorch_lightning as pl
from flash.image import ImageClassificationData, ImageClassifier

## 1. Load the data

In [None]:
# Data-pipeline settings, kept in one place so they are easy to tune.
IMAGE_SIZE = (384, 384)
BATCH_SIZE = 24
NUM_WORKERS = 2
VAL_FRACTION = 0.2
SEED = 42

# Seed python/numpy/torch RNGs so runs are repeatable. This assumes Flash draws the val_split
# indices from one of these RNGs — TODO confirm for the installed Flash version.
pl.seed_everything(SEED)

datamodule = ImageClassificationData.from_data_frame(
    "image",   # input column: image file names, presumably relative to train_images_root
    "labels",  # target column: space-delimited class labels
    train_data_frame=train_data,
    train_images_root=os.path.join(base_path, "train_images"),
    transform_kwargs={"image_size": IMAGE_SIZE},
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    val_split=VAL_FRACTION,
)
# Expected to be True, since images can carry several labels.
print(datamodule.multi_label)

## 2. Build the model

In [None]:
# Classifier configuration: a timm "tf_efficientnet_b4_ns" backbone optimised with AdamW.
# The label set and the multi-label mode are taken from the datamodule built above.
classifier_kwargs = dict(
    backbone="tf_efficientnet_b4_ns",
    optimizer=torch.optim.AdamW,
    learning_rate=0.005,
    labels=datamodule.labels,
    multi_label=datamodule.multi_label,
)
model = ImageClassifier(**classifier_kwargs)

## 4. Create the trainer

In [None]:
# Set DEBUG_RUN = True to train/validate on only 10% of the batches for a quick smoke test.
# The default (False) uses all batches (Lightning's default of 1.0).
DEBUG_RUN = False
data_fraction = 0.1 if DEBUG_RUN else 1.0

# Metrics go to <log_dir>/metrics.csv, which is read back when plotting the training curves.
logger = pl.loggers.CSVLogger(save_dir='logs/')

trainer = flash.Trainer(
    # NOTE(review): `gpus` is deprecated in newer Lightning releases in favour of
    # accelerator/devices; keep it while the installed version accepts it.
    gpus=1,
    logger=logger,
    max_epochs=5,
    precision=16,            # 16-bit mixed precision
    val_check_interval=0.5,  # run validation twice per training epoch
    limit_train_batches=data_fraction,
    limit_val_batches=data_fraction,
)

## 5. Train the model

In [None]:
# Train the model
trainer.finetune(model, datamodule=datamodule, strategy=('freeze_unfreeze', 1))

# Save it!
trainer.save_checkpoint("image_classification_model.pt")

In [None]:
import matplotlib.pyplot as plt

# Read back the metrics written by the CSVLogger during training, indexed by epoch.
# Built as a chain from the CSV, so re-running the cell is safe.
metrics = (
    pd.read_csv(os.path.join(trainer.logger.log_dir, "metrics.csv"))
    .drop(columns="step")
    .set_index("epoch")
)
# Drop columns that never received a value, so they show up in neither the table nor the plot.
metrics_logged = metrics.dropna(axis=1, how="all")
display(metrics_logged.head())

g = sns.relplot(data=metrics_logged, kind="line")
plt.gcf().set_size_inches(12, 4)
# sns.set() already enables a grid, and a bare plt.grid() would toggle it off; be explicit.
plt.grid(True)