Skip to content

Commit

Permalink
fixed quadratic time mmd bugs and speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 1, 2016
1 parent dcd3641 commit 90c6174
Show file tree
Hide file tree
Showing 8 changed files with 1,237 additions and 182 deletions.
14 changes: 9 additions & 5 deletions src/shogun/statistical_testing/MMD.cpp
Expand Up @@ -47,7 +47,6 @@ struct CMMD::Self

void create_statistic_job();
void create_variance_job();

void create_computation_jobs();

void merge_samples(NextSamples&, std::vector<std::shared_ptr<CFeatures>>&) const;
Expand All @@ -59,7 +58,7 @@ struct CMMD::Self

CMMD& owner;

bool use_gpu_for_computation;
bool use_gpu;
index_t num_null_samples;

EStatisticType statistic_type;
Expand All @@ -72,7 +71,7 @@ struct CMMD::Self
};

CMMD::Self::Self(CMMD& cmmd) : owner(cmmd),
use_gpu_for_computation(false), num_null_samples(250),
use_gpu(false), num_null_samples(250),
statistic_type(EStatisticType::UNBIASED_FULL),
variance_estimation_method(EVarianceEstimationMethod::DIRECT),
null_approximation_method(ENullApproximationMethod::PERMUTATION),
Expand Down Expand Up @@ -169,7 +168,7 @@ void CMMD::Self::compute_kernel(ComputationManager& cm, std::vector<std::shared_

void CMMD::Self::compute_jobs(ComputationManager& cm) const
{
if (use_gpu_for_computation)
if (use_gpu)
{
cm.use_gpu().compute();
}
Expand Down Expand Up @@ -345,7 +344,12 @@ const index_t CMMD::get_num_null_samples() const

void CMMD::use_gpu(bool gpu)
{
self->use_gpu_for_computation = gpu;
self->use_gpu = gpu;
}

bool CMMD::use_gpu() const
{
return self->use_gpu;
}

void CMMD::set_statistic_type(EStatisticType stype)
Expand Down
3 changes: 2 additions & 1 deletion src/shogun/statistical_testing/MMD.h
Expand Up @@ -70,7 +70,7 @@ class CMMD : public CTwoSampleTest
CKernel* get_kernel() const;
*/
virtual float64_t compute_statistic() override;
float64_t compute_variance();
virtual float64_t compute_variance();

void set_statistic_type(EStatisticType stype);
const EStatisticType get_statistic_type() const;
Expand All @@ -93,6 +93,7 @@ class CMMD : public CTwoSampleTest
virtual const operation get_direct_estimation_method() const = 0;
virtual const float64_t normalize_statistic(float64_t statistic) const = 0;
virtual const float64_t normalize_variance(float64_t variance) const = 0;
bool use_gpu() const;
private:
struct Self;
std::unique_ptr<Self> self;
Expand Down

0 comments on commit 90c6174

Please sign in to comment.