Skip to content

Commit

Permalink
full and blockwise train/test data fetchers fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 1, 2016
1 parent 5c35b41 commit ea78548
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 5 deletions.
Expand Up @@ -34,7 +34,8 @@ using namespace shogun;
using namespace internal;

BlockwiseDetails::BlockwiseDetails() : m_blocksize(0), m_num_blocks_per_burst(1),
m_max_num_samples_per_burst(0), m_next_block_index(0), m_total_num_blocks(0)
m_max_num_samples_per_burst(0), m_next_block_index(0), m_total_num_blocks(0),
m_full_data(true)
{
}

Expand Down
3 changes: 3 additions & 0 deletions src/shogun/statistical_testing/internals/BlockwiseDetails.h
Expand Up @@ -86,6 +86,9 @@ class BlockwiseDetails

/** Total number of blocks to be fetched. Set by data fetchers */
index_t m_total_num_blocks;

/** Whether the block should consist of full data (i.e. no block at all) */
bool m_full_data;
};

}
Expand Down
8 changes: 4 additions & 4 deletions src/shogun/statistical_testing/internals/DataFetcher.cpp
Expand Up @@ -52,6 +52,7 @@ void DataFetcher::set_train_test_ratio(float64_t train_test_ratio)
m_num_samples=m_train_test_details.get_total_num_samples();
index_t num_training_samples=m_num_samples*train_test_ratio/(train_test_ratio+1);
m_train_test_details.set_num_training_samples(num_training_samples);
SG_SINFO("Must set the train/test mode by calling set_train_mode(True/False)!\n");
}

float64_t DataFetcher::get_train_test_ratio() const
Expand Down Expand Up @@ -85,9 +86,7 @@ void DataFetcher::set_train_mode(bool train_mode)
SGVector<index_t> inds(m_num_samples);
std::iota(inds.data(), inds.data()+inds.size(), start_index);
if (train_test_subset_used)
{
m_samples->remove_subset();
}
m_samples->add_subset(inds);
train_test_subset_used=true;
}
Expand Down Expand Up @@ -120,9 +119,9 @@ void DataFetcher::use_fold(index_t idx)
void DataFetcher::start()
{
REQUIRE(m_num_samples>0, "Number of samples is 0!\n");
if (m_block_details.m_blocksize==0 || m_block_details.m_blocksize>m_num_samples)
if (m_block_details.m_full_data || m_block_details.m_blocksize>m_num_samples)
{
SG_SINFO("Block details invalid! Fetching entire data (%d samples)!\n", m_num_samples);
SG_SINFO("Fetching entire data (%d samples)!\n", m_num_samples);
m_block_details.with_blocksize(m_num_samples);
}
m_block_details.m_total_num_blocks=m_num_samples/m_block_details.m_blocksize;
Expand Down Expand Up @@ -165,5 +164,6 @@ const index_t DataFetcher::get_num_samples() const

BlockwiseDetails& DataFetcher::fetch_blockwise()
{
m_block_details.m_full_data=false;
return m_block_details;
}
112 changes: 112 additions & 0 deletions tests/unit/statistical_testing/internals/DataManager_unittest.cc
Expand Up @@ -521,6 +521,7 @@ TEST(DataManager, train_data_two_distributions_normal_feats)

next_burst=mgr.next();
ASSERT_TRUE(next_burst.empty());
mgr.end();

// test data
mgr.set_train_mode(false);
Expand All @@ -540,6 +541,7 @@ TEST(DataManager, train_data_two_distributions_normal_feats)

next_burst=mgr.next();
ASSERT_TRUE(next_burst.empty());
mgr.end();

// full data
mgr.set_train_test_ratio(0);
Expand All @@ -560,5 +562,115 @@ TEST(DataManager, train_data_two_distributions_normal_feats)

next_burst=mgr.next();
ASSERT_TRUE(next_burst.empty());
mgr.end();
}

