diff --git a/src/shogun/statistical_testing/MMD.cpp b/src/shogun/statistical_testing/MMD.cpp index 3632fd875f2..e01063f1988 100644 --- a/src/shogun/statistical_testing/MMD.cpp +++ b/src/shogun/statistical_testing/MMD.cpp @@ -68,6 +68,7 @@ struct CMMD::Self void compute_jobs(ComputationManager&) const; std::pair compute_statistic_variance(); + std::pair, SGMatrix> compute_statistic_and_Q(); SGVector sample_null(); CMMD& owner; @@ -252,6 +253,73 @@ std::pair CMMD::Self::compute_statistic_variance() return std::make_pair(statistic, variance); } +std::pair, SGMatrix> CMMD::Self::compute_statistic_and_Q() +{ + REQUIRE(kernel_selection_mgr.num_kernels()>0, "No kernels specified for kernel learning! " + "Please add kernels using add_kernel() method!\n"); + + const size_t num_kernels=kernel_selection_mgr.num_kernels(); + SGVector statistic(num_kernels); + SGMatrix Q(num_kernels, num_kernels); + + std::fill(statistic.data(), statistic.data()+statistic.size(), 0); + std::fill(Q.data(), Q.data()+Q.size(), 0); + + std::vector term_counters_statistic(num_kernels, 1); + SGMatrix term_counters_Q(num_kernels, num_kernels); + std::fill(term_counters_Q.data(), term_counters_Q.data()+term_counters_Q.size(), 1); + + DataManager& dm=owner.get_data_manager(); + ComputationManager cm; + create_computation_jobs(); + cm.enqueue_job(statistic_job); + + dm.start(); + auto next_burst=dm.next(); + std::vector blocks; + std::vector> mmds(num_kernels); + while (!next_burst.empty()) + { + merge_samples(next_burst, blocks); + REQUIRE(blocks.size()%2==0, "The number of blocks per burst (%d this burst) has to be even!\n", blocks.size()); + for (size_t k=0; k CMMD::Self::sample_null() { const KernelManager& km=owner.get_kernel_manager(); @@ -341,7 +409,8 @@ void CMMD::select_kernel(EKernelSelectionMethod kmethod) break; } default: - SG_ERROR("Unsupported kernel selection method specified!\n"); + SG_ERROR("Unsupported kernel selection method specified! " + "Presently only accepted values are MAXIMIZE_MMD, MAXIMIZE_POWER!\n"); break; } SG_DEBUG("Leaving!\n");