# Investigating the Impact of Cluster Sampling on the Size of the Training, Testing, and Validation Sets

When training the MVN NGBoost model on the drifter data, we must take extra care to ensure that information about the testing and validation sets is not inadvertently introduced into the training data. The ~400,000 drifter observations that form our data set come from ~2000 unique drifters. We expect observations taken by the same drifter to be highly correlated, so we ensure that all of the observations made by any single drifter fall in precisely one of the training, testing, or validation sets. Keeping each drifter's observations within a single set prevents the training data from carrying information about the other sets through this correlation.

O'Malley et al. (2023) deal with this issue using cluster sampling: the data are split into clusters defined by drifter ID, and the clusters are then randomly assigned to the training, testing, and validation sets, which contain 81%, 10%, and 9% of the drifter IDs respectively. However, the number of observations varies significantly between drifter ID clusters, so the proportion of the overall data found in each set may differ significantly from the nominal 81-10-9 split. At its most extreme, this discrepancy could result in testing and training sets of comparable sizes.
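
A minimal sketch of this kind of cluster sampling is given below, assuming the observations sit in a pandas DataFrame with one drifter identifier column. The column name `id`, the function name `cluster_split`, the rounding of set sizes, and the synthetic data are all illustrative assumptions; this is not the code used by O'Malley et al. (2023).

```python
import numpy as np
import pandas as pd


def cluster_split(df, id_col="id", fractions=(0.81, 0.10, 0.09), seed=0):
    """Assign whole clusters (all rows sharing an ID) to train/test/validation."""
    rng = np.random.default_rng(seed)
    # Shuffle the unique IDs, then cut the shuffled list at the nominal fractions.
    ids = rng.permutation(df[id_col].unique())
    n_train = int(round(fractions[0] * len(ids)))
    n_test = int(round(fractions[1] * len(ids)))
    train_ids, test_ids, val_ids = np.split(ids, [n_train, n_train + n_test])
    return tuple(df[df[id_col].isin(part)] for part in (train_ids, test_ids, val_ids))


# Synthetic stand-in for the drifter data (an assumption, not the real data):
# 200 hypothetical drifters with very uneven numbers of observations each.
rng = np.random.default_rng(1)
counts = rng.geometric(p=0.01, size=200)
demo = pd.DataFrame({"id": np.repeat(np.arange(200), counts)})

train, test, val = cluster_split(demo)
# The nominal split is by drifter ID; the realised split is by observation count.
realised = [len(part) / len(demo) for part in (train, test, val)]
print(realised)
```

The split is exact in terms of drifter IDs, while the printed fractions show how the share of observations in each set can drift away from 81-10-9.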

In this notebook, we will investigate whether the sizes of the training, testing, and validation sets that result from 81-10-9 cluster sampling differ significantly from the nominal 81-10-9 values.
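
One way such an investigation could be set up is to repeat the random assignment many times and record the share of observations that lands in each set. The sketch below needs only a vector of per-drifter observation counts. The synthetic counts, the helper name `realised_fractions`, and the number of repetitions are illustrative assumptions, not results from the drifter data.

```python
import numpy as np


def realised_fractions(counts, fractions=(0.81, 0.10, 0.09), n_repeats=1000, seed=0):
    """Return an (n_repeats, 3) array of observation shares per set."""
    rng = np.random.default_rng(seed)
    counts = np.asarray(counts)
    n_train = int(round(fractions[0] * len(counts)))
    n_test = int(round(fractions[1] * len(counts)))
    out = np.empty((n_repeats, 3))
    for i in range(n_repeats):
        # Shuffling the per-drifter counts is equivalent to shuffling drifter IDs.
        shuffled = rng.permutation(counts)
        parts = np.split(shuffled, [n_train, n_train + n_test])
        out[i] = [part.sum() for part in parts]
    return out / counts.sum()


# Hypothetical per-drifter observation counts (heavily skewed on purpose).
counts = np.random.default_rng(1).geometric(p=0.005, size=2000)
shares = realised_fractions(counts)
print("mean share:", shares.mean(axis=0))
print("5th pct   :", np.percentile(shares, 5, axis=0))
print("95th pct  :", np.percentile(shares, 95, axis=0))
```

The spread between the 5th and 95th percentiles gives a rough sense of how far a single random split might sit from the nominal values.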

In [1]:
# import modules and load data
import sklearn
import pandas as pd
import numpy as np
# pd.read_hdf needs the optional PyTables dependency
# (package `tables` on pip, `pytables` on conda)
path_to_data = '../data/filtered_nao_drifters_with_sst_gradient.h5'
data = pd.read_hdf(path_to_data)
# add the day of the year (1-366) of each observation as a column
data['day_of_year'] = data['time'].apply(lambda t: t.timetuple().tm_yday)

ImportError: Missing optional dependency 'pytables'.  Use pip or conda to install pytables.

In [None]:
# split the drifter IDs into training, testing and validation 

def train_test_validation_split():
    pass