# Build Zooniverse Dataset

## Imports

In [None]:
# IPython magics: enable the autoreload extension in mode 2, so edited modules
# (e.g. the local `loaders` module imported further down) are re-imported
# before a cell runs instead of requiring a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import json
import warnings

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as patches

import ipywidgets as widgets
from IPython.display import Image as IpImage
from IPython.display import display
from ipywidgets import HBox


## Remove Warnings

In [None]:
# Ignore these warning families for the rest of the session. The filters are
# installed in this order, so the last category ends up first in the filter list.
for _category in (UserWarning, FutureWarning, DeprecationWarning):
    warnings.simplefilter(action="ignore", category=_category)


## Constants

In [None]:
# Root folder of every notebook input, relative to the working directory.
data_in_path = Path("..") / "data_in"

# Sub-folders: raw CSV exports, source images, generated dataset splits.
csvs_path = data_in_path / "csvs"
images_path = data_in_path / "images"
dataset_path = data_in_path / "datasets"

# Displayed as a tuple of booleans: True where the folder exists.
tuple(p.is_dir() for p in (data_in_path, csvs_path, images_path, dataset_path))


## Handle Zooniverse Raw Data

### Load Source Dataframe

In [None]:
# Load the raw Zooniverse classification export, order the rows by user,
# creation date and subject, then keep only the rectangle-drawing workflow.
df = (
    pd.read_csv(
        str(csvs_path / "fancy-a-cup-of-marchantia-classifications_final.csv")
    )
    .sort_values(by=["user_name", "created_at", "subject_ids"])
    .reset_index()  # the previous row labels survive as an "index" column
    .loc[lambda frame: frame.workflow_name == "Draw rectangles around the gemma cups"]
)
df


### Expand Columns

#### Metadata

In [None]:
metadata = df["metadata"].apply(lambda x: json.loads(x)).apply(pd.Series)
metadata


In [None]:
subject_selection_state = (
    metadata["subject_selection_state"].apply(pd.Series).drop([0], axis=1)
)
subject_selection_state


In [None]:
viewport = metadata["viewport"].apply(pd.Series)
viewport


In [None]:
interventions = metadata["interventions"].apply(pd.Series).drop([0], axis=1)
interventions


In [None]:
subject_dimensions = metadata["subject_dimensions"].apply(pd.Series)
subject_dimensions


In [None]:
subject_dimensions[0].apply(pd.Series)


In [None]:
# Wide frame = original columns (minus the raw metadata string)
#            + expanded subject_selection_state + expanded viewport
#            + the remaining flat metadata keys.
# The "interventions" and "subject_dimensions" expansions are not re-attached.
df_xµd = pd.concat(
    [
        df.drop(columns=["metadata"]),
        subject_selection_state,
        viewport,
        metadata.drop(
            columns=[
                "subject_selection_state",
                "viewport",
                "interventions",
                "subject_dimensions",
            ]
        ),
    ],
    axis="columns",
)
df_xµd


### Avoid Clashes Between Columns Sharing the Same Name

In [None]:
# Move "retired" and "created_at" out of the way under new names (appended as
# the last two columns) because the subject data expanded in the next section
# reuses those labels.
df_xµd = df_xµd.assign(
    retired_bool=df_xµd.retired,
    creation_time=df_xµd.created_at,
).drop(columns=["retired", "created_at"])
df_xµd


### Split subject_data and annotations

In [None]:
# Step 1: swap the two JSON columns for their expanded content.
#   - subject_data: only the first value of the decoded object is used
#   - annotations: first and last characters are stripped before decoding
dfx = pd.concat(
    [
        df_xµd.drop(columns=["subject_data", "annotations"]),
        df_xµd["subject_data"]
        .apply(lambda raw: list(json.loads(raw).values())[0])
        .apply(pd.Series),
        df_xµd["annotations"].str.slice(1, -1).apply(json.loads).apply(pd.Series),
    ],
    axis="columns",
)
# Step 2: the expanded subject data brings its own nested "retired" column; expand it too.
dfx = pd.concat(
    [dfx.drop(columns=["retired"]), dfx["retired"].apply(pd.Series)],
    axis="columns",
)
# Step 3: lower-case every column label, then order the rows.
dfx = dfx.rename(columns=str.lower).sort_values(
    by=["user_name", "filename", "created_at"]
)
dfx


