In [1]:
from sklearn.model_selection import train_test_split  # sklearn.cross_validation was removed in scikit-learn 0.20
import pandas as pd
import numpy as np

>*train_test_split*

> Split arrays or matrices into random train and test subsets

> Quick utility that wraps input validation and next(iter(ShuffleSplit(n_samples))) and application to input data into a single call for splitting (and optionally subsampling) data in a oneliner.
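
To make the quoted description concrete, the sketch below performs the split by hand with `ShuffleSplit` and then again with `train_test_split`. The toy array `X`, the `test_size` of 0.25, and the `random_state` value are assumptions chosen for illustration, and the import path assumes a scikit-learn release that ships `sklearn.model_selection`.

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit, train_test_split

# Toy data (assumed for illustration): 20 samples, 2 features each.
X = np.arange(40).reshape(20, 2)

# Manual route: ask ShuffleSplit for one shuffled train/test index pair.
splitter = ShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
train_idx, test_idx = next(splitter.split(X))
X_train_manual, X_test_manual = X[train_idx], X[test_idx]

# One-liner route: validation, shuffling, and indexing in a single call.
X_train, X_test = train_test_split(X, test_size=0.25, random_state=0)

print(X_train_manual.shape, X_test_manual.shape)
print(X_train.shape, X_test.shape)
```

Both routes produce a 15-row training part and a 5-row testing part. The one-liner mainly saves the index bookkeeping.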


# Don't test where you train

I cannot stress this enough in machine learning. When you test against the same data you used to train your model, you get biased results that are heavily in favor of your model. That sets expectations your model will not meet on data it has never seen.

To guard against this, we split the entire data set into two parts, typically 2/3 : 1/3, where:

 - 2/3 of the data is used to *train* your classifier
 - 1/3 of the data is used to *score* your predictions
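
As a quick illustration of why the held-out third matters, here is a minimal sketch that scores one model on both parts. The synthetic data set from `make_classification` and the `DecisionTreeClassifier` are assumptions picked for the demo, not part of the original walkthrough. An unpruned decision tree tends to memorize its training rows, so the gap is easy to see.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (assumed for illustration): 300 samples, 10 features.
X, y = make_classification(n_samples=300, n_features=10, random_state=0)

# Hold out one third of the rows for scoring.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=1 / 3, random_state=0
)

clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

train_acc = clf.score(X_train, y_train)  # scored on rows the model has seen
test_acc = clf.score(X_test, y_test)     # scored on held-out rows

print(f"accuracy on training rows: {train_acc:.2f}")
print(f"accuracy on held-out rows: {test_acc:.2f}")
```

The training accuracy is the flattering number. The held-out accuracy is the one to report.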
 
Let's explore how this is done with the `train_test_split` function from `sklearn`.

In [5]:
df = pd.DataFrame(np.random.rand(20, 20))

In [7]:
df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 20 entries, 0 to 19
Data columns (total 20 columns):
0     20 non-null float64
1     20 non-null float64
2     20 non-null float64
3     20 non-null float64
4     20 non-null float64
5     20 non-null float64
6     20 non-null float64
7     20 non-null float64
8     20 non-null float64
9     20 non-null float64
10    20 non-null float64
11    20 non-null float64
12    20 non-null float64
13    20 non-null float64
14    20 non-null float64
15    20 non-null float64
16    20 non-null float64
17    20 non-null float64
18    20 non-null float64
19    20 non-null float64
dtypes: float64(20)
memory usage: 3.3 KB


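When we pass our data frame and a second array of the same length, the function gives us 4 variables in return. Here the second array is `df.columns`, which only works because the frame happens to have as many columns (20) as rows. In practice it would hold one label per row.

  - The training rows of the data frame
  - The testing rows of the data frame
  - The entries of the second array that line up with the training rows
  - The entries of the second array that line up with the testing rows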

In [19]:
a_train, a_test, b_train, b_test = train_test_split(df, df.columns, test_size=0.33)

Let's take a look at the data:

In [23]:
a_train.head()

,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
12,0.970403,0.708571,0.998744,0.928082,0.855476,0.775062,0.972785,0.788251,0.020014,0.849588,0.825154,0.712335,0.064031,0.988875,0.134767,0.529837,0.458691,0.631557,0.829496,0.58989
5,0.151598,0.617705,0.977788,0.536776,0.919199,0.537046,0.952154,0.65937,0.217122,0.243799,0.156505,0.396018,0.551522,0.047119,0.089754,0.580041,0.025621,0.262723,0.088457,0.435452
10,0.673584,0.231551,0.475732,0.548999,0.155537,0.088409,0.561271,0.290757,0.623942,0.798143,0.302051,0.919332,0.042444,0.370164,0.311087,0.116597,0.472446,0.485228,0.316859,0.351818
16,0.789376,0.747166,0.78614,0.961024,0.964004,0.920879,0.045742,0.220985,0.426392,0.337657,0.471051,0.915705,0.097256,0.681336,0.958656,0.192472,0.706834,0.992227,0.071504,0.065386
2,0.634175,0.767076,0.350808,0.793821,0.424617,0.736844,0.44052,0.789312,0.489747,0.518073,0.571981,0.299183,0.487292,0.991872,0.6149,0.789361,0.341727,0.219032,0.647471,0.514597


In [24]:
a_test.head()

,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
17,0.041406,0.801079,0.254391,0.808014,0.516018,0.313591,0.524745,0.401558,0.371148,0.855183,0.004855,0.033155,0.179191,0.468404,0.864066,0.95939,0.960397,0.811633,0.96279,0.298365
11,0.588307,0.571047,0.358195,0.08411,0.887161,0.519337,0.244755,0.568301,0.702507,0.643895,0.783572,0.275994,0.662148,0.302002,0.12296,0.447097,0.078177,0.276677,0.780683,0.388227
3,0.890638,0.333972,0.016114,0.063376,0.64857,0.444806,0.994291,0.388455,0.654143,0.337419,0.20073,0.299403,0.297975,0.02681,0.697415,0.560548,0.081823,0.112725,0.065202,0.604099
9,0.965264,0.464182,0.723637,0.330365,0.723136,0.397838,0.674391,0.94034,0.801646,0.573835,0.269711,0.986313,0.410641,0.32423,0.709903,0.35063,0.083365,0.02241,0.200231,0.650912
7,0.070581,0.106885,0.504367,0.266632,0.264758,0.440598,0.987359,0.407747,0.197972,0.686632,0.851574,0.002224,0.416951,0.460715,0.221023,0.042986,0.740501,0.587717,0.192691,0.750716


In [25]:
b_train

Int64Index([12, 5, 10, 16, 2, 0, 14, 8, 1, 19, 18, 13, 15], dtype='int64')

In [26]:
b_test

Int64Index([17, 11, 3, 9, 7, 4, 6], dtype='int64')
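
The call above passes `df.columns` as the second array. A more typical pattern passes one label per row. The sketch below shows that pattern under stated assumptions: the `features` frame, the `labels` series, and the column names are hypothetical, and the import assumes a scikit-learn release with `sklearn.model_selection`.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical data: 20 rows, 3 feature columns, and one label per row.
rng = np.random.RandomState(0)
features = pd.DataFrame(rng.rand(20, 3), columns=["f0", "f1", "f2"])
labels = pd.Series(rng.randint(0, 2, size=20), name="label")

X_train, X_test, y_train, y_test = train_test_split(
    features, labels, test_size=0.33, random_state=0
)

# With 20 rows and test_size=0.33, the test part is rounded up to 7 rows.
print(len(X_train), len(X_test))

# Both inputs are cut at the same row positions, so their indexes line up.
print(X_train.index.equals(y_train.index))
print(X_test.index.equals(y_test.index))
```

The 13 / 7 row counts match the lengths of `b_train` and `b_test` printed above, since both come from 20 rows with `test_size=0.33`.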

Okay, that was fun. Let's simplify it a little more and split a single data frame, without passing a second array.

In [14]:
df2 = pd.DataFrame(np.random.rand(20, 20))

When we pass only a dataframe to the function we will get back:

 - A training dataframe
 - A testing dataframe

In [15]:
a_train, a_test = train_test_split(df)

In [29]:
a_train.head()

,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
12,0.970403,0.708571,0.998744,0.928082,0.855476,0.775062,0.972785,0.788251,0.020014,0.849588,0.825154,0.712335,0.064031,0.988875,0.134767,0.529837,0.458691,0.631557,0.829496,0.58989
5,0.151598,0.617705,0.977788,0.536776,0.919199,0.537046,0.952154,0.65937,0.217122,0.243799,0.156505,0.396018,0.551522,0.047119,0.089754,0.580041,0.025621,0.262723,0.088457,0.435452
10,0.673584,0.231551,0.475732,0.548999,0.155537,0.088409,0.561271,0.290757,0.623942,0.798143,0.302051,0.919332,0.042444,0.370164,0.311087,0.116597,0.472446,0.485228,0.316859,0.351818
16,0.789376,0.747166,0.78614,0.961024,0.964004,0.920879,0.045742,0.220985,0.426392,0.337657,0.471051,0.915705,0.097256,0.681336,0.958656,0.192472,0.706834,0.992227,0.071504,0.065386
2,0.634175,0.767076,0.350808,0.793821,0.424617,0.736844,0.44052,0.789312,0.489747,0.518073,0.571981,0.299183,0.487292,0.991872,0.6149,0.789361,0.341727,0.219032,0.647471,0.514597


In [30]:
a_test.head()

,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
17,0.041406,0.801079,0.254391,0.808014,0.516018,0.313591,0.524745,0.401558,0.371148,0.855183,0.004855,0.033155,0.179191,0.468404,0.864066,0.95939,0.960397,0.811633,0.96279,0.298365
11,0.588307,0.571047,0.358195,0.08411,0.887161,0.519337,0.244755,0.568301,0.702507,0.643895,0.783572,0.275994,0.662148,0.302002,0.12296,0.447097,0.078177,0.276677,0.780683,0.388227
3,0.890638,0.333972,0.016114,0.063376,0.64857,0.444806,0.994291,0.388455,0.654143,0.337419,0.20073,0.299403,0.297975,0.02681,0.697415,0.560548,0.081823,0.112725,0.065202,0.604099
9,0.965264,0.464182,0.723637,0.330365,0.723136,0.397838,0.674391,0.94034,0.801646,0.573835,0.269711,0.986313,0.410641,0.32423,0.709903,0.35063,0.083365,0.02241,0.200231,0.650912
7,0.070581,0.106885,0.504367,0.266632,0.264758,0.440598,0.987359,0.407747,0.197972,0.686632,0.851574,0.002224,0.416951,0.460715,0.221023,0.042986,0.740501,0.587717,0.192691,0.750716
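
Because the split is random, rerunning the cell normally shuffles the rows differently each time. A minimal sketch of pinning the shuffle with `random_state` follows. The `frame` variable and the seed values are assumptions for illustration, and the 25% default hold-out reflects recent scikit-learn releases when neither `test_size` nor `train_size` is given.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical frame for illustration, seeded so the values repeat too.
frame = pd.DataFrame(np.random.RandomState(0).rand(20, 4))

# No test_size given: the default holds out 25% of the rows.
train_a, test_a = train_test_split(frame, random_state=42)
train_b, test_b = train_test_split(frame, random_state=42)

print(len(train_a), len(test_a))

# The same random_state yields the same rows in the same order.
print(train_a.index.equals(train_b.index))
print(test_a.index.equals(test_b.index))
```

Fixing the seed is handy while writing a tutorial or debugging, since the train and test rows stay the same from run to run.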
