# Imports

In [None]:
import json
import random

import ipywidgets as widgets
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import Image as IpImage
from IPython.display import display
from ipywidgets import Button, HBox, VBox
from PIL import Image as PilImage
from siuba import _, filter, group_by, summarize
from sklearn.cluster import KMeans

# %matplotlib widget

# Wrangle dataframe

## Load source dataframe

In [None]:
# Read the raw classification export, order it by user, timestamp and subject,
# keep the pre-sort row labels in an "index" column, and retain only the rows
# of the rectangle-drawing workflow.
df = (
    pd.read_csv(
        "../data_in/fancy-a-cup-of-marchantia-classifications_final.csv"
    )
    .sort_values(by=["user_name", "created_at", "subject_ids"])
    .reset_index()
    .loc[
        lambda frame: frame["workflow_name"]
        == "Draw rectangles around the gemma cups"
    ]
)
df.head()

In [None]:
# Column labels of the workflow-filtered classifications frame.
df.columns

In [None]:
# (rows, columns) of the workflow-filtered classifications frame.
df.shape

## Expand columns

### Expand metadata

In [None]:
# Parse each row's JSON "metadata" string and spread its top-level keys into columns.
metadata = df["metadata"].apply(lambda x: json.loads(x)).apply(pd.Series)
metadata.columns

In [None]:
# Spread the nested "subject_selection_state" entries into columns, then drop the column labelled 0.
# NOTE(review): presumably column 0 appears for rows that carry no nested mapping — TODO confirm.
subject_selection_state = metadata["subject_selection_state"].apply(pd.Series).drop([0], axis=1)
subject_selection_state.columns

In [None]:
subject_selection_state.head()

In [None]:
# Spread the nested "viewport" entries into their own columns.
viewport = metadata["viewport"].apply(pd.Series)
viewport.columns

In [None]:
# Same expansion for "interventions", again dropping the column labelled 0.
interventions = metadata["interventions"].apply(pd.Series).drop([0], axis=1)

In [None]:
interventions.head()

In [None]:
# Spread "subject_dimensions" into columns.
subject_dimensions = metadata["subject_dimensions"].apply(pd.Series)
subject_dimensions.columns

In [None]:
subject_dimensions.head()

In [None]:
# Expand the entries of column 0 one level further for inspection only (result is not stored).
subject_dimensions[0].apply(pd.Series)

In [None]:
# Assemble the wide frame: the original columns (minus the raw metadata
# string), the expanded subject-selection and viewport columns, and the
# remaining top-level metadata keys. The nested metadata keys are left out of
# the last piece.
df_xµd = pd.concat(
    objs=[
        df.drop(columns=["metadata"]),
        subject_selection_state,
        viewport,
        metadata.drop(
            columns=[
                "viewport",
                "interventions",
                "subject_dimensions",
                "subject_selection_state",
            ]
        ),
    ],
    axis=1,
)
df_xµd.head(3)

In [None]:
# Column labels of the widened frame.
df_xµd.columns

### Avoid column name collisions

In [None]:
# Move "retired" and "created_at" to new names appended at the end of the
# frame, so that these labels are free when further columns are expanded.
df_xµd = df_xµd.assign(
    retired_bool=df_xµd["retired"],
    creation_time=df_xµd["created_at"],
).drop(columns=["retired", "created_at"])
df_xµd

### Split subject_data and annotations

In [None]:
# Columns taken from the first value of each parsed subject_data JSON object.
subject_columns = (
    df_xµd["subject_data"]
    .apply(lambda raw: list(json.loads(raw).values())[0])
    .apply(pd.Series)
)
# Columns taken from the annotations string once its first and last
# characters are stripped and the remainder is parsed as JSON.
annotation_columns = (
    df_xµd["annotations"].str.slice(1, -1).apply(json.loads).apply(pd.Series)
)
dfx = pd.concat(
    [
        df_xµd.drop(columns=["subject_data", "annotations"]),
        subject_columns,
        annotation_columns,
    ],
    axis=1,
)

# Spread the "retired" column into its own columns, lower-case every column
# label and order the rows by user, image and creation timestamp.
retired_columns = dfx["retired"].apply(pd.Series)
dfx = (
    pd.concat([dfx.drop(columns=["retired"]), retired_columns], axis=1)
    .rename(columns=str.lower)
    .sort_values(by=["user_name", "filename", "created_at"])
)
dfx.head(3)

