In [1]:
import os
import shutil
from pathlib import Path

from hcmus.core import appconfig
from hcmus.lbs import LabelStudioConnector

[32m2025-06-06 21:46:49.507[0m | [1mINFO    [0m | [36mhcmus.core.appconfig[0m:[36m<module>[0m:[36m7[0m - [1mLoad DotEnv: True[0m


In [2]:
# Connect to the Label Studio project registered under the "train" key;
# URL, API key and temp dir are all read from appconfig.
train_project_id = appconfig.LABEL_STUDIO_PROJECT_MAPPING["train"]
lsb_connector = LabelStudioConnector(
    url=appconfig.LABEL_STUDIO_URL,
    api_key=appconfig.LABEL_STUDIO_API_KEY,
    project_id=train_project_id,
    temp_dir=appconfig.LABEL_STUDIO_TEMP_DIR,
)

In [3]:
# Fetch all tasks, derive the label mapping from them, then download the
# images, keeping only samples that carry at least one label.
tasks = lsb_connector.get_tasks()
labels = lsb_connector.extract_labels(tasks)
dataset = [
    sample
    for sample in lsb_connector.download_dataset(tasks, labels)
    if sample.get("target").get("labels")
]

[32m2025-06-06 21:46:51.116[0m | [1mINFO    [0m | [36mhcmus.lbs._label_studio_connector[0m:[36mget_tasks[0m:[36m125[0m - [1mNew `page_to` applied: 23[0m
Loading tasks: 100%|██████████| 23/23 [00:12<00:00,  1.91it/s]
Downloading images: 100%|██████████| 2298/2298 [00:12<00:00, 182.21it/s] 


In [4]:
def build_classification_dataset(
    dataset,
    labels,
    prioritized_label="8935136865648",
    unknown_label="unknown",
):
    """Collapse each (possibly multi-label) item into a single class label.

    Selection rules, applied to an item's list of label ids:
      * the ``unknown_label`` id is ignored whenever the item also has another
        label (an item tagged *only* as unknown keeps the unknown id);
      * if the prioritized label's id is among the remaining ids, it wins;
      * otherwise the last remaining id in the item's label list is used;
      * an item with no labels at all gets ``-1``.

    Args:
        dataset: Iterable of dicts with an ``"image"`` entry and a
            ``"target"`` dict whose ``"labels"`` entry is a list of label ids
            (the values of ``labels``).
        labels: Mapping of label name -> label id. Only read, never modified.
        prioritized_label: Name prefix of the label that takes precedence.
            The first name in ``labels`` starting with it is used; if none
            matches, no label is prioritized.
        unknown_label: Name of the catch-all label; if it is absent from
            ``labels`` nothing is treated as unknown.

    Returns:
        A list of ``{"image": ..., "label": ...}`` dicts, one per input item,
        in input order.
    """
    unknown_idx = labels.get(unknown_label)
    prioritized_idx = next(
        (v for k, v in labels.items() if k.startswith(prioritized_label)),
        None,
    )

    ret_dataset = []
    for item in dataset:
        raw_labels = item.get("target").get("labels") or []
        # Prefer real labels over the catch-all; fall back to the raw list so
        # an "unknown"-only item is not silently turned into -1.
        known_labels = [label for label in raw_labels if label != unknown_idx]
        candidates = known_labels or list(raw_labels)

        if not candidates:
            final_label = -1
        elif prioritized_idx is not None and prioritized_idx in candidates:
            final_label = prioritized_idx
        else:
            # Single label, or several labels without the prioritized one:
            # take the last one listed (same tie-break as before).
            final_label = candidates[-1]

        ret_dataset.append({
            "image": item.get("image"),
            "label": final_label,
        })
    return ret_dataset

In [5]:
def select_labels_m_samples(cls_dataset, m_samples=12):
    """Return the labels that occur at least ``m_samples`` times.

    Args:
        cls_dataset: Iterable of dicts with a ``"label"`` entry.
        m_samples: Minimum number of samples a label needs to be kept.

    Returns:
        A list of labels ordered by ascending sample count.
    """
    import pandas as pd

    counts = (
        pd.DataFrame({"label": [sample.get("label") for sample in cls_dataset]})
        .groupby("label")
        .size()
        .sort_values()
    )
    return counts[counts >= m_samples].index.tolist()

In [6]:
def filter_cls_dataset(cls_dataset, selected_labels):
    """Keep only the samples whose label is one of ``selected_labels``.

    Args:
        cls_dataset: Iterable of dicts with a ``"label"`` entry.
        selected_labels: Iterable of (hashable) labels to keep; may be a
            one-shot iterator.

    Returns:
        A new list with the matching samples in their original order. The
        sample dicts themselves are not copied.
    """
    # Materialize once as a set: O(1) membership per sample, and a generator
    # argument is not exhausted by the first lookup.
    keep = set(selected_labels)
    return [item for item in cls_dataset if item.get("label") in keep]

In [7]:
# One label per image, then drop classes with fewer than
# MIN_SAMPLES_PER_LABEL examples.
MIN_SAMPLES_PER_LABEL = 8
cls_dataset = build_classification_dataset(dataset, labels)
selected_labels = select_labels_m_samples(cls_dataset, m_samples=MIN_SAMPLES_PER_LABEL)
filtered_dataset = filter_cls_dataset(cls_dataset, selected_labels)

In [8]:
# Reverse lookup: label id -> label name (later names win on duplicate ids).
idx2label = dict(zip(labels.values(), labels.keys()))

In [9]:
# Export the filtered samples as one folder per label:
# dataset/<label name>/<original image file name>.
dataset_dir = Path("dataset")
dataset_dir.mkdir(parents=True, exist_ok=True)

for item in filtered_dataset:
    label = item.get("label")
    path = item.get("image")
    # "/" would create nested directories; "'" is replaced too so folder
    # names stay identical to previously exported datasets.
    label_name = idx2label[label].replace("/", "-").replace("'", "-")
    label_dir = dataset_dir / label_name
    label_dir.mkdir(exist_ok=True)
    # Copy in-process rather than via `os.system("cp ...")`: a quote in the
    # image path broke the shell command and allowed shell injection.
    try:
        shutil.copy(path, label_dir)
    except OSError as exc:
        # Same policy as before: report the offending label and stop.
        print(label_name, exc)
        break