Skip to content

Commit

Permalink
removed simulate_null
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 4, 2016
1 parent 2cf82b3 commit 518c72c
Showing 1 changed file with 14 additions and 21 deletions.
35 changes: 14 additions & 21 deletions src/shogun/statistical_testing/MMD.cpp
Expand Up @@ -50,12 +50,13 @@ struct CMMD::Self
void create_variance_job(index_t Bx);

std::pair<float64_t, float64_t> compute_statistic_variance();
SGVector<float64_t> sample_null();

CMMD& owner;

bool use_gpu_for_computation;
bool simulate_null;
index_t num_null_samples;

EStatisticType statistic_type;
EVarianceEstimationMethod variance_estimation_method;
ENullApproximationMethod null_approximation_method;
Expand All @@ -65,9 +66,11 @@ struct CMMD::Self
};

CMMD::Self::Self(CMMD& cmmd) : owner(cmmd),
use_gpu_for_computation(false), simulate_null(false), num_null_samples(0),
statistic_type(EStatisticType::UNBIASED_FULL), variance_estimation_method(EVarianceEstimationMethod::DIRECT),
null_approximation_method(ENullApproximationMethod::PERMUTATION), statistic_job(nullptr), variance_job(nullptr)
use_gpu_for_computation(false), num_null_samples(250),
statistic_type(EStatisticType::UNBIASED_FULL),
variance_estimation_method(EVarianceEstimationMethod::DIRECT),
null_approximation_method(ENullApproximationMethod::PERMUTATION),
statistic_job(nullptr), variance_job(nullptr)
{
}

Expand Down Expand Up @@ -102,7 +105,7 @@ void CMMD::Self::create_variance_job(index_t Bx)
variance_job = owner.get_direct_estimation_method();
break;
case EVarianceEstimationMethod::PERMUTATION:
switch(statistic_type)
switch (statistic_type)
{
case EStatisticType::UNBIASED_FULL:
variance_job = mmd::WithinBlockPermutation<mmd::UnbiasedFull>(Bx);
Expand Down Expand Up @@ -155,13 +158,6 @@ std::pair<float64_t, float64_t> CMMD::Self::compute_statistic_variance()

auto block_p_q = block_p->create_merged_copy(block_q.get());
SG_REF(block_p_q);
if (simulate_null)
{
SGVector<index_t> inds(block_p_q->get_num_vectors());
std::iota(inds.vector, inds.vector + inds.vlen, 0);
CMath::permute(inds);
block_p_q->add_subset(inds);
}

block_p = nullptr;
block_q = nullptr;
Expand Down Expand Up @@ -242,6 +238,11 @@ std::pair<float64_t, float64_t> CMMD::Self::compute_statistic_variance()
return std::make_pair(statistic, variance);
}

SGVector<float64_t> CMMD::Self::sample_null()
{
return SGVector<float64_t>();
}

CMMD::CMMD() : CTwoSampleTest()
{
self = std::unique_ptr<Self>(new Self(*this));
Expand All @@ -263,15 +264,7 @@ float64_t CMMD::compute_variance()

SGVector<float64_t> CMMD::sample_null()
{
SGVector<float64_t> null_samples(self->num_null_samples);
auto old = self->simulate_null;
self->simulate_null = true;
for (auto i = 0; i < self->num_null_samples; ++i)
{
null_samples[i] = compute_statistic();
}
self->simulate_null = old;
return null_samples;
return self->sample_null();
}

void CMMD::set_num_null_samples(index_t null_samples)
Expand Down

0 comments on commit 518c72c

Please sign in to comment.