Skip to content

Commit

Permalink
renaming k_fold parameter to num_folds everywhere but in full_kfold_c…
Browse files Browse the repository at this point in the history
…ross_validate
  • Loading branch information
w4nderlust committed Feb 11, 2020
1 parent ae8699f commit 3928395
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions ludwig/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@ def test(


def kfold_cross_validate(
k_fold,
num_folds,
model_definition=None,
model_definition_file=None,
data_csv=None,
Expand All @@ -1015,7 +1015,7 @@ def kfold_cross_validate(
# Inputs
:param k_fold: (int) number of folds to create for the cross-validation
:param num_folds: (int) number of folds to create for the cross-validation
:param model_definition: (dict, default: None) a dictionary containing
information needed to build a model. Refer to the [User Guide]
(http://ludwig.ai/user_guide/#model-definition) for details.
Expand All @@ -1037,7 +1037,7 @@ def kfold_cross_validate(

(kfold_cv_stats,
kfold_split_indices) = experiment_kfold_cross_validate(
k_fold,
num_folds,
model_definition=model_definition,
model_definition_file=model_definition_file,
data_csv=data_csv,
Expand Down
10 changes: 5 additions & 5 deletions ludwig/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def full_experiment(


def kfold_cross_validate(
k_fold,
num_folds,
model_definition=None,
model_definition_file=None,
data_csv=None,
Expand All @@ -405,7 +405,7 @@ def kfold_cross_validate(
**kwargs
):
# check for k_fold
if k_fold is None:
if num_folds is None:
raise ValueError(
'k_fold parameter must be specified'
)
Expand All @@ -422,7 +422,7 @@ def kfold_cross_validate(
'model_definition_file can be provided'
)

logger.info('starting {:d}-fold cross validation'.format(k_fold))
logger.info('starting {:d}-fold cross validation'.format(num_folds))

# extract out model definition for use
if model_definition_file is not None:
Expand All @@ -444,7 +444,7 @@ def kfold_cross_validate(
kfold_split_indices = {}

for train_indices, test_indices, fold_num in \
generate_kfold_splits(data_df, k_fold, random_seed):
generate_kfold_splits(data_df, num_folds, random_seed):
with tempfile.TemporaryDirectory(dir=data_dir) as temp_dir_name:
curr_train_df = data_df.iloc[train_indices]
curr_test_df = data_df.iloc[test_indices]
Expand Down Expand Up @@ -519,7 +519,7 @@ def kfold_cross_validate(

kfold_cv_stats['overall'] = overall_kfold_stats

logger.info('completed {:d}-fold cross validation'.format(k_fold))
logger.info('completed {:d}-fold cross validation'.format(num_folds))

return kfold_cv_stats, kfold_split_indices

Expand Down

0 comments on commit 3928395

Please sign in to comment.