Skip to content

Commit

Permalink
refactored maximize cross validation kernel selection method
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 3, 2016
1 parent 0a8cf98 commit 186ec36
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 41 deletions.
2 changes: 2 additions & 0 deletions src/shogun/statistical_testing/internals/KernelSelection.h
Expand Up @@ -60,6 +60,8 @@ class KernelSelection
protected:
const KernelManager& kernel_mgr;
CMMD* estimator;
virtual void init_measures()=0;
virtual void compute_measures()=0;
};

}
Expand Down
3 changes: 1 addition & 2 deletions src/shogun/statistical_testing/internals/MaxMeasure.h
Expand Up @@ -57,10 +57,9 @@ class MaxMeasure : public KernelSelection
virtual SGVector<float64_t> get_measure_vector();
virtual SGMatrix<float64_t> get_measure_matrix();
protected:
virtual void init_measures();
virtual void compute_measures();
SGVector<float64_t> measures;

virtual void init_measures();
};

}
Expand Down
73 changes: 42 additions & 31 deletions src/shogun/statistical_testing/internals/MaxXValidation.cpp
Expand Up @@ -54,59 +54,70 @@ MaxXValidation::~MaxXValidation()

/**
 * Accessor for the measure vector kept in the `measures` member, which
 * init_measures() sizes to one entry per kernel of the kernel manager.
 *
 * @return the `measures` member
 */
SGVector<float64_t> MaxXValidation::get_measure_vector()
{
	// The earlier not-implemented marker and placeholder return made this
	// statement unreachable; only the real result is returned now.
	return measures;
}

/**
 * Accessor for the rejection table kept in the `rejections` member, which
 * init_measures() shapes as (num_folds*num_run) rows by num_kernels columns.
 *
 * @return the `rejections` member
 */
SGMatrix<float64_t> MaxXValidation::get_measure_matrix()
{
	// The earlier not-implemented marker and placeholder return made this
	// statement unreachable; only the real result is returned now.
	return rejections;
}

void MaxXValidation::compute_measures(SGVector<float64_t>& measures, SGVector<index_t>& term_counters)
void MaxXValidation::init_measures()
{
const size_t num_kernels=kernel_mgr.num_kernels();
for (size_t i=0; i<num_kernels; ++i)
{
auto kernel=kernel_mgr.kernel_at(i);
estimator->set_kernel(kernel);
bool rejected=estimator->compute_p_value(estimator->compute_statistic())<alpha;
auto delta=measures[i]-rejected;
measures[i]=delta/term_counters[i]++;
estimator->cleanup();
}
const index_t num_kernels=kernel_mgr.num_kernels();
auto& dm=estimator->get_data_manager();
const index_t N=dm.get_num_folds();
REQUIRE(N!=0, "Number of folds is not set!\n");
if (rejections.num_rows!=N*num_run || rejections.num_cols!=num_kernels)
rejections=SGMatrix<float64_t>(N*num_run, num_kernels);
std::fill(rejections.data(), rejections.data()+rejections.size(), 0);
if (measures.size()!=num_kernels)
measures=SGVector<float64_t>(num_kernels);
std::fill(measures.data(), measures.data()+measures.size(), 0);
}

CKernel* MaxXValidation::select_kernel()
void MaxXValidation::compute_measures()
{
auto& dm=estimator->get_data_manager();
dm.set_xvalidation_mode(true);
auto existing_kernel=estimator->get_kernel();

const index_t N=dm.get_num_folds();
// TODO write a more meaningful error message
REQUIRE(N!=0, "Number of folds is not set!\n");
SG_SINFO("Performing %d fold cross-validattion!\n", N);
// train mode is already ON by now! set by the caller
SGVector<float64_t> measures(kernel_mgr.num_kernels());
std::fill(measures.data(), measures.data()+measures.size(), 0);
SGVector<index_t> term_counters(measures.size());
std::fill(term_counters.data(), term_counters.data()+term_counters.size(), 1);

const size_t num_kernels=kernel_mgr.num_kernels();
auto existing_kernel=estimator->get_kernel();
for (auto i=0; i<num_run; ++i)
{
// TODO set permutation beforehand
for (auto j=0; j<N; ++j)
{
dm.use_fold(j);
compute_measures(measures, term_counters);
for (size_t k=0; k<num_kernels; ++k)
{
auto kernel=kernel_mgr.kernel_at(k);
estimator->set_kernel(kernel);
rejections(i*N+j, k)=estimator->compute_p_value(estimator->compute_statistic())<alpha;
estimator->cleanup();
}
}
}

estimator->set_kernel(existing_kernel);
dm.set_xvalidation_mode(false);
estimator->set_kernel(existing_kernel);

auto min_element=std::min_element(measures.vector, measures.vector+measures.vlen);
auto min_idx=std::distance(measures.vector, min_element);
SG_SDEBUG("Selected kernel at %d position!\n", min_idx);
return kernel_mgr.kernel_at(min_idx);
// Collapse the rejection table into one score per kernel: the mean of each
// column, i.e. the fraction of (run, fold) rows in which that kernel's
// p-value fell below alpha.
for (auto j=0; j<rejections.num_cols; ++j)
{
	auto begin=rejections.get_column_vector(j);
	auto size=rejections.num_rows;
	// Seed the sum with 0.0 so that both the accumulation and the division
	// are carried out in floating point. An integer seed makes the quotient
	// an integer division, truncating every rate to 0 unless all rows rejected.
	measures[j]=std::accumulate(begin, begin+size, 0.0)/size;
}
}

