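# Gemma cup prediction

Detect gemma cups in plant images with a trained PyTorch detection model, cache the predicted boxes, inspect them interactively, and render per-plant time-lapse videos.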

## Imports

In [1]:
import os

import pandas as pd
import numpy as np
import cv2

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from tqdm.notebook import tqdm

import ipywidgets as widgets
from IPython.display import Image as IpImage
from IPython.display import display
from ipywidgets import Button, HBox, VBox

import torch
from torchvision import transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

from skimage import transform

from PIL import Image as PilImage

## Define constants

In [2]:
# Root folder of all input data.
data_path = os.path.join("..", "data_in")
# Raw plant images; each file is addressed by its hash (see `df.hash`).
images_path = os.path.join(data_path, "images")

## Load dataframe

In [3]:
# Image index: experiment/plant metadata plus the hashed file name of each image.
# Images on disk are named by their hash, so `filename` simply mirrors `hash`.
df: pd.DataFrame = (
    pd.read_csv(os.path.join(data_path, "filename_to_hash_v2.csv"))
    .assign(filename=lambda frame: frame["hash"])
)
df.sort_values(["hash"]).head()

Unnamed: 0,experiment,plant,date_time,camera,view_option,hash,date,time,filename
1772,10ac_mpo1_1904,10ac309_ic_mock_xx_309,2019-05-12 10:10:09,msp,sw755,b-1HoJ-Hqz5STrwrZHGBYdjAE3Q.jpg,2019-05-12,10:10:09,b-1HoJ-Hqz5STrwrZHGBYdjAE3Q.jpg
3542,10ac_mpo1_1904,10ac79_nc_mock_xx_79,2019-05-12 05:46:06,msp,sw755,b-38OOnRNVd8OdKdmNYZXXy83E.jpg,2019-05-12,05:46:06,b-38OOnRNVd8OdKdmNYZXXy83E.jpg
902,10ac_mpo1_1904,10ac184_hb_mock_xx_184,2019-05-04 07:12:57,msp,sw755,b-3Q-HdqeuB2sRxMIzzUPNjZfLSw.jpg,2019-05-04,07:12:57,b-3Q-HdqeuB2sRxMIzzUPNjZfLSw.jpg
1847,10ac_mpo1_1904,10ac315_bu_mock_xx_315,2019-05-16 10:14:27,msp,sw755,b-3wnyR8oNsu-V149ZYfCL-dfeDk.jpg,2019-05-16,10:14:27,b-3wnyR8oNsu-V149ZYfCL-dfeDk.jpg
3625,10ac_mpo1_1904,10ac86_c2_mock_xx_86,2019-05-18 05:47:45,msp,sw755,b-4-5gzYj0rmr9-dvGMkA3-FNkEs.jpg,2019-05-18,05:47:45,b-4-5gzYj0rmr9-dvGMkA3-FNkEs.jpg


In [4]:
# (rows, columns) of the image index
(len(df.index), len(df.columns))

(3769, 9)

## Check that images load from `images_path`

In [5]:
# Pick any image by hash to check that it can be read from images_path.
dd_sample = widgets.Dropdown(options=sorted(df["hash"].tolist()))

image_output = widgets.Output(layout=widgets.Layout(border="1px solid black"))

def predict_gemma_cups(change):
    """Show the image selected in ``dd_sample`` inside ``image_output``.

    Dropdown observer callback: ``change.new`` is the newly selected file
    name, resolved relative to ``images_path``. Despite its name, this version
    only displays the raw image; no prediction is run.

    NOTE(review): this name is redefined later by the prediction widget cell.
    """
    image_output.clear_output()
    with image_output:
        fig, ax = plt.subplots(1, 1, figsize=(16, 8))
        ax.set_axis_off()
        # Materialize the pixels inside a context manager so the file handle
        # is closed right away instead of whenever PIL's object is collected.
        with PilImage.open(os.path.join(images_path, change.new)) as pil_image:
            ax.imshow(np.asarray(pil_image))
        plt.show()

    
# Redraw the preview every time a different image is selected.
dd_sample.observe(handler=predict_gemma_cups, names="value")
display(dd_sample, image_output)

Dropdown(options=('b-1HoJ-Hqz5STrwrZHGBYdjAE3Q.jpg', 'b-38OOnRNVd8OdKdmNYZXXy83E.jpg', 'b-3Q-HdqeuB2sRxMIzzUPN…

Output(layout=Layout(border='1px solid black'))

## Set device

In [6]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

## Predict

### Define image loader

In [7]:
def image_loader(image_name):
    """Load an image file as a normalized RGB tensor on ``device``.

    Args:
        image_name: Path of the image file to read.

    Returns:
        A float32 tensor of shape (3, H, W) with values in [0, 1], moved to
        the notebook-level ``device``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_name``.
    """
    image = cv2.imread(image_name, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check it surfaces later as an obscure cvtColor error.
        raise FileNotFoundError(f"Could not read image: {image_name}")
    # OpenCV decodes to BGR; the model is fed RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
    # ToTensor only rescales uint8 input, so float pixels are normalized here.
    image /= 255.0
    image = transforms.ToTensor()(image)
    return image.to(device)

### Load model

In [8]:
# The checkpoint is a full pickled model (it is called and `.eval()`-ed
# directly), not a state_dict. Alternative checkpoint: ../models/default_model.pth
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine, and
# .to(device) keeps the model on the same device image_loader puts inputs on.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
# On torch>=2.6 (weights_only=True by default) this needs weights_only=False.
loaded_model = torch.load(
    os.path.join("..", "models", "20210407_200_model.pth"),
    map_location=device,
).to(device)
loaded_model.eval();

### Prediction cache data frame

In [9]:
# CSV cache of model outputs, one row per predicted box:
# hash, x1, y1, x2, y2, score (written by predict_boxes).
cache_path = os.path.join(
    "..",
    "data_out",
    "predictions_cache.csv",
)

In [10]:
# Reuse cached predictions when available. On a fresh checkout, start from an
# empty cache that has the columns predict_boxes filters on and appends to.
if os.path.exists(cache_path):
    dfp = pd.read_csv(cache_path)
else:
    dfp = pd.DataFrame(columns=["hash", "x1", "y1", "x2", "y2", "score"]).astype(
        {"x1": "float64", "y1": "float64", "x2": "float64", "y2": "float64", "score": "float64"}
    )

In [11]:
# Distribution of cached box coordinates and scores (quartiles spelled out).
dfp.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,x1,y1,x2,y2,score
count,3369.0,3369.0,3369.0,3369.0,3369.0
mean,751.560063,764.981823,805.656039,818.602066,0.62807
std,155.055634,159.958966,154.264105,160.545321,0.393645
min,43.064606,22.946793,96.588043,111.769791,0.050009
25%,642.514648,664.517578,698.356018,721.35968,0.159753
50%,743.159973,766.627869,798.286011,820.228027,0.898524
75%,857.062134,868.411621,911.55957,920.940063,0.980238
max,1599.763428,1599.624146,1600.0,1599.994995,0.994229


### Prediction function

In [15]:
def predict_boxes(hash, threshold, swap_colors: bool = True, show_discarded: bool = False):
    """Return the predicted boxes for one image and a copy of it with the boxes drawn.

    Predictions come from the global cache ``dfp`` when it already has rows
    for ``hash``. Otherwise the model runs, and the new rows are appended to
    ``dfp`` and persisted to ``cache_path``.

    Args:
        hash: File name of the image inside ``images_path`` (the cache key).
        threshold: Minimum score for a box to be kept.
        swap_colors: If True, return the image in RGB order (for matplotlib).
            If False, keep OpenCV's BGR order (for ``cv2.VideoWriter``). Box
            colours are chosen so they look the same in both cases.
        show_discarded: If True, also draw boxes scoring below ``threshold``
            (in red). Otherwise those boxes are skipped.

    Returns:
        ``(boxes, scores, img, tmp)``: the ``[x1, y1, x2, y2]`` box
        coordinates, their scores, the annotated image array, and this
        image's prediction rows as a DataFrame.

    Raises:
        FileNotFoundError: If OpenCV cannot read the image.
    """
    global dfp

    tmp = dfp[dfp.hash == hash]
    if tmp.shape[0] > 0:
        boxes = [[x1, y1, x2, y2] for x1, y1, x2, y2 in zip(tmp.x1, tmp.y1, tmp.x2, tmp.y2)]
        scores = tmp.score.to_list()
    else:
        images = [image_loader(os.path.join(images_path, hash))]
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            res = loaded_model(images)
        boxes = res[0]["boxes"].detach().cpu().numpy()
        scores = res[0]["scores"].detach().cpu().numpy()
        tmp = pd.DataFrame(
            {
                "hash": [hash for _ in scores],
                "x1": [b[0] for b in boxes],
                "y1": [b[1] for b in boxes],
                "x2": [b[2] for b in boxes],
                "y2": [b[3] for b in boxes],
                "score": list(scores),
            }
        )
        # NOTE(review): an image with zero predicted boxes adds no rows, so it
        # is re-predicted on every call; consider a sentinel row if that matters.
        dfp = pd.concat([dfp, tmp])
        dfp.to_csv(
            cache_path,
            index=False,
        )

    img_path = os.path.join(images_path, hash)
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    if swap_colors:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Pure red in the image's own channel order: (R, G, B) after the swap,
    # (B, G, R) as loaded by OpenCV.
    discarded_color = (255, 0, 0) if swap_colors else (0, 0, 255)
    for box, score in zip(boxes, scores):
        if score < threshold:
            if not show_discarded:
                continue
            color = discarded_color
        else:
            # Fade from red (low score) to blue (high score) in either order.
            red, blue = int((1 - score) * 255), int(score * 255)
            color = (red, 0, blue) if swap_colors else (blue, 0, red)
        cv2.rectangle(
            img,
            (int(box[0]), int(box[1])),
            (int(box[2]), int(box[3])),
            color,
            3,
        )

    return boxes, scores, img, tmp

### Prediction widget

In [16]:
# Controls for browsing predictions: an image picker, the score threshold, and
# a toggle that also draws boxes scoring below the threshold.
dd_sample = widgets.Dropdown(options=sorted(df.hash.to_list()))

image_output = widgets.Output(layout={"border": "1px solid black"})
rects_output = widgets.Output(layout={"border": "1px solid black"})
score_threshold = widgets.FloatSlider(
    value=0.8,
    min=0,
    max=1.0,
    step=0.05,  # same granularity as the video-building threshold slider
    description="Score threshold",  # was misspelled `decsiption`, so no label showed
)
cb_show_discarded = widgets.Checkbox(value=False, description="show discarded")

def predict_gemma_cups(hash, threshold, show_discarded):
    """Redraw the annotated image and its box table for the given control values.

    Args:
        hash: Image file name (cache key) to display.
        threshold: Score threshold forwarded to ``predict_boxes``.
        show_discarded: Whether boxes below ``threshold`` are drawn as well.
    """
    image_output.clear_output()
    rects_output.clear_output()

    _, _, img, tmp_df = predict_boxes(
        hash=hash,
        threshold=threshold,
        show_discarded=show_discarded,
    )

    with image_output:
        fig, ax = plt.subplots(1, 1, figsize=(14, 14))
        ax.set_axis_off()
        ax.imshow(img)
        plt.show()

    with rects_output:
        # drop=True: the index inherited from the cache slice carries no
        # meaning here and would otherwise show up as an "index" column.
        display(
            tmp_df.sort_values(
                ["score"],
                ascending=False,
            ).reset_index(drop=True)
        )
        
def on_image_changed(change):
    """Dropdown observer: show the newly selected image with the current controls."""
    predict_gemma_cups(
        change.new,
        score_threshold.value,
        show_discarded=cb_show_discarded.value,
    )


def on_threshold_changed(change):
    """Slider observer: redraw the current image using the new threshold."""
    predict_gemma_cups(
        dd_sample.value,
        change.new,
        show_discarded=cb_show_discarded.value,
    )


def on_show_discarded_changed(change):
    """Checkbox observer: redraw the current image, toggling discarded boxes."""
    predict_gemma_cups(
        dd_sample.value,
        score_threshold.value,
        show_discarded=change.new,
    )

    
# Re-render the prediction view whenever any control changes value.
dd_sample.observe(on_image_changed, "value")
score_threshold.observe(on_threshold_changed, "value")
cb_show_discarded.observe(on_show_discarded_changed, "value")

display(
    HBox(children=[dd_sample, score_threshold, cb_show_discarded]),
    HBox(children=[image_output, rects_output]),
)

HBox(children=(Dropdown(options=('b-1HoJ-Hqz5STrwrZHGBYdjAE3Q.jpg', 'b-38OOnRNVd8OdKdmNYZXXy83E.jpg', 'b-3Q-Hd…

HBox(children=(Output(layout=Layout(border='1px solid black')), Output(layout=Layout(border='1px solid black')…

## Build videos

In [None]:
# Controls for rendering one time-lapse video per selected plant.
plants = widgets.SelectMultiple(
    options=sorted(df["plant"].unique()),
    value=[],
    description="Plants",
    disabled=False,
)
threshold = widgets.FloatSlider(
    value=0.80,
    min=0,
    max=1.0,
    step=0.05,
    description="Score threshold",
)
build_video = widgets.Button(description="Build video")
progress_output = widgets.Output()  # overall progress across plants
output = widgets.Output()  # per-plant frame progress

display(HBox([plants, threshold, build_video]), progress_output, output)

def on_button_clicked(b):
    """Render one MP4 time-lapse for each plant selected in ``plants``.

    For each plant, the frames are its images in ``date_time`` order. Each
    image is annotated by ``predict_boxes`` at the slider's score threshold,
    resized to 640x640, and written ``frame_duration`` times. Files go to
    ``../data_out/videos/<plant>_<threshold>.mp4``.

    Args:
        b: The clicked ``Button`` (unused; required by ``on_click``).

    Raises:
        OSError: If OpenCV cannot open the output video for writing.
    """
    frame_rate = 24.0
    frame_duration = 6  # frames per source image: 6 frames / 24 fps = 0.25 s
    v_height, v_width = 640, 640
    videos_dir = os.path.join("..", "data_out", "videos")
    # cv2.VideoWriter fails silently (isOpened() is False) if the folder is missing.
    os.makedirs(videos_dir, exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

    progress_output.clear_output()

    with progress_output:
        with tqdm(total=len(plants.value)) as gpbar:
            for plant in plants.value:
                df_tmp = df[
                    df.plant == plant
                ].sort_values(
                    ["date_time"]
                ).reset_index(drop=True)
                output.clear_output()

                with output:
                    v_output = os.path.join(
                        videos_dir,
                        f"{plant}_{threshold.value:.2f}.mp4"
                    )
                    out = cv2.VideoWriter(v_output, fourcc, frame_rate, (v_width, v_height))
                    if not out.isOpened():
                        raise OSError(f"Could not open video writer for {v_output}")
                    try:
                        with tqdm(total=df_tmp.shape[0]) as pbar:
                            for image_hash in df_tmp.hash.to_list():
                                # swap_colors=False keeps BGR, which VideoWriter expects.
                                _, _, img, _ = predict_boxes(
                                    hash=image_hash,
                                    threshold=threshold.value,
                                    swap_colors=False,
                                )
                                img = cv2.resize(
                                    img,
                                    (v_width, v_height),
                                    interpolation=cv2.INTER_CUBIC
                                )
                                for _ in range(frame_duration):
                                    out.write(img)
                                pbar.update(1)
                    finally:
                        # Close the writer explicitly so the file is finalized even
                        # if a frame fails, rather than whenever it is collected.
                        out.release()
                gpbar.update(1)

build_video.on_click(on_button_clicked)