# Leaf disk sheet segmentation and leaf disk coordinate detection

## Imports

In [None]:
import os
import sys
from pathlib import Path
import datetime

from collections import OrderedDict

from tqdm.notebook import tqdm

%load_ext autoreload
%autoreload 2

from collections import defaultdict
import copy
import random

import pandas as pd

from sklearn.model_selection import train_test_split

import albumentations as A

from PIL import Image

from sklearn.model_selection import train_test_split

from albumentations.pytorch import ToTensorV2

import cv2
import matplotlib.pyplot as plt
import numpy as np

import ternausnet.models
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim
from torch.utils.data import Dataset, DataLoader
import torchvision

from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.callbacks import RichProgressBar

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.callbacks import ModelCheckpoint

import pytorch_lightning as pl

import ipywidgets as widgets
from IPython.display import Image as IpImage
from IPython.display import display
from ipywidgets import Button, HBox, VBox

import panel as pn
import hvplot.pandas
pn.extension()


sys.path.insert(0, os.path.join("..", "scripts"))

import ld_dataset as ldd
import ld_plot as ldp
import ld_image as ldi
import ld_th_pl_lightning as ldpl
import gav_oidium_const as goc


## Build local images dataframe

In [None]:
def _visible_stems(folder):
    """Return the stems of the entries of *folder*, skipping hidden (dot-prefixed) ones."""
    return [entry.stem for entry in folder.glob("*") if not entry.stem.startswith(".")]


# Image stems of the unannotated, annotated (training) and "wot" image folders.
una_images = _visible_stems(ldd.una_images_folder)
ano_images = _visible_stems(ldd.train_images_folder)
wot_images = _visible_stems(ldd.wot_images_folder)


In [None]:
# Number of images found in the annotated, unannotated and "wot" folders.
len(ano_images), len(una_images), len(wot_images)

In [None]:
# Unique image stems available locally, across all three folders.
local_images = list({*ano_images, *una_images, *wot_images})
len(local_images)

In [None]:
# Columns of interest in the merged annotation CSV.
# NOTE(review): all_columns is not referenced elsewhere in this notebook;
# confirm whether it was meant to restrict the columns loaded below.
all_columns = (
    "experiment rep image_name ligne colonne oiv sporulation "
    "densite_sporulation necrose surface_necrosee taille_necrose"
).split()

raw_merged_path = os.path.join("..", "data_in", "raw_merged.csv")
df = pd.read_csv(raw_merged_path, sep=";")
df

In [None]:
# Keep only the rows whose image is available locally. Names are compared
# case-insensitively, with "_" treated as "-" on both sides.
# `.str.replace` performs a substring replacement. The plain
# `Series.replace("_", "-")` used previously only replaced cells whose *entire*
# value was "_", so the dataframe side was never actually normalised.
local_image_keys = {name.lower().replace("_", "-") for name in local_images}
df = df[
    df.image_name.str.lower()
    .str.replace("_", "-", regex=False)
    .isin(local_image_keys)
]

df.sort_values("image_name").to_csv(
    os.path.join("..", "data_in", "local_raw_merged.csv"),
    sep=";",
    index=False,
)

df


## Prepare environment

In [None]:
# Device selected by the project's Lightning helpers; also passed to the model below.
ldpl.g_device


In [None]:
# NOTE(review): sample_frac is not referenced elsewhere in this notebook; confirm it is still needed.
sample_frac = 1

## Check image duplicates

In [None]:
def _visible_paths(folder):
    """Return the paths (as strings) of the non-hidden entries of *folder*."""
    return list(map(str, (p for p in folder.glob("*") if not p.stem.startswith("."))))


# Full paths of the unannotated and annotated (training) images.
una_images = _visible_paths(ldd.una_images_folder)
ano_images = _visible_paths(ldd.train_images_folder)

len(una_images), len(ano_images)

#### Check for annotated images that also appear among the unannotated images

