Skip to content

Commit

Permalink
added statistic and Q computation methods
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 3, 2016
1 parent 0f8d00f commit 763a616
Showing 1 changed file with 70 additions and 1 deletion.
71 changes: 70 additions & 1 deletion src/shogun/statistical_testing/MMD.cpp
Expand Up @@ -68,6 +68,7 @@ struct CMMD::Self
void compute_jobs(ComputationManager&) const;

std::pair<float64_t, float64_t> compute_statistic_variance();
std::pair<SGVector<float64_t>, SGMatrix<float64_t>> compute_statistic_and_Q();
SGVector<float64_t> sample_null();

CMMD& owner;
Expand Down Expand Up @@ -252,6 +253,73 @@ std::pair<float64_t, float64_t> CMMD::Self::compute_statistic_variance()
return std::make_pair(statistic, variance);
}

std::pair<SGVector<float64_t>, SGMatrix<float64_t>> CMMD::Self::compute_statistic_and_Q()
{
REQUIRE(kernel_selection_mgr.num_kernels()>0, "No kernels specified for kernel learning! "
"Please add kernels using add_kernel() method!\n");

const size_t num_kernels=kernel_selection_mgr.num_kernels();
SGVector<float64_t> statistic(num_kernels);
SGMatrix<float64_t> Q(num_kernels, num_kernels);

std::fill(statistic.data(), statistic.data()+statistic.size(), 0);
std::fill(Q.data(), Q.data()+Q.size(), 0);

std::vector<index_t> term_counters_statistic(num_kernels, 1);
SGMatrix<index_t> term_counters_Q(num_kernels, num_kernels);
std::fill(term_counters_Q.data(), term_counters_Q.data()+term_counters_Q.size(), 1);

DataManager& dm=owner.get_data_manager();
ComputationManager cm;
create_computation_jobs();
cm.enqueue_job(statistic_job);

dm.start();
auto next_burst=dm.next();
std::vector<CFeatures*> blocks;
std::vector<std::vector<float32_t>> mmds(num_kernels);
while (!next_burst.empty())
{
merge_samples(next_burst, blocks);
REQUIRE(blocks.size()%2==0, "The number of blocks per burst (%d this burst) has to be even!\n", blocks.size());
for (size_t k=0; k<num_kernels; ++k)
{
CKernel* kernel=kernel_selection_mgr.kernel_at(k);
compute_kernel(cm, blocks, kernel);
compute_jobs(cm);
mmds[k]=cm.result(0);
for (size_t i=0; i<mmds[k].size(); ++i)
{
auto delta=mmds[k][i]-statistic[k];
statistic[k]+=delta/term_counters_statistic[k]++;
}
}
for (size_t i=0; i<num_kernels; ++i)
{
for (size_t j=0; j<=i; ++j)
{
for (size_t k=0; k<blocks.size()-1; k+=2)
{
auto term=(mmds[i][k]-mmds[i][k+1])*(mmds[i][k]-mmds[i][k+1]);
Q(i, j)+=(term-Q(i, j))/term_counters_Q(i, j)++;
}
Q(j, i)=Q(i, j);
}
}
next_burst=dm.next();
}
mmds.clear();

dm.end();
cm.done();

std::for_each(statistic.data(), statistic.data()+statistic.size(), [this](float64_t val)
{
val=owner.normalize_statistic(val);
});
return std::make_pair(statistic, Q);
}

SGVector<float64_t> CMMD::Self::sample_null()
{
const KernelManager& km=owner.get_kernel_manager();
Expand Down Expand Up @@ -341,7 +409,8 @@ void CMMD::select_kernel(EKernelSelectionMethod kmethod)
break;
}
default:
SG_ERROR("Unsupported kernel selection method specified!\n");
SG_ERROR("Unsupported kernel selection method specified! "
"Presently only accepted values are MAXIMIZE_MMD, MAXIMIZE_POWER!\n");
break;
}
SG_DEBUG("Leaving!\n");
Expand Down

0 comments on commit 763a616

Please sign in to comment.