From c75f457650b25d4792917de2cbe686c725b72d16 Mon Sep 17 00:00:00 2001 From: Heiko Strathmann Date: Thu, 7 Jun 2018 11:08:40 +0100 Subject: [PATCH] make xvalidation always leave state of machine/features invariant previously, this was only done for number of threads >1 which lead to inconsistent behaviour --- src/shogun/evaluation/CrossValidation.cpp | 40 +++++++---------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/src/shogun/evaluation/CrossValidation.cpp b/src/shogun/evaluation/CrossValidation.cpp index 56c97e97725..4f16fb0ecc8 100644 --- a/src/shogun/evaluation/CrossValidation.cpp +++ b/src/shogun/evaluation/CrossValidation.cpp @@ -241,24 +241,14 @@ float64_t CCrossValidation::evaluate_one_run( CrossValidationFoldStorage* fold = new CrossValidationFoldStorage(); SG_REF(fold) - CMachine* machine; - CFeatures* features; - CLabels* labels; - CEvaluation* evaluation_criterion; + auto machine = (CMachine*)m_machine->clone(); - if (get_global_parallel()->get_num_threads() == 1) - { - machine = m_machine; - features = m_features; - evaluation_criterion = m_evaluation_criterion; - } - else - { - machine = (CMachine*)m_machine->clone(); - features = (CFeatures*)m_features->clone(); - evaluation_criterion = - (CEvaluation*)m_evaluation_criterion->clone(); - } + // TODO while these are not used through const interfaces, + // we unfortunately have to clone, even though these could be shared + auto features = (CFeatures*)m_features->clone(); + auto labels = (CLabels*)m_labels->clone(); + auto evaluation_criterion = + (CEvaluation*)m_evaluation_criterion->clone(); /* evtl. update xvalidation output class */ fold->set_run_index(index); @@ -271,10 +261,6 @@ float64_t CCrossValidation::evaluate_one_run( features->add_subset(inverse_subset_indices); /* set label subset for training */ - if (get_global_parallel()->get_num_threads() == 1) - labels = m_labels; - else - labels = machine->get_labels(); labels->add_subset(inverse_subset_indices); SG_DEBUG("training set %d:\n", i) @@ -287,6 +273,7 @@ float64_t CCrossValidation::evaluate_one_run( /* train machine on training features and remove subset */ SG_DEBUG("starting training\n") + machine->set_labels(labels); machine->train(features); SG_DEBUG("finished training\n") @@ -340,13 +327,10 @@ float64_t CCrossValidation::evaluate_one_run( /* clean up, remove subsets */ labels->remove_subset(); - if (get_global_parallel()->get_num_threads() != 1) - { - SG_UNREF(machine); - SG_UNREF(features); - SG_UNREF(labels); - SG_UNREF(evaluation_criterion); - } + SG_UNREF(machine); + SG_UNREF(features); + SG_UNREF(labels); + SG_UNREF(evaluation_criterion); SG_UNREF(result_labels); SG_UNREF(fold) }