# Test the algorithm to split indices and compared with sklearn's train_test_split

In [33]:
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd

In [34]:
np.random.seed(0)
X = np.random.randn(100, 5)
Y = np.random.randn(100, 3)

dfX = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])])
dfY = pd.DataFrame(Y, columns=[f"target_{i}" for i in range(Y.shape[1])])

display(dfX)
display(dfY)

Unnamed: 0,feature_0,feature_1,feature_2,feature_3,feature_4
0,1.764052,0.400157,0.978738,2.240893,1.867558
1,-0.977278,0.950088,-0.151357,-0.103219,0.410599
2,0.144044,1.454274,0.761038,0.121675,0.443863
3,0.333674,1.494079,-0.205158,0.313068,-0.854096
4,-2.552990,0.653619,0.864436,-0.742165,2.269755
...,...,...,...,...,...
95,0.994394,1.319137,-0.882419,1.128594,0.496001
96,0.771406,1.029439,-0.908763,-0.424318,0.862596
97,-2.655619,1.513328,0.553132,-0.045704,0.220508
98,-1.029935,-0.349943,1.100284,1.298022,2.696224


Unnamed: 0,target_0,target_1,target_2
0,0.382732,-0.034242,1.096347
1,-0.234216,-0.347451,-0.581268
2,-1.632635,-1.567768,-1.179158
3,1.301428,0.895260,1.374964
4,-1.332212,-1.968625,-0.660056
...,...,...,...
95,0.148450,0.529045,0.422629
96,-1.359781,-0.041401,-0.757871
97,-0.050084,-0.897401,1.312470
98,-0.858972,-0.898942,0.074586


In [35]:
dfX_train, dfX_test, dfY_train, dfY_test = train_test_split(
    dfX, dfY, test_size=0.2, random_state=42
)

In [36]:
dfX_train

Unnamed: 0,feature_0,feature_1,feature_2,feature_3,feature_4
55,-0.390953,0.493742,-0.116104,-2.030684,2.064493
88,-0.395229,-1.159421,-0.085931,0.194293,0.875833
26,-0.769916,0.539249,-0.674333,0.031831,-0.635846
42,0.910179,0.317218,0.786328,-0.466419,-0.944446
69,-0.280355,-0.364694,0.156704,0.578521,0.349654
...,...,...,...,...,...
60,-1.306527,1.658131,-0.118164,-0.680178,0.666383
71,-0.521189,-1.843070,-0.477974,-0.479656,0.620358
14,0.729091,0.128983,1.139401,-1.234826,0.402342
92,-0.517519,-0.978830,-0.439190,0.181338,-0.502817


In [37]:
dfX_test

Unnamed: 0,feature_0,feature_1,feature_2,feature_3,feature_4
83,0.684501,0.370825,0.142062,1.519995,1.719589
53,0.188779,0.523891,0.088422,-0.310886,0.0974
70,-0.764144,-1.437791,1.364532,-0.689449,-0.652294
45,0.063262,0.156507,0.232181,-0.597316,-0.237922
44,-0.955945,-0.345982,-0.463596,0.481481,-1.540797
39,-0.171546,0.771791,0.823504,2.163236,1.336528
22,1.867559,0.906045,-0.861226,1.910065,-0.268003
80,-0.598654,-1.115897,0.766663,0.356293,-1.768538
10,-0.895467,0.386902,-0.510805,-1.180632,-0.028182
0,1.764052,0.400157,0.978738,2.240893,1.867558


In [38]:
indices = np.arange(len(dfX))
indices

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
       68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
       85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99])

In [39]:
train_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42)
print(train_indices)
print(test_indices)

[55 88 26 42 69 15 40 96  9 72 11 47 85 28 93  5 66 65 35 16 49 34  7 95
 27 19 81 25 62 13 24  3 17 38  8 78  6 64 36 89 56 99 54 43 50 67 46 68
 61 97 79 41 58 48 98 57 75 32 94 59 63 84 37 29  1 52 21  2 23 87 91 74
 86 82 20 60 71 14 92 51]
[83 53 70 45 44 39 22 80 10  0 18 30 73 33 90  4 76 77 12 31]


In [40]:
# Verify that the split indices correspond to the data splits
dfX_train_idx = dfX_train.index.values
dfX_test_idx = dfX_test.index.values

In [41]:
dfX_train_idx == train_indices

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True])

In [42]:
dfX_test_idx == test_indices

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True])