In [8]:
# !pip install ipywidgets
# !pip install ipyevents

In [9]:
import os
from IPython.display import display
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np

import ipywidgets as widgets
from ipyevents import Event
from ipywidgets import GridspecLayout
import traceback
import json
import glob

In [10]:
def create_label_index(root_folder: str) -> dict:
    label_folders = [dn for dn in os.listdir(root_folder) if os.path.isdir(os.path.join(root_folder, dn))]
    label_index = {"classes": label_folders, "data": {}, "meta_data": {"last_task_id": 0}}

    for idx, dn in enumerate(label_folders):
        folder = os.path.join(root_folder, dn)
        fns = [os.path.join(folder, file) for file in os.listdir(folder) if file.endswith(("png", "jpg", "jpeg", "JPG", "JPEG"))]
        for fn in fns:
            label_index["data"][fn] = {"label": idx, "checked": False}

    return label_index

def initialize_label_index(root_folder: str) -> dict:
    index_fn = os.path.join(root_folder, "labels.json")
    
    if os.path.exists(index_fn):
        with open(index_fn, "r") as f:
            label_index = json.load(f)
    else:
        label_index = create_label_index(root_folder)
        save_label_index(root_folder, label_index)
    return label_index

def save_label_index(root_folder: str, label_index: dict) -> None:
    index_fn = os.path.join(root_folder, "labels.json")
    with open(index_fn, "w") as f:
          json.dump(label_index, f)

In [25]:
IMAGE_BORDER_WIDTH = 6
N_IMAGES_H, N_IMAGES_V = 5, 5

IMAGE_SIZE = 125
MARGIN = 5

N_IMAGES_PER_TASK = N_IMAGES_H * N_IMAGES_V
IMAGE_BORDER_STYLES = {
    "-1_normal": f"{IMAGE_BORDER_WIDTH}px solid black",
    
    "0_normal": f"{IMAGE_BORDER_WIDTH}px solid #A93226",
    "0_hover": f"{IMAGE_BORDER_WIDTH}px solid #CB4335",

    "1_normal": f"{IMAGE_BORDER_WIDTH}px solid #2471A3",
    "1_hover": f"{IMAGE_BORDER_WIDTH}px solid #2E86C1",

    "2_normal": f"{IMAGE_BORDER_WIDTH}px solid #229954",
    "2_hover": f"{IMAGE_BORDER_WIDTH}px solid #28B463",
}

In [26]:
def initialize_image(image_size: int) -> widgets.Box:
    image = widgets.Image(
        value=bytes(),
        format='jpeg',
    )
    
    box = widgets.Box([image], height=image_size, width=image_size)
    box.meta_data = {}
    box.meta_data["file_name"] = None
    box.meta_data["label"] = None

    box.layout.border = ""
    
    image.layout.width = "100%"
    image.layout.height = "100%"
    image.layout.object_fit = "contain"
    image.layout.overflow = "hidden"

    box.layout.width = f"{image_size}px"
    box.layout.height = f"{image_size}px"

    def handle_event(target, event):
        lines = ['{}: {}'.format(k, v) for k, v in event.items()]
        try:
            label = target.meta_data["label"]
            if event["type"] == "click":
                label = target.meta_data["label"] = (label + 1) % 3
    
                target.layout.border = IMAGE_BORDER_STYLES[f"{label}_hover"]
            elif event["type"] == "mouseenter":
                target.layout.border = IMAGE_BORDER_STYLES[f"{label}_hover"]
            elif event["type"] == "mouseleave":
                target.layout.border = IMAGE_BORDER_STYLES[f"{label}_normal"]
            lines.append( "{}: {}".format("label", label))
            lines.append("{}: {}".format("style", image.layout.border))
        except Exception as ex:
            info.value = traceback.format_exc()
    
    box_event = Event(source=box, watched_events=['click', 'mouseenter', 'mouseleave'])
    box_event.on_dom_event(lambda e: handle_event(box, e))

    return box


def update_task_grid(grid) -> None:
    try:
        for idx in range(N_IMAGES_PER_TASK):
            i = idx // N_IMAGES_H
            j = idx % N_IMAGES_H
            
            fn = grid[i, j].meta_data["file_name"]
            label = grid[i, j].meta_data["label"]
            if fn is None:
                continue

            label_index["data"][fn]["label"] = label
            label_index["data"][fn]["checked"] = True
            info.value = "updated"  + str(label)
        label_index["meta_data"]["last_task_id"] = grid.meta_data["current_task_id"]
        save_label_index(data_root_folder, label_index)

        current_task_id = grid.meta_data["current_task_id"]
        
        for idx, fn in enumerate(image_paths[current_task_id * N_IMAGES_PER_TASK:(current_task_id + 1) * N_IMAGES_PER_TASK]):
            i = idx // N_IMAGES_H
            j = idx % N_IMAGES_H
            grid[i, j].layout.visibility = "visible"
            with open(fn, "rb") as f:
                image = f.read()
            grid[i, j].children[0].value = image
            label = label_index["data"][fn]["label"]
            
            grid[i, j].meta_data["label"] = label
            grid[i, j].meta_data["file_name"] = fn
            grid[i, j].layout.border = IMAGE_BORDER_STYLES[f"{label}_normal"]
        for idx in range(idx + 1, N_IMAGES_PER_TASK):
            i = idx // N_IMAGES_H
            j = idx % N_IMAGES_H
            grid[i, j].layout.visibility = "hidden"
            grid[i, j].meta_data["label"] = None
            grid[i, j].meta_data["file_name"] = None
    except:
        info.value = traceback.format_exc()    


def goto_next_task(grid):
    grid.meta_data["current_task_id"] = (grid.meta_data["current_task_id"] + 1) % grid.meta_data["n_tasks"]
    update_task_grid(grid)


def goto_previous_task(grid):
    grid.meta_data["current_task_id"] = (grid.meta_data["current_task_id"] - 1) % grid.meta_data["n_tasks"]
    update_task_grid(grid)

In [36]:
data_root_folder = ""  # format .../dataset/0/, .../dataset/1/, .../dataset/2/.

label_index = initialize_label_index(data_root_folder)

In [37]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [38]:
def grid_on_keyup(grid, event):
    try:
        if event["type"] != "keyup":
            return
        if str(event["key"]) == "ArrowLeft":
            goto_previous_task(grid)
        elif str(event["key"]) == "ArrowRight":
            goto_next_task(grid)
    except:
        info.value = "Exception: " + traceback.format_exc()

n_tasks = int(np.ceil(len(label_index["data"]) / N_IMAGES_PER_TASK))
image_paths = list(label_index["data"].keys())

grid = GridspecLayout(N_IMAGES_V, N_IMAGES_H, height=f"{IMAGE_SIZE * N_IMAGES_V + MARGIN * (N_IMAGES_V - 1)}px", width=f"{IMAGE_SIZE * N_IMAGES_H + MARGIN * (N_IMAGES_H - 1)}px")
grid.meta_data = {}
grid.meta_data["n_tasks"] = n_tasks
grid.meta_data["current_task_id"] = label_index["meta_data"]["last_task_id"]
grid_event = Event(source=grid, watched_events=['keyup'])
grid_event.on_dom_event(lambda e: grid_on_keyup(grid, e))

info = widgets.HTML("")
for idx in range(N_IMAGES_PER_TASK):
    i = idx // N_IMAGES_H
    j = idx % N_IMAGES_H
    grid[i, j] = initialize_image(IMAGE_SIZE)

update_task_grid(grid)

display(grid, info)

GridspecLayout(children=(Box(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\…

HTML(value='')