Skip to content

Commit

Permalink
removed statistic computation API for multiple kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday committed Jul 29, 2016
1 parent 5a6b463 commit 66b6fdb
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 125 deletions.
172 changes: 54 additions & 118 deletions src/shogun/statistical_testing/MMD.cpp
Expand Up @@ -49,7 +49,7 @@ struct CMMD::Self
void create_statistic_job(index_t Bx);
void create_variance_job(index_t Bx);

std::pair<SGVector<float64_t>, SGVector<float64_t>> compute_statistic_variance();
std::pair<float64_t, float64_t> compute_statistic_variance();

CMMD& owner;

Expand Down Expand Up @@ -120,47 +120,19 @@ void CMMD::Self::create_variance_job(index_t Bx)
};
}

std::pair<SGVector<float64_t>, SGVector<float64_t>> CMMD::Self::compute_statistic_variance()
std::pair<float64_t, float64_t> CMMD::Self::compute_statistic_variance()
{
DataManager& dm = owner.get_data_manager();
const KernelManager& km = owner.get_kernel_manager();

SGVector<float64_t> statistic;
SGVector<float64_t> stat_perm;
SGVector<float64_t> variance;
float64_t statistic = 0;
float64_t permuted_samples_statistic = 0;
float64_t variance = 0;

auto kernel = km.kernel_at(0);
ASSERT(kernel != nullptr);
auto num_kernels = 1;
REQUIRE(kernel != nullptr, "Kernel is not set!\n");

std::vector<CKernel*> kernels;

if (kernel->get_kernel_type() == K_COMBINED)
{
auto combined_kernel = static_cast<CCombinedKernel*>(kernel);
num_kernels = combined_kernel->get_num_subkernels();

kernels = std::vector<CKernel*>(num_kernels);
for (auto i = 0; i < num_kernels; ++i)
{
kernels[i] = combined_kernel->get_kernel(i);
}
}
else
{
kernels.push_back(kernel);
}

statistic = SGVector<float64_t>(num_kernels);
stat_perm = SGVector<float64_t>(num_kernels);
variance = SGVector<float64_t>(num_kernels);

std::fill(statistic.vector, statistic.vector + statistic.vlen, 0);
std::fill(stat_perm.vector, stat_perm.vector + stat_perm.vlen, 0);
std::fill(variance.vector, variance.vector + variance.vlen, 0);

std::vector<index_t> term_counters(statistic.vlen);
std::fill(term_counters.data(), term_counters.data() + term_counters.size(), 1);
index_t term_counters = 1;

ComputationManager cm;
dm.start();
Expand All @@ -176,7 +148,7 @@ std::pair<SGVector<float64_t>, SGVector<float64_t>> CMMD::Self::compute_statisti
blocks.resize(next_burst.num_blocks());

#pragma omp parallel for
for (auto i = 0; i < next_burst.num_blocks(); ++i)
for (size_t i = 0; i < blocks.size(); ++i)
{
auto block_p = next_burst[0][i];
auto block_q = next_burst[1][i];
Expand All @@ -197,83 +169,74 @@ std::pair<SGVector<float64_t>, SGVector<float64_t>> CMMD::Self::compute_statisti
blocks[i] = std::shared_ptr<CFeatures>(block_p_q, [](CFeatures* ptr) { SG_UNREF(ptr); });
}

for (auto i = 0; i < kernels.size(); ++i)
{
#pragma omp parallel for
for (auto j = 0; j < blocks.size(); ++j)
{
try
{
auto curr_kernel = std::unique_ptr<CKernel>(static_cast<CKernel*>(kernels[i]->clone()));
curr_kernel->init(blocks[j].get(), blocks[j].get());
cm.data(j) = std::unique_ptr<CCustomKernel>(new CCustomKernel(curr_kernel.get()))->get_kernel_matrix();
curr_kernel->remove_lhs_and_rhs();
}
catch (ShogunException e)
{
SG_SERROR("%s, Try using less number of blocks per burst!\n", e.get_exception_string());
}
}

// enqueue statistic and variance computation jobs on the computed kernel matrices
cm.enqueue_job(statistic_job);
cm.enqueue_job(variance_job);

if (use_gpu_for_computation)
for (size_t i = 0; i < blocks.size(); ++i)
{
try
{
cm.use_gpu().compute();
auto kernel_clone = std::unique_ptr<CKernel>(static_cast<CKernel*>(kernel->clone()));
kernel_clone->init(blocks[i].get(), blocks[i].get());
cm.data(i) = std::unique_ptr<CCustomKernel>(new CCustomKernel(kernel_clone.get()))->get_kernel_matrix();
kernel_clone->remove_lhs_and_rhs();
}
else
catch (ShogunException e)
{
cm.use_cpu().compute();
SG_SERROR("%s, Try using less number of blocks per burst!\n", e.get_exception_string());
}
}

auto mmds = cm.next_result();
auto vars = cm.next_result();
// enqueue statistic and variance computation jobs on the computed kernel matrices
cm.enqueue_job(statistic_job);
cm.enqueue_job(variance_job);

for (auto j = 0; j < mmds.size(); ++j)
{
auto delta = mmds[j] - statistic[i];
statistic[i] += delta / term_counters[i];
}
if (use_gpu_for_computation)
{
cm.use_gpu().compute();
}
else
{
cm.use_cpu().compute();
}

if (variance_estimation_method == EVarianceEstimationMethod::DIRECT)
auto mmds = cm.next_result();
auto vars = cm.next_result();

for (size_t i = 0; i < mmds.size(); ++i)
{
auto delta = mmds[i] - statistic;
statistic += delta / term_counters;
}

if (variance_estimation_method == EVarianceEstimationMethod::DIRECT)
{
for (size_t i = 0; i < mmds.size(); ++i)
{
for (auto j = 0; j < mmds.size(); ++j)
{
auto delta = vars[j] - variance[i];
variance[i] += delta / term_counters[i];
}
auto delta = vars[i] - variance;
variance += delta / term_counters;
}
else
}
else
{
for (size_t i = 0; i < mmds.size(); ++i)
{
for (auto j = 0; j < mmds.size(); ++j)
{
auto delta = vars[j] - stat_perm[i];
stat_perm[i] += delta / term_counters[i];
variance[i] += delta * (vars[j] - stat_perm[i]);
}
auto delta = vars[i] - permuted_samples_statistic;
permuted_samples_statistic += delta / term_counters;
variance += delta * (vars[i] - permuted_samples_statistic);
}
term_counters[i]++;
}
term_counters++;

next_burst = dm.next();
}

dm.end();

// normalize statistic and variance
std::for_each(statistic.vector, statistic.vector + statistic.vlen, [this](float64_t& v)
{
v = owner.normalize_statistic(v);
});
statistic = owner.normalize_statistic(statistic);

if (variance_estimation_method == EVarianceEstimationMethod::PERMUTATION)
{
std::for_each(variance.vector, variance.vector + variance.vlen, [this](float64_t& v)
{
v = owner.normalize_variance(v);
});
variance = owner.normalize_variance(variance);
}

return std::make_pair(statistic, variance);
Expand All @@ -290,33 +253,11 @@ CMMD::~CMMD()

float64_t CMMD::compute_statistic()
{
return self->compute_statistic_variance().first[0];
}

float64_t CMMD::compute_variance()
{
return self->compute_statistic_variance().second[0];
}

SGVector<float64_t> CMMD::compute_statistic(bool multiple_kernels)
{
if (multiple_kernels)
{
const KernelManager& km = get_kernel_manager();
auto kernel = km.kernel_at(0);
ASSERT(kernel->get_kernel_type() == K_COMBINED);
}
return self->compute_statistic_variance().first;
}

SGVector<float64_t> CMMD::compute_variance(bool multiple_kernels)
float64_t CMMD::compute_variance()
{
if (multiple_kernels)
{
const KernelManager& km = get_kernel_manager();
auto kernel = km.kernel_at(0);
ASSERT(kernel->get_kernel_type() == K_COMBINED);
}
return self->compute_statistic_variance().second;
}

Expand Down Expand Up @@ -348,11 +289,6 @@ void CMMD::use_gpu(bool gpu)
self->use_gpu_for_computation = gpu;
}

