Skip to content

Commit

Permalink
make xvalidation always leave state of machine/features invariant
Browse files Browse the repository at this point in the history
previously, this was only done for number of threads >1 which lead to inconsistent behaviour
  • Loading branch information
karlnapf committed Jun 7, 2018
1 parent cd08c4e commit c75f457
Showing 1 changed file with 12 additions and 28 deletions.
40 changes: 12 additions & 28 deletions src/shogun/evaluation/CrossValidation.cpp
Expand Up @@ -241,24 +241,14 @@ float64_t CCrossValidation::evaluate_one_run(
CrossValidationFoldStorage* fold = new CrossValidationFoldStorage();
SG_REF(fold)

CMachine* machine;
CFeatures* features;
CLabels* labels;
CEvaluation* evaluation_criterion;
auto machine = (CMachine*)m_machine->clone();

if (get_global_parallel()->get_num_threads() == 1)
{
machine = m_machine;
features = m_features;
evaluation_criterion = m_evaluation_criterion;
}
else
{
machine = (CMachine*)m_machine->clone();
features = (CFeatures*)m_features->clone();
evaluation_criterion =
(CEvaluation*)m_evaluation_criterion->clone();
}
// TODO while these are not used through const interfaces,
// we unfortunately have to clone, even though these could be shared
auto features = (CFeatures*)m_features->clone();
auto labels = (CLabels*)m_labels->clone();
auto evaluation_criterion =
(CEvaluation*)m_evaluation_criterion->clone();

/* evtl. update xvalidation output class */
fold->set_run_index(index);
Expand All @@ -271,10 +261,6 @@ float64_t CCrossValidation::evaluate_one_run(
features->add_subset(inverse_subset_indices);

/* set label subset for training */
if (get_global_parallel()->get_num_threads() == 1)
labels = m_labels;
else
labels = machine->get_labels();
labels->add_subset(inverse_subset_indices);

SG_DEBUG("training set %d:\n", i)
Expand All @@ -287,6 +273,7 @@ float64_t CCrossValidation::evaluate_one_run(

/* train machine on training features and remove subset */
SG_DEBUG("starting training\n")
machine->set_labels(labels);
machine->train(features);
SG_DEBUG("finished training\n")

Expand Down Expand Up @@ -340,13 +327,10 @@ float64_t CCrossValidation::evaluate_one_run(

/* clean up, remove subsets */
labels->remove_subset();
if (get_global_parallel()->get_num_threads() != 1)
{
SG_UNREF(machine);
SG_UNREF(features);
SG_UNREF(labels);
SG_UNREF(evaluation_criterion);
}
SG_UNREF(machine);
SG_UNREF(features);
SG_UNREF(labels);
SG_UNREF(evaluation_criterion);
SG_UNREF(result_labels);
SG_UNREF(fold)
}
Expand Down

0 comments on commit c75f457

Please sign in to comment.