In [None]:
# Column labels after splitting subject_data and annotations.
dfx.columns

In [None]:
# Preview of the annotations expansion: the first and last characters of each
# string are stripped before the remainder is parsed as JSON.
df_xµd["annotations"].str[1:-1].apply(lambda x: json.loads(x)).apply(pd.Series)

In [None]:
# Preview of the subject_data expansion: each string is parsed as JSON and the
# first value of the resulting object is spread into columns.
df_xµd[
    "subject_data"
].apply(
    lambda x: list(json.loads(x).values())[0]
).apply(
    pd.Series
)

### Fix user_agent

In [None]:
# Keep only characters 13 to 27 of the raw user-agent string; this fragment is
# what the OS lookup in the next step matches against.
dfx["user_agent"] = dfx["user_agent"].str.slice(13, 28)
dfx["user_agent"].unique()

In [None]:
# Truncated user-agent fragment -> human-readable platform label.
os_label_by_fragment = {
    "Windows NT 10.0": "Windows 10",
    "Macintosh; Inte": "OSX",
    "Linux; Android ": "Android",
    "X11; CrOS x86_6": "Chrome OS",
    "Windows NT 6.1;": "Windows 7",
    "iPhone; CPU iPh": "iPhone",
    "X11; Linux x86_": "Linux",
    "Windows NT 6.3;": "Windows 8.1",
    "iPad; CPU OS 12": "iPad",
}
# Any fragment missing from the table is labelled "Oops" so that it stands
# out in the unique() listing below.
dfx.user_agent = np.select(
    condlist=[dfx.user_agent == fragment for fragment in os_label_by_fragment],
    choicelist=list(os_label_by_fragment.values()),
    default="Oops",
)
dfx.user_agent.unique()

### Keep only needed columns

In [None]:
# (rows, columns) of the fully expanded frame.
dfx.shape

In [None]:
# Column labels available to choose from when narrowing the frame.
dfx.columns

In [None]:
# Keep only the columns needed downstream. An explicit copy is taken because
# a new column is assigned to df_keep afterwards; assigning to a bare column
# subset of dfx triggers pandas' SettingWithCopyWarning.
df_keep = dfx[
    [
        "user_name",
        "user_agent",
        "filename",
        "value",
        "classifications_count",
        "creation_time",
    ]
].copy()
df_keep.head(3)

### Count observations

In [None]:
# Number of entries in each classification's "value" sequence.
df_keep["rect_count"] = df_keep["value"].map(len)
df_keep.head()

### Tidy up

In [None]:
# Column labels that will serve as identifier variables when melting.
df_keep.columns

In [None]:
# Step 1: spread each classification's list of rectangles into one column per
# list position, then melt so that every row holds a single rectangle entry
# (or NaN where a classification had fewer rectangles than the widest one).
tidy = (
    pd.concat(
        [
            df_keep.drop(["value"], axis=1),
            df_keep.value.apply(pd.Series),
        ],
        axis=1,
    )
    .melt(
        id_vars=[
            "user_name",
            "filename",
            "user_agent",
            "rect_count",
            "classifications_count",
            "creation_time",
        ],
        var_name="dummy",
        value_name="rectangle",
    )
    .drop(["dummy"], axis=1)
    .dropna(subset=["user_name", "filename", "user_agent", "rect_count"])
)

# Step 2: spread each rectangle entry into columns and keep the geometry ones.
tidy = pd.concat(
    [tidy.drop(["rectangle"], axis=1), tidy.rectangle.apply(pd.Series)], axis=1
)[
    [
        "user_name",
        "filename",
        "creation_time",
        "user_agent",
        "rect_count",
        "classifications_count",
        "x",
        "y",
        "width",
        "height",
    ]
]

# Step 3: drop rows without a usable classifications_count and deduplicate.
# Duplicates must be removed BEFORE reset_index(): the "index" column it adds
# is unique per row, so deduplicating afterwards never removes anything.
tidy = (
    (tidy >> filter(_.classifications_count >= 0))
    .drop_duplicates()
    .reset_index()
)
tidy

# Explore the data