void CMMD::set_simulate_null(bool simulate_null)
{
self->simulate_null = simulate_null;
}

void CMMD::set_statistic_type(EStatisticType stype)
{
self->statistic_type = stype;
Expand Down
21 changes: 14 additions & 7 deletions src/shogun/statistical_testing/MMD.h
Expand Up @@ -51,34 +51,41 @@ enum class ENullApproximationMethod
MMD2_GAMMA
};

/**
 * Identifies a kernel selection method.
 *
 * MEDIAN_HEURISTIC is the correctly spelled name. MEDIAN_HEURISRIC is the
 * original, misspelled enumerator; it is kept as an alias with the same
 * value so that code written against the old spelling keeps compiling.
 * Prefer MEDIAN_HEURISTIC in new code.
 *
 * The numeric values of MAXIMIZE_MMD (1) and MAXIMIZE_POWER (2) are
 * unchanged by the addition of the alias.
 */
enum class EKernelSelectionMethod
{
	MEDIAN_HEURISTIC,
	/* legacy misspelling, same value as MEDIAN_HEURISTIC */
	MEDIAN_HEURISRIC = MEDIAN_HEURISTIC,
	MAXIMIZE_MMD,
	MAXIMIZE_POWER
};

class CMMD : public CTwoSampleTest
{
using operation = std::function<float64_t(SGMatrix<float64_t>)>;
public:
CMMD();
virtual ~CMMD();

/*
void add_kernel(CKernel *kernel);
void select_kernel(EKernelSelectionMethod kmethod);
CKernel* get_kernel() const;
*/
virtual float64_t compute_statistic() override;
SGVector<float64_t> compute_statistic(bool multiple_kernels);

float64_t compute_variance();
SGVector<float64_t> compute_variance(bool multiple_kernels);

void set_statistic_type(EStatisticType stype);
const EStatisticType get_statistic_type() const;

void set_variance_estimation_method(EVarianceEstimationMethod vmethod);
const EVarianceEstimationMethod get_variance_estimation_method() const;

void set_simulate_null(bool simulate_null);
void set_num_null_samples(index_t null_samples);
const index_t get_num_null_samples() const;

virtual SGVector<float64_t> sample_null() override;

void set_null_approximation_method(ENullApproximationMethod nmethod);
const ENullApproximationMethod get_null_approximation_method() const;

virtual SGVector<float64_t> sample_null() override;

void use_gpu(bool gpu);

virtual const char* get_name() const;
Expand Down

0 comments on commit 66b6fdb

Please sign in to comment.