In [2]:
import os
import shutil
import random
from tqdm import tqdm

In [3]:
# Seed the global RNG up front so the shuffle in split_train_val is repeatable.
random.seed(42)

# How many images to pull out of the pool for each held-out split.
VAL_NUM = 10_000
TEST_NUM = 20_000

# Source pool: every image and YOLO label file starts out in the "all" folders.
# ROOT_PATH is relative, so it resolves against the notebook's working directory.
ROOT_PATH = "../../../data"
IMG_SRC_PATH = os.path.join(ROOT_PATH, "images/all")
LABEL_SRC_PATH = os.path.join(ROOT_PATH, "YOLO_labels/all")

In [4]:
def split_train_val():
    """Choose which images go to the validation and normal-test splits.

    Lists the entries of ``IMG_SRC_PATH``, shuffles them with the global
    ``random`` generator, and slices off the first ``VAL_NUM`` names for
    validation and the following ``TEST_NUM`` names for the normal test split.
    Nothing is moved on disk here; entries that are not selected simply stay
    in ``IMG_SRC_PATH`` (presumably as the training set -- confirm with the
    training pipeline).

    Returns:
        tuple[list[str], list[str]]: ``(val_img, normal_test_img)`` file names
        (not full paths). The two lists never overlap. If the folder holds
        fewer than ``VAL_NUM + TEST_NUM`` entries the slices are silently
        shorter.
    """
    shuffled = os.listdir(IMG_SRC_PATH)
    # os.listdir order is arbitrary; sort first so that the seeded shuffle
    # produces the same permutation on every machine.
    shuffled.sort()
    random.shuffle(shuffled)

    test_end = VAL_NUM + TEST_NUM
    return shuffled[:VAL_NUM], shuffled[VAL_NUM:test_end]

def split_val(val_img):
    """Move the validation images and their YOLO label files into the val split.

    Every file name in ``val_img`` is moved from ``IMG_SRC_PATH`` to
    ``<ROOT_PATH>/images/val``. The label file that shares the image's stem
    (``<stem>.txt`` in ``LABEL_SRC_PATH``) is moved to
    ``<ROOT_PATH>/YOLO_labels/val`` when it exists; an image without a label
    file is moved on its own.

    Args:
        val_img: Iterable of image file names (not paths) located in
            ``IMG_SRC_PATH``.

    Note:
        The files are moved, not copied, so calling this twice with the same
        names fails on the second call because the sources are gone.
    """
    img_dst_dir = os.path.join(ROOT_PATH, "images/val")
    label_dst_dir = os.path.join(ROOT_PATH, "YOLO_labels/val")

    os.makedirs(img_dst_dir, exist_ok=True)
    os.makedirs(label_dst_dir, exist_ok=True)

    for img_name in tqdm(val_img, desc="Splitting val"):
        # Swap only the extension. A plain str.replace("jpg", "txt") would also
        # rewrite "jpg" inside the stem (e.g. "jpg_0001.jpg") and would leave
        # ".JPG" / ".jpeg" / ".png" names untouched, so the label was skipped.
        label_name = os.path.splitext(img_name)[0] + ".txt"

        shutil.move(
            os.path.join(IMG_SRC_PATH, img_name),
            os.path.join(img_dst_dir, img_name),
        )

        label_src = os.path.join(LABEL_SRC_PATH, label_name)
        # Labels are optional: an image without a label file is still moved.
        if os.path.exists(label_src):
            shutil.move(label_src, os.path.join(label_dst_dir, label_name))
    
def split_normal_test(normal_test_img):
    """Move the normal-test images and their YOLO label files into the test split.

    Every file name in ``normal_test_img`` is moved from ``IMG_SRC_PATH`` to
    ``<ROOT_PATH>/images/test/normal``. The label file that shares the image's
    stem (``<stem>.txt`` in ``LABEL_SRC_PATH``) is moved to
    ``<ROOT_PATH>/YOLO_labels/test/normal`` when it exists; an image without a
    label file is moved on its own.

    Args:
        normal_test_img: Iterable of image file names (not paths) located in
            ``IMG_SRC_PATH``.

    Note:
        The files are moved, not copied, so calling this twice with the same
        names fails on the second call because the sources are gone.
    """
    img_dst_dir = os.path.join(ROOT_PATH, "images/test/normal")
    label_dst_dir = os.path.join(ROOT_PATH, "YOLO_labels/test/normal")

    os.makedirs(img_dst_dir, exist_ok=True)
    os.makedirs(label_dst_dir, exist_ok=True)

    for img_name in tqdm(normal_test_img, desc="Processing normal test"):
        # Swap only the extension. A plain str.replace("jpg", "txt") would also
        # rewrite "jpg" inside the stem (e.g. "jpg_0001.jpg") and would leave
        # ".JPG" / ".jpeg" / ".png" names untouched, so the label was skipped.
        label_name = os.path.splitext(img_name)[0] + ".txt"

        shutil.move(
            os.path.join(IMG_SRC_PATH, img_name),
            os.path.join(img_dst_dir, img_name),
        )

        label_src = os.path.join(LABEL_SRC_PATH, label_name)
        # Labels are optional: an image without a label file is still moved.
        if os.path.exists(label_src):
            shutil.move(label_src, os.path.join(label_dst_dir, label_name))

In [5]:
# Choose the file names for both held-out splits; nothing is moved on disk yet.
# The result depends on the global `random` state, so re-run the cell that
# calls random.seed(42) first if the same selection has to be reproduced.
val_img, normal_test_img = split_train_val()

In [6]:
# Destructive step: moves (not copies) the chosen images and any matching
# label files out of the "all" folders into the val folders.
split_val(val_img)

Splitting val: 100%|██████████| 10000/10000 [00:12<00:00, 788.12it/s]


In [7]:
# Destructive step: moves (not copies) the chosen images and any matching
# label files out of the "all" folders into the test/normal folders.
split_normal_test(normal_test_img)

Processing normal test: 100%|██████████| 20000/20000 [00:30<00:00, 648.90it/s] 


In [8]:
# Keep a record of which file names went into each split, one name per line,
# written to the notebook's working directory.
for list_path, names in (("val.txt", val_img), ("normal_test.txt", normal_test_img)):
    with open(list_path, "w") as out:
        out.writelines(name + "\n" for name in names)