Skip to content

Commit

Permalink
bugfix for train-test ratio in statistical testing (#4134)
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Feb 4, 2018
1 parent 93d06f8 commit 0b519ee
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 14 deletions.
51 changes: 39 additions & 12 deletions src/shogun/statistical_testing/internals/DataManager.cpp
Expand Up @@ -222,24 +222,36 @@ const bool DataManager::is_blockwise() const

void DataManager::set_train_test_mode(bool on)
{
train_test_mode=on;
if (!train_test_mode)
if (!on)
{
train_mode=default_train_mode;
train_test_ratio=default_train_test_ratio;
cross_validation_mode=default_cross_validation_mode;

set_train_mode(train_mode);
set_train_test_ratio(train_test_ratio);

train_test_mode = on;
REQUIRE(fetchers.size() > 0, "Features are not set!\n");
typedef std::unique_ptr<DataFetcher> fetcher_type;
std::for_each(
fetchers.begin(), fetchers.end(), [this](fetcher_type& f) {
f->set_train_test_mode(train_test_mode);
});
}
REQUIRE(fetchers.size()>0, "Features are not set!");
typedef std::unique_ptr<DataFetcher> fetcher_type;
std::for_each(fetchers.begin(), fetchers.end(), [this, on](fetcher_type& f)
else
{
f->set_train_test_mode(on);
if (on)
{
f->set_train_mode(train_mode);
f->set_train_test_ratio(train_test_ratio);
}
});
train_test_mode = on;
REQUIRE(fetchers.size() > 0, "Features are not set!\n");
typedef std::unique_ptr<DataFetcher> fetcher_type;
std::for_each(
fetchers.begin(), fetchers.end(), [this](fetcher_type& f) {
f->set_train_test_mode(train_test_mode);
});

set_train_mode(train_mode);
set_train_test_ratio(train_test_ratio);
}
}

bool DataManager::is_train_test_mode() const
Expand All @@ -250,7 +262,14 @@ bool DataManager::is_train_test_mode() const
void DataManager::set_train_mode(bool on)
{
if (train_test_mode)
{
train_mode=on;
REQUIRE(fetchers.size() > 0, "Features are not set!\n");
typedef std::unique_ptr<DataFetcher> fetcher_type;
std::for_each(
fetchers.begin(), fetchers.end(),
[this](fetcher_type& f) { f->set_train_mode(train_mode); });
}
else
{
SG_SERROR("Train mode cannot be used without turning on Train/Test mode first!"
Expand Down Expand Up @@ -282,7 +301,15 @@ bool DataManager::is_cross_validation_mode() const
void DataManager::set_train_test_ratio(float64_t ratio)
{
if (train_test_mode)
{
train_test_ratio=ratio;
REQUIRE(fetchers.size() > 0, "Features are not set!\n");
typedef std::unique_ptr<DataFetcher> fetcher_type;
std::for_each(
fetchers.begin(), fetchers.end(), [this](fetcher_type& f) {
f->set_train_test_ratio(train_test_ratio);
});
}
else
{
SG_SERROR("Train-test ratio cannot be set without turning on Train/Test mode first!"
Expand Down
11 changes: 10 additions & 1 deletion src/shogun/statistical_testing/internals/mmd/ComputeMMD.h
Expand Up @@ -87,7 +87,16 @@ struct ComputeMMD
{
ASSERT(m_n_x>0 && m_n_y>0);
const index_t size=m_n_x+m_n_y;
ASSERT(kernel_matrix.num_rows==size && kernel_matrix.num_cols==size);
REQUIRE(
kernel_matrix.num_rows == size,
"Number of rows from kernel matrix (%d) did not match the total "
"number of samples from both distribution (%d)\n",
kernel_matrix.num_rows, size);
REQUIRE(
kernel_matrix.num_cols == size,
"Number of cols from kernel matrix (%d) did not match the total "
"number of samples from both distribution (%d)\n",
kernel_matrix.num_cols, size);

typedef Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> MatrixXt;
typedef Eigen::Block<Eigen::Map<const MatrixXt> > BlockXt;
Expand Down
35 changes: 34 additions & 1 deletion tests/unit/statistical_testing/KernelSelection_unittest.cc
Expand Up @@ -108,6 +108,39 @@ TEST(KernelSelectionMaxMMD, quadratic_time_single_kernel_dense)
EXPECT_NEAR(selected_kernel->get_width(), 0.25, 1E-10);
}

// Regression check: with differently sized samples (10 vs. 20) and a
// train/test ratio of 4, kernel selection on a quadratic-time MMD must
// run to completion without throwing.
TEST(
	KernelSelectionMaxMMD,
	quadratic_time_single_kernel_dense_unequal_train_test_ratio)
{
	const index_t num_samples_p = 10;
	const index_t num_samples_q = 20;
	const index_t num_dims = 1;
	const float64_t mean_shift = 0.5;
	const index_t kernel_count = 10;
	const float64_t ratio = 4;

	// Fix the seed before any data is generated so the run is reproducible.
	sg_rand->set_seed(12345);

	auto generator_p = some<CMeanShiftDataGenerator>(0, num_dims, 0);
	auto generator_q = some<CMeanShiftDataGenerator>(mean_shift, num_dims, 0);

	auto samples_p = generator_p->get_streamed_features(num_samples_p);
	auto samples_q = generator_q->get_streamed_features(num_samples_q);

	auto mmd = some<CQuadraticTimeMMD>(samples_p, samples_q);
	mmd->set_statistic_type(ST_BIASED_FULL);

	// Register one Gaussian kernel per power of two: 2^-5, 2^-4, ..., 2^4.
	for (int k = 0; k < kernel_count; ++k)
	{
		const int exponent = k - 5;
		float64_t tau = pow(2, exponent);
		mmd->add_kernel(new CGaussianKernel(10, tau));
	}
	mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_MMD);

	// Train/test mode has to be enabled before the ratio is set.
	mmd->set_train_test_mode(true);
	mmd->set_train_test_ratio(ratio);
	EXPECT_NO_THROW(mmd->select_kernel());
}

#ifdef USE_GPL_SHOGUN
TEST(KernelSelectionMaxMMD, linear_time_weighted_kernel_streaming)
{
Expand Down Expand Up @@ -289,7 +322,7 @@ TEST(KernelSelectionMaxCrossValidation, quadratic_time_single_kernel_dense)
mmd->set_train_test_mode(false);

auto selected_kernel=static_cast<CGaussianKernel*>(mmd->get_kernel());
EXPECT_NEAR(selected_kernel->get_width(), 0.25, 1E-10);
EXPECT_NEAR(selected_kernel->get_width(), 0.03125, 1E-10);
}

TEST(KernelSelectionMaxCrossValidation, linear_time_single_kernel_dense)
Expand Down

0 comments on commit 0b519ee

Please sign in to comment.