In [None]:
# Annotated images whose file stem also appears in the unannotated folder.
# Both lists hold full paths, and they come from two different folders, so
# comparing the path strings directly can never match. Compare the stems instead.
una_stems = {Path(p).stem for p in una_images}
[img for img in ano_images if Path(img).stem in una_stems]

#### Check for unannotated images that also appear among the annotated images

In [None]:
# Unannotated images whose file stem also appears in the annotated folder.
# Both lists hold full paths, and they come from two different folders, so
# comparing the path strings directly can never match. Compare the stems instead.
ano_stems = {Path(p).stem for p in ano_images}
[img for img in una_images if Path(img).stem in ano_stems]

## Images

In [None]:
# Build the dataframe of training items from the training image and mask folders.
train_items_folders = {
    "images_folder": ldd.train_images_folder,
    "masks_folder": ldd.train_masks_folder,
}
df_train_images = ldd.build_items_dataframe(**train_items_folders)

In [None]:
# Preview the training items dataframe.
df_train_images

In [None]:
# Sanity-check the training image folder against the training mask folder.
consistency_folders = dict(
    images_folder=ldd.train_images_folder,
    masks_folder=ldd.train_masks_folder,
)
ldd.check_items_consistency(**consistency_folders)

In [None]:
# 70 % train / 15 % test / 15 % val, stratified by year.
# A fixed seed keeps the split identical across kernel restarts. The version
# overview further down calls ldpl.update_overviews(test), presumably to
# compare saved checkpoints on this split. That comparison is only meaningful
# if no checkpoint was trained on these test images, which an unseeded split
# cannot guarantee.
# NOTE: checkpoints produced before this seed was introduced used other splits.
split_seed = 42
train, test = train_test_split(
    df_train_images,
    test_size=0.3,
    stratify=df_train_images["year"],
    random_state=split_seed,
)
test, val = train_test_split(
    test, test_size=0.5, stratify=test["year"], random_state=split_seed
)

print(len(train), len(test), len(val))


In [None]:
# Number of images per year in each split. Counts are aligned on the year
# labels rather than by position. A year absent from a split is reported as 0
# instead of making the DataFrame constructor fail with a length mismatch.
years = pd.Index(sorted(df_train_images.year.unique()), name="year")
df_set_counts = pd.DataFrame(
    {
        split_name: split_df.year.value_counts().reindex(years, fill_value=0)
        for split_name, split_df in (("train", train), ("val", val), ("test", test))
    },
    index=years,
)
df_set_counts


In [None]:
# Total number of images in each split (column sums over all years).
df_set_counts.sum(axis=0)

In [None]:
# Preview two random training items with their masks.
train_preview = train.sample(n=2)
ldp.display_images_and_masks_grid(train_preview, fontsize=10, figsize=(8, 6))


In [None]:
# Preview two random validation items with their masks.
val_preview = val.sample(n=2)
ldp.display_images_and_masks_grid(val_preview, fontsize=10, figsize=(8, 6))


In [None]:
# Preview two random test items with their masks.
test_preview = test.sample(n=2)
ldp.display_images_and_masks_grid(test_preview, fontsize=10, figsize=(8, 6))


## Albumentations

In [None]:
# Network input size: both sides are multiples of 32 (presumably required by
# the encoder's downsampling; confirm) with a 3:2 width/height ratio.
img_width, img_height = 32 * 21, 32 * 14
assert img_width / img_height == 1.5

alb_resizer = [A.Resize(height=img_height, width=img_width)]

# Resize, then random flips and photometric changes (training only).
train_transformers_list = [
    *alb_resizer,
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
    A.RandomGamma(p=0.33),
    A.CLAHE(p=0.33),
]

# Pipeline without normalisation, used only for the visual preview below.
transformer = A.Compose(train_transformers_list)

preview_items = df_train_images.sample(n=1).reset_index(drop=True)
image, mask = ldd.open_image_and_mask(0, preview_items)

transformed = transformer(image=image, mask=mask)

# Normalise with the given per-channel mean/std, then convert to a tensor.
to_tensor = [
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
]

