From 518c72c9a2cfe2ffba4d574cf7422bc712275182 Mon Sep 17 00:00:00 2001 From: lambday Date: Fri, 1 Apr 2016 02:35:42 +0530 Subject: [PATCH] removed simulate_null --- src/shogun/statistical_testing/MMD.cpp | 35 +++++++++++--------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/src/shogun/statistical_testing/MMD.cpp b/src/shogun/statistical_testing/MMD.cpp index 4b4452b9a5b..fecb88666fc 100644 --- a/src/shogun/statistical_testing/MMD.cpp +++ b/src/shogun/statistical_testing/MMD.cpp @@ -50,12 +50,13 @@ struct CMMD::Self void create_variance_job(index_t Bx); std::pair compute_statistic_variance(); + SGVector sample_null(); CMMD& owner; bool use_gpu_for_computation; - bool simulate_null; index_t num_null_samples; + EStatisticType statistic_type; EVarianceEstimationMethod variance_estimation_method; ENullApproximationMethod null_approximation_method; @@ -65,9 +66,11 @@ struct CMMD::Self }; CMMD::Self::Self(CMMD& cmmd) : owner(cmmd), - use_gpu_for_computation(false), simulate_null(false), num_null_samples(0), - statistic_type(EStatisticType::UNBIASED_FULL), variance_estimation_method(EVarianceEstimationMethod::DIRECT), - null_approximation_method(ENullApproximationMethod::PERMUTATION), statistic_job(nullptr), variance_job(nullptr) + use_gpu_for_computation(false), num_null_samples(250), + statistic_type(EStatisticType::UNBIASED_FULL), + variance_estimation_method(EVarianceEstimationMethod::DIRECT), + null_approximation_method(ENullApproximationMethod::PERMUTATION), + statistic_job(nullptr), variance_job(nullptr) { } @@ -102,7 +105,7 @@ void CMMD::Self::create_variance_job(index_t Bx) variance_job = owner.get_direct_estimation_method(); break; case EVarianceEstimationMethod::PERMUTATION: - switch(statistic_type) + switch (statistic_type) { case EStatisticType::UNBIASED_FULL: variance_job = mmd::WithinBlockPermutation(Bx); @@ -155,13 +158,6 @@ std::pair CMMD::Self::compute_statistic_variance() auto block_p_q = block_p->create_merged_copy(block_q.get()); SG_REF(block_p_q); - if (simulate_null) - { - SGVector inds(block_p_q->get_num_vectors()); - std::iota(inds.vector, inds.vector + inds.vlen, 0); - CMath::permute(inds); - block_p_q->add_subset(inds); - } block_p = nullptr; block_q = nullptr; @@ -242,6 +238,11 @@ std::pair CMMD::Self::compute_statistic_variance() return std::make_pair(statistic, variance); } +SGVector CMMD::Self::sample_null() +{ + return SGVector(); +} + CMMD::CMMD() : CTwoSampleTest() { self = std::unique_ptr(new Self(*this)); @@ -263,15 +264,7 @@ float64_t CMMD::compute_variance() SGVector CMMD::sample_null() { - SGVector null_samples(self->num_null_samples); - auto old = self->simulate_null; - self->simulate_null = true; - for (auto i = 0; i < self->num_null_samples; ++i) - { - null_samples[i] = compute_statistic(); - } - self->simulate_null = old; - return null_samples; + return self->sample_null(); } void CMMD::set_num_null_samples(index_t null_samples)