In [1]:
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder

from sklego.model_selection import GroupTimeSeriesSplit

import pandas as pd


# Load the labelled training data and derive the three model inputs from it:
# the features (every column except "wolf_seen"), the target ("wolf_seen"),
# and the per-row "year", which serves as the group key for the
# GroupTimeSeriesSplit cross-validator.
train = pd.read_csv("wolf_train.csv", index_col=0)
X_train, y_train, groups_train = (
    train.drop(columns="wolf_seen"),
    train["wolf_seen"],
    train["year"],
)

# Cross-validator: GroupTimeSeriesSplit with 4 splits over the "year" groups.
# The groups are handed over at fit time (see below), not here.
cv = GroupTimeSeriesSplit(4)

# Hyper-parameter grid. Keys use make_pipeline's auto-generated step name
# (the lowercased class name, "randomforestclassifier") plus "__<param>".
param_grid = {
    "randomforestclassifier__n_estimators": [5, 10, 20],
    "randomforestclassifier__max_depth": [2, 5, 9],
    "randomforestclassifier__bootstrap": [True, False],
    "randomforestclassifier__min_samples_leaf": [1, 2, 4],
    "randomforestclassifier__min_samples_split": [2, 5, 10],
}

# Model: one-hot encode the categorical columns and feed them to a random
# forest. ColumnTransformer is left at its default remainder="drop", so any
# other column of X_train (e.g. "year") is not used as a feature.
# ``clf`` is the forest that actually sits inside the pipeline, so its
# constructor settings are the base the grid search varies.
categorical_features = ["season", "day", "landscape", "method"]
clf = RandomForestClassifier(class_weight="balanced_subsample", random_state=42)
pipe = make_pipeline(
    ColumnTransformer(
        [("cat", OneHotEncoder(handle_unknown="ignore"), categorical_features)]
    ),
    clf,
)

# Grid search. Pass the splitter object itself rather than a pre-built
# ``cv.split(...)`` generator: a generator is exhausted after the first fit
# (re-running ``fit`` then fails with no CV splits) and cannot be deep-copied,
# which breaks ``sklearn.base.clone`` / nested cross-validation.
grid_clf = GridSearchCV(pipe, param_grid=param_grid, cv=cv, refit=True)

# Fit the model; ``groups`` is forwarded to ``cv.split``. That call also
# populates the state that ``cv.summary()`` reads afterwards.
grid_clf.fit(X_train, y_train, groups=groups_train)

# Inspect how GroupTimeSeriesSplit assigned rows to groups. This relies on
# cv.split() having already run (it does during the grid-search fit above);
# NOTE(review): sklego documents summary() as only available after split().
# The result has one row per group: the first and last "index" value (the
# years, judging by the displayed output), the summed observations, and the
# group's ideal size and its deviation from that ideal.
# Kept as the cell's final bare expression so the notebook renders it.
(
    cv.summary()
    .groupby('group')
    .agg(
        start_year=pd.NamedAgg(column='index', aggfunc='min'),
        last_year=pd.NamedAgg(column='index', aggfunc='max'),
        observations=pd.NamedAgg(column='observations', aggfunc='sum'),
        ideal_group_size=pd.NamedAgg(column='ideal_group_size', aggfunc='first'),
        diff_from_ideal_group_size=pd.NamedAgg(
            column='diff_from_ideal_group_size', aggfunc='first'
        ),
    )
)

group,start_year,last_year,observations,ideal_group_size,diff_from_ideal_group_size
0,1980,1997,1049,1036,13
1,1998,2004,990,1036,-46
2,2005,2010,1037,1036,1
3,2011,2015,1153,1036,117
4,2016,2019,951,1036,-85
