In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pytest
import numpy as np
import pandas as pd
from numpy.testing import assert_array_equal

from sklearn.utils.estimator_checks import check_estimator
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier

from sklearn_questions import KNearestNeighbors
from sklearn_questions import MonthlySplit


In [3]:
k = 3

X, y = make_classification(n_samples=200, n_features=20,
                               random_state=42)
X_train, X_test, y_train, y_test = \
    train_test_split(X, y, random_state=42)
knn = KNeighborsClassifier(n_neighbors=k)
y_pred_sk = knn.fit(X_train, y_train).predict(X_test)

onn = KNearestNeighbors(k)
onn.fit(X_train, y_train)
y_pred_me = onn.predict(X_test)

assert_array_equal(y_pred_me, y_pred_sk)

assert onn.score(X_test, y_test) == knn.score(X_test, y_test)

In [4]:
end_date = '2021-01-31'
shuffle_data = False
expected_splits = 12

In [5]:
date = pd.date_range(start='2020-01-01', end=end_date, freq='D')
n_samples = len(date)
X = pd.DataFrame(range(n_samples), index=date, columns=['val'])
y = pd.DataFrame(
    np.array([i % 2 for i in range(n_samples)]),
    index=date
)

if shuffle_data:
    X, y = shuffle(X, y, random_state=0)

X_1d = X['val']

cv = MonthlySplit()
cv_repr = "MonthlySplit(time_col='index')"

# Test if the repr works without any errors
assert cv_repr == repr(cv)

In [6]:
# Test if get_n_splits works correctly
assert cv.get_n_splits(X, y) == expected_splits

In [7]:
for train_split, test_split in cv.split(X, y):
    a = X.iloc[train_split]

train 1-2020, test 2-2020
train 2-2020, test 3-2020
train 3-2020, test 4-2020
train 4-2020, test 5-2020
train 5-2020, test 6-2020
train 6-2020, test 7-2020
train 7-2020, test 8-2020
train 8-2020, test 9-2020
train 9-2020, test 10-2020
train 10-2020, test 11-2020
train 11-2020, test 12-2020
train 12-2020, test 1-2021


In [8]:
train_split

array([335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347,
       348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360,
       361, 362, 363, 364, 365])

In [10]:
list(cv.split(X, y))

train 1-2020, test 2-2020
train 2-2020, test 3-2020
train 3-2020, test 4-2020
train 4-2020, test 5-2020
train 5-2020, test 6-2020
train 6-2020, test 7-2020
train 7-2020, test 8-2020
train 8-2020, test 9-2020
train 9-2020, test 10-2020
train 10-2020, test 11-2020
train 11-2020, test 12-2020
train 12-2020, test 1-2021


[(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
         17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]),
  array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
         48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59])),
 (array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
         48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]),
  array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
         77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90])),
 (array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
         77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90]),
  array([ 91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
         104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
         117, 118, 119, 120])),
 (array([ 91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
         104, 105, 106, 107, 108, 109, 1

In [11]:
X.iloc[list(cv.split(X, y))[10][0]].index.month

train 1-2020, test 2-2020
train 2-2020, test 3-2020
train 3-2020, test 4-2020
train 4-2020, test 5-2020
train 5-2020, test 6-2020
train 6-2020, test 7-2020
train 7-2020, test 8-2020
train 8-2020, test 9-2020
train 9-2020, test 10-2020
train 10-2020, test 11-2020
train 11-2020, test 12-2020
train 12-2020, test 1-2021


Index([11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
       11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11],
      dtype='int32')

In [12]:
# Test if the cross-validator works as expected even if
# the data is 1d
np.testing.assert_equal(
    list(cv.split(X, y)), list(cv.split(X_1d, y))
)

train 1-2020, test 2-2020
train 2-2020, test 3-2020
train 3-2020, test 4-2020
train 4-2020, test 5-2020
train 5-2020, test 6-2020
train 6-2020, test 7-2020
train 7-2020, test 8-2020
train 8-2020, test 9-2020
train 9-2020, test 10-2020
train 10-2020, test 11-2020
train 11-2020, test 12-2020
train 12-2020, test 1-2021
train 1-2020, test 2-2020
train 2-2020, test 3-2020
train 3-2020, test 4-2020
train 4-2020, test 5-2020
train 5-2020, test 6-2020
train 6-2020, test 7-2020
train 7-2020, test 8-2020
train 8-2020, test 9-2020
train 9-2020, test 10-2020
train 10-2020, test 11-2020
train 11-2020, test 12-2020
train 12-2020, test 1-2021


In [13]:
# Test that train, test indices returned are integers and
# data is correctly ordered
for train, test in cv.split(X, y):
    assert np.asarray(train).dtype.kind == "i"
    assert np.asarray(test).dtype.kind == "i"

    X_train, X_test = X.iloc[train], X.iloc[test]
    y_train, y_test = y.iloc[train], y.iloc[test]
    assert X_train.index.max() < X_test.index.min()
    assert y_train.index.max() < y_test.index.min()
    assert X.index.equals(y.index)

with pytest.raises(ValueError, match='datetime'):
    cv = MonthlySplit(time_col='val')
    next(cv.split(X, y))

train 1-2020, test 2-2020
train 2-2020, test 3-2020
train 3-2020, test 4-2020
train 4-2020, test 5-2020
train 5-2020, test 6-2020
train 6-2020, test 7-2020
train 7-2020, test 8-2020
train 8-2020, test 9-2020
train 9-2020, test 10-2020
train 10-2020, test 11-2020
train 11-2020, test 12-2020
train 12-2020, test 1-2021


In [18]:
end_date = '2021-01-31'#, '2020-12-31']
shuffle_data = False

In [19]:
date = pd.date_range(
    start='2020-01-01 00:00', end=end_date, freq='D'
)
n_samples = len(date)
X = pd.DataFrame({'val': range(n_samples), 'date': date})
y = pd.DataFrame(
    np.array([i % 2 for i in range(n_samples)])
)

if shuffle_data:
    X, y = shuffle(X, y, random_state=0)

cv = MonthlySplit(time_col='date')

# Test that train, test indices returned are integers and
# data is correctly ordered
n_splits = 0
last_time = None

In [20]:
list(cv.split(X, y))

train 1-2020, test 2-2020
train 2-2020, test 3-2020
train 3-2020, test 4-2020
train 4-2020, test 5-2020
train 5-2020, test 6-2020
train 6-2020, test 7-2020
train 7-2020, test 8-2020
train 8-2020, test 9-2020
train 9-2020, test 10-2020
train 10-2020, test 11-2020
train 11-2020, test 12-2020
train 12-2020, test 1-2021


[(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
         17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]),
  array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
         48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59])),
 (array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
         48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]),
  array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
         77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90])),
 (array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
         77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90]),
  array([ 91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
         104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
         117, 118, 119, 120])),
 (array([ 91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
         104, 105, 106, 107, 108, 109, 1

In [17]:
for train, test in cv.split(X, y):

    X_train, X_test = X.iloc[train], X.iloc[test]
    assert X_train['date'].max() < X_test['date'].min()
    assert X_train['date'].dt.month.nunique() == 1
    assert X_test['date'].dt.month.nunique() == 1
    assert X_train['date'].dt.year.nunique() == 1
    assert X_test['date'].dt.year.nunique() == 1
    if last_time is not None:
        assert X_test['date'].min() > last_time
    last_time = X_test['date'].max()
    n_splits += 1

assert 'idx' not in X.columns

assert n_splits == cv.get_n_splits(X, y)

train 1-2020, test 2-2020


AssertionError: 