train_transformer = A.Compose([*train_transformers_list, *to_tensor])
val_transformer = A.Compose([*alb_resizer, *to_tensor])
test_transformer = A.Compose([*alb_resizer, *to_tensor])

preview_augmentations = [transformer(image=image, mask=mask) for _ in range(5)]
ldp.visualize_augmentations(
    image=image,
    mask=mask,
    augmentations=preview_augmentations,
)


## Models

### Lightning

In [None]:
# Segmentation model; data, augmentation pipelines and hyper-parameters are
# handed to ldpl.LeafDiskSegmentation.
model = ldpl.LeafDiskSegmentation(
    batch_size=16,
    selected_device=ldpl.g_device,
    learning_rate=0.0005,
    max_epochs=400,
    num_workers=0,
    train_augmentations=train_transformer,
    train_data=train,
    val_augmentations=val_transformer,
    val_data=val,
    accumulate_grad_batches=3,
)

trainer = Trainer(
    accelerator="gpu",
    max_epochs=model.max_epochs,
    log_every_n_steps=5,
    callbacks=[
        RichProgressBar(),
        # Stop once val_loss has not improved by at least 5e-4 for 15 checks.
        EarlyStopping(monitor="val_loss", mode="min", patience=15, min_delta=0.0005),
        DeviceStatsMonitor(),
        # Keep the 3 checkpoints with the lowest val_loss.
        ModelCheckpoint(
            save_top_k=3,
            monitor="val_loss",
            auto_insert_metric_name=True,
            # Fixed-point formatting for both losses. A bare ".5" switches to
            # exponent notation for small values, so train_loss and val_loss
            # were rendered inconsistently in checkpoint names.
            filename="{epoch}-{step}-{train_loss:.5f}-{val_loss:.5f}",
        ),
    ],
    accumulate_grad_batches=model.accumulate_grad_batches,
)
trainer.fit(model)


In [None]:
# Overview of the trained model versions. ldpl.update_overviews receives the
# test split (presumably to score each checkpoint on it; see ldpl).
version_overview = ldpl.update_overviews(test)
version_overview


In [None]:
predictor = ldpl.LeafDiskPredictor()

# Table of available versions, without the checkpoint path column.
op_df_versions_overview = widgets.Output()
with op_df_versions_overview:
    display(
        version_overview.drop(columns=["checkpoint_fileName"]).reset_index(drop=True)
    )

# Entries are "<index> | <checkpoint path>"; "-1" is the "nothing selected" placeholder.
checkpoint_files = version_overview.checkpoint_fileName.to_list()
cb_select_version = widgets.Dropdown(
    options=[
        "-1 | Select version",
        *(f"{i} | {fn}" for i, fn in enumerate(checkpoint_files)),
    ],
    description="Select version",
)

cb_image = widgets.Dropdown(
    options=sorted(
        str(p) for p in ldd.una_images_folder.glob("*") if not p.stem.startswith(".")
    ),
    description="Select an image:",
)

cb_ancillary_image = widgets.Dropdown(
    options=["Raw mask", "Cleaned mask", "Probabilities"],
    description="Ancillary image",
)

# Panes for the source image, the ancillary view and the annotated prediction.
src_image, ancillary_image, pred_image = (widgets.Output() for _ in range(3))


def update_images(image_path, anc_mode):
    """Refresh the three output panes for the image at *image_path*.

    The panes show the source image, an ancillary view chosen by *anc_mode*
    ("Raw mask", "Cleaned mask" or "Probabilities"), and the image overlaid
    with the cleaned mask and its indexed contours. If no model is loaded,
    every pane is cleared instead.
    """
    if predictor.model is None:
        for pane in (src_image, ancillary_image, pred_image):
            pane.clear_output()
        return

    source = ldd.open_image(image_path)

    raw_mask = predictor.predict_image(image_path)
    cleaned_mask = ldi.clean_contours(mask=raw_mask.copy(), size_thrshold=0.75)
    disk_contours = ldi.index_contours(cleaned_mask)

    if anc_mode == "Raw mask":
        ancillary = raw_mask
    elif anc_mode == "Cleaned mask":
        ancillary = cleaned_mask
    elif anc_mode == "Probabilities":
        ancillary = predictor.predict_image(image_path, return_probabilities=True)
    else:
        ancillary = None

    src_image.clear_output()
    with src_image:
        ldp.visualize_image(image=source, title=Path(image_path).stem)

    ancillary_image.clear_output()
    with ancillary_image:
        plt.imshow(ancillary, cmap=plt.cm.RdPu)
        plt.tight_layout()
        plt.axis("off")
        plt.show()

    pred_image.clear_output()
    with pred_image:
        overlay = ldi.apply_mask(source, cleaned_mask, draw_contours=8)
        ldp.visualize_image(
            ldi.print_contours_indexes(cleaned_mask, disk_contours, canvas=overlay),
            figsize=(12, 8),
        )


