In [1]:
import os
import numpy as np
from glob import glob
from sklearn.model_selection import train_test_split

In [2]:
# --- Configuration -----------------------------------------------------------
# data_dir   : directory searched for per-cube arrays named <cube>_X.npy,
#              <cube>_y.npy and <cube>_ids.npy.
# output_dir : directory created here (if missing) for the split lists.
# random_seed: single seed shared by every shuffle/split in this notebook.
data_dir = '../results/stacked_k8_dataset'
output_dir = '../results/split_k8_cube_lists'
random_seed = 28
os.makedirs(output_dir, exist_ok=True)

# --- Discover cube files -----------------------------------------------------
# Sorting makes the cube order independent of the filesystem listing order,
# which keeps the seeded split below reproducible.
X_SUFFIX, Y_SUFFIX, IDS_SUFFIX = '_X.npy', '_y.npy', '_ids.npy'
X_paths = sorted(glob(os.path.join(data_dir, '*' + X_SUFFIX)))
y_paths = sorted(glob(os.path.join(data_dir, '*' + Y_SUFFIX)))
id_paths = sorted(glob(os.path.join(data_dir, '*' + IDS_SUFFIX)))

# A cube's name is its X file name with the trailing suffix removed. Slicing
# strips only the suffix, whereas str.replace would also remove any earlier
# occurrence of '_X.npy' inside the name.
cube_names = [os.path.basename(p)[:-len(X_SUFFIX)] for p in X_paths]

# Cross-check the companion files: every cube found via its X file should also
# have a _y and an _ids file.
cubes_with_y = {os.path.basename(p)[:-len(Y_SUFFIX)] for p in y_paths}
cubes_with_ids = {os.path.basename(p)[:-len(IDS_SUFFIX)] for p in id_paths}
missing_y = [name for name in cube_names if name not in cubes_with_y]
missing_ids = [name for name in cube_names if name not in cubes_with_ids]

if missing_y:
    # The pixel-count cell loads <cube>_y.npy for every cube, so fail here with
    # a clear message instead of a FileNotFoundError part-way through that loop.
    raise ValueError(f"{len(missing_y)} cube(s) have no {Y_SUFFIX} file: {missing_y}")
if missing_ids:
    # Nothing in this notebook reads the _ids files, so only warn.
    print(f"Warning: {len(missing_ids)} cube(s) have no {IDS_SUFFIX} file: {missing_ids}")

print(f"Found {len(cube_names)} cubes")

Found 137 cubes


In [3]:
# Partition the cube names into train / validation / test lists.
# The split is made on whole cubes (names), not on individual pixels.
np.random.seed(random_seed)
shuffled = np.random.permutation(cube_names)

# Stage 1: hold out 20% of all cubes as the test set.
train_val_names, test_names = train_test_split(
    shuffled,
    test_size=0.2,
    random_state=random_seed,
)

# Stage 2: take 25% of the remaining cubes (about 20% of the total) as validation.
train_names, val_names = train_test_split(
    train_val_names,
    test_size=0.25,
    random_state=random_seed,
)

n_train, n_val, n_test = len(train_names), len(val_names), len(test_names)

print("\n Split summary:")
print(f"Train cubes: {n_train}")
print(f"Val cubes:   {n_val}")
print(f"Test cubes:  {n_test}")


 Split summary:
Train cubes: 81
Val cubes:   28
Test cubes:  28


In [4]:
# Save lists: write each split's cube names to its own .npy file in output_dir.
cube_lists = {
    'train_cubes.npy': train_names,
    'val_cubes.npy': val_names,
    'test_cubes.npy': test_names,
}
for file_name, split_members in cube_lists.items():
    np.save(os.path.join(output_dir, file_name), split_members)

print("\n Cube lists saved:")
print("train_cubes.npy, val_cubes.npy, test_cubes.npy")


 Cube lists saved:
train_cubes.npy, val_cubes.npy, test_cubes.npy


In [6]:
# Count pixels per set.
# A cube's pixel count is taken as len() of its <cube>_y.npy array
# (assumes the first axis indexes pixels - TODO confirm).

# Map each cube name to its split once, so the loop below does a dict lookup
# instead of scanning the split arrays for every cube. setdefault keeps the
# original train -> val -> test precedence should a name ever appear twice.
split_of = {}
for split_label, split_members in (
    ('train', train_names),
    ('val', val_names),
    ('test', test_names),
):
    for cube in split_members:
        split_of.setdefault(str(cube), split_label)

pixel_totals = {'train': 0, 'val': 0, 'test': 0}
for name in cube_names:
    split_label = split_of.get(str(name))
    if split_label is None:
        # Cube belongs to no split; it contributes to no total.
        continue
    # mmap_mode='r' parses only the .npy header, so the length is read without
    # pulling the whole array into memory.
    y = np.load(os.path.join(data_dir, f'{name}_y.npy'), mmap_mode='r')
    pixel_totals[split_label] += len(y)

train_px = pixel_totals['train']
val_px = pixel_totals['val']
test_px = pixel_totals['test']

print("\n Final pixel counts:")
print(f"Train: {train_px:,}")
print(f"Val:   {val_px:,}")
print(f"Test:  {test_px:,}")


 Final pixel counts:
Train: 14,527,694
Val:   4,906,452
Test:  4,181,618


In [8]:
# List the cubes held out for testing and validation, alphabetically within
# each split.
for heading, split_members in (
    ("\n Test cubes:", test_names),
    ("\n Validation cubes:", val_names),
):
    print(heading)
    for name in sorted(split_members):
        print("-", name)



 Test cubes:
- frt00007bc8_07_if166j_mtr3
- frt00008389_07_if166j_mtr3
- frt000093be_07_if166j_mtr3
- frt00009c31_07_if166j_mtr3
- frt00009c6a_07_if166j_mtr3
- frt0000bec0_07_if165j_mtr3
- frt0000bfd1_07_if166j_mtr3
- frt0000c202_07_if165j_mtr3
- frt0000c62b_07_if166j_mtr3
- frt0000c968_07_if166j_mtr3
- frt0000d6d6_07_if166j_mtr3
- frt00011d4c_07_if166j_mtr3
- frt0001487a_07_if166j_mtr3
- frt000161ef_07_if167j_mtr3
- frt00016655_07_if166j_mtr3
- frt00016a73_07_if166j_mtr3
- frt00016ed6_07_if165j_mtr3
- frt00017103_07_if165j_mtr3
- frt000174f4_07_if166j_mtr3
- frt0001fd76_07_if166j_mtr3
- frt00021da6_07_if166j_mtr3
- frt00023370_00_if166j_mtr3
- frt00024c1a_07_if165j_mtr3
- hrl0000cc16_07_if184j_mtr3
- hrl000116c6_07_if183j_mtr3
- hrl0001b769_07_if183j_mtr3
- hrs00011c01_07_if175j_mtr3
- hrs00012aa7_07_if175j_mtr3

 Validation cubes:
- frt00003e12_07_if166j_mtr3
- frt000047a3_07_if166j_mtr3
- frt00005c5e_07_if166j_mtr3
- frt0000805f_07_if166j_mtr3
- frt0000871c_07_if166j_mtr3
- frt0000