# Identify bad images in training / validation sets

In [None]:
# Enable IPython's autoreload extension so edited modules are re-imported
# automatically before each cell executes (mode 2 = all modules).
%load_ext autoreload
%autoreload 2

In [None]:
import holoviews as hv
import janitor
import numpy as np
import pandas as pd
import panel as pn
from pytorch_hcs.vis import set_hv_defaults

from pybbbc import BBBC021

# Apply the plotting defaults shipped with pytorch_hcs.vis
# (presumably global HoloViews options -- confirm in that module).
set_hv_defaults()

In [None]:
from pybbbc import BBBC021

# Load BBBC021 restricted to every mechanism of action except "null".
bbbc021 = BBBC021(
    moa=[moa_name for moa_name in BBBC021.MOA if moa_name != "null"],
)

In [None]:
# Per-image metadata table of the filtered dataset; the bare name on the last
# line displays it as the cell output.
image_df = bbbc021.image_df
image_df

Extract the absolute image indices so that we can map these images back to the original, unfiltered BBBC021 dataset, which also includes the null-MoA images.

In [None]:
# Array of the `image_idx` column: element i is the absolute index of the i-th
# image of the filtered dataset.
abs_image_idcs = bbbc021.image_df["image_idx"].values
abs_image_idcs

In [None]:
from holoviews.streams import Stream, param

In [None]:
# One-off initialisation, kept for reference: builds a table with one row per
# filtered image (same order as `abs_image_idcs`), every image starting out as
# "unclassified". Uncomment only when starting a classification from scratch.
# quality_df = pd.DataFrame(
#     dict(
#         image_idx=abs_image_idcs,
#         quality=pd.Categorical(len(abs_image_idcs)*['unclassified'], categories=["good", "bad", "unclassified"]),
#     )
# )

# Resume from the classification table stored on disk.
# NOTE(review): the cells below index this table by position, so its row order
# is assumed to match the filtered dataset order -- confirm for the saved file.
quality_df = pd.read_parquet("data/image_quality_classification.parquet")

quality_df

In [None]:
def make_layout(image_idx):
    """Build the HoloViews layout shown for one image of the filtered dataset.

    One ``hv.Image`` panel is created per channel, followed by an ``hv.RGB``
    overlay of all channels. Each per-channel title carries the absolute image
    index, the current quality label and the channel name.

    Args:
        image_idx: Position of the image within ``bbbc021``; the same position
            is used to look up ``quality_df`` and ``abs_image_idcs``.

    Returns:
        An ``hv.Layout`` of the panels arranged in two columns.
    """
    pixels, _metadata = bbbc021[image_idx]

    # Title prefix shared by all per-channel panels.
    current_quality = quality_df["quality"].iloc[image_idx]
    title_prefix = f"idx: {abs_image_idcs[image_idx]} | {current_quality}"

    # The leading axis holds the channels; the other two give the plot extent,
    # which is identical for every panel.
    _, n_rows, n_cols = pixels.shape
    extent = (0, 0, n_cols, n_rows)

    channel_cmaps = ["fire", "kg", "kb"]

    panels = [
        hv.Image(
            channel_pixels,
            bounds=extent,
            label=f"{title_prefix} | {bbbc021.CHANNELS[channel_no]}",
        ).opts(cmap=channel_cmaps[channel_no])
        for channel_no, channel_pixels in enumerate(pixels)
    ]

    # hv.RGB wants the channel axis last.
    overlay = hv.RGB(
        pixels.transpose(1, 2, 0),
        bounds=extent,
        label="Channel overlay",
    )
    panels.append(overlay)

    return hv.Layout(panels).cols(2)


# Slider selecting which image (by position in the filtered dataset) is shown.
image_idx_slider = pn.widgets.IntSlider(
    name="image_idx",
    value=0,
    start=0,
    end=len(abs_image_idcs) - 1,
)

# Classification buttons: label the displayed image.
good_btn = pn.widgets.Button(name="Good", button_type="success")
bad_btn = pn.widgets.Button(name="Bad", button_type="danger")

# Navigation buttons: move without labelling.
previous_btn = pn.widgets.Button(name="Previous Image")
next_btn = pn.widgets.Button(name="Next Image")


def previous_callback(_):
    """Move the slider back by one image, stopping at the first image.

    Args:
        _: Click event supplied by the button; unused.
    """
    # Clamp to the slider's lower bound so clicking "Previous" on the first
    # image does not assign an out-of-range (negative) index.
    image_idx_slider.value = max(
        image_idx_slider.value - 1, image_idx_slider.start
    )


def next_callback(_):
    """Advance the slider by one image, stopping at the last image.

    Args:
        _: Click event supplied by the button; unused.
    """
    # Clamp to the slider's upper bound so clicking "Next" on the last image
    # does not assign an index past the end of the dataset.
    image_idx_slider.value = min(
        image_idx_slider.value + 1, image_idx_slider.end
    )


def good_callback(_):
    """Label the displayed image as "good" and advance to the next image.

    Args:
        _: Click event supplied by the button; unused.
    """
    cur_image_idx = image_idx_slider.value

    # Write through a single positional indexer on the frame itself. The
    # chained form `quality_df["quality"].iloc[i] = ...` assigns into an
    # intermediate Series, which pandas warns about and which does not reach
    # the frame at all when Copy-on-Write is active.
    quality_col = quality_df.columns.get_loc("quality")
    quality_df.iloc[cur_image_idx, quality_col] = "good"

    # Advance, but stay on the last image rather than stepping past the
    # slider's upper bound.
    image_idx_slider.value = min(cur_image_idx + 1, image_idx_slider.end)


def bad_callback(_):
    """Label the displayed image as "bad" and advance to the next image.

    Args:
        _: Click event supplied by the button; unused.
    """
    cur_image_idx = image_idx_slider.value

    # Write through a single positional indexer on the frame itself. The
    # chained form `quality_df["quality"].iloc[i] = ...` assigns into an
    # intermediate Series, which pandas warns about and which does not reach
    # the frame at all when Copy-on-Write is active.
    quality_col = quality_df.columns.get_loc("quality")
    quality_df.iloc[cur_image_idx, quality_col] = "bad"

    # Advance, but stay on the last image rather than stepping past the
    # slider's upper bound.
    image_idx_slider.value = min(cur_image_idx + 1, image_idx_slider.end)


# Register each handler with its button; the handler receives the click event
# as its only argument.
previous_btn.on_click(previous_callback)
next_btn.on_click(next_callback)
good_btn.on_click(good_callback)
bad_btn.on_click(bad_callback)

# Assemble the labelling UI: image panels on top, then the index slider, then
# the navigation / classification buttons.
#
# `pn.bind` is used without `watch=True`: the DynamicMap already re-invokes the
# bound function whenever the slider value changes. With `watch=True` the
# function is additionally executed by a watcher on every change and its
# result thrown away, so each image would be loaded and rendered twice.
pane = pn.Column(
    hv.DynamicMap(pn.bind(make_layout, image_idx_slider)).opts(
        hv.opts.RGB(frame_width=550), hv.opts.Image(frame_width=550)
    ),
    image_idx_slider,
    pn.Row(
        previous_btn,
        next_btn,
        good_btn,
        bad_btn,
    ),
)

pane

In [None]:
# quality_df.to_parquet('data/image_quality_classification.parquet')

In [None]:
# Absolute indices of every image currently labelled "bad".
bad_idcs = quality_df.loc[quality_df["quality"] == "bad", "image_idx"].values
bad_idcs