# Image Data Labeler
> labeling image data

In [1]:
# default_exp turk.image

In [173]:
# export
import json
from pathlib import Path
from forgebox.files import file_detail
from forgebox.html import DOM
import pandas as pd
from typing import List, Dict
from ipywidgets import interact, interact_manual, Button, SelectMultiple, \
    Output, HBox, VBox
from PIL import Image as PILImage
from ipywidgets import Image as ImageWidget
import logging
from tqdm.notebook import tqdm

## Dataset for testing

For contributors to this library: you can use the default test images in `../test/img`, or uncomment the following cells to download more images of your choice.

In [174]:
# !pip install -q jmd_imagescraper

In [175]:
# from jmd_imagescraper.core import duckduckgo_search, ImgSize

In [176]:
# duckduckgo_search("../test/img", "Nature", "nature", max_results=20)

In [177]:
!du -sh ../test

996K	../test


## Labeler

In [178]:
# Peek at the first entry of the "path" column in the listing built for ../test
file_detail("../test")["path"][0]

'/Users/salvor/github/unpackai/nbs/../test/img/Nature/007_9554a747.jpg'

In [188]:
# export
class ImageLabeler:
    """
    Base class for the notebook image labelers.

    It lists the files under a folder, keeps the image files in
    ``self.image_df`` and tracks labeling state in ``self.progress``::

        {"meta": {"image_folder": ..., "labels": ..., "identifier": ...},
         "data": {<identifier value as str>: <label or None>, ...}}

    Subclasses are expected to provide ``create_label_btns(row)`` and
    ``done_page()``; both are called from this class but not defined here.
    """

    # Extensions accepted when the caller does not pass ``formats``.
    DEFAULT_FORMATS = ("jpg", "jpeg", "png", "bmp")

    def __init__(self,
                 image_folder: Path,
                 formats: List[str] = None,
                 ):
        """
        image_folder: Path, a folder full of images
        formats: a list of allowed formats, defaults to DEFAULT_FORMATS
        """
        self.image_folder = image_folder
        self.file_df = file_detail(image_folder)
        self.filter_image(formats)
        self.output = Output()

    def __repr__(self):
        return f"{self.__class__.__name__} on [{self.image_folder}({len(self.image_df)})], see labeler.image_df"

    def filter_image(
        self,
        formats: List[str] = None
    ) -> pd.DataFrame:
        """
        Filter the file dataframe to image only files
        assign image_df attribute to the object

        formats: allowed values of the ``file_type`` column; each one is
            matched as given and in upper case. The argument itself is
            never modified.
        """
        base = list(self.DEFAULT_FORMATS if formats is None else formats)
        # Build a fresh list: extending ``formats`` in place would grow the
        # caller's list (or a shared default) on every call.
        allowed = base + [fmt.upper() for fmt in base]
        is_image = self.file_df.file_type.isin(allowed)
        self.image_df = self.file_df[is_image].reset_index(drop=True)
        return self.image_df

    @property
    def identifier(self):
        """Name of the image_df column used as the key of progress['data']."""
        return self.progress['meta']['identifier']

    def save_progress(
        self,
        location: Path = Path("."),
        filename="unpackai_imglbl.json"
    ):
        """
        Save the progress to location/filename
        default save to current directory ./unpackai_imglbl.json

        location may be a Path or a plain string.
        """
        target = Path(location) / filename
        with open(target, "w") as f:
            json.dump(self.progress, f)
        logging.info(f"Progess Saved to {target}")

    @classmethod
    def load_saved(cls, filepath="./unpackai_imglbl.json"):
        """
        Load saved labeler's progress

        A new labeler is built on the image folder recorded in the file,
        then the saved progress is attached to it.
        """
        with open(filepath, "r") as f:
            progress = json.load(f)
        obj = cls(progress['meta']['image_folder'])
        obj.progress = progress
        return obj

    def new_progress(self, labels: List[str], identifier: str = "path"):
        """
        Start an empty progress record: every image is unlabeled (None).

        labels: the labels offered to the user
        identifier: image_df column whose values key the progress data
        """
        self.progress = dict(
            meta=dict(
                # stored as str so the record stays JSON serializable
                # even when image_folder was given as a Path
                image_folder=str(self.image_folder),
                labels=list(labels),
                identifier=identifier,
            ),
            data={str(k): None for k in self.image_df[identifier]},
        )

    def __call__(self, labels: List[str] = None):
        """
        Generator over the keys of the images that are still unlabeled.

        labels: labels to offer; when omitted, the labels of an already
            attached progress are reused, otherwise ["pos", "neg"].

        Being a generator, nothing below runs until the first next().
        """
        if labels is None:
            progress = getattr(self, "progress", None)
            labels = progress['meta']['labels'] if progress else ["pos", "neg"]
        self.labels = list(labels)
        if not hasattr(self, "progress"):
            self.new_progress(self.labels)

        for k, v in tqdm(self.progress['data'].items(), leave=False):
            if v is None:
                yield k

    def __getitem__(self, key):
        """
        render a page according to key
        """
        row = self.get_row_data(key)
        image_key = str(row[self.identifier])
        self.output.clear_output()
        with self.output:
            # Close the opened file once the resized copy has been shown;
            # the resized image is a separate in-memory object.
            with PILImage.open(row[self.identifier]) as img:
                display(img.resize((512, 512)))
            label_btns = self.create_label_btns(row)

            # current labeled label
            current = self.progress['data'][image_key]
            if current is not None:
                DOM(f"Current Label:{current}", "h5")()

            # navigation buttons, skipping the ones that do not apply
            candidates = (
                self.create_show_last_btn(image_key),
                self.create_show_next_btn(image_key),
                self.create_save_btn(),
            )
            nav_btns = [btn for btn in candidates if btn is not None]
            display(VBox([label_btns,
                          HBox(nav_btns)
                          ]))

    def get_row_data(self, key):
        """
        Return the image_df row whose identifier equals key, as a dict.

        Raises IndexError when no row matches.
        """
        identifier = self.identifier
        # Plain comparison instead of DataFrame.query: a query string is
        # parsed as an expression, so quotes or backslashes inside a path
        # would break it.
        matched = self.image_df[
            self.image_df[identifier].astype(str) == str(key)]
        return dict(matched.to_dict(orient='records')[0])

    def render_page(self):
        """
        Render a new page
        When the generator is exhausted, save and show the done page.
        """
        try:
            key = next(self.gen)
        except StopIteration:
            self.save_progress()
            self.done_page()
            return
        self[key]

    def _create_nav_btn(self, target_key, description, icon):
        """Build a button that renders the page of target_key when clicked."""
        btn = Button(description=description, icon=icon)
        btn.on_click(lambda _btn: self[target_key])
        return btn

    def create_show_last_btn(self, key):
        """Button to the previous image, or None for the first image."""
        keys = list(self.progress["data"])
        idx = keys.index(str(key))
        if idx == 0:
            return None
        return self._create_nav_btn(keys[idx-1], "Last", "arrow-left")

    def create_show_next_btn(self, key):
        """Button to the next image, or None for the last image."""
        keys = list(self.progress["data"])
        idx = keys.index(str(key))
        if idx >= len(keys)-1:
            return None
        return self._create_nav_btn(keys[idx+1], "Next", "arrow-right")

    def create_save_btn(self):
        """Button that saves the progress to the default location."""
        btn = Button(description="Save", icon='save')
        # on_click passes the button to the handler; drop it so that it is
        # not mistaken for the ``location`` argument.
        btn.on_click(lambda _btn: self.save_progress())
        return btn


