In [None]:
import logging
import random
import shutil
from pathlib import Path

# Root logger: show INFO and above, each record prefixed with timestamp,
# level and logger name. basicConfig() does nothing if the root logger
# already has handlers.
LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s %(message)s"
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

In [None]:
# Input folders. Each mask in `input_folder_mask` is paired with the image of
# the corresponding name in `input_folder_img`.
# Paths are composed with the `/` operator instead of a literal backslash so
# they resolve identically on Windows and POSIX, and so the string literals
# contain no invalid escape sequences such as "\I", "\M" or "\d".
input_folder_img = Path("data") / "IMG"
input_folder_mask = Path("data") / "MASK"

# Root folder that receives the TRAIN / EVAL / TEST sub-folders.
output_folder = Path("results") / "dataset_distributed"

# Fractions of the masks assigned to the evaluation and test splits;
# whatever remains goes to the training split.
eval_percent = 0.35
test_percent = 0.05

In [None]:
# One output sub-folder per split, all under the common output root.
output_train, output_eval, output_test = (
    output_folder / split_name for split_name in ("TRAIN", "EVAL", "TEST")
)

# Every *.tif file directly inside the mask folder is one tile to distribute.
masks = [*input_folder_mask.glob("*.tif")]
mask_count = len(masks)

# Eval and test sizes are truncated towards zero; the training split takes
# the remainder so the three sizes always add up to the number of masks.
val = int(mask_count * eval_percent)
test = int(mask_count * test_percent)
train = mask_count - val - test

logging.info(f"Train: {train}")
logging.info(f"Eval: {val}")
logging.info(f"Test: {test}")

In [None]:
# Fixed seed so that "Restart Kernel -> Run All" reproduces the same split.
SPLIT_SEED = 42

# Fail loudly on impossible sizes (e.g. the percentages sum to more than 1 or
# are negative) instead of letting the slices below silently produce a wrong
# partition.
if min(train, val) < 0 or train + val > len(masks):
    raise ValueError(
        f"Invalid split sizes: train={train}, eval={val}, total={len(masks)}"
    )

# glob() yields files in a file-system dependent order, so sort first;
# otherwise the same seed could still give different splits on different
# machines. A private Random instance leaves the global RNG state untouched.
shuffled_masks = sorted(masks)
random.Random(SPLIT_SEED).shuffle(shuffled_masks)

# Slice the single shuffled list into three disjoint partitions that together
# cover every mask: `train` tiles, then `val` tiles, then the rest.
train_tiles = shuffled_masks[:train]
eval_tiles = shuffled_masks[train : train + val]
test_tiles = shuffled_masks[train + val :]

logging.info(f"Train tiles: {len(train_tiles)}")
logging.info(f"Eval tiles: {len(eval_tiles)}")
logging.info(f"Test tiles: {len(test_tiles)}")

In [None]:
def copy_split(tiles, destination, image_folder):
    """Copy each mask and its matching image into one split folder.

    Args:
        tiles: Iterable of mask file paths belonging to the split.
        destination: Split root folder; files are written to its "MASK" and
            "IMG" sub-folders, which are created if they do not exist.
        image_folder: Folder holding the source images. The image for a mask
            is looked up under the mask's file name with "MASK" replaced by
            "IMG".

    Raises:
        OSError: Propagated from ``shutil.copy``, e.g. ``FileNotFoundError``
            when a mask has no matching image file.
    """
    mask_dir = destination / "MASK"
    img_dir = destination / "IMG"

    # Create the folders once per split rather than once per tile; this also
    # means an empty split still ends up with its (empty) folders on disk.
    mask_dir.mkdir(parents=True, exist_ok=True)
    img_dir.mkdir(parents=True, exist_ok=True)

    for mask in tiles:
        shutil.copy(mask, mask_dir / mask.name)

        image = image_folder / mask.name.replace("MASK", "IMG")
        shutil.copy(image, img_dir / image.name)


# The three splits are handled identically, so drive them from one table.
for tiles, destination in (
    (train_tiles, output_train),
    (eval_tiles, output_eval),
    (test_tiles, output_test),
):
    copy_split(tiles, destination, input_folder_img)

logging.info("Dataset distributed successfully.")