/**
 * Picks the kernel whose entry in `measures` is the largest.
 *
 * Calls init_measures() and compute_measures() to refresh the measures, then
 * locates the maximum entry and asks the kernel manager for the kernel stored
 * at that index.
 *
 * @return the kernel held by the kernel manager at the index of the largest
 * measure
 */
CKernel* MaxXValidation::select_kernel()
{
	init_measures();
	compute_measures();
	auto max_element=std::max_element(measures.vector, measures.vector+measures.vlen);
	auto max_idx=std::distance(measures.vector, max_element);
	// std::distance yields a ptrdiff_t, which is wider than int on LP64
	// targets; convert explicitly so the argument matches the %d conversion
	// (assumes SG_SDEBUG forwards printf-style varargs -- TODO confirm).
	SG_SDEBUG("Selected kernel at %d position!\n", static_cast<int>(max_idx));
	return kernel_mgr.kernel_at(max_idx);
}
5 changes: 4 additions & 1 deletion src/shogun/statistical_testing/internals/MaxXValidation.h
Expand Up @@ -56,9 +56,12 @@ class MaxXValidation : public KernelSelection
virtual SGVector<float64_t> get_measure_vector();
virtual SGMatrix<float64_t> get_measure_matrix();
protected:
void compute_measures(SGVector<float64_t>&, SGVector<index_t>&);
virtual void init_measures();
virtual void compute_measures();
const index_t num_run;
const float64_t alpha;
SGMatrix<float64_t> rejections;
SGVector<float64_t> measures;
};

}
Expand Down
14 changes: 9 additions & 5 deletions src/shogun/statistical_testing/internals/MedianHeuristic.cpp
Expand Up @@ -58,13 +58,18 @@ MedianHeuristic::~MedianHeuristic()
}

// Measure-initialisation hook for the median heuristic. The body consists
// solely of the SG_SNOTIMPLEMENTED macro, so no measure storage is prepared
// here; select_kernel() in this file calls compute_measures() instead.
void MedianHeuristic::init_measures()
{
SG_SNOTIMPLEMENTED;
}

void MedianHeuristic::compute_measures()
{
distance=estimator->compute_distance();
SG_REF(distance);
n=distance->get_num_vec_lhs();
REQUIRE(distance->get_num_vec_lhs()==distance->get_num_vec_rhs(),
"Distance matrix is supposed to be a square matrix (was of dimension %dX%d)!\n",
distance->get_num_vec_lhs(), distance->get_num_vec_rhs());
"Distance matrix is supposed to be a square matrix (was of dimension %dX%d)!\n",
distance->get_num_vec_lhs(), distance->get_num_vec_rhs());
measures=SGVector<float64_t>((n*(n-1))/2);
size_t write_idx=0;
for (auto j=0; j<n; ++j)
Expand All @@ -77,8 +82,7 @@ void MedianHeuristic::init_measures()

/**
 * Accessor for the measure vector kept in the `measures` member, which
 * compute_measures() allocates with n*(n-1)/2 entries.
 *
 * @return the `measures` member
 */
SGVector<float64_t> MedianHeuristic::get_measure_vector()
{
	// The earlier not-implemented marker and placeholder return made this
	// statement unreachable; only the real result is returned now.
	return measures;
}

SGMatrix<float64_t> MedianHeuristic::get_measure_matrix()
Expand All @@ -88,7 +92,7 @@ SGMatrix<float64_t> MedianHeuristic::get_measure_matrix()

CKernel* MedianHeuristic::select_kernel()
{
init_measures();
compute_measures();
auto median_distance=measures[measures.size()/2];
SG_SDEBUG("kernel width (shogun): %f\n", median_distance);

Expand Down
4 changes: 2 additions & 2 deletions src/shogun/statistical_testing/internals/MedianHeuristic.h
Expand Up @@ -58,8 +58,8 @@ class MedianHeuristic : public KernelSelection
virtual SGVector<float64_t> get_measure_vector();
virtual SGMatrix<float64_t> get_measure_matrix();
protected:
void init_measures();
void compute_measures();
virtual void init_measures();
virtual void compute_measures();
CCustomDistance* distance;
SGVector<float64_t> measures;
int32_t n;
Expand Down

0 comments on commit 186ec36

Please sign in to comment.