Skip to content

Commit

Permalink
refactor: incorporate review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jimthompson5802 committed Jan 11, 2020
1 parent 4d2662e commit 1861d6f
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions ludwig/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,10 @@ def full_train(


def kfold_cross_validate(
model_definition,
k_fold,
model_definition_file=None,
data_train_csv=None,
output_directory='results',
k_fold=None,
**kwargs
):

Expand Down Expand Up @@ -449,8 +448,12 @@ def kfold_cross_validate(
)

# score on hold out fold
preds = model.predict(preprocessed_data[2],
model_definition['training']['batch_size'])
eval_batch_size = model_definition['training']['eval_batch_size']
batch_size = model_definition['training']['batch_size']
preds = model.predict(
preprocessed_data[2],
eval_batch_size if eval_batch_size != 0 else batch_size
)

# augment the training statistics with scoring metric fron the hold out fold
train_stats['fold_metric'] = preds['combined']
Expand Down

0 comments on commit 1861d6f

Please sign in to comment.