# Split data

In [None]:
import math

from sklearn.model_selection import train_test_split

def _split_count(fraction, n_samples):
    """Return how many of ``n_samples`` a split of ``fraction`` should hold.

    Uses the same ``ceil`` convention as sklearn's float ``test_size``, but
    rounds the product first so floating-point noise such as
    ``10.000000000000002`` does not push the count up by one.
    """
    return math.ceil(round(fraction * n_samples, 9))


def split_data(data, labels, val_split_perc, test_split_perc=0.0, random_state=7,
               *, stratify=True, verbose=True):
    """Split ``data``/``labels`` into train, validation and (optionally) test sets.

    Both fractions are relative to the *whole* dataset. For example, with 100
    samples, ``val_split_perc=0.1`` and ``test_split_perc=0.3`` produce
    60/10/30 train/validation/test samples.

    Args:
        data: Samples to split; must support ``len()`` and sklearn indexing.
        labels: Targets aligned with ``data``; also used for stratification.
        val_split_perc: Fraction of all samples for validation, in (0, 1).
        test_split_perc: Fraction of all samples for test, in [0, 1).
            ``0`` (the default) skips the test split.
        random_state: Seed forwarded to every ``train_test_split`` call.
        stratify: If True (default), preserve the label distribution in each
            split. Pass False for labels that cannot be stratified, such as
            continuous regression targets.
        verbose: If True (default), print the size of each split.

    Returns:
        ``(x_train, y_train, x_val, y_val, x_test, y_test)``. The test pair is
        ``(None, None)`` when ``test_split_perc`` is 0.

    Raises:
        ValueError: If a fraction is out of range or the two fractions sum to
            1 or more. ``train_test_split`` may raise further ``ValueError``s,
            e.g. when a class has too few members to stratify.
    """
    if not 0 < val_split_perc < 1:
        raise ValueError(f"val_split_perc must be in (0, 1), got {val_split_perc}")
    if not 0 <= test_split_perc < 1:
        raise ValueError(f"test_split_perc must be in [0, 1), got {test_split_perc}")
    if val_split_perc + test_split_perc >= 1:
        raise ValueError(
            "val_split_perc + test_split_perc must be < 1, got "
            f"{val_split_perc} + {test_split_perc}"
        )

    # Sizes are fixed up front as absolute counts of the whole dataset. Do not
    # rescale the validation fraction against the post-test remainder
    # (val / (1 - test)): that float ratio, ceiled by sklearn, can add a sample.
    # Example: n=100, val=0.1, test=0.3 gives 0.1 / 0.7 * 70 == 10.000000000000002.
    n_samples = len(data)
    n_val = _split_count(val_split_perc, n_samples)

    if test_split_perc > 0:
        # Step 1: split off the test set.
        x_temp, x_test, y_temp, y_test = train_test_split(
            data, labels,
            test_size=_split_count(test_split_perc, n_samples),
            stratify=labels if stratify else None,
            random_state=random_state,
        )
        # Step 2: carve exactly n_val samples out of the remainder.
        x_train, x_val, y_train, y_val = train_test_split(
            x_temp, y_temp,
            test_size=n_val,
            stratify=y_temp if stratify else None,
            random_state=random_state,
        )
    else:
        # Only split train and validation.
        x_train, x_val, y_train, y_val = train_test_split(
            data, labels,
            test_size=n_val,
            stratify=labels if stratify else None,
            random_state=random_state,
        )
        x_test, y_test = None, None

    if verbose:
        print(f"Train count: {len(x_train)}")
        print(f"Validation count: {len(x_val)}")
        print(f"Test count: {len(x_test) if x_test is not None else 0}")

    return x_train, y_train, x_val, y_val, x_test, y_test