Skip to content

Commit

Permalink
feat: add function to generate kfold splits
Browse files Browse the repository at this point in the history
  • Loading branch information
jimthompson5802 committed Jan 1, 2020
1 parent ba030e2 commit bb5c316
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions ludwig/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,15 @@ def full_train(
)


def generate_kfold_splits(data_train_df_fp, k_fold):
    """Yield the row-index split for each fold of a shuffled k-fold partition.

    The CSV at ``data_train_df_fp`` is loaded with ``pd.read_csv`` and its
    rows are partitioned by ``KFold(n_splits=k_fold, shuffle=True)``.
    NOTE(review): ``KFold`` is presumably scikit-learn's
    ``sklearn.model_selection.KFold``; the import is outside this view.

    Because this is a generator, nothing is read or split until the first
    item is requested.

    :param data_train_df_fp: path (or any other input accepted by
        ``pd.read_csv``) of the training CSV to split.
    :param k_fold: number of folds, forwarded to ``KFold`` as ``n_splits``.
    :return: generator of ``(train_index, test_index, fold_num)`` tuples,
        where the first two items are whatever ``KFold.split`` produces for
        the fold and ``fold_num`` counts the folds starting at 1.
    """
    splitter = KFold(n_splits=k_fold, shuffle=True)
    frame = pd.read_csv(data_train_df_fp)
    # Fold numbers are 1-based: the first split yielded is fold 1.
    for fold_num, (train_index, test_index) in enumerate(
            splitter.split(frame), start=1):
        yield train_index, test_index, fold_num


def kfold_cross_validate(
model_definition,
model_definition_file=None,
Expand All @@ -422,21 +431,18 @@ def kfold_cross_validate(

# place each fold in a separate directory
data_dir = os.path.dirname(data_train_csv)
kf = KFold(n_splits=k_fold, shuffle=True)
i = 0
kfold_training_stats = {}
for train_index, test_index in kf.split(data_df):
for train_index, test_index, fold_num in generate_kfold_splits(data_train_csv, k_fold):
with tempfile.TemporaryDirectory(dir=data_dir) as temp_dir_name:
# save training and validation subset for the fold into a temporary directory
train_csv_fp = os.path.join(temp_dir_name, 'train_fold.csv')
test_csv_fp = os.path.join(temp_dir_name, 'test_fold.csv')
i += 1
logger.info("\n\n>>>>> for fold {:d} created temporary directory: {}".format(i, temp_dir_name))
logger.info("\n\n>>>>> for fold {:d} created temporary directory: {}".format(fold_num, temp_dir_name))
data_df.iloc[train_index].to_csv(train_csv_fp, index=False)
data_df.iloc[test_index].to_csv(test_csv_fp, index=False)

# train and validate model on this fold
logger.info("training on fold {:d}".format(i))
logger.info("training on fold {:d}".format(fold_num))
(model,
preprocessed_data,
_,
Expand All @@ -446,7 +452,7 @@ def kfold_cross_validate(
data_train_csv = train_csv_fp,
data_test_csv = test_csv_fp,
experiment_name='cross_validation',
model_name='fold_' + str(i),
model_name='fold_' + str(fold_num),
output_directory=os.path.join(temp_dir_name,'results'))

# score on hold out fold
Expand All @@ -457,7 +463,7 @@ def kfold_cross_validate(
train_stats['fold_metric'] = preds['combined']

# collect training statistics for this fold
kfold_training_stats['fold_'+str(i)] = train_stats
kfold_training_stats['fold_'+str(fold_num)] = train_stats

# save consolidated training statistics from k-fold cv runs
save_json(os.path.join(output_directory,'kfold_training_statistics.json'), kfold_training_stats)
Expand Down

0 comments on commit bb5c316

Please sign in to comment.