In [18]:
import numpy as np

In [39]:
def splitData(data, label=None, ratio=[0.8, 0.1]):
    '''
    Split the data as training, testing, validation datasets with input ratio.
    The ratio array presents the ratios of training and testing dataset.
    If the summation is less than 1, it means that the rest ratio is for validation dataset.
    Args:
        data (numpy ndarray): the original dataset, it will be split by the 1st dimension
        label (numpy array): the pairs containing {channel: adjust value}
    Returns:
        the dataframe with the adjusted test values
     
    '''
    # copy and shuffle the data
    data = data.copy()
    numInst = data.shape[0]
    indices = np.arange(numInst)
    np.random.shuffle(indices)
    
    # determine the indices to split the data
    indTrain = indices[:round(numInst*ratio[0])]
    indTest = indices[round(numInst*ratio[0]):round(numInst*sum(ratio))]
    indValid = indices[round(numInst*sum(ratio)):]
    
    # determine the datasets with the indices
    dataTrain = data[indTrain]
    dataTest = data[indTest]
    dataValid = data[indValid]
    
    # return the split datasets directly if there is no label 
    if label is None:
        return dataTrain, dataTest, dataValid
    
    # split labels with the same indices if there ae labels
    else:
        label = label.copy()
        labTrain = label[indTrain]
        labTest = label[indTest]
        labValid = label[indValid]
        return (dataTrain, labTrain), (dataTest, labTest), (dataValid, labValid)

In [41]:
x = np.array([[1, 2],
              [2, 3],
              [3, 4],
              [4, 5]])
y = np.array([0, 1, 2, 3])
print(splitData(x))
print(splitData(x, y))
print(splitData(x, y, ratio=[0.5, 0.25]))

[0 1 2 3]
[0 3 1 2]
(array([[1, 2],
       [4, 5],
       [2, 3]]), array([[3, 4]]), array([], shape=(0, 2), dtype=int64))
[0 1 2 3]
[1 3 2 0]
((array([[2, 3],
       [4, 5],
       [3, 4]]), array([1, 3, 2])), (array([[1, 2]]), array([0])), (array([], shape=(0, 2), dtype=int64), array([], dtype=int64)))
[0 1 2 3]
[2 0 1 3]
((array([[3, 4],
       [1, 2]]), array([2, 0])), (array([[2, 3]]), array([1])), (array([[4, 5]]), array([3])))


In [21]:
x

array([[1, 2],
       [2, 3],
       [3, 4],
       [4, 5]])