# Leaf disk sheet segmentation and disk coordinate detection

## Imports

In [None]:
import os
import sys

%load_ext autoreload
%autoreload 2

from collections import defaultdict
import copy
import random

from sklearn.model_selection import train_test_split

import albumentations as A
import albumentations.augmentations.functional as F
import albumentations.augmentations.geometric as G

from sklearn.model_selection import train_test_split

from albumentations.pytorch import ToTensorV2

import cv2
import matplotlib.pyplot as plt
import numpy as np

import ternausnet.models
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim
from torch.utils.data import Dataset, DataLoader

import ipywidgets as widgets
from IPython.display import Image as IpImage
from IPython.display import display
from ipywidgets import Button, HBox, VBox


sys.path.insert(0, os.path.join("..", "scripts"))

import ld_dataset as ldd
import ld_plot as ldp
import ld_train_helpers as ldth
import ld_image as ldi


## Prepare environment

In [None]:
# Select the compute device. `torch.backends.mps.is_built()` only reports that
# this PyTorch binary was compiled with MPS support; it can be True on a
# machine where no MPS device is usable. `is_available()` checks both.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

device


## Images

In [None]:
# Index the training images together with their masks folder; the resulting
# dataframe is what the splits and samples below are drawn from.
df_train_images = ldd.build_items_dataframe(
    masks_folder=ldd.train_masks_folder,
    images_folder=ldd.train_images_folder,
)
df_train_images

In [None]:
np.shape(df_train_images)  # (rows, columns) of the image index

In [None]:
# 70% train / 15% test / 15% validation, stratified by year so each split
# keeps the same year mix. A fixed random_state makes the split identical on
# every notebook re-run, so trained models and sampled figures are comparable.
split_seed = 42

train, test = train_test_split(
    df_train_images,
    test_size=0.3,
    stratify=df_train_images["year"],
    random_state=split_seed,
)
test, val = train_test_split(
    test,
    test_size=0.5,
    stratify=test["year"],
    random_state=split_seed,
)

print(len(train), len(test), len(val))


In [None]:
_val_preview = val.sample(n=5)  # five random validation items
ldp.display_image_grid(_val_preview)


## Albumentations

In [None]:
# Resize step shared by the train, validation and test pipelines.
alb_resizer = A.Resize(height=256, width=256)

# Training augmentations: resize first, then geometric (flips, 90-degree
# rotations, transpose) and photometric (brightness/contrast, gamma, CLAHE)
# transforms, each firing with its own probability p.
train_transformers_list = [alb_resizer]
train_transformers_list.extend(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.3),
        A.RandomGamma(p=0.3),
        A.CLAHE(p=0.36),
    ]
)

transformer = A.Compose(train_transformers_list)

# Augment dataframe item 0 once and show it next to the untouched original.
image, mask = ldd.open_image_and_mask(0, df_train_images)
transformed = transformer(image=image, mask=mask)
image_transformed, mask_transformed = transformed["image"], transformed["mask"]

ldp.visualize_augmented_item(
    image_transformed,
    mask_transformed,
    original_image=image,
    original_mask=mask,
)


In [None]:
# Size of the raw sample image, displayed as (width, height).
(original_height, original_width, original_channel_count) = image.shape
original_width, original_height


In [None]:
np.shape(image_transformed)  # augmented image size after the 256x256 resize


In [None]:
np.shape(mask_transformed)  # augmented mask size after the 256x256 resize


In [None]:
# Standard ImageNet per-channel mean/std, defined once instead of being
# repeated in every pipeline.
# NOTE(review): presumably chosen to match a pretrained encoder built by
# ldth.create_model — confirm.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _to_model_input():
    """Return fresh normalize + to-tensor steps appended to every pipeline."""
    return [
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]


# Training: random augmentations, then normalization and tensor conversion.
train_transformer = A.Compose(train_transformers_list + _to_model_input())

# Evaluation: deterministic resize + normalization only. Validation and test
# are built from the same recipe so they cannot drift apart.
val_transformer = A.Compose([alb_resizer] + _to_model_input())
test_transformer = A.Compose([alb_resizer] + _to_model_input())


## Datasets

In [None]:
# Training split with random augmentation followed by normalization.
train_dataset = ldd.LeafDeafSegmentationDataset(
    transform=train_transformer,
    df_img=train,
)


In [None]:
# Validation split with the deterministic resize + normalization pipeline.
val_dataset = ldd.LeafDeafSegmentationDataset(
    transform=val_transformer,
    df_img=val,
)


In [None]:
# Seed Python's `random` module before drawing four augmented variants of
# training item 0.
# NOTE(review): whether this makes the draws reproducible depends on which RNG
# the installed albumentations version uses — confirm.
random.seed(42)
ldp.visualize_augmented_dataset_item(train_dataset, samples=4, idx=0)


