# Dataset Split

Split the dataset into test, validation and training sets.

An important point to make about the preprocessing is that
any preprocessing statistics (e.g. the data mean)
must only be computed on the training data.
(https://cs231n.github.io/neural-networks-2/)

The data is imbalanced: the classes have very different sample counts.
Consider using imbalanced-learn to rebalance them:
http://contrib.scikit-learn.org/imbalanced-learn/index.html.

In [1]:
%matplotlib inline

In [2]:
import h5py
import numpy as np
import sklearn.utils
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [3]:
# Load the full feature matrix and label vector into memory.
# The file is opened explicitly read-only: this cell never writes, and an
# explicit mode avoids depending on h5py's version-specific default.
with h5py.File('data/data.hdf5', 'r') as f:
    X = f['X'][...]  # [...] reads the entire dataset
    y = f['y'][...]

In [4]:
# Report how many samples carry each distinct label found in y.
labels, counts = np.unique(y, return_counts=True)
template = 'label {}: {} samples'
for idx, label in enumerate(labels):
    print(template.format(label, counts[idx]))

label 0: 5301 samples
label 1: 6103 samples
label 2: 1533 samples


In [5]:
# Number of distinct class labels in y (labels 0, 1 and 2 in the counts above).
N_LABELS = 3

In [6]:
# Fixed seed so the split is reproducible; without it every run assigns
# different samples to the training, validation and test sets.
RANDOM_SEED = 0

# Hold out 10% of all samples for testing, then 20% of the remainder for
# validation. Both splits are stratified on the labels.
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.1, stratify=y, random_state=RANDOM_SEED)
X_tr, X_val, y_tr, y_val = train_test_split(
    X_tr, y_tr, test_size=0.2, stratify=y_tr, random_state=RANDOM_SEED)

# Print the per-class sample counts and the total size of each subset.
for name, subset in (('training', y_tr), ('validation', y_val), ('testing', y_te)):
    _, cnts = np.unique(subset, return_counts=True)
    print(name + ' set', cnts, sum(cnts), sep='\t')

training set	[3817 4393 1104]	9314
validation set	[ 954 1099  276]	2329
testing set	[530 611 153]	1294


In [7]:
# Disabled cell: writes the six split arrays back into data/data.hdf5,
# first deleting any previously stored splits. Features are stored as
# float64 and labels as int8.
# NOTE(review): a single try wraps all six deletions, so the first missing
# dataset skips the remaining deletions, and the bare except hides any other
# error. Before re-enabling, delete each key individually and open the file
# with an explicit writable mode (e.g. 'a') -- confirm against the installed
# h5py version.
# with h5py.File('data/data.hdf5') as f:
#     try:
#         del f['X_tr']
#         del f['X_val']
#         del f['X_te']
#         del f['y_tr']
#         del f['y_val']
#         del f['y_te']
#     except:
#         pass

#     f.create_dataset('X_tr', X_tr.shape, dtype=np.float64)[...] = X_tr
#     f.create_dataset('X_val', X_val.shape, dtype=np.float64)[...] = X_val
#     f.create_dataset('X_te', X_te.shape, dtype=np.float64)[...] = X_te

#     f.create_dataset('y_tr', y_tr.shape, dtype=np.int8)[...] = y_tr
#     f.create_dataset('y_val', y_val.shape, dtype=np.int8)[...] = y_val
#     f.create_dataset('y_te', y_te.shape, dtype=np.int8)[...] = y_te