From 78c3c8cece2b8edead68679f37c1a63fd36bcc66 Mon Sep 17 00:00:00 2001 From: lambday Date: Mon, 4 Jul 2016 12:57:56 +0100 Subject: [PATCH] optimized multi kernel permutation test cross validation --- ...tiKernelPermutationTestCrossValidation.cpp | 154 ++++++++++-------- .../KernelSelection_unittest.cc | 10 +- 2 files changed, 91 insertions(+), 73 deletions(-) diff --git a/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.cpp b/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.cpp index 7ad4efa9e11..ad55b6acfa4 100644 --- a/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.cpp +++ b/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.cpp @@ -75,14 +75,14 @@ MultiKernelPermutationTestCrossValidation::~MultiKernelPermutationTestCrossValid void MultiKernelPermutationTestCrossValidation::add_term(terms_t& terms, float64_t val, index_t i, index_t j) { - if (i=j) { SG_SDEBUG("Adding Kernel(%d,%d)=%f to term_0!\n", i, j, val); terms.term[0]+=val; if (i==j) terms.diag[0]+=val; } - else if (i>=n_x && j>=n_x && i<=j) + else if (i>=n_x && j>=n_x && i>=j) { SG_SDEBUG("Adding Kernel(%d,%d)=%f to term_1!\n", i, j, val); terms.term[1]+=val; @@ -151,104 +151,124 @@ void MultiKernelPermutationTestCrossValidation::operator()(const KernelManager& SGVector statistic(kernel_mgr.num_kernels()); SGMatrix null_samples(kernel_mgr.num_kernels(), num_null_samples); Map null_samples_map(null_samples.data(), null_samples.num_rows, null_samples.num_cols); + auto stacks=std::vector >(num_null_samples); + for (auto n=0; n(new CSubsetStack()); + } - for (auto i=0; ibuild_subsets(); - kfold_y->build_subsets(); - for (auto j=0; j x_inds=kfold_x->generate_subset_inverse(j); - SGVector y_inds=kfold_y->generate_subset_inverse(j); - std::for_each(y_inds.data(), y_inds.data()+y_inds.size(), [this](index_t& val) { val += n_x; }); + kfold_x->build_subsets(); + kfold_y->build_subsets(); + for (auto j=0; j x_inds=kfold_x->generate_subset_inverse(j); + SGVector y_inds=kfold_y->generate_subset_inverse(j); + std::for_each(y_inds.data(), y_inds.data()+y_inds.size(), [this](index_t& val) { val += n_x; }); - SGVector xy_inds(x_inds.size()+y_inds.size()); - std::copy(x_inds.data(), x_inds.data()+x_inds.size(), xy_inds.data()); - std::copy(y_inds.data(), y_inds.data()+y_inds.size(), xy_inds.data()+x_inds.size()); + SGVector xy_inds(x_inds.size()+y_inds.size()); + std::copy(x_inds.data(), x_inds.data()+x_inds.size(), xy_inds.data()); + std::copy(y_inds.data(), y_inds.data()+y_inds.size(), xy_inds.data()+x_inds.size()); + std::for_each(stacks.begin(), stacks.end(), [&xy_inds](std::shared_ptr& stack) + { + stack->add_subset(xy_inds); + }); - SGVector inverted_inds(n_x+n_y); - std::fill(inverted_inds.data(), inverted_inds.data()+n_x+n_y, -1); - for (int idx=0; idx inverted_inds(n_x+n_y); + std::fill(inverted_inds.data(), inverted_inds.data()+n_x+n_y, -1); + for (int idx=0; idx stat_terms(kernel_mgr.num_kernels()); + for (auto col=0; colkernel(row, col), inverted_row, inverted_col); + for (size_t k=0; kkernel(row, col); + add_term(stat_terms[k], kernel_value, inverted_row, inverted_col); + } } } } - statistic[k]=compute_mmd(stat_terms); - } - - // compute the null samples - for (auto n=0; n(); - stack->add_subset(xy_inds); + for (size_t k=0; k permutation_inds(xy_inds.size()); - std::iota(permutation_inds.data(), permutation_inds.data()+permutation_inds.size(), 0); - CMath::permute(permutation_inds); - stack->add_subset(permutation_inds); + #pragma omp for + for (auto n=0; n permutation_inds(xy_inds.size()); + std::iota(permutation_inds.data(), permutation_inds.data()+permutation_inds.size(), 0); + CMath::permute(permutation_inds); - SGVector inds=stack->get_last_subset()->get_subset_idx(); + stacks[n]->add_subset(permutation_inds); + SGVector inds=stacks[n]->get_last_subset()->get_subset_idx(); + stacks[n]->remove_subset(); - SGVector inverted_permutation_inds(n_x+n_y); - std::fill(inverted_permutation_inds.data(), inverted_permutation_inds.data()+n_x+n_y, -1); - for (int idx=0; idx inverted_permutation_inds(n_x+n_y); + std::fill(inverted_permutation_inds.data(), inverted_permutation_inds.data()+n_x+n_y, -1); + for (int idx=0; idx terms(kernel_mgr.num_kernels()); + for (auto col=0; colkernel(row, col), permuted_row, permuted_col); + for (size_t k=0; kkernel(row, col); + add_term(terms[k], kernel_value, permuted_row, permuted_col); + } } } } - null_samples(k, n)=compute_mmd(terms); + for (size_t k=0; k null_samples_k(transposed_null_samples.col(k).data(), num_null_samples, false); - std::sort(null_samples_k.data(), null_samples_k.data()+null_samples_k.size()); -// null_samples_k.display_vector("null_samples_k"); - SG_SDEBUG("statistic=%f\n", statistic[k]); - float64_t idx=null_samples_k.find_position_to_insert(statistic[k]); - SG_SDEBUG("index=%f\n", idx); - auto p_value=1.0-idx/num_null_samples; - bool rejected=p_value& stack) + { + stack->remove_subset(); + }); + + // transpose the null_samples matrix for faster access + MatrixXd transposed_null_samples=null_samples_map.transpose(); + // cout << transposed_null_samples << endl; + //#pragma omp for + for (size_t k=0; k null_samples_k(transposed_null_samples.col(k).data(), num_null_samples, false); + std::sort(null_samples_k.data(), null_samples_k.data()+null_samples_k.size()); + // null_samples_k.display_vector("null_samples_k"); + SG_SDEBUG("statistic=%f\n", statistic[k]); + float64_t idx=null_samples_k.find_position_to_insert(statistic[k]); + SG_SDEBUG("index=%f\n", idx); + auto p_value=1.0-idx/num_null_samples; + bool rejected=p_valueset_statistic_type(ST_BIASED_FULL); mmd->set_null_approximation_method(NAM_PERMUTATION); mmd->set_num_null_samples(10); -// mmd->io->set_loglevel(MSG_DEBUG); for (auto i=0, sigma=-5; iset_train_test_mode(true); mmd->set_train_test_ratio(train_test_ratio); mmd->select_kernel(); - mmd->get_kernel_selection_strategy()->get_measure_matrix().display_matrix(); mmd->set_train_test_mode(false); auto selected_kernel=static_cast(mmd->get_kernel());