In [None]:
# Distinct rectangles drawn by one user on one image, sorted by geometry.
# The bookkeeping "index" column is dropped before deduplicating; resetting
# the index first would add a unique "level_0" column (because "index"
# already exists) and make drop_duplicates() a no-op.
(
    tidy[
        (tidy.user_name == "Brooker1957")
        & (tidy.filename == "b0xhA8TCuQtLRbirX369iE7dJvUE.jpg")
    ]
    .drop(columns=["index"])
    .dropna()
    .drop_duplicates()
    .sort_values(["user_name", "filename", "x", "y", "width", "height"])
    .reset_index(drop=True)
)

In [None]:
# One row per distinct (user, image, timestamp, platform, rectangle count):
# the rectangle geometry is dropped so each classification is counted once.
no_rect_df = tidy[
    ["user_name", "filename", "creation_time", "user_agent", "rect_count"]
].drop_duplicates()

# Per-image summary of how many rectangles the classifications contain.
df_stats = (
    no_rect_df.groupby("filename")
    .rect_count.agg(
        count="count",
        min="min",
        max="max",
        mean="mean",
        median="median",
        std="std",
        # Series.mode() returns a Series (several values on ties), which is
        # not a valid aggregation result; reduce it to a scalar. mode() is
        # sorted, so ties resolve to the smallest count.
        mode=lambda counts: counts.mode().iloc[0],
    )
    .reset_index()
    .sort_values("filename")
)
df_stats

In [None]:
def _bordered_output():
    """Return a fresh Output widget framed by a thin black border."""
    return widgets.Output(layout={"border": "1px solid black"})


# Dropdown entries: a placeholder followed by every image name, sorted.
observations = [
    "Select an observation",
    *sorted(df_stats.filename.unique().tolist()),
]

# Selection controls. The user and date dropdowns start empty and are filled
# by update_overview.
obs_selected = widgets.Dropdown(
    options=observations, description="Select an observation:"
)
user_selected = widgets.Dropdown(options=[], description="Select a user:")
date_selected = widgets.Dropdown(options=[], description="Select a date:")
shape_selected = widgets.Dropdown(
    options=["Rectangle", "Circle"],
    value="Rectangle",
    description="Draw shape:",
)

button = widgets.Button(description="Render")

# Output panels: annotated image, rectangle table, per-classification rows
# and per-image statistics.
image_with_rects = _bordered_output()
text_rects = _bordered_output()
dataframe: widgets.Output = _bordered_output()
stats_output = _bordered_output()


