Skip to content

Commit

Permalink
Merge pull request #3660 from lkuchenb/feature/parallel_xval_reintegr…
Browse files Browse the repository at this point in the history
…ation

Reintegrate parallel crossvalidation
  • Loading branch information
vigsterkr committed Feb 28, 2017
2 parents be8e005 + 6eb1b2a commit 277aeb7
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 20 deletions.
69 changes: 49 additions & 20 deletions src/shogun/evaluation/CrossValidation.cpp
Expand Up @@ -264,32 +264,49 @@ float64_t CCrossValidation::evaluate_one_run()
m_machine->set_store_model_features(true);

/* do actual cross-validation */
#pragma omp parallel for
for (index_t i=0; i <num_subsets; ++i)
{
CMachine* machine;
CFeatures* features;
CLabels* labels;

if (get_global_parallel()->get_num_threads()==1)
machine=m_machine;
else
machine=(CMachine*)m_machine->clone();

/* evtl. update xvalidation output class */
CCrossValidationOutput* current=(CCrossValidationOutput*)
m_xval_outputs->get_first_element();
#pragma omp critical
{
while (current)
{
current->update_fold_index(i);
SG_UNREF(current);
current=(CCrossValidationOutput*)
m_xval_outputs->get_next_element();
}
}

/* set feature subset for training */
SGVector<index_t> inverse_subset_indices=
m_splitting_strategy->generate_subset_inverse(i);
m_features->add_subset(inverse_subset_indices);
for (index_t p=0; p<m_features->get_num_preprocessors(); p++)
{
CPreprocessor* preprocessor = m_features->get_preprocessor(p);
preprocessor->init(m_features);
SG_UNREF(preprocessor);
}

if (get_global_parallel()->get_num_threads()==1)
features=m_features;
else
features=(CFeatures*)m_features->clone();

features->add_subset(inverse_subset_indices);

/* set label subset for training */
m_labels->add_subset(inverse_subset_indices);
if (get_global_parallel()->get_num_threads()==1)
labels=m_labels;
else
labels=machine->get_labels();
labels->add_subset(inverse_subset_indices);

SG_DEBUG("training set %d:\n", i)
if (io->get_loglevel()==MSG_DEBUG)
Expand All @@ -300,30 +317,33 @@ float64_t CCrossValidation::evaluate_one_run()

/* train machine on training features and remove subset */
SG_DEBUG("starting training\n")
m_machine->train(m_features);
machine->train(features);
SG_DEBUG("finished training\n")

/* evtl. update xvalidation output class */
#pragma omp critical
{
current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
while (current)
{
current->update_train_indices(inverse_subset_indices, "\t");
current->update_trained_machine(m_machine, "\t");
current->update_trained_machine(machine, "\t");
SG_UNREF(current);
current=(CCrossValidationOutput*)
m_xval_outputs->get_next_element();
}
}

m_features->remove_subset();
m_labels->remove_subset();
features->remove_subset();
labels->remove_subset();

/* set feature subset for testing (subset method that stores pointer) */
SGVector<index_t> subset_indices =
m_splitting_strategy->generate_subset_indices(i);
m_features->add_subset(subset_indices);
features->add_subset(subset_indices);

/* set label subset for testing */
m_labels->add_subset(subset_indices);
labels->add_subset(subset_indices);

SG_DEBUG("test set %d:\n", i)
if (io->get_loglevel()==MSG_DEBUG)
Expand All @@ -334,33 +354,42 @@ float64_t CCrossValidation::evaluate_one_run()

/* apply machine to test features and remove subset */
SG_DEBUG("starting evaluation\n")
SG_DEBUG("%p\n", m_features)
CLabels* result_labels=m_machine->apply(m_features);
SG_DEBUG("%p\n", features)
CLabels* result_labels=machine->apply(features);
SG_DEBUG("finished evaluation\n")
m_features->remove_subset();
features->remove_subset();
SG_REF(result_labels);

/* evaluate */
results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
results[i]=m_evaluation_criterion->evaluate(result_labels, labels);
SG_DEBUG("result on fold %d is %f\n", i, results[i])

/* evtl. update xvalidation output class */
#pragma omp critical
{
current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
while (current)
{
current->update_test_indices(subset_indices, "\t");
current->update_test_result(result_labels, "\t");
current->update_test_true_result(m_labels, "\t");
current->update_test_true_result(labels, "\t");
current->post_update_results();
current->update_evaluation_result(results[i], "\t");
SG_UNREF(current);
current=(CCrossValidationOutput*)
m_xval_outputs->get_next_element();
}
}

/* clean up, remove subsets */
labels->remove_subset();
if (get_global_parallel()->get_num_threads()!=1)
{
SG_UNREF(machine);
SG_UNREF(features);
SG_UNREF(labels);
}
SG_UNREF(result_labels);
m_labels->remove_subset();
}

SG_DEBUG("done unlocked evaluation\n", get_name())
Expand Down
5 changes: 5 additions & 0 deletions src/shogun/evaluation/CrossValidation.h
Expand Up @@ -108,6 +108,11 @@ class CCrossValidationResult : public CEvaluationResult
* speed up computations. Can be turned off by the set_autolock() method.
* Locking in general may speed up things (eg for kernel machines the kernel
* matrix is precomputed), however, it is not always supported.
*
* Crossvalidation runs with current number of threads
* (Parallel::set_num_threads) for unlocked case, and currently duplicates all
* objects (might be changed later).
*
*/
class CCrossValidation: public CMachineEvaluation
{
Expand Down

0 comments on commit 277aeb7

Please sign in to comment.