In [None]:
df_xµd["annotations"].str[1:-1].apply(lambda x: json.loads(x)).apply(pd.Series)


In [None]:
df_xµd["subject_data"].apply(lambda x: list(json.loads(x).values())[0]).apply(pd.Series)


### Fix User Agent

In [None]:
dfx.user_agent = dfx.user_agent.str[13:28]
dfx.user_agent.unique()


In [None]:
dfx.user_agent = np.select(
    [
        dfx.user_agent == "Windows NT 10.0",
        dfx.user_agent == "Macintosh; Inte",
        dfx.user_agent == "Linux; Android ",
        dfx.user_agent == "X11; CrOS x86_6",
        dfx.user_agent == "Windows NT 6.1;",
        dfx.user_agent == "iPhone; CPU iPh",
        dfx.user_agent == "X11; Linux x86_",
        dfx.user_agent == "Windows NT 6.3;",
        dfx.user_agent == "iPad; CPU OS 12",
    ],
    [
        "Windows 10",
        "OSX",
        "Android",
        "Chrome OS",
        "Windows 7",
        "iPhone",
        "Linux",
        "Windows 8.1",
        "iPad",
    ],
    default="Oops",
)
dfx.user_agent.unique()


### Keep only needed columns

In [None]:
dfx


In [None]:
df_keep = dfx[
    [
        "user_name",
        "user_agent",
        "filename",
        "value",
        "classifications_count",
        "creation_time",
    ]
]
df_keep


### Count Observations

In [None]:
# Number of entries in each classification's "value" (presumably the list of
# drawn rectangles - confirm against the annotation schema).
# assign() returns a new frame instead of writing a column into df_keep, which
# is itself a column selection of dfx; the in-place write raised pandas'
# SettingWithCopyWarning, a class the filters at the top do not cover.
df_keep = df_keep.assign(rect_count=df_keep.value.apply(len))
df_keep


### Keep Only Images That Exist on Disk

In [None]:
df_keep.filename.unique().shape


In [None]:
existing_files = [
    f
    for f in df_keep.filename.sort_values().unique()
    if images_path.joinpath(f).is_file()
]
len(existing_files)


In [None]:
df_filter_files = df_keep[df_keep.filename.isin(existing_files)]
df_filter_files


### Tidy Up

In [None]:
from siuba import _, filter

# Identifier columns carried unchanged through the reshape.
id_columns = [
    "user_name",
    "filename",
    "user_agent",
    "rect_count",
    "classifications_count",
    "creation_time",
]

# Spread each "value" into one column per entry, then melt those columns so
# that every entry gets its own row; the positional label is thrown away.
tidy_melt = (
    pd.concat(
        [
            df_filter_files.drop(columns=["value"]),
            df_filter_files.value.apply(pd.Series),
        ],
        axis="columns",
    )
    .melt(id_vars=id_columns, var_name="dummy", value_name="rectangle")
    .drop(columns=["dummy"])
    .dropna(subset=["user_name", "filename", "user_agent", "rect_count"])
)


In [None]:
rectangles_path = csvs_path.joinpath("rectangles.csv")
if rectangles_path.is_file():
    rectangles = pd.read_csv(rectangles_path)
else:
    rectangles = tidy_melt.rectangle.apply(pd.Series)
rectangles


In [None]:
rectangles.dropna(subset=["x"]).reset_index(drop=True)

