### Reference: [Cross-validation in time series](https://medium.com/@soumyachess1496/cross-validation-in-time-series-566ae4981ce4)

In [17]:
import numpy as np
from sklearn.model_selection import TimeSeriesSplit
import matplotlib.pyplot as plt

In [3]:
# Six toy samples: the two feature rows alternate, the targets count 1..6.
X = np.tile([[1, 2], [3, 4]], (3, 1))
y = np.arange(1, 7)
# Splitter with all defaults; printing it shows the default parameters.
tscv = TimeSeriesSplit()
print(tscv)

TimeSeriesSplit(gap=0, max_train_size=None, n_splits=5, test_size=None)


In [6]:
# Each fold trains on a growing prefix of the series and tests on what follows.
for train_index, test_index in tscv.split(X):
    print('TRAIN:', train_index, 'TEST:', test_index)
    X_train = X[train_index]
    X_test = X[test_index]
    y_train = y[train_index]
    y_test = y[test_index]

TRAIN: [0] TEST: [1]
TRAIN: [0 1] TEST: [2]
TRAIN: [0 1 2] TEST: [3]
TRAIN: [0 1 2 3] TEST: [4]
TRAIN: [0 1 2 3 4] TEST: [5]


## Cross Validation on Time Series

### TimeSeriesSplit with explicit `n_splits` and `test_size`

In [15]:
# Seed a local generator so the printed X / y (and the folds below) are
# reproducible on Restart & Run All.
rng = np.random.default_rng(0)
X = rng.standard_normal((12, 2))
y = rng.integers(0, 2, 12)
print('X : ', X)
print('y : ', y)
# 3 folds with a fixed test window of 2 samples; each fold trains on every
# sample that precedes its test window.
tscv = TimeSeriesSplit(n_splits=3, test_size=2)
for train_index, test_index in tscv.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

X :  [[-1.79056293  0.87866891]
 [ 0.80106972 -0.44265385]
 [-0.93416833 -0.26727398]
 [ 0.11162316  0.44384291]
 [ 0.09443055 -0.9949118 ]
 [-1.55649479 -1.37197826]
 [ 0.15181086  0.75668017]
 [ 0.03937572 -1.36762463]
 [ 0.97198744  0.39055339]
 [ 0.60791456  0.79445008]
 [-0.30140106 -1.5759772 ]
 [-1.2601332  -0.01146805]]
y :  [1 1 0 1 0 1 0 1 0 0 0 0]
TRAIN: [0 1 2 3 4 5] TEST: [6 7]
TRAIN: [0 1 2 3 4 5 6 7] TEST: [8 9]
TRAIN: [0 1 2 3 4 5 6 7 8 9] TEST: [10 11]


### TimeSeriesSplit with default settings (`n_splits=5`)

In [14]:
# Seed a local generator so the printed X / y are reproducible; the gap
# example further down reuses this X and y.
rng = np.random.default_rng(1)
X = rng.standard_normal((12, 2))
y = rng.integers(0, 2, 12)
print('X : ', X)
print('y : ', y)
# No arguments: the default n_splits=5 is used, and the test window size is
# derived from the number of samples.
tscv = TimeSeriesSplit()
for train_index, test_index in tscv.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

X :  [[ 0.83749989  0.71384347]
 [-0.63799982 -1.92607307]
 [-1.26717527 -0.50270929]
 [ 1.3819399  -1.37599415]
 [ 1.67738402  0.45264177]
 [ 0.44576145 -0.31876357]
 [-2.06525453 -0.13837251]
 [-0.852445    0.92908317]
 [ 0.15408028 -0.36622604]
 [-1.16253471  0.4851462 ]
 [ 0.67267134  0.60131115]
 [-1.14238636  0.53727494]]
y :  [1 1 0 1 1 1 1 1 0 1 0 1]
TRAIN: [0 1] TEST: [2 3]
TRAIN: [0 1 2 3] TEST: [4 5]
TRAIN: [0 1 2 3 4 5] TEST: [6 7]
TRAIN: [0 1 2 3 4 5 6 7] TEST: [8 9]
TRAIN: [0 1 2 3 4 5 6 7 8 9] TEST: [10 11]


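### TimeSeriesSplit with a gap

With `gap=2`, the 2 samples immediately before each test window are excluded from the training set (e.g. train `[0 1 2 3]`, test `[6 7]`), so the model is never trained on observations directly adjacent to the ones it is evaluated on.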

In [19]:
def plot_cv_indices(cv, X, y, ax, n_splits, lw=10, cmap_cv=None, cmap_data=None):
    """Plot the train/test indices produced by a cross-validation splitter.

    Each CV iteration is drawn as one horizontal row of markers. The marker
    colour encodes whether a sample is in the training set (0), in the test
    set (1), or in neither (NaN, e.g. samples skipped by ``gap``). A final row
    shows the class labels ``y``.

    Parameters
    ----------
    cv : splitter object
        Anything with a ``split(X=..., y=..., groups=...)`` method yielding
        ``(train_indices, test_indices)`` pairs, e.g. ``TimeSeriesSplit``.
    X : array-like
        Samples; only ``len(X)`` is used.
    y : array-like
        Class labels used to colour the bottom row.
    ax : matplotlib Axes
        Axes to draw on.
    n_splits : int
        Number of CV iterations. Controls the tick labels and y-limits, so it
        should match the number of splits ``cv`` actually yields.
    lw : float, default 10
        Marker line width.
    cmap_cv : Colormap, optional
        Colour map for the train/test rows. Defaults to ``plt.cm.coolwarm``.
    cmap_data : Colormap, optional
        Colour map for the class row. Defaults to ``plt.cm.Paired``.

    Returns
    -------
    ax : matplotlib Axes
        The same axes, for chaining.
    """
    # Explicit parameters instead of notebook globals defined in a later cell.
    if cmap_cv is None:
        cmap_cv = plt.cm.coolwarm
    if cmap_data is None:
        cmap_data = plt.cm.Paired

    n_samples = len(X)
    positions = np.arange(n_samples)

    # Stays -1 if cv yields no splits, so the class row still gets a position.
    ii = -1
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=None)):
        # NaN marks samples that are in neither set (e.g. dropped by `gap`).
        indices = np.full(n_samples, np.nan)
        indices[tt] = 1
        indices[tr] = 0

        ax.scatter(positions, [ii + .5] * n_samples,
                   c=indices, marker='_', lw=lw, cmap=cmap_cv,
                   vmin=-.2, vmax=1.2)

    # Plot the data classes at the end
    ax.scatter(positions, [ii + 1.5] * n_samples,
               c=y, marker='_', lw=lw, cmap=cmap_data)

    # One tick per CV row plus one for the class row; the tick count must
    # equal the label count or matplotlib rejects the labels.
    yticklabels = list(range(n_splits)) + ['class']
    ax.set(yticks=np.arange(n_splits + 1) + .5, yticklabels=yticklabels,
           xlabel='Sample index', ylabel="CV iteration",
           ylim=[n_splits + 1.2, -.1], xlim=[0, n_samples])
    ax.set_title('{}'.format(type(cv).__name__), fontsize=15)
    return ax

In [23]:
# Colour maps for CV-index plots: train/test rows and the class row.
cmap_cv = plt.cm.coolwarm
cmap_data = plt.cm.Paired

# gap=2: the two samples right before each test window are left out of training.
tscv = TimeSeriesSplit(n_splits=3, test_size=2, gap=2)
for train_index, test_index in tscv.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train = X[train_index]
    X_test = X[test_index]
    y_train = y[train_index]
    y_test = y[test_index]

TRAIN: [0 1 2 3] TEST: [6 7]
TRAIN: [0 1 2 3 4 5] TEST: [8 9]
TRAIN: [0 1 2 3 4 5 6 7] TEST: [10 11]
