## Step 1: Library Imports

In [1]:
import os
import random
import json
from collections import defaultdict
import numpy as np
from sklearn.model_selection import train_test_split

## Step 2: Dataset Paths and Settings

In [2]:
# Seed and split proportions. The seed is passed to every train_test_split
# call; VAL_SIZE and TEST_SIZE are fractions of the full dataset.
RANDOM_SEED = 42
VAL_SIZE = 0.15
TEST_SIZE = 0.15

# Root folder of the IRMAS mono training data (one sub-folder per class).
DATASET_ROOT = r"E:\InstruNet-AI\data\IRMAS-TrainingData"

# Folder that receives the split metadata JSON files; created if missing.
SPLIT_META_DIR = r"E:\InstruNet-AI\data\splits"
os.makedirs(SPLIT_META_DIR, exist_ok=True)

## Step 3: Scan Dataset and Collect File Paths

In [3]:
# Index every .wav file as a (file_path, class_name) pair; each sub-directory
# of DATASET_ROOT is treated as one class.
data_index = []  # (file_path, class_name)

# os.listdir() returns entries in arbitrary order, and the seeded split
# performed later depends on the order of this index. Sorting both levels
# makes the resulting split reproducible across machines and filesystems.
for class_name in sorted(os.listdir(DATASET_ROOT)):
    class_dir = os.path.join(DATASET_ROOT, class_name)
    if not os.path.isdir(class_dir):
        continue  # ignore stray files sitting next to the class folders

    for file in sorted(os.listdir(class_dir)):
        if file.lower().endswith(".wav"):  # case-insensitive extension match
            file_path = os.path.join(class_dir, file)
            data_index.append((file_path, class_name))

print(f"Total mono samples found: {len(data_index)}")

Total mono samples found: 6705


## Step 4: Encode Labels Numerically

In [4]:
# Give each class an integer id by enumerating the distinct class names in
# sorted order, and build the reverse (id -> name) lookup alongside it.
class_names = sorted({label for _, label in data_index})
class_to_id = {}
id_to_class = {}
for idx, cls in enumerate(class_names):
    class_to_id[cls] = idx
    id_to_class[idx] = cls

print("Class mapping:")
for cls in class_names:
    print(f"{cls} -> {class_to_id[cls]}")

Class mapping:
cel -> 0
cla -> 1
flu -> 2
gac -> 3
gel -> 4
org -> 5
pia -> 6
sax -> 7
tru -> 8
vio -> 9
voi -> 10


In [5]:
# Write the class-name -> id mapping to label_map.json inside SPLIT_META_DIR.
label_map_path = os.path.join(SPLIT_META_DIR, "label_map.json")
with open(label_map_path, "w") as label_map_file:
    json.dump(class_to_id, label_map_file, indent=4)

## Step 5: Prepare Stratified Labels

In [6]:
# Unzip the index into two parallel lists: file paths and their integer
# class ids (used as the stratification target for the splits).
file_paths, labels = [], []
for path, name in data_index:
    file_paths.append(path)
    labels.append(class_to_id[name])

## Step 6: Train / Temp Split (Stratified)

In [7]:
# First cut: hold out the combined validation + test share, stratified on the
# class ids. What remains is the training set.
holdout_fraction = VAL_SIZE + TEST_SIZE
X_train, X_temp, y_train, y_temp = train_test_split(
    file_paths, labels,
    test_size=holdout_fraction,
    random_state=RANDOM_SEED,
    stratify=labels,
)

print(f"Train size: {len(X_train)}")
print(f"Temp size : {len(X_temp)}")

Train size: 4693
Temp size : 2012


## Step 7: Validation / Test Split (Stratified)

In [8]:
# Second cut: divide the held-out pool into validation and test sets.
# TEST_SIZE is a fraction of the full dataset, so it is rescaled here to a
# fraction of the held-out pool.
relative_test_size = TEST_SIZE / (VAL_SIZE + TEST_SIZE)

X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp,
    test_size=relative_test_size,
    random_state=RANDOM_SEED,
    stratify=y_temp,
)

print(f"Validation size: {len(X_val)}")
print(f"Test size      : {len(X_test)}")

Validation size: 1006
Test size      : 1006


## Step 8: Verify Class Distribution (Sanity Check)

In [9]:
def count_classes(labels, id_to_class):
    """Count how many labels fall into each class.

    Args:
        labels: Iterable of class ids; each is looked up in ``id_to_class``.
        id_to_class: Mapping from class id to class name.

    Returns:
        A plain dict mapping class name to its number of occurrences, with
        keys in order of first appearance in ``labels``.
    """
    tally = {}
    for class_id in labels:
        name = id_to_class[class_id]
        tally[name] = tally.get(name, 0) + 1
    return tally

# Print the per-class sample counts of each split so the stratification can
# be checked by eye.
for heading, split_labels in (
    ("Train distribution:", y_train),
    ("\nValidation distribution:", y_val),
    ("\nTest distribution:", y_test),
):
    print(heading)
    print(count_classes(split_labels, id_to_class))