In [None]:
tidy = pd.concat([tidy_melt.drop(["rectangle"], axis=1), rectangles,], axis=1,)[
    [
        "user_name",
        "filename",
        "creation_time",
        "user_agent",
        "rect_count",
        "classifications_count",
        "x",
        "y",
        "width",
        "height",
    ]
]


In [None]:
tidy.isna().any()


In [None]:
tidy_nona = tidy.dropna().reset_index(drop=True)
tidy_nona


### Explore Data

In [None]:
# Distinct rectangles one user drew on one image, ordered by position and size.
(
    tidy_nona.loc[
        (tidy_nona.filename == "b0xhA8TCuQtLRbirX369iE7dJvUE.jpg")
        & (tidy_nona.user_name == "Brooker1957")
    ]
    .dropna()
    .reset_index(drop=True)
    .drop_duplicates()
    .sort_values(by=["user_name", "filename", "x", "y", "width", "height"])
    .dropna()
)


In [None]:
# One row per classification: geometry columns are left out and the rows that
# were repeated once per rectangle are collapsed.
no_rect_df = tidy_nona.loc[
    :, ["user_name", "filename", "creation_time", "user_agent", "rect_count"]
].drop_duplicates()

# Per-image summary of the number of rectangles drawn per classification.
df_stats = (
    no_rect_df.groupby("filename")["rect_count"]
    .agg(
        count="count",
        min="min",
        max="max",
        mean="mean",
        median="median",
        std="std",
        mode=lambda counts: counts.mode(),
    )
    .reset_index()
    .sort_values(by="filename")
)
df_stats


In [None]:
import random
from PIL import Image as PilImage
import matplotlib.colors as mcolors

# Dropdown entries: a placeholder followed by every image name, alphabetically.
observations = [
    "Select an observation",
    *sorted(df_stats.filename.unique().tolist()),
]
obs_selected = widgets.Dropdown(
    options=observations, description="Select an observation:"
)

# User and date choices start empty; update_overview fills them on demand.
user_selected = widgets.Dropdown(options=[], description="Select a user:")
date_selected = widgets.Dropdown(options=[], description="Select a date:")

shape_selected = widgets.Dropdown(
    options=["Rectangle", "Circle"], description="Draw shape:", value="Rectangle"
)

# No click handler is attached to this button in this cell.
button = widgets.Button(description="Render")

# Four bordered output panes: annotated image, rectangle table,
# per-classification table and per-image statistics.
image_with_rects, text_rects, dataframe, stats_output = (
    widgets.Output(layout={"border": "1px solid black"}) for _ in range(4)
)