def update_overview(observation, user, date_, shape, update_user, update_date):
    """Refresh the overview panels for the selected observation.

    Shows the per-classification rows and the summary statistics of
    ``observation``, optionally refreshes the user/date dropdown options, and
    draws every recorded rectangle (or a marker at its centre) over the image.

    Args:
        observation: File name chosen in the observation dropdown.
        user: User name; only used to build the date dropdown options.
        date_: Accepted for interface compatibility; unused, because
            per-user / per-date filtering of the rectangles is disabled.
        shape: ``"Rectangle"`` draws box outlines, ``"Circle"`` draws a small
            circle at each box centre; any other value draws nothing.
        update_user: ``"update"`` repopulates the user dropdown, ``"clear"``
            empties it; any other value leaves it untouched.
        update_date: Same protocol as ``update_user``, for the date dropdown.
    """
    selected_rows = tidy[tidy.filename == observation]

    dataframe.clear_output()
    with dataframe:
        display(
            no_rect_df.drop("filename", axis=1)[
                no_rect_df.filename == observation
            ].reset_index()
        )

    stats_output.clear_output()
    with stats_output:
        display(df_stats[df_stats.filename == observation].reset_index())

    if update_user == "update":
        user_selected.options = ["None", "All"] + sorted(
            selected_rows.user_name.unique().tolist()
        )
    elif update_user == "clear":
        user_selected.options = []
    if update_date == "update":
        date_selected.options = ["All"] + sorted(
            selected_rows[selected_rows.user_name == user]
            .creation_time.unique()
            .tolist()
        )
    elif update_date == "clear":
        date_selected.options = []

    image_with_rects.clear_output()
    text_rects.clear_output()
    if selected_rows.empty:
        # Placeholder entry (or an unknown name): there is no image to open,
        # so leave the image and rectangle panels blank instead of failing.
        return

    # NOTE: every rectangle recorded for the observation is drawn; the
    # ``user`` and ``date_`` selections do not filter the set.
    rects = selected_rows.dropna()[["x", "y", "width", "height"]].reset_index(
        drop=True
    )

    with image_with_rects:
        # Open via a context manager so the file handle is released once
        # matplotlib has taken the pixel data.
        with PilImage.open(f"../data_in/images/{observation}") as img:
            fig, ax = plt.subplots()
            fig.set_size_inches(14, 14)
            ax.set_axis_off()
            ax.imshow(img)

        # One random named colour per rectangle so overlapping boxes can be
        # told apart.
        colors = random.choices(list(mcolors.CSS4_COLORS), k=len(rects))
        for x, y, w, h, c in zip(
            rects.x, rects.y, rects.width, rects.height, colors
        ):
            if shape == "Circle":
                ax.add_patch(
                    patches.Circle(
                        (x + w // 2, y + h // 2),
                        8,
                        linewidth=8,
                        edgecolor=c,
                        facecolor="none",
                    )
                )
            elif shape == "Rectangle":
                ax.add_patch(
                    patches.Rectangle(
                        (x, y), w, h, linewidth=2, edgecolor=c, facecolor="none"
                    )
                )
        plt.show()

    with text_rects:
        display(rects)


def _refresh_overview(update_user, update_date, **overrides):
    """Call update_overview with the current widget values.

    ``overrides`` replaces individual selections (``observation``, ``user``,
    ``date_``, ``shape``), typically with the ``new`` value of a change event.
    """
    selection = {
        "observation": obs_selected.value,
        "user": user_selected.value,
        "date_": date_selected.value,
        "shape": shape_selected.value,
    }
    selection.update(overrides)
    update_overview(
        update_user=update_user, update_date=update_date, **selection
    )


def on_observation_selected(change):
    """Redraw for a newly chosen observation; refill users and clear dates."""
    _refresh_overview("update", "clear", observation=change.new)


def on_user_selected(change):
    """Redraw for a newly chosen user and refill the date dropdown."""
    _refresh_overview("", "update", user=change.new)


def on_timestamp_selected(change):
    """Redraw with the current selections, leaving both dropdowns untouched."""
    _refresh_overview("", "")


def on_shape_selected(change):
    """Redraw using the newly chosen marker shape."""
    _refresh_overview("", "", shape=change.new)


# Only a change of observation triggers a redraw; the remaining hooks are
# kept below but are not registered.
obs_selected.observe(on_observation_selected, names="value")
# user_selected.observe(on_user_selected, names="value")
# date_selected.observe(on_timestamp_selected, names="value")
# shape_selected.observe(on_shape_selected, names="value")
# button.on_click(on_button_clicked)

# Layout: controls on top, statistics below, then the three panels side by side.
controls_row = HBox(
    [obs_selected, user_selected, date_selected, shape_selected, button]
)
panels_row = HBox([dataframe, image_with_rects, text_rects])
display(controls_row, stats_output, panels_row)

In [None]:
# Dropdown entries: a placeholder followed by every image name, sorted.
images_list = [
    "Select an observation",
    *sorted(df_stats.filename.unique().tolist()),
]
dd_image = widgets.Dropdown(
    description="Select an observation:",
    options=images_list,
)

# Toggles controlling which overlays print_ground_truth draws.
is_print_all = widgets.Checkbox(
    description="Print all annotations centers",
    value=False,
    disabled=False,
    indent=False,
)
is_print_centers = widgets.Checkbox(
    description="Print all kmeans centers",
    value=False,
    disabled=False,
    indent=False,
)
is_print_rectangles = widgets.Checkbox(
    description="Print rectangles",
    value=True,
    disabled=False,
    indent=False,
)

# Four bordered output panels, created in the order they are named.
image_output, image_stat_output, user_gt, value_counts = (
    widgets.Output(layout={"border": "1px solid black"}) for _unused in range(4)
)


def _consensus_rect_counts(counts):
    """Pick the rectangle counts to trust from a ``value_counts()`` Series.

    Keeps the most frequent count alone when it is more than twice as frequent
    as the runner-up, otherwise keeps the two most frequent counts.
    """
    values = counts.index.to_list()[:2]
    freqs = counts.to_list()[:2]
    if len(values) > 1 and freqs[0] <= 2 * freqs[1]:
        return values
    return values[:1]


def print_ground_truth(
    observation: str,
    print_all: bool = False,
    print_centers: bool = False,
    print_rectangles: bool = True,
):
    """Render the consensus rectangles for one image.

    Rectangles from classifications whose rectangle count matches the
    consensus count(s) are kept, boxes at least twice as wide or tall as the
    median are discarded, and the remaining box centres are clustered with
    k-means. The per-cluster median box is drawn as the consensus rectangle.

    Args:
        observation: File name of the image to render.
        print_all: Scatter every retained box centre, coloured by cluster.
        print_centers: Scatter the k-means cluster centres.
        print_rectangles: Draw the per-cluster median rectangle.
    """
    image_stat_output.clear_output()
    with image_stat_output:
        display(df_stats[df_stats.filename == observation].reset_index())

    classifications = no_rect_df[no_rect_df.filename == observation]
    vc = classifications["rect_count"].value_counts()
    allowed_counts = _consensus_rect_counts(vc)

    rects = tidy[
        (tidy.filename == observation) & (tidy.rect_count.isin(allowed_counts))
    ].dropna()[["x", "y", "width", "height"]]
    rects = rects[
        (rects.width < 2 * rects.width.median())
        & (rects.height < 2 * rects.height.median())
    ].reset_index()

    image_output.clear_output()
    # No classifications means the placeholder entry (or an unknown name) is
    # selected: there is no image file to open, so the panel stays blank.
    if not vc.empty:
        with image_output:
            # Context manager releases the file handle once matplotlib has
            # taken the pixel data.
            with PilImage.open(f"../data_in/images/{observation}") as img:
                fig, ax = plt.subplots()
                fig.set_size_inches(14, 14)
                ax.set_axis_off()
                ax.imshow(img)
            if not rects.empty:
                rects = rects.assign(
                    cx=rects.x + rects.width // 2,
                    cy=rects.y + rects.height // 2,
                )
                X = rects[["cx", "cy"]].to_numpy()
                # KMeans raises when asked for more clusters than samples,
                # which can happen after the outlier filter above.
                kmeans = KMeans(
                    n_clusters=min(max(allowed_counts), len(rects)),
                    random_state=42,
                ).fit(X)
                y_pred = kmeans.predict(X)
                if print_all:
                    ax.scatter(rects.cx, rects.cy, c=y_pred, alpha=0.5)
                if print_rectangles:
                    boxes = (
                        rects.assign(y_pred=y_pred)
                        .groupby("y_pred")
                        .median()
                        .reset_index()
                    )
                    for x, y, w, h in zip(
                        boxes.x, boxes.y, boxes.width, boxes.height
                    ):
                        ax.add_patch(
                            patches.Rectangle(
                                (x, y),
                                w,
                                h,
                                linewidth=2,
                                edgecolor="r",
                                facecolor="none",
                            )
                        )
                if print_centers:
                    centers = kmeans.cluster_centers_
                    ax.scatter(
                        centers[:, 0], centers[:, 1], c="b", s=200, alpha=0.5
                    )
            plt.show()

    user_gt.clear_output()
    with user_gt:
        display(
            classifications.reset_index().drop(
                ["filename", "user_agent"], axis=1
            )
        )

    value_counts.clear_output()
    with value_counts:
        display(vc)


def _redraw_ground_truth(**overrides):
    """Call print_ground_truth with the current widget values.

    ``overrides`` replaces individual arguments, typically with the ``new``
    value carried by a change event.
    """
    arguments = {
        "observation": dd_image.value,
        "print_all": is_print_all.value,
        "print_centers": is_print_centers.value,
        "print_rectangles": is_print_rectangles.value,
    }
    arguments.update(overrides)
    print_ground_truth(**arguments)


def on_image_selected(change):
    """Redraw for the newly selected image."""
    _redraw_ground_truth(observation=change.new)


def on_print_all_changed(change):
    """Redraw after the "all annotation centers" toggle changed."""
    _redraw_ground_truth(print_all=change.new)


def on_print_centers_changed(change):
    """Redraw after the "kmeans centers" toggle changed."""
    _redraw_ground_truth(print_centers=change.new)


def on_print_rectangles_changed(change):
    """Redraw after the "rectangles" toggle changed."""
    _redraw_ground_truth(print_rectangles=change.new)


# Redraw whenever the image selection or any overlay toggle changes.
dd_image.observe(on_image_selected, names="value")
for toggle, handler in (
    (is_print_all, on_print_all_changed),
    (is_print_centers, on_print_centers_changed),
    (is_print_rectangles, on_print_rectangles_changed),
):
    toggle.observe(handler, names="value")

# Layout: controls on top, statistics below, then the three panels side by side.
ground_truth_controls = HBox(
    [dd_image, is_print_all, is_print_centers, is_print_rectangles]
)
ground_truth_panels = HBox([image_output, user_gt, value_counts])
display(ground_truth_controls, image_stat_output, ground_truth_panels)

In [None]:
# Build one frame of consensus boxes per image. Images for which no usable
# rectangle remains contribute a single all-NaN row so they stay listed.
df_lst = []
box_columns = ["x", "y", "width", "height"]
for filename in df_stats.filename.unique().tolist():
    # Trust the most frequent rectangle count alone when it is more than
    # twice as frequent as the runner-up, otherwise trust the top two.
    vc = no_rect_df[no_rect_df.filename == filename]["rect_count"].value_counts()
    if len(vc) > 1:
        v = vc.index.to_list()[:2]
        c = vc.to_list()[:2]
        allowed_counts = [v[0]] if c[0] > 2 * c[1] else v
    else:
        allowed_counts = vc.index.to_list()

    rects = tidy[
        (tidy.filename == filename) & (tidy.rect_count.isin(allowed_counts))
    ][["filename", *box_columns]].dropna()
    # Discard boxes at least twice as wide or tall as the median box.
    rects = rects[
        (rects.width < 2 * rects.width.median())
        & (rects.height < 2 * rects.height.median())
    ].reset_index(drop=True)

    if rects.empty:
        # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0.
        df_lst.append(
            pd.DataFrame(
                [[filename, np.nan, np.nan, np.nan, np.nan]],
                columns=["filename", *box_columns],
            )
        )
        continue

    centers = np.column_stack(
        [rects.x + rects.width // 2, rects.y + rects.height // 2]
    )
    # KMeans raises when asked for more clusters than samples, which can
    # happen after the outlier filter above.
    labels = (
        KMeans(
            n_clusters=min(max(allowed_counts), len(rects)),
            random_state=42,
        )
        .fit(centers)
        .predict(centers)
    )
    # Replace every box by the median box of its cluster (one groupby for all
    # four columns), then collapse the now-identical rows.
    cluster_medians = rects.groupby(labels)[box_columns].transform("median")
    df_lst.append(
        rects.assign(**{column: cluster_medians[column] for column in box_columns})
        .drop_duplicates()
        .reset_index(drop=True)
    )

In [None]:
# Stack the per-image consensus boxes and export them without the index column.
df_final = pd.concat(objs=df_lst)
df_final.to_csv("../data_out/boxes_final.csv", index=False)
df_final

In [None]:
# Image selector for the exported boxes, listing every file name in sorted order.
final_filenames = sorted(df_final.filename.unique().tolist())
dd_final_image = widgets.Dropdown(options=final_filenames)

# Bordered panels for the annotated image and the matching rows.
final_image_output, rects_output = (
    widgets.Output(layout={"border": "1px solid black"}) for _unused in range(2)
)


def print_final_rects(change):
    """Show the exported consensus boxes for the newly selected image.

    Args:
        change: Widget change event; ``change.new`` is the selected file name.
    """
    selected = change.new
    rects = df_final[df_final.filename == selected]
    # Images without consensus boxes are stored as a single all-NaN row;
    # those rows are listed in the table but must not be drawn.
    drawable = rects.dropna(subset=["x", "y", "width", "height"])

    final_image_output.clear_output()
    with final_image_output:
        # Context manager releases the file handle once matplotlib has taken
        # the pixel data.
        with PilImage.open(f"../data_in/images/{selected}") as img:
            fig, ax = plt.subplots()
            fig.set_size_inches(14, 14)
            ax.set_axis_off()
            ax.imshow(img)
        for x, y, w, h in zip(
            drawable.x, drawable.y, drawable.width, drawable.height
        ):
            ax.add_patch(
                patches.Rectangle(
                    (x, y),
                    w,
                    h,
                    linewidth=2,
                    edgecolor="r",
                    facecolor="none",
                )
            )
        plt.show()

    rects_output.clear_output()
    with rects_output:
        display(rects)


# Redraw on every selection change; the image and its table sit side by side.
dd_final_image.observe(print_final_rects, names="value")
final_panels = HBox([final_image_output, rects_output])
display(dd_final_image, final_panels)