## Create and train model

In [None]:
# Run configuration passed to the ldth helpers (create_model,
# train_and_validate, test_model, predict_image).
params = dict(
    model="UNet11",
    device=device,
    lr=1e-3,
    batch_size=5,
    num_workers=0,
    epochs=30,
)


In [None]:
# Build the model described by `params` (presumably the "model" entry selects
# the architecture), then train it on `train_dataset` while validating on
# `val_dataset`.
# `model` is bound before training starts on purpose: if the training call is
# interrupted, the name still refers to the created instance.
# NOTE(review): whether that instance then holds partially-trained weights
# depends on ldth.train_and_validate training in place — confirm.
model = ldth.create_model(params)
model = ldth.train_and_validate(model, train_dataset, val_dataset, params)


## Test model

In [None]:
# Run inference on a small random subset of the held-out test split.
# The sample size is capped at the split size, because DataFrame.sample
# raises ValueError when asked for more rows than exist (replace=False).
n_test_preview = min(10, len(test))
test_sample = test.sample(n=n_test_preview)

test_dataset = ldd.LeafDeafSegmentationInferenceDataset(
    image_list=test_sample.image_path.to_list(),
    transform=test_transformer,
    dataframe=test_sample,
)

# NOTE(review): the next cell unpacks each item as
# (mask_256x256, original_height, original_width) — confirm in ldth.test_model.
predictions = ldth.test_model(model, params, test_dataset, batch_size=16)


In [None]:
def _resize_mask(mask, height, width):
    """Nearest-neighbour resize of `mask` to (height, width).

    Nearest-neighbour keeps mask values discrete (no blended edge values).
    cv2.resize takes dsize as (width, height) and drops a trailing singleton
    channel axis, so that axis is restored when the input had one.
    """
    resized = cv2.resize(
        mask,
        (int(width), int(height)),
        interpolation=cv2.INTER_NEAREST,
    )
    if mask.ndim == 3 and resized.ndim == 2:
        resized = resized[..., np.newaxis]
    return resized


# Bring every 256x256 prediction back to its source image size. The
# comprehension's loop names stay local, so the `original_height` /
# `original_width` globals of the sample image are not overwritten.
predicted_masks = [
    _resize_mask(pred_mask, pred_height, pred_width)
    for pred_mask, pred_height, pred_width in predictions
]


In [None]:
# Show each sampled test item alongside its upscaled predicted mask.
ldp.display_image_grid(
    test_sample,
    predicted_masks=predicted_masks,
)


## Predict a single annotated image

### Predicted mask

In [None]:
# Pick one image from the held-out test split rather than the full dataframe.
# The full dataframe also contains the training rows the model was fit on,
# which would make this visual check look better than it really is.
row = test.sample(n=1)
img_path = row.image_path.to_list()[0]
src_image = ldd.open_image(img_path)

# Later cells use this directly as a cv2 mask over `src_image`, so it must
# match the source image size.
predicted_mask = ldth.predict_image(
    img_path,
    model=model,
    params=params,
    threshold=0.5,
    img_transformer=test_transformer,
)

ldp.visualize_item(src_image, predicted_mask)


### Show the mask as an overlay

In [None]:
# Keep only the source pixels where the predicted mask is non-zero.
_masked_foreground = cv2.bitwise_and(src_image, src_image, mask=predicted_mask)
ldp.visualize_image(_masked_foreground)


In [None]:
# Dim the background: scale the L (lightness) channel of the LAB image by
# `bck_grd_luma`, saturate to the uint8 range, then convert back to BGR.
bck_grd_luma = 0.3
lum, a, b = cv2.split(cv2.cvtColor(src_image, cv2.COLOR_BGR2LAB))
# One clip + uint8 cast replaces the platform-dependent np.uint intermediate
# and the manual ">= 255" saturation. Values still truncate toward zero, so
# results are unchanged for non-negative luma; negative luma now clips to 0
# instead of wrapping.
lum = np.clip(lum * bck_grd_luma, 0, 255).astype(np.uint8)
background_img = cv2.cvtColor(cv2.merge((lum, a, b)), cv2.COLOR_LAB2BGR)

# Show only the region outside the predicted mask (assumes a 0/255 uint8 mask).
ldp.visualize_image(
    cv2.bitwise_and(background_img, background_img, mask=255 - predicted_mask)
)


In [None]:
# Composite: the dimmed background outside the mask, the original source
# pixels inside it.
_dimmed_background = cv2.bitwise_and(
    background_img, background_img, mask=255 - predicted_mask
)
_masked_foreground = cv2.bitwise_and(src_image, src_image, mask=predicted_mask)
ldp.visualize_image(cv2.bitwise_or(_dimmed_background, _masked_foreground))


