In [1]:
import pandas as pd
import numpy as np

from sklearn.model_selection import TimeSeriesSplit

In [3]:
# Daily calendar for 2021; start and end are both inclusive, giving one entry per day.
date_list = pd.date_range('2021-01-01', '2021-12-31', freq='D')
date_list

DatetimeIndex(['2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04',
               '2021-01-05', '2021-01-06', '2021-01-07', '2021-01-08',
               '2021-01-09', '2021-01-10',
               ...
               '2021-12-22', '2021-12-23', '2021-12-24', '2021-12-25',
               '2021-12-26', '2021-12-27', '2021-12-28', '2021-12-29',
               '2021-12-30', '2021-12-31'],
              dtype='datetime64[ns]', length=365, freq='D')

In [4]:
# Synthetic target: one random integer in [1, 99] per entry of `date_list`.
# Seeded so Restart & Run All reproduces the same series, and sized from the
# calendar itself so the two cannot drift out of sync if the date range changes.
rng = np.random.default_rng(42)
y = rng.integers(1, 100, size=len(date_list))  # high bound is exclusive, like randint
y[:10]

array([11, 54, 23, 95, 68, 85, 38, 83, 66, 73])

In [5]:
# Frame keyed by the daily calendar; it has no columns yet.
df = pd.DataFrame(data=None, index=date_list)
df.head(n=5)

2021-01-01
2021-01-02
2021-01-03
2021-01-04
2021-01-05


In [6]:
# Attach the synthetic values as column 'y'; `assign` builds a new frame
# instead of mutating the existing one in place.
df = df.assign(y=y)
df.head()

Unnamed: 0,y
2021-01-01,11
2021-01-02,54
2021-01-03,23
2021-01-04,95
2021-01-05,68


In [7]:
# Exploratory: print the TimeSeriesSplit documentation (expanding-window CV —
# successive training sets are supersets of earlier ones, per the output below).
# NOTE(review): consider dropping this cell from the final notebook; its long
# output adds noise to the narrative.
help(TimeSeriesSplit)

Help on class TimeSeriesSplit in module sklearn.model_selection._split:

class TimeSeriesSplit(_BaseKFold)
 |  TimeSeriesSplit(n_splits=5, *, max_train_size=None)
 |  
 |  Time Series cross-validator
 |  
 |  .. versionadded:: 0.18
 |  
 |  Provides train/test indices to split time series data samples
 |  that are observed at fixed time intervals, in train/test sets.
 |  In each split, test indices must be higher than before, and thus shuffling
 |  in cross validator is inappropriate.
 |  
 |  This cross-validation object is a variation of :class:`KFold`.
 |  In the kth split, it returns first k folds as train set and the
 |  (k+1)th fold as test set.
 |  
 |  Note that unlike standard cross-validation methods, successive
 |  training sets are supersets of those that come before them.
 |  
 |  Read more in the :ref:`User Guide <cross_validation>`.
 |  
 |  Parameters
 |  ----------
 |  n_splits : int, default=5
 |      Number of splits. Must be at least 2.
 |  
 |      .. versionchange

In [8]:
# Chronological holdout: time-ordered data must not be shuffled, so the first
# TRAIN_SIZE rows form the training set and every later row is held out.
TRAIN_SIZE = 300
train_data = df.iloc[:TRAIN_SIZE]
test_data = df.iloc[TRAIN_SIZE:]
# The two parts must cover the frame exactly, with no overlap.
assert len(train_data) + len(test_data) == len(df)

print(train_data.shape)
print(test_data.shape)

(300, 1)
(65, 1)


In [70]:
X = train_data.index
y = train_data['y']

# Illustrate TimeSeriesSplit's expanding window on the first 20 rows only, so the
# printed indices stay short. With 20 samples and 10 splits, each test fold is one row.
tscv = TimeSeriesSplit(n_splits=10)
for train_index, test_index in tscv.split(X[:20]):
    print('TRAIN: ', train_index, "TEST:", test_index)
    # Slice by position. `y` is indexed by dates, so `y[int_array]` would rely on
    # pandas' positional fallback in Series.__getitem__, which is deprecated in
    # recent pandas. Integer-array indexing of an Index (`X[...]`) is positional.
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]
    
    

TRAIN:  [0 1 2 3 4 5 6 7 8 9] TEST: [10]
TRAIN:  [ 0  1  2  3  4  5  6  7  8  9 10] TEST: [11]
TRAIN:  [ 0  1  2  3  4  5  6  7  8  9 10 11] TEST: [12]
TRAIN:  [ 0  1  2  3  4  5  6  7  8  9 10 11 12] TEST: [13]
TRAIN:  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13] TEST: [14]
TRAIN:  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14] TEST: [15]
TRAIN:  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15] TEST: [16]
TRAIN:  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16] TEST: [17]
TRAIN:  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17] TEST: [18]
TRAIN:  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18] TEST: [19]


In [63]:
def train_test_split(df, n=10):
    """Split positions ``0 .. len(df)-1`` into consecutive fixed-size windows.

    Each fold trains on ``n`` consecutive positions and tests on the single
    position immediately after them. That test position is also the first
    training position of the next fold. Unlike sklearn's ``TimeSeriesSplit``,
    the training windows do not grow; they are disjoint from one another.

    The positions left over at the end (at most ``n``) form one final, shorter
    fold when at least two remain: all but the last position train, and the last
    position tests. A single leftover position is skipped, because it was already
    the test position of the preceding fold.

    Note: the name shadows ``sklearn.model_selection.train_test_split``. It is
    kept unchanged for compatibility with existing callers.

    Parameters
    ----------
    df : sized object (DataFrame, Index, ndarray, list, ...)
        Only its length is used.
    n : int, default 10
        Training-window length; must be >= 1.

    Returns
    -------
    tuple of (list of np.ndarray, list of int)
        The training positions and the matching test position for each fold.
        Every test position is a valid index, ``< len(df)``.

    Raises
    ------
    ValueError
        If ``n < 1``. With ``n == 0`` the window start would never advance.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n!r}")

    df_length = len(df)
    train_index = []
    test_index = []

    start = 0
    # Emit a full window only when its test position (start + n) exists.
    while start + n < df_length:
        train_index.append(np.arange(start, start + n))
        test_index.append(start + n)
        start += n

    # Tail fold over the remaining positions [start, df_length). It needs at
    # least two of them, so the training part is non-empty and the test
    # position is new.
    if df_length - start >= 2:
        train_index.append(np.arange(start, df_length - 1))
        test_index.append(df_length - 1)

    return train_index, test_index

In [67]:
# Fold the first 102 positions with the default window of 10. Because 102 is not
# a multiple of 10, the last fold covers the short remainder.
train_test_split(X[:102], n=10)

([array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
  array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19]),
  array([20, 21, 22, 23, 24, 25, 26, 27, 28, 29]),
  array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39]),
  array([40, 41, 42, 43, 44, 45, 46, 47, 48, 49]),
  array([50, 51, 52, 53, 54, 55, 56, 57, 58, 59]),
  array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69]),
  array([70, 71, 72, 73, 74, 75, 76, 77, 78, 79]),
  array([80, 81, 82, 83, 84, 85, 86, 87, 88, 89]),
  array([90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
  array([100])],
 [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 101])