class SingleClassImageLabeler(ImageLabeler):
    """
    Labeler that assigns exactly one label per image: one button per label.
    """

    def __init__(self, image_folder: Path):
        """
        image_folder: Path, a folder full of images
        """
        super().__init__(image_folder)

    def __call__(self, labels: List[str] = None):
        """
        Start (or resume) labeling and show the widget.

        labels: labels to offer; when omitted, the labels of an already
            attached progress are reused, otherwise ["pos", "neg"].
        """
        if labels is None:
            progress = getattr(self, "progress", None)
            labels = progress["meta"]["labels"] if progress else ["pos", "neg"]
        self.gen = super().__call__(labels)

        self.render_page()

        display(self.output)

    def create_label_btns(self, row):
        """
        Build one button per label; clicking it stores that label for the
        image described by row and moves on to the next unlabeled image.
        """
        key = str(row[self.identifier])

        def make_handler(label):
            # A factory gives every button its own ``label``. A closure
            # defined directly in the loop would see only the last label.
            def on_click(_btn):
                self.progress["data"][key] = label
                self.render_page()
            return on_click

        btns = []
        for label in self.labels:
            btn = Button(description=label, icon="check-circle")
            btn.on_click(make_handler(label))
            btns.append(btn)

        return HBox(btns)

    def done_page(self):
        """Replace the output with the end-of-iteration message."""
        self.output.clear_output()
        with self.output:
            DOM("That's the end of the iteration", "h3")()


class MultiClassImageLabeler(ImageLabeler):
    """
    Labeler that assigns a list of labels per image through a
    multiple-selection box and a confirm button.
    """

    def __init__(self, image_folder: Path):
        """
        image_folder: Path, a folder full of images
        """
        super().__init__(image_folder)

    def __call__(self, labels: List[str] = None):
        """
        Start (or resume) labeling and show the widget.

        labels: labels to offer; when omitted, the labels of an already
            attached progress are reused, otherwise ["pos", "neg"].
        """
        if labels is None:
            progress = getattr(self, "progress", None)
            labels = progress["meta"]["labels"] if progress else ["pos", "neg"]
        self.gen = super().__call__(labels)
        DOM("press Command(mac) or Ctrl(win/linux) to select multiple","h4")()
        self.render_page()
        display(self.output)

    def create_label_btns(self, row):
        """
        Build the selection box and the confirm button; clicking the button
        stores the selected labels (as a list) for the image described by
        row and moves on to the next unlabeled image.
        """
        key = str(row[self.identifier])
        select = SelectMultiple(options=self.labels)
        btn = Button(description="Okay!", icon="check-circle")

        def confirm(_btn):
            # select.value is read at click time, so it reflects the
            # user's current selection
            self.progress["data"][key] = list(select.value)
            self.render_page()

        btn.on_click(confirm)

        return HBox([select, btn])

    def done_page(self):
        """Replace the output with the end-of-iteration message."""
        self.output.clear_output()
        with self.output:
            DOM("That's the end of the iteration", "h3")()

In [184]:
# Build a single-label labeler on the test folder; the repr reports the number of images found
slabel = SingleClassImageLabeler("../test")
slabel

SingleClassImageLabeler on [../test(20)], see labeler.image_df

In [181]:
# Start labeling with the default labels
slabel()

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

Output()

In [186]:
# Restore a labeler from the default progress file ./unpackai_imglbl.json
slabel = SingleClassImageLabeler.load_saved()

In [187]:
# Resume labeling: only the images whose saved label is still None are shown
slabel()

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

Output()

In [189]:
# Build a multi-label labeler on the same test folder
mlabel = MultiClassImageLabeler("../test")
mlabel

MultiClassImageLabeler on [../test(20)], see labeler.image_df

In [190]:
# Start labeling with a custom label set; several labels can be selected per image
mlabel(labels=["spring", "summer", "autumn", "winter"])

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

Output()