In [2]:
import os
import shutil
import random
import json
from pathlib import Path

def _match_case_type(file_name, case_types_by_specificity):
    """Return the first case type in the given order that prefixes file_name, or None."""
    for case_type in case_types_by_specificity:
        if file_name.startswith(case_type):
            return case_type
    return None


def split_dataset(src_dir, dst_dir, split_config_file, make_valid=True, seed=42, version="main"):
    """Copy a dataset into train/valid directories, splitting each case type by its ratio.

    Files are read from ``<src_dir>/images/*.jpg`` and ``<src_dir>/labels/*.txt`` and
    assigned to a case type by file-name prefix, where the case types are the keys of
    the JSON config. For each case type the image/label pairs are shuffled and the
    first ``int(n * ratio)`` pairs are copied to ``<dst_dir>/train_<version>``; the
    remainder is copied to ``<dst_dir>/valid_<version>`` only when ``make_valid`` is
    true. Source files are copied, never moved. Files matching no case type are ignored.

    Args:
        src_dir: Directory containing the ``images`` and ``labels`` sub-directories.
        dst_dir: Directory under which ``train_<version>`` / ``valid_<version>`` are created.
        split_config_file: Path to a JSON file mapping each case-type prefix to an
            object with an optional ``"ratio"`` (train fraction, default 1).
        make_valid: When false, the validation remainder is not copied (the
            ``valid_<version>`` directories are still created, but left empty).
        seed: Seed for the global ``random`` module, making the split reproducible.
        version: Suffix used in the output directory names.

    Returns:
        None.

    Raises:
        FileNotFoundError: If ``<src_dir>/images`` or ``<src_dir>/labels`` is missing.
        ValueError: If a ratio is not a number in [0, 1], or if the images and labels
            of a case type do not pair up one-to-one by file stem.
    """
    # Seed the global RNG so the per-case-type shuffles below are reproducible.
    random.seed(seed)

    images_dir = Path(src_dir) / 'images'
    labels_dir = Path(src_dir) / 'labels'
    # A missing source directory would otherwise just look like "no data" for every case type.
    for required_dir in (images_dir, labels_dir):
        if not required_dir.is_dir():
            raise FileNotFoundError(f"Source directory not found: {required_dir}")

    # Load split configuration from JSON file
    with open(split_config_file, 'r', encoding='utf-8') as f:
        split_config = json.load(f)

    # Validate every ratio up front: a negative ratio would yield a negative slice
    # index and a ratio above 1 would silently put everything in train.
    ratios = {}
    for case_type, options in split_config.items():
        ratio = options.get('ratio', 1)  # Default to 1 (all training, no validation)
        if isinstance(ratio, bool) or not isinstance(ratio, (int, float)) or not 0 <= ratio <= 1:
            raise ValueError(
                f"Invalid ratio {ratio!r} for case type: {case_type} (expected a number in [0, 1])"
            )
        ratios[case_type] = ratio

    # Group files by case type, keyed by stem so images and labels can be paired by name.
    # The longest prefix is tried first so that e.g. "case10_x.jpg" is assigned to
    # "case10" rather than being absorbed by "case1".
    case_type_files = {case_type: {'images': {}, 'labels': {}} for case_type in split_config}
    case_types_by_specificity = sorted(split_config, key=len, reverse=True)
    sources = (('images', images_dir.glob('*.jpg')), ('labels', labels_dir.glob('*.txt')))
    for kind, paths in sources:
        for path in paths:
            case_type = _match_case_type(path.name, case_types_by_specificity)
            if case_type is not None:
                case_type_files[case_type][kind][path.stem] = path

    # Build and validate the image/label pairs for every case type before copying
    # anything, so a mismatch cannot leave a half-written destination behind.
    planned = []
    for case_type, files in case_type_files.items():
        images_by_stem = files['images']
        labels_by_stem = files['labels']

        # Skip empty case types
        if not images_by_stem or not labels_by_stem:
            print(f"Warning: No data for case type {case_type}. Skipping...")
            continue

        # Equal counts are not enough: every image needs the label with the same stem.
        unpaired = sorted(images_by_stem.keys() ^ labels_by_stem.keys())
        if unpaired:
            raise ValueError(
                f"Mismatch between image and label files for case type: {case_type} "
                f"(unpaired stems, first few: {unpaired[:5]})"
            )

        # Sort by stem so the shuffle result depends only on the seed, not on glob order.
        pairs = [(images_by_stem[stem], labels_by_stem[stem]) for stem in sorted(images_by_stem)]
        planned.append((case_type, pairs))

    # Create the output directories if they don't exist
    train_images_dir = Path(dst_dir) / f"train_{version}" / "images"
    valid_images_dir = Path(dst_dir) / f"valid_{version}" / "images"
    train_labels_dir = Path(dst_dir) / f"train_{version}" / "labels"
    valid_labels_dir = Path(dst_dir) / f"valid_{version}" / "labels"
    for dir_path in [train_images_dir, valid_images_dir, train_labels_dir, valid_labels_dir]:
        dir_path.mkdir(parents=True, exist_ok=True)

    for case_type, pairs in planned:
        # Shuffle the pairs together so each image stays with its label.
        random.shuffle(pairs)
        split_index = int(len(pairs) * ratios[case_type])

        # Copy (not move) the training part
        for image, label in pairs[:split_index]:
            shutil.copy(image, train_images_dir / image.name)
            shutil.copy(label, train_labels_dir / label.name)

        # Copy the remainder as validation data only when requested
        if make_valid:
            for image, label in pairs[split_index:]:
                shutil.copy(image, valid_images_dir / image.name)
                shutil.copy(label, valid_labels_dir / label.name)

    print(f"Data split into {train_images_dir} and {valid_images_dir if make_valid else 'no valid set'}.")


In [None]:
# Roboflow project export -> train_main / valid_main
# make_valid=True, so the remainder of each case type is copied to valid_main.
split_dataset(
    src_dir='/Users/jjookim/Projects/AIForce/datasets/final_dataset/train',
    dst_dir='/Users/jjookim/Projects/AIForce/datasets/all_data',
    split_config_file='/Users/jjookim/Projects/AIForce/datasets/jsons/all_data.json',
    make_valid=True,
    seed=42,
    version="main",
)

In [None]:
# train_main -> train_v#
# make_valid=False: only the training share selected by train_v1.json is copied.
split_dataset(
    src_dir='/Users/jjookim/Projects/AIForce/datasets/all_data/train_main',
    dst_dir='/Users/jjookim/Projects/AIForce/datasets/all_data',
    split_config_file='/Users/jjookim/Projects/AIForce/datasets/jsons/train_v1.json',
    make_valid=False,
    seed=42,
    version="v1",
)

In [None]:
# valid_main -> valid_v#
# Warning: with make_valid=False the selected files are written to a
# train_<version> directory, even though here they are the validation subset.
# Rename/move that output to valid_v# after running this cell. The version
# name is deliberately distinct from "v1" so the output cannot land in train_v1.
# NOTE(review): split_config_file still points at train_v1.json; confirm whether
# a validation-specific config should be used here instead.
split_dataset(
    src_dir='/Users/jjookim/Projects/AIForce/datasets/all_data/valid_main',
    dst_dir='/Users/jjookim/Projects/AIForce/datasets/all_data',
    split_config_file='/Users/jjookim/Projects/AIForce/datasets/jsons/train_v1.json',
    make_valid=False,
    seed=42,
    version="v1_from_valid_main",
)