def update_overview(observation, user, date_, shape, update_user, update_date):
    """Refresh every pane of the exploration widget for the selected image.

    Fills the per-classification table, the per-image statistics and the
    dropdown options. For a real image it also draws the picture with every
    annotated rectangle on top and lists those rectangles.

    Args:
        observation: Image filename from the observation dropdown, or the
            "Select an observation" placeholder.
        user: Selected user name; only used to rebuild the date options.
        date_: Selected date; accepted for the callbacks' sake but unused.
        shape: "Rectangle" outlines each box, "Circle" marks each box centre;
            any other value draws nothing.
        update_user: "update" rebuilds the user dropdown for ``observation``,
            "clear" empties it, anything else leaves it untouched.
        update_date: Same protocol for the date dropdown.

    NOTE(review): the user / date selections do not filter what is drawn; all
    annotations of the image are always shown.
    """
    per_classification = no_rect_df[no_rect_df.filename == observation]
    dataframe.clear_output()
    with dataframe:
        display(per_classification.drop("filename", axis=1).reset_index())

    stats_output.clear_output()
    with stats_output:
        display(df_stats[df_stats.filename == observation].reset_index())

    if update_user == "update":
        user_selected.options = ["None", "All"] + sorted(
            tidy_nona[tidy_nona.filename == observation].user_name.unique().tolist()
        )
    elif update_user == "clear":
        user_selected.options = []
    if update_date == "update":
        date_selected.options = ["All"] + sorted(
            tidy_nona[
                (tidy_nona.filename == observation) & (tidy_nona.user_name == user)
            ]
            .creation_time.unique()
            .tolist()
        )
    elif update_date == "clear":
        date_selected.options = []

    image_with_rects.clear_output()
    text_rects.clear_output()
    # The placeholder entry is not a file: leave the image panes empty instead
    # of failing on a missing image.
    if observation == "Select an observation":
        return

    rects = (
        tidy_nona.loc[
            tidy_nona.filename == observation, ["x", "y", "width", "height"]
        ]
        .dropna()
        .reset_index(drop=True)
    )

    with image_with_rects:
        fig, ax = plt.subplots()
        fig.set_size_inches(14, 14)
        ax.set_axis_off()
        # Draw the image named by the argument (not the widget global) and
        # release the file handle as soon as the pixels are handed to matplotlib.
        with PilImage.open(images_path.joinpath(observation)) as img:
            ax.imshow(img)

        # One random named colour per rectangle, to tell overlapping boxes apart.
        palette = list(mcolors.CSS4_COLORS)
        colors = random.choices(palette, k=len(rects))
        for x, y, w, h, c in zip(rects.x, rects.y, rects.width, rects.height, colors):
            if shape == "Circle":
                ax.add_patch(
                    patches.Circle(
                        (x + w // 2, y + h // 2),
                        8,
                        linewidth=8,
                        edgecolor=c,
                        facecolor="none",
                    )
                )
            elif shape == "Rectangle":
                ax.add_patch(
                    patches.Rectangle(
                        (x, y), w, h, linewidth=2, edgecolor=c, facecolor="none"
                    )
                )
        plt.show()

    with text_rects:
        display(rects)


def _refresh_overview(update_user, update_date, **overrides):
    """Call update_overview with the current widget selections.

    Keyword overrides (observation, user, date_, shape) take the place of the
    value read from the corresponding dropdown.
    """
    selection = {
        "observation": obs_selected.value,
        "user": user_selected.value,
        "date_": date_selected.value,
        "shape": shape_selected.value,
    }
    selection.update(overrides)
    update_overview(update_user=update_user, update_date=update_date, **selection)


def on_observation_selected(change):
    """New image chosen: rebuild the user options and empty the date options."""
    _refresh_overview("update", "clear", observation=change.new)


def on_user_selected(change):
    """New user chosen: rebuild the date options for that user."""
    _refresh_overview("", "update", user=change.new)


def on_timestamp_selected(change):
    """New date chosen: redraw without touching any dropdown options."""
    _refresh_overview("", "")


def on_shape_selected(change):
    """New shape chosen: redraw with that shape, options untouched."""
    _refresh_overview("", "", shape=change.new)


# Re-render whenever the value of one of the dropdowns changes.
for _dropdown, _callback in (
    (obs_selected, on_observation_selected),
    (user_selected, on_user_selected),
    (date_selected, on_timestamp_selected),
    (shape_selected, on_shape_selected),
):
    _dropdown.observe(_callback, names="value")
# The "Render" button is shown but has no click handler:
# button.on_click(on_button_clicked)

# Layout: controls on top, statistics, the two tables side by side, then the image.
display(
    HBox([obs_selected, user_selected, date_selected, shape_selected, button]),
    stats_output,
    HBox([dataframe, text_rects]),
    image_with_rects,
)


In [None]:
from sklearn.cluster import KMeans

# Dropdown entries: a placeholder followed by every image name, alphabetically.
images_list = [
    "Select an observation",
    *sorted(df_stats.filename.unique().tolist()),
]
dd_image = widgets.Dropdown(
    options=images_list, description="Select an observation:"
)

# Overlay toggles; only the consensus rectangles are enabled by default.
is_print_all = widgets.Checkbox(
    value=False,
    description="Print all annotations centers",
    disabled=False,
    indent=False,
)
is_print_centers = widgets.Checkbox(
    value=False,
    description="Print all kmeans centers",
    disabled=False,
    indent=False,
)
is_print_rectangles = widgets.Checkbox(
    value=True,
    description="Print rectangles",
    disabled=False,
    indent=False,
)

# Four bordered output panes filled by print_ground_truth.
image_output, image_stat_output, user_gt, value_counts = (
    widgets.Output(layout={"border": "1px solid black"}) for _ in range(4)
)


def _allowed_rect_counts(count_frequencies):
    """Return the rectangle counts trusted for one image.

    ``count_frequencies`` is the ``value_counts()`` of the per-classification
    rectangle counts (most frequent first). The top count is kept alone when
    it is more than twice as frequent as the runner-up; otherwise the two most
    frequent counts are kept.
    """
    top_counts = count_frequencies.index.to_list()[:2]
    if len(top_counts) < 2:
        return top_counts
    first, second = count_frequencies.to_list()[:2]
    return top_counts[:1] if first > 2 * second else top_counts


def _draw_consensus(ax, rects, n_clusters, print_all, print_centers, print_rectangles):
    """Cluster the rectangle centres with k-means and draw the requested overlays."""
    rects = rects.assign(
        cx=rects.x + rects.width // 2,
        cy=rects.y + rects.height // 2,
    )
    centres = rects[["cx", "cy"]].to_numpy()
    # KMeans refuses more clusters than samples; the size filter applied by the
    # caller can leave fewer rectangles than the expected count.
    kmeans = KMeans(
        n_clusters=min(n_clusters, len(rects)),
        random_state=42,
    ).fit(centres)
    labels = kmeans.predict(centres)
    if print_all:
        ax.scatter(rects.cx, rects.cy, c=labels, alpha=0.5)
    if print_rectangles:
        # One box per cluster: the median of its members' geometry.
        boxes = rects.assign(y_pred=labels).groupby("y_pred").median().reset_index()
        for x, y, w, h in zip(boxes.x, boxes.y, boxes.width, boxes.height):
            ax.add_patch(
                patches.Rectangle(
                    (x, y), w, h, linewidth=2, edgecolor="r", facecolor="none"
                )
            )
    if print_centers:
        found = kmeans.cluster_centers_
        ax.scatter(found[:, 0], found[:, 1], c="b", s=200, alpha=0.5)


def print_ground_truth(
    observation: str,
    print_all: bool = False,
    print_centers: bool = False,
    print_rectangles: bool = True,
):
    """Render the consensus ("ground truth") boxes of one image.

    Args:
        observation: Image filename, or the "Select an observation"
            placeholder (in which case no image is drawn).
        print_all: Scatter every annotation centre, coloured by cluster.
        print_centers: Scatter the k-means cluster centres.
        print_rectangles: Draw one median rectangle per cluster.
    """
    image_stat_output.clear_output()
    with image_stat_output:
        display(df_stats[df_stats.filename == observation].reset_index())

    per_classification = no_rect_df[no_rect_df.filename == observation]
    count_frequencies = per_classification["rect_count"].value_counts()
    allowed_counts = _allowed_rect_counts(count_frequencies)

    rects = tidy_nona[
        (tidy_nona.filename == observation)
        & (tidy_nona.rect_count.isin(allowed_counts))
    ].dropna()[["x", "y", "width", "height"]]
    # Discard boxes at least twice as wide or as tall as the median box.
    plausible = (rects.width < 2 * rects.width.median()) & (
        rects.height < 2 * rects.height.median()
    )
    rects = rects[plausible].reset_index(drop=True)

    image_output.clear_output()
    # The placeholder entry is not a file: skip the drawing instead of failing
    # on a missing image.
    if observation != "Select an observation":
        with image_output:
            fig, ax = plt.subplots()
            fig.set_size_inches(14, 14)
            ax.set_axis_off()
            with PilImage.open(images_path.joinpath(observation)) as img:
                ax.imshow(img)
            if not rects.empty:
                _draw_consensus(
                    ax,
                    rects,
                    max(allowed_counts),
                    print_all,
                    print_centers,
                    print_rectangles,
                )
            plt.show()

    user_gt.clear_output()
    with user_gt:
        display(
            per_classification.reset_index().drop(["filename", "user_agent"], axis=1)
        )

    value_counts.clear_output()
    with value_counts:
        display(count_frequencies)


def _redraw_ground_truth(**overrides):
    """Call print_ground_truth with the current widget values.

    Keyword overrides replace the value read from the corresponding widget.
    """
    settings = {
        "observation": dd_image.value,
        "print_all": is_print_all.value,
        "print_centers": is_print_centers.value,
        "print_rectangles": is_print_rectangles.value,
    }
    settings.update(overrides)
    print_ground_truth(**settings)


def on_image_selected(change):
    """Another image was picked in the dropdown."""
    _redraw_ground_truth(observation=change.new)


def on_print_all_changed(change):
    """The "all annotation centres" checkbox was toggled."""
    _redraw_ground_truth(print_all=change.new)


def on_print_centers_changed(change):
    """The "k-means centres" checkbox was toggled."""
    _redraw_ground_truth(print_centers=change.new)


def on_print_rectangles_changed(change):
    """The "rectangles" checkbox was toggled."""
    _redraw_ground_truth(print_rectangles=change.new)


# Re-render whenever the image or one of the overlay toggles changes.
for _control, _callback in (
    (dd_image, on_image_selected),
    (is_print_all, on_print_all_changed),
    (is_print_centers, on_print_centers_changed),
    (is_print_rectangles, on_print_rectangles_changed),
):
    _control.observe(_callback, names="value")

# NOTE(review): only the controls are displayed. The output panes below are
# commented out, so the handlers render into widgets that are never shown.
display(
    HBox([dd_image, is_print_all, is_print_centers, is_print_rectangles]),
    # image_stat_output,
    # HBox([image_output, user_gt, value_counts]),
)


### Apply K-Means to Find the True Bounding Boxes

In [None]:
from tqdm import tqdm_notebook as tqdmj

# Build one frame of consensus bounding boxes per image and collect them in df_lst.
df_lst = []
box_cols = ["x", "y", "width", "height"]
for filename in tqdmj(df_stats.filename.unique().tolist()):
    # Trusted rectangle counts: the most frequent count alone when it is more
    # than twice as frequent as the runner-up, otherwise the two most frequent.
    vc = no_rect_df[no_rect_df.filename == filename]["rect_count"].value_counts()
    top_counts = vc.index.to_list()[:2]
    if len(top_counts) == 2 and vc.iloc[0] > 2 * vc.iloc[1]:
        allowed_counts = top_counts[:1]
    else:
        allowed_counts = top_counts

    rects = tidy_nona.loc[
        (tidy_nona.filename == filename)
        & (tidy_nona.rect_count.isin(allowed_counts)),
        ["filename", *box_cols],
    ].dropna()
    # Discard boxes at least twice as wide or as tall as the median box.
    rects = rects[
        (rects.width < 2 * rects.width.median())
        & (rects.height < 2 * rects.height.median())
    ].reset_index(drop=True)

    if rects.empty:
        # Keep the image in the dataset with a single all-NaN box row.
        # np.nan replaces the np.NaN alias, which NumPy 2.0 removed.
        df_lst.append(
            pd.DataFrame(
                [[filename, np.nan, np.nan, np.nan, np.nan]],
                columns=["filename", *box_cols],
            )
        )
        continue

    centres = list(
        zip(rects.x + rects.width // 2, rects.y + rects.height // 2)
    )
    rects["y_pred"] = (
        KMeans(
            # KMeans refuses more clusters than samples; the size filter above
            # can leave fewer rectangles than the expected count.
            n_clusters=min(max(allowed_counts), len(centres)),
            random_state=42,
        )
        .fit(centres)
        .predict(centres)
    )
    # Replace every box by the median box of its cluster, then collapse the
    # now-identical rows so that one row per cluster remains.
    rects[box_cols] = rects.groupby("y_pred")[box_cols].transform("median")
    df_lst.append(
        rects[["filename", *box_cols]].drop_duplicates().reset_index(drop=True)
    )


In [None]:
df_kmeans = pd.concat(df_lst)
df_kmeans


In [None]:
df_kmeans.filename.unique().shape


### Remove Empty Bounding Boxes

In [None]:
df_kmeans.loc[(df_kmeans.width == 0) | (df_kmeans.height == 0)]


In [None]:
df_no_empty_recs = (
    df_kmeans[((df_kmeans.width != 0) & (df_kmeans.height != 0))]
    .assign(
        x1=lambda x: x.x,
        y1=lambda x: x.y,
        x2=lambda x: x.x + x.width,
        y2=lambda x: x.y + x.height,
    )
    .reset_index(drop=True)
)
df_no_empty_recs


## Merge Zooniverse Data With TPMP data

### Load TPMP CSV Data

In [None]:
# Lookup table joined on its "hash" column in the merge cell.
df_t = pd.read_csv(csvs_path / "filename_to_hash_v2.csv")
df_t

### Merge Data

In [None]:
from siuba import _, filter
import numpy as np

# The Zooniverse "filename" is used as the join key against the "hash" column
# of the TPMP table; after the join it is restored under the name "filename".
df_merge_zt = (
    df_no_empty_recs.rename(columns={"filename": "hash"})
    .sort_values("hash")
    .merge(df_t, on="hash")
    .sort_values(["experiment", "plant", "date", "time"])
    .assign(filename=lambda merged: merged.hash)
    .drop(columns=["hash"])
    .loc[
        :,
        [
            "experiment",
            "plant",
            "camera",
            "view_option",
            "date_time",
            "date",
            "time",
            "filename",
            "x",
            "y",
            "width",
            "height",
        ],
    ]
    # Corner coordinates are recomputed after the column selection.
    .assign(
        x1=lambda box: box.x,
        y1=lambda box: box.y,
        x2=lambda box: box.x + box.width,
        y2=lambda box: box.y + box.height,
    )
    .reset_index(drop=True)
)

df_merge_zt


In [None]:
# Persist the merged table (row labels omitted); the next section reloads it.
df_merge_zt.to_csv(csvs_path / "zooniverse_tpmp_data.csv", index=False)

## Build Datasets

### Load Data

In [None]:
df_src = pd.read_csv(str(csvs_path.joinpath("zooniverse_tpmp_data.csv")))
df_src

In [None]:
df_src[(df_src.x1 >= df_src.x2) | (df_src.y1 >= df_src.y2)] 

### Split Datasets

#### Build Counts Dataframe

In [None]:
df_strat = df_src.groupby("filename").count().reset_index(drop=False)
df_strat["count"] = df_strat.x
df_strat = df_strat[["filename","count"]]
df_strat


In [None]:
df_count = pd.merge(left=df_src, right=df_strat, on=["filename"])
df_count

#### Split with Stratify on Count

In [None]:
# Number of distinct images available for splitting.
df_count["filename"].nunique(dropna=False)

In [None]:
from sklearn.model_selection import train_test_split

# One row per image with its box count, ordered by count, plus a two-level
# stratum: 1 for images with at least one box, 0 for images without any.
df_stratify = (
    df_count.loc[:, ["filename", "count"]]
    .drop_duplicates()
    .sort_values(by="count")
    .reset_index(drop=True)
    .assign(strata=lambda frame: np.where(frame["count"] > 0, 1, 0))
)
df_stratify


In [None]:

# 80 % of the images go to training; the remaining 20 % are split evenly into
# test and validation. Both splits are stratified on "strata".
# A fixed seed makes the three exported CSVs reproducible from one notebook run
# to the next, in line with the seeded KMeans used earlier in this notebook.
split_seed = 42
train_files, holdout_files = train_test_split(
    df_stratify,
    test_size=0.2,
    stratify=df_stratify["strata"],
    random_state=split_seed,
)
test_files, val_files = train_test_split(
    holdout_files,
    test_size=0.5,
    stratify=holdout_files["strata"],
    random_state=split_seed,
)

len(train_files), len(test_files), len(val_files)


In [None]:

train = df_src[df_src.filename.isin(train_files.filename.to_list())].sort_values(["filename"])
test = df_src[df_src.filename.isin(test_files.filename.to_list())].sort_values(["filename"])
val = df_src[df_src.filename.isin(val_files.filename.to_list())].sort_values(["filename"])

len(train), len(test), len(val)


In [None]:
train

## Test Datasets

In [None]:
import loaders as lds

# Size handed to the project's image transforms - presumably the edge length
# in pixels of the resized image; confirm in `loaders`.
image_size = 512
# Wrap the training rows in the project's GemmaDataset, using the test-time
# transform from `loaders`.
tst_ds = lds.GemmaDataset(
    train,
    images_path=images_path,
    transform=lds.get_test_image_transform(image_size=image_size),
)

# Smoke test: show the "boxes" entry of the second element of the first sample
# (presumably the target dict - verify against loaders.GemmaDataset).
tst_ds[0][1]["boxes"]

### Test boxes

In [None]:
import matplotlib.pyplot as plt

# Show one randomly sampled training image as rendered by the dataset's
# draw_image_with_boxes helper (project code, behaviour not visible here).
plt.imshow(
    tst_ds.draw_image_with_boxes(filename=train.sample(n=1).filename.to_list()[0])
)
plt.tight_layout()
plt.axis("off")
plt.show()


### Test Suspect Image

In [None]:
from matplotlib.pyplot import figure

# Larger canvas for a closer look at one specific image.
figure(figsize=(10, 10), dpi=80)

# NOTE(review): tst_ds is built from the train split only and the split is
# random - confirm draw_image_with_boxes can resolve this file when it landed
# in the validation or test split.
plt.imshow(tst_ds.draw_image_with_boxes(filename="b0xhA8TCuQtLRbirX369iE7dJvUE.jpg"))
plt.tight_layout()
plt.axis("off")
plt.show()


### Test Transformations/Augmentations

In [None]:
# Pick one random training image and render it 12 times in a 3 x 4 grid.
# NOTE(review): tst_ds uses get_test_image_transform; whether that transform is
# randomised (so the 12 renders differ) is not visible here - confirm in `loaders`.
file_name = train.sample(n=1).filename.to_list()[0]

lds.make_patches_grid(
    images=[tst_ds.draw_image_with_boxes(filename=file_name) for _ in range(12)],
    row_count=3,
    col_count=4,
    figsize=(10, 7.5),
)


### Test Tracking

In [None]:
# Dataset restricted to every row of one randomly chosen plant, ordered by
# date_time, using the resize-only transform from `loaders`.
ds_plant = lds.GemmaDataset(
    csv=df_src[df_src.plant == df_src.sample(n=1).plant.to_list()[0]].sort_values(
        ["date_time"]
    ),
    images_path=images_path,
    transform=lds.get_resize_only_image_transform(image_size=image_size),
)

# Render every image of that plant with its boxes in a 3 x 4 grid.
# NOTE(review): how make_patches_grid handles more than 12 images is not
# visible here - confirm in `loaders`.
lds.make_patches_grid(
    images=[ds_plant.draw_image_with_boxes(filename=fn) for fn in ds_plant.images],
    row_count=3,
    col_count=4,
    figsize=(20, 15),
)


## Save Datasets

In [None]:
# Write each split to <name>.csv under dataset_path, row labels omitted.
for n, d in {"train": train, "val": val, "test": test}.items():
    d.to_csv(str(dataset_path / f"{n}.csv"), index=False)