def on_image_changed(change):
    """Redraw the panes for the newly selected image, keeping the current ancillary mode."""
    new_image_path = change.new
    update_images(new_image_path, cb_ancillary_image.value)


def on_anc_mode_changed(change):
    """Redraw the current image using the newly selected ancillary mode."""
    current_image = cb_image.value
    update_images(current_image, change.new)


def on_version_changed(change):
    """Load the checkpoint picked in ``cb_select_version`` and refresh the panes.

    Dropdown entries have the form ``"<index> | <checkpoint path>"``. The
    placeholder entry uses index ``-1`` and unloads the current model.
    """
    # Split on the first "|" only and strip the padding around it. Removing
    # every space would corrupt checkpoint paths that contain spaces.
    idx, _, filename = str(change.new).partition("|")
    if int(idx.strip()) >= 0:
        # Only an attribute of the module-level predictor is changed, so no
        # `global` declaration is needed.
        predictor.model = ldpl.LeafDiskSegmentation.load_from_checkpoint(
            filename.strip()
        )
    else:
        predictor.model = None
    # update_images clears every pane when no model is loaded, so predictions
    # from a previously selected checkpoint are not left on screen.
    update_images(cb_image.value, cb_ancillary_image.value)


# Connect each dropdown's "value" changes to its handler.
for widget, handler in (
    (cb_select_version, on_version_changed),
    (cb_image, on_image_changed),
    (cb_ancillary_image, on_anc_mode_changed),
):
    widget.observe(handler, names="value")

# Layout: version table + selector, image/mode selectors, then the image panes.
header_row = HBox([op_df_versions_overview, cb_select_version])
selector_row = HBox([cb_image, cb_ancillary_image])
images_row = HBox([VBox([src_image, ancillary_image]), pred_image])
display(VBox([header_row, selector_row, images_row]))


In [None]:
# Checkpoint path of the currently selected version ("<index> | <path>").
# Split on the first "|" only and strip the padding around it, so paths that
# contain spaces are kept intact.
str(cb_select_version.value).partition("|")[2].strip()

## Predict on all available unused (unannotated) images

Create the output folder for this prediction run (named after the current timestamp).

In [None]:
# One output folder per run, named after the current time (second resolution).
run_stamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
predictions_root = ldd.root_folder.parent.parent.joinpath("data_out", "predictions")
data_out_fld = predictions_root.joinpath(run_stamp)
data_out_fld.mkdir(parents=True, exist_ok=True)
data_out_fld

In [None]:
# Run the currently selected model on every unannotated image and save each
# image, overlaid with the cleaned mask and indexed contours, to data_out_fld.
if predictor.model is None:
    # Fail early with an actionable message instead of an opaque error from
    # predict_image when no version has been chosen in the widget above.
    raise RuntimeError(
        "No model loaded: select a version in the widget above before running batch predictions."
    )

# Materialise the non-hidden file list first so the progress bar knows its total.
una_image_paths = [
    p for p in ldd.una_images_folder.glob("*") if not p.stem.startswith(".")
]
for src_path in tqdm(una_image_paths):
    image = ldd.open_image(str(src_path))
    predicted_mask = predictor.predict_image(str(src_path))
    clean_mask = ldi.clean_contours(mask=predicted_mask.copy(), size_thrshold=0.75)
    contours = ldi.index_contours(clean_mask)

    overlay = ldi.print_contours_indexes(
        clean_mask,
        contours,
        canvas=ldi.apply_mask(image, clean_mask, draw_contours=8),
    )
    Image.fromarray(overlay).save(str(data_out_fld.joinpath(src_path.name)))


