Skip to content

Commit

Permalink
fixed alpha bug ;)
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 4, 2016
1 parent 44a7ba4 commit 3009842
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 7 deletions.
Expand Up @@ -232,16 +232,20 @@ void MultiKernelPermutationTestCrossValidation::operator()(const KernelManager&

// transpose the null_samples matrix for faster access
MatrixXd transposed_null_samples=null_samples_map.transpose();
// cout << transposed_null_samples << endl;
#pragma omp parallel for
for (size_t k=0; k<kernel_mgr.num_kernels(); ++k)
{
SGVector<float64_t> null_samples_k(transposed_null_samples.col(k).data(), num_null_samples, false);
std::sort(null_samples_k.data(), null_samples_k.data()+null_samples_k.size());
// null_samples_k.display_vector("null_samples_k");
SG_SDEBUG("statistic=%f\n", statistic[k]);
float64_t idx=null_samples_k.find_position_to_insert(statistic[k]);
SG_SDEBUG("index=%f\n", idx);
auto p_value=1.0-idx/num_null_samples;
SG_SDEBUG("p-value=%f, rejected=%d\n", p_value, p_value<alpha);
rejections(i*num_folds+j, k)=p_value<alpha;
bool rejected=p_value<alpha;
SG_SDEBUG("p-value=%f, alpha=%f, rejected=%d\n", p_value, alpha, rejected);
rejections(i*num_folds+j, k)=rejected;
}
}
}
Expand All @@ -260,7 +264,7 @@ void MultiKernelPermutationTestCrossValidation::set_num_folds(index_t nf)
num_folds=nf;
}

void MultiKernelPermutationTestCrossValidation::set_alpha(index_t alp)
void MultiKernelPermutationTestCrossValidation::set_alpha(float64_t alp)
{
alpha=alp;
}
Expand Down
Expand Up @@ -62,7 +62,7 @@ class MultiKernelPermutationTestCrossValidation

void set_num_runs(index_t nr);
void set_num_folds(index_t nf);
void set_alpha(index_t alp);
void set_alpha(float64_t alp);
void set_measure_matrix(SGMatrix<float64_t> measures);
void set_distance(CCustomDistance* distance);
private:
Expand Down
Expand Up @@ -258,7 +258,7 @@ void PermutationTestCrossValidation::set_num_folds(index_t nf)
num_folds=nf;
}

void PermutationTestCrossValidation::set_alpha(index_t alp)
void PermutationTestCrossValidation::set_alpha(float64_t alp)
{
alpha=alp;
}
Expand Down
Expand Up @@ -56,7 +56,7 @@ class PermutationTestCrossValidation
template <typename T> void operator()(const SGMatrix<T>& km, index_t k);
void set_num_runs(index_t nr);
void set_num_folds(index_t nf);
void set_alpha(index_t alp);
void set_alpha(float64_t alp);
void set_measure_matrix(SGMatrix<float64_t> measures);
private:
struct terms_t;
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/statistical_testing/KernelSelection_unittest.cc
Expand Up @@ -238,7 +238,7 @@ TEST(KernelSelectionMaxCrossValidation, quadratic_time_single_kernel_dense)
auto mmd=some<CQuadraticTimeMMD>(feats_p, feats_q);
mmd->set_statistic_type(ST_BIASED_FULL);
mmd->set_null_approximation_method(NAM_PERMUTATION);
mmd->set_num_null_samples(1);
mmd->set_num_null_samples(10);
// mmd->io->set_loglevel(MSG_DEBUG);
for (auto i=0, sigma=-5; i<num_kernels; ++i, sigma+=1)
{
Expand Down

0 comments on commit 3009842

Please sign in to comment.