TEST(DataManager, train_data_two_distributions_normal_feats_blockwise)
{
const index_t dim=3;
const index_t num_vec=8;
const index_t blocksize=2;
const index_t num_blocks_per_burst=2;
const index_t num_distributions=2;
const index_t train_test_ratio=3;

SGMatrix<float64_t> data_p(dim, num_vec);
std::iota(data_p.matrix, data_p.matrix+dim*num_vec, 0);

SGMatrix<float64_t> data_q(dim, num_vec);
std::iota(data_q.matrix, data_q.matrix+dim*num_vec, dim*num_vec);

using feat_type=CDenseFeatures<float64_t>;
auto feats_p=new feat_type(data_p);
auto feats_q=new feat_type(data_q);

DataManager mgr(num_distributions);
mgr.samples_at(0)=feats_p;
mgr.samples_at(1)=feats_q;
mgr.set_blocksize(blocksize);
mgr.set_num_blocks_per_burst(num_blocks_per_burst);

mgr.set_train_test_ratio(train_test_ratio);

// train data
mgr.set_train_mode(true);
mgr.start();

auto next_burst=mgr.next();
ASSERT_TRUE(!next_burst.empty());

auto total=0;

while (!next_burst.empty())
{
ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst);
for (auto i=0; i<next_burst.num_blocks(); ++i)
{
auto tmp_p=dynamic_cast<feat_type*>(next_burst[0][i].get());
auto tmp_q=dynamic_cast<feat_type*>(next_burst[1][i].get());
ASSERT_TRUE(tmp_p!=nullptr);
ASSERT_TRUE(tmp_q!=nullptr);
ASSERT_TRUE(tmp_p->get_num_vectors()==blocksize/2);
ASSERT_TRUE(tmp_q->get_num_vectors()==blocksize/2);
total+=tmp_p->get_num_vectors();
}
next_burst=mgr.next();
}
ASSERT_TRUE(total==num_vec*train_test_ratio/(train_test_ratio+1));
mgr.end();

// test data
mgr.set_train_mode(false);
mgr.start();

next_burst=mgr.next();
ASSERT_TRUE(!next_burst.empty());

total=0;

while (!next_burst.empty())
{
ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst);
for (auto i=0; i<next_burst.num_blocks(); ++i)
{
auto tmp_p=dynamic_cast<feat_type*>(next_burst[0][i].get());
auto tmp_q=dynamic_cast<feat_type*>(next_burst[1][i].get());
ASSERT_TRUE(tmp_p!=nullptr);
ASSERT_TRUE(tmp_q!=nullptr);
ASSERT_TRUE(tmp_p->get_num_vectors()==blocksize/2);
ASSERT_TRUE(tmp_q->get_num_vectors()==blocksize/2);
total+=tmp_p->get_num_vectors();
}
next_burst=mgr.next();
}
ASSERT_TRUE(total==num_vec/(train_test_ratio+1));
mgr.end();

// full data
mgr.set_train_test_ratio(0);
mgr.set_train_mode(false);
mgr.start();

next_burst=mgr.next();
ASSERT_TRUE(!next_burst.empty());

total=0;

while (!next_burst.empty())
{
ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst);
for (auto i=0; i<next_burst.num_blocks(); ++i)
{
auto tmp_p=dynamic_cast<feat_type*>(next_burst[0][i].get());
auto tmp_q=dynamic_cast<feat_type*>(next_burst[1][i].get());
ASSERT_TRUE(tmp_p!=nullptr);
ASSERT_TRUE(tmp_q!=nullptr);
ASSERT_TRUE(tmp_p->get_num_vectors()==blocksize/2);
ASSERT_TRUE(tmp_q->get_num_vectors()==blocksize/2);
total+=tmp_p->get_num_vectors();
}
next_burst=mgr.next();
}
ASSERT_TRUE(total==num_vec);
mgr.end();
}

0 comments on commit ea78548

Please sign in to comment.