In [5]:
import os
import random
from pathlib import Path

In [6]:
from sklearn.model_selection import train_test_split

In [7]:
def split_eurosat(dataset_root, output_root, seed=369, train_ratio=0.5, val_ratio=0.2, test_ratio=0.3, max_images=2000, pattern="*.jpg"):
    """Write per-class train/val/test file lists for an image-folder dataset.

    Every sub-directory of ``dataset_root`` is treated as one class. For each
    class the matching files are sorted, shuffled, optionally capped at
    ``max_images`` and split into train/val/test. The resulting paths
    (relative to the parent of ``dataset_root``) are written, one per line,
    to ``train.txt``, ``val.txt`` and ``test.txt`` inside ``output_root``.

    Args:
        dataset_root: Directory (str or Path) containing one folder per class.
        output_root: Directory (str or Path) that receives the three list
            files; created, including parents, if it does not exist.
        seed: Seed for both the shuffle and ``train_test_split``.
        train_ratio: Fraction of each class assigned to the training split.
        val_ratio: Fraction of each class assigned to the validation split.
        test_ratio: Fraction of each class assigned to the test split.
        max_images: Maximum number of images kept per class after shuffling;
            ``None`` keeps every image.
        pattern: Glob pattern used to find images inside each class folder.

    Returns:
        None. The split lists are written to disk.

    Raises:
        ValueError: If a ratio is not strictly between 0 and 1, if the three
            ratios do not sum to 1, or if a class folder has no file matching
            ``pattern``.
    """
    # Validate up front: the first split holds out ``1 - train_ratio`` and the
    # second one uses ``test_ratio / (val_ratio + test_ratio)``, so ratios that
    # do not sum to 1 would silently yield different proportions than asked for.
    ratios = {"train_ratio": train_ratio, "val_ratio": val_ratio, "test_ratio": test_ratio}
    for ratio_name, ratio in ratios.items():
        if not 0 < ratio < 1:
            raise ValueError(f"{ratio_name} must be strictly between 0 and 1, got {ratio!r}")
    ratio_sum = train_ratio + val_ratio + test_ratio
    if abs(ratio_sum - 1.0) > 1e-9:
        raise ValueError(
            f"train_ratio + val_ratio + test_ratio must equal 1, got {ratio_sum!r}"
        )

    # Accept plain strings as well as Path objects.
    dataset_root = Path(dataset_root)
    output_root = Path(output_root)

    # Class names are the sub-directory names; sorted for a stable order.
    classes = sorted(d.name for d in dataset_root.iterdir() if d.is_dir())

    train_files, val_files, test_files = [], [], []

    # Private generator seeded like the former ``random.seed(seed)`` call, so
    # the shuffles are driven by the same seed without reseeding the global
    # ``random`` module for the rest of the session.
    rng = random.Random(seed)

    for cls in classes:
        class_dir = dataset_root / cls
        # Paths are stored relative to the parent of dataset_root, so every
        # entry starts with the dataset folder name.
        files = sorted(str(f.relative_to(dataset_root.parent)) for f in class_dir.glob(pattern))
        if not files:
            raise ValueError(f"No files matching {pattern!r} found in class folder {class_dir}")

        # Shuffle before capping so the kept subset is a random sample.
        rng.shuffle(files)

        # Cap the class size (the original note attributes the 2000 default
        # to the Pasture category).
        if max_images is not None:
            files = files[:max_images]

        # All labels in this call are identical, so ``stratify`` cannot
        # rebalance anything; the argument is kept unchanged so the call
        # matches how existing split files were generated.
        labels = [cls] * len(files)

        # Step 1: train vs. held-out remainder.
        train_f, temp_f, _, temp_l = train_test_split(
            files, labels,
            test_size=(1 - train_ratio),
            stratify=labels,
            random_state=seed
        )

        # Step 2: divide the remainder into validation and test.
        val_f, test_f, _, _ = train_test_split(
            temp_f, temp_l,
            test_size=test_ratio / (val_ratio + test_ratio),
            stratify=temp_l,
            random_state=seed
        )

        train_files.extend(train_f)
        val_files.extend(val_f)
        test_files.extend(test_f)

    # Sanity check: no path may appear in more than one split.
    train_set, val_set, test_set = set(train_files), set(val_files), set(test_files)
    assert not (train_set & val_set), "train and val splits overlap"
    assert not (train_set & test_set), "train and test splits overlap"
    assert not (val_set & test_set), "val and test splits overlap"
    print("Splits are disjoint.")

    # Create the output directory only once the split succeeded, so a failed
    # run does not leave an empty directory behind.
    output_root.mkdir(parents=True, exist_ok=True)
    (output_root / "train.txt").write_text("\n".join(train_files), encoding="utf-8")
    (output_root / "val.txt").write_text("\n".join(val_files), encoding="utf-8")
    (output_root / "test.txt").write_text("\n".join(test_files), encoding="utf-8")

    print("Files written to:", output_root)
 

In [8]:
# Build the train/val/test list files for the EuroSAT_RGB folder.
split_eurosat("EuroSAT_RGB", "train_val_test")

Splits are disjoint.
Files written to: train_val_test