Train distribution:
{'cla': 353, 'pia': 505, 'gac': 446, 'org': 477, 'vio': 406, 'sax': 438, 'cel': 272, 'gel': 532, 'tru': 404, 'flu': 316, 'voi': 544}

Validation distribution:
{'vio': 87, 'gel': 114, 'tru': 87, 'pia': 108, 'voi': 117, 'flu': 67, 'org': 102, 'gac': 96, 'cel': 58, 'sax': 94, 'cla': 76}

Test distribution:
{'pia': 108, 'voi': 117, 'sax': 94, 'cel': 58, 'org': 103, 'vio': 87, 'cla': 76, 'gel': 114, 'tru': 86, 'gac': 95, 'flu': 68}


## Step 9: Save Split Metadata

In [10]:
# Gather the three file lists under their split names and write them to
# split_files.json inside SPLIT_META_DIR.
split_data = dict(train=X_train, val=X_val, test=X_test)

split_files_path = os.path.join(SPLIT_META_DIR, "split_files.json")
with open(split_files_path, "w") as split_file:
    json.dump(split_data, split_file, indent=2)

print("Split metadata saved successfully.")

Split metadata saved successfully.


## Step 10: Load Split Metadata

In [11]:
import json
import shutil

In [12]:
# Read the split file lists back from disk. The path is derived from
# SPLIT_META_DIR (the folder the file was written to) instead of repeating the
# absolute path, so the writer and the reader cannot drift apart.
SPLIT_META_PATH = os.path.join(SPLIT_META_DIR, "split_files.json")

with open(SPLIT_META_PATH, "r") as f:
    split_data = json.load(f)

print(split_data.keys())  # train, val, test

dict_keys(['train', 'val', 'test'])


## Step 11: Define New Dataset Root

In [13]:
# Destination root for the reorganised dataset, with one sub-folder per split.
NEW_DATASET_ROOT = r"E:\InstruNet-AI\data\irmas_mono"

for split_name in ("train", "val", "test"):
    split_root = os.path.join(NEW_DATASET_ROOT, split_name)
    os.makedirs(split_root, exist_ok=True)

## Step 12: Copy Files into Split Folders

In [14]:
# Copy every file into <NEW_DATASET_ROOT>/<split>/<class>/, taking the class
# name from the file's parent folder in the source dataset.
# NOTE: files left in the destination by an earlier run are never removed, so
# re-running after the split has changed can leave stale files behind. Start
# from an empty NEW_DATASET_ROOT when regenerating the split.
created_dirs = set()  # destination folders already created in this run

for split, file_list in split_data.items():
    print(f"Processing split: {split}")

    for file_path in file_list:
        class_name = os.path.basename(os.path.dirname(file_path))
        file_name = os.path.basename(file_path)

        dest_dir = os.path.join(NEW_DATASET_ROOT, split, class_name)
        if dest_dir not in created_dirs:
            # Create each class folder once instead of once per file.
            os.makedirs(dest_dir, exist_ok=True)
            created_dirs.add(dest_dir)

        dest_path = os.path.join(dest_dir, file_name)

        # COPY the file so the source dataset stays intact
        # (use shutil.move instead to relocate the files).
        shutil.copy2(file_path, dest_path)

Processing split: train
Processing split: val
Processing split: test


## Step 13: Verify the New Structure

In [15]:
# Count the .wav files per class in every split folder and compare each
# split's total with the number of files listed for it in split_data.
for split in ["train", "val", "test"]:
    split_dir = os.path.join(NEW_DATASET_ROOT, split)
    total = 0
    print(f"\n{split.upper()}")

    # Sorted so the report order does not depend on the filesystem.
    for cls in sorted(os.listdir(split_dir)):
        cls_dir = os.path.join(split_dir, cls)
        if not os.path.isdir(cls_dir):
            continue  # a stray file here would make os.listdir(cls_dir) fail

        # Case-insensitive match, same rule as the dataset scan.
        n = sum(1 for f in os.listdir(cls_dir) if f.lower().endswith(".wav"))
        total += n
        print(f"{cls}: {n}")

    print(f"Total {split}: {total}")

    # A mismatch means files are missing or left over from an earlier run.
    expected = len(split_data[split])
    if total != expected:
        print(f"WARNING: expected {expected} files in {split}, found {total}")


TRAIN
cel: 272
cla: 353
flu: 316
gac: 446
gel: 532
org: 477
pia: 505
sax: 438
tru: 404
vio: 406
voi: 544
Total train: 4693

VAL
cel: 58
cla: 76
flu: 67
gac: 96
gel: 114
org: 102
pia: 108
sax: 94
tru: 87
vio: 87
voi: 117
Total val: 1006

TEST
cel: 58
cla: 76
flu: 68
gac: 95
gel: 114
org: 103
pia: 108
sax: 94
tru: 86
vio: 87
voi: 117
Total test: 1006