In [None]:
# Predict and overlay the mask of another random held-out test image. It is
# sampled from `test` rather than the full dataframe, so the model is not
# judged on images it was trained on.
row = test.sample(n=1)
img_path = row.image_path.to_list()[0]


ldp.show_masked_image(
    image=ldd.open_image(img_path),
    mask=ldth.predict_image(
        img_path,
        model=model,
        params=params,
        threshold=0.5,
        img_transformer=test_transformer,
    ),
)


### Prediction widget

In [None]:
# Image picker over the training-images folder. Only regular, non-hidden files
# are offered: pathlib's glob("*") also matches hidden entries such as macOS
# ".DS_Store" and sub-directories, which cannot be opened as images.
dd_image = widgets.Dropdown(
    options=sorted(
        str(path)
        for path in ldd.train_images_folder.glob("*")
        if path.is_file() and not path.name.startswith(".")
    ),
    description="Select an image:",
)

# Background luma forwarded to ldp.show_masked_image (presumably the same
# lightness multiplier as `bck_grd_luma` above).
sl_luma = widgets.FloatSlider(
    value=0.3,
    min=0.0,
    max=2.0,
    step=0.1,
    description="Background luma",
)
# Threshold forwarded to ldth.predict_image, limited to [0, 1]. The notebook's
# 0.5 default suggests a probability cut-off, in which case values above 1
# would select no pixels at all.
# NOTE(review): assumes predict_image compares per-pixel probabilities in
# [0, 1] against this value — confirm in ld_train_helpers.
sl_threshold = widgets.FloatSlider(
    value=0.5,
    min=0.0,
    max=1.0,
    step=0.1,
    description="Prediction threshold",
)

# Area where the widget callbacks render the masked preview.
output_image = widgets.Output()


def on_prediction_widget_changed(image, threshold, luma):
    """Re-render the masked preview of `image` into `output_image`.

    Args:
        image: path of the selected image (the dropdown's string value).
        threshold: prediction threshold forwarded to ldth.predict_image.
        luma: background luma forwarded to ldp.show_masked_image.
    """
    output_image.clear_output()
    with output_image:
        source = ldd.open_image(image)
        predicted = ldth.predict_image(
            image,
            model=model,
            params=params,
            threshold=threshold,
            img_transformer=test_transformer,
        )
        ldp.show_masked_image(image=source, mask=predicted, luma=luma)


def on_image_changed(change):
    """Dropdown observer: re-render using the newly selected image path."""
    new_image = change.new
    on_prediction_widget_changed(new_image, sl_threshold.value, sl_luma.value)


def on_threshold_changed(change):
    """Threshold-slider observer: re-render using the new threshold."""
    current_image = dd_image.value
    on_prediction_widget_changed(current_image, change.new, sl_luma.value)


def on_luma_changed(change):
    """Luma-slider observer: re-render using the new background luma."""
    current_image = dd_image.value
    on_prediction_widget_changed(current_image, sl_threshold.value, change.new)


# Hook each control to its handler; all of them react to "value" changes.
for _widget, _handler in (
    (dd_image, on_image_changed),
    (sl_threshold, on_threshold_changed),
    (sl_luma, on_luma_changed),
):
    _widget.observe(_handler, names="value")

_controls = HBox([dd_image, sl_threshold, sl_luma])
display(_controls, output_image)


## Leaf disk indexing

In [None]:
# Predict the mask of the image currently selected in the dropdown above
# (fixed 0.5 threshold) and show it next to the source image.
img_path = dd_image.value


mask = ldth.predict_image(
    img_path,
    model=model,
    params=params,
    threshold=0.5,
    img_transformer=test_transformer,
)

ldp.visualize_item(
    ldd.open_image(img_path),
    mask,
    direction="re",
    figsize=(12, 10),
)

In [None]:
np.unique(np.asarray(mask))  # distinct values present in the predicted mask

In [None]:
# Index the contours found in the mask and draw their indices over ldi's
# thresholded rendering of that same mask (both at threshold 0.8).
contour_threshold = 0.8
contours = ldi.index_contours(mask, threshold=contour_threshold)

threshold_canvas = ldi.print_contour_threshold(mask, threshold=contour_threshold)
ldp.visualize_image(
    ldi.print_contours_indexs(mask, contours, canvas=threshold_canvas),
    figsize=(12, 8),
)


In [None]:
# Same contour indexing, this time drawn over the source image itself.
contours = ldi.index_contours(mask, threshold=0.8)
source_canvas = ldd.open_image(img_path)

ldp.visualize_image(
    ldi.print_contours_indexs(mask, contours, canvas=source_canvas),
    figsize=(12, 8),
)