## Retrieve an isolated leaf disk

In [None]:
# Pick an unannotated image, then the row/column of the leaf disk to extract.
cb_src_image = widgets.Dropdown(
    options=sorted(
        str(p) for p in ldd.una_images_folder.glob("*") if not p.stem.startswith(".")
    ),
    # Label typo fixed ("dislk" -> "disk").
    description="Select a leaf disk:",
)
cb_row = widgets.Dropdown(options=["A", "B", "C"])
cb_col = widgets.Dropdown(options=[1, 2, 3, 4])

# Panes: whole image, extracted leaf disk, and extracted disk with the mask applied.
op_full_plate = widgets.Output()
op_leafdisk = widgets.Output()
op_leafdisk_no_bck = widgets.Output()


def show_leaf_disk(image_path, row, col):
    """Display the leaf disk at (*row*, *col*) of the image at *image_path*.

    Three outputs are refreshed:

    * the whole image with the selected contour drawn (``ldi.print_single_contour``),
    * the leaf disk extracted by ``ldi.get_leaf_disk``,
    * the same extraction with ``mask=clean_mask`` (presumably removing the
      background; confirm in ``ldi``).

    When no model is loaded, all three outputs are cleared instead, mirroring
    the guard used by ``update_images``.
    """
    if predictor.model is None:
        for output in (op_full_plate, op_leafdisk, op_leafdisk_no_bck):
            output.clear_output()
        return

    image = ldd.open_image(image_path)

    predicted_mask = predictor.predict_image(image_path)
    clean_mask = ldi.clean_contours(mask=predicted_mask.copy(), size_thrshold=0.75)
    contours = ldi.index_contours(clean_mask)

    op_full_plate.clear_output()
    op_leafdisk.clear_output()
    op_leafdisk_no_bck.clear_output()

    with op_full_plate:
        ldp.visualize_image(
            ldi.print_single_contour(
                clean_mask,
                contours,
                row=row,
                col=col,
                canvas=ldi.apply_mask(image, clean_mask, draw_contours=8),
            ),
            figsize=(12, 8),
        )

    # Copies are passed so that get_leaf_disk cannot alter `image` between the
    # two extractions.
    with op_leafdisk:
        ldp.visualize_image(ldi.get_leaf_disk(image.copy(), contours, row, col))

    with op_leafdisk_no_bck:
        ldp.visualize_image(
            ldi.get_leaf_disk(
                image.copy(),
                contours,
                row,
                col,
                mask=clean_mask,
            )
        )


def on_ld_image_changed(change):
    """Show the selected row/column leaf disk of the newly chosen image."""
    new_image_path = change.new
    show_leaf_disk(new_image_path, cb_row.value, cb_col.value)


def on_ld_row_changed(change):
    """Show the leaf disk in the newly chosen row of the current image."""
    current_image = cb_src_image.value
    show_leaf_disk(current_image, change.new, cb_col.value)


def on_ld_col_changed(change):
    """Show the leaf disk in the newly chosen column of the current image."""
    current_image = cb_src_image.value
    current_row = cb_row.value
    show_leaf_disk(current_image, current_row, change.new)


# Re-render the leaf disk whenever the image, row or column changes.
# The image dropdown must use the image handler. It was previously wired to
# on_ld_col_changed, which passed the new image path as the column.
cb_src_image.observe(on_ld_image_changed, names="value")
cb_row.observe(on_ld_row_changed, names="value")
cb_col.observe(on_ld_col_changed, names="value")

display(
    VBox(
        [
            HBox([cb_src_image, cb_row, cb_col]),
            HBox([op_full_plate]),
            HBox([op_leafdisk_no_bck, op_leafdisk]),
        ]
    )
)
