# Dataset Splitting #

Reads through a dataset folder and creates a set of train/val/test folders, using the input labels for stratification.
- Binary classification 
- 5-fold cross-validation

## Imports ##

In [1]:
import os
import shutil
import numpy as np
from sklearn.model_selection import KFold

## Inputs ##

In [9]:
# Root folder of the source dataset; the code below reads one sub-folder per
# entry of target_labels from it.
original_dir = "D://Rui//NIDACT2 Transplant ECGs//Multiclass_HTx_Rejection_Dataset_1(300_300)//Binary_0vs1-2R"
# Output root: same location as the source, with a " (5-Folds)" suffix.
save_dir = f"{original_dir} (5-Folds)"

# Class folder names; a label's position in this list is its class index.
target_labels = ["0", "1-2R"]
nClasses = len(target_labels)


## Binary classification

### Create New Split Folders (If needed)

In [9]:
# Build the train/val/test directory tree with one sub-folder per class.
# os.makedirs(..., exist_ok=True) is used instead of os.mkdir so that a missing
# save_dir is created on the way and re-running the cell does not raise
# FileExistsError.
train_path = save_dir + '//train//'
val_path = save_dir + '//val//'
test_path = save_dir + '//test//'

for split_path in (train_path, val_path, test_path):
    # Create the split folder itself even when target_labels is empty.
    os.makedirs(split_path, exist_ok=True)
    for class_name in target_labels:
        os.makedirs(split_path + class_name, exist_ok=True)


### Main Script ##

In [10]:
# Split each class 60:20:20 into train/val/test by (sample index mod 5):
# remainders 0-2 go to train, 3 goes to val and 4 goes to test.
train_labels = []
val_labels = []
test_labels = []

# Remainder -> (label list to extend, destination split folder).
split_for_remainder = {
    0: (train_labels, train_path),
    1: (train_labels, train_path),
    2: (train_labels, train_path),
    3: (val_labels, val_path),
    4: (test_labels, test_path),
}

# Visit the classes by folder name instead of counting the directories that
# os.walk() happens to yield: the walk order is not guaranteed to match
# target_labels, and a stray file in the dataset root or an extra sub-folder
# would shift every class index.
classes = target_labels
for class_idx, class_name in enumerate(classes):
    class_dir = os.path.join(original_dir, class_name)
    with os.scandir(class_dir) as entries:
        # Sorted so that the split is reproducible from run to run.
        filenames = sorted(entry.name for entry in entries if entry.is_file())

    for file_idx, filename in enumerate(filenames):
        # One-hot label for the current class.
        label = np.zeros(len(classes))
        label[class_idx] = 1

        split_labels, split_path = split_for_remainder[file_idx % 5]
        split_labels.append(label)
        shutil.copyfile(
            os.path.join(class_dir, filename),
            os.path.join(split_path, class_name, filename),
        )
    

## 5-Fold Cross-Validation

In [11]:
# Build one folder per fold, each with one sub-folder per class.
# exist_ok=True keeps the cell re-runnable and tolerates a save_dir that was
# already created by an earlier step.
os.makedirs(save_dir, exist_ok=True)

# Fold folder paths, in fold order; index i holds the folder of fold i + 1.
fold_dir = []

number_of_fold = [1, 2, 3, 4, 5]
for fold_number in number_of_fold:
    fold_path = save_dir + "//" + str(fold_number)
    os.makedirs(fold_path, exist_ok=True)
    # Loop over every class rather than hard-coding the first two labels.
    for class_name in target_labels:
        os.makedirs(fold_path + "//" + class_name, exist_ok=True)
    fold_dir.append(fold_path)


In [12]:
# Distribute the samples of every class round-robin over the fold folders:
# sample i of a class is copied into fold (i mod number of folds).
# NOTE(review): filepaths is initialised but never filled in this cell;
# confirm whether it was meant to record the copied files.
filepaths = []
labels = []

# Visit the classes by folder name instead of counting the directories that
# os.walk() happens to yield: the walk order is not guaranteed to match
# target_labels, and a stray file in the dataset root or an extra sub-folder
# would shift every class index.
classes = target_labels
for class_idx, class_name in enumerate(classes):
    class_dir = os.path.join(original_dir, class_name)
    with os.scandir(class_dir) as entries:
        # Sorted so that the fold assignment is reproducible from run to run.
        filenames = sorted(entry.name for entry in entries if entry.is_file())

    for file_idx, filename in enumerate(filenames):
        # One-hot label for the current class.
        label = np.zeros(len(classes))
        label[class_idx] = 1
        labels.append(label)

        destination_dir = fold_dir[file_idx % len(fold_dir)]
        shutil.copyfile(
            os.path.join(class_dir, filename),
            os.path.join(destination_dir, class_name, filename),
        )