Skip to content

Commit

Permalink
full and blockwise train/test streaming data fetchers added
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 1, 2016
1 parent c613b04 commit 6beb360
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 5 deletions.
3 changes: 3 additions & 0 deletions src/shogun/statistical_testing/internals/DataFetcher.cpp
Expand Up @@ -50,6 +50,7 @@ const char* DataFetcher::get_name() const
void DataFetcher::set_train_test_ratio(float64_t train_test_ratio)
{
m_num_samples=m_train_test_details.get_total_num_samples();
REQUIRE(m_num_samples>0, "Number of samples is not set!\n");
index_t num_training_samples=m_num_samples*train_test_ratio/(train_test_ratio+1);
m_train_test_details.set_num_training_samples(num_training_samples);
SG_SINFO("Must set the train/test mode by calling set_train_mode(True/False)!\n");
Expand All @@ -68,10 +69,12 @@ void DataFetcher::set_train_mode(bool train_mode)
m_num_samples=m_train_test_details.get_num_training_samples();
if (m_num_samples==0)
SG_SERROR("The number of training samples is 0! Please set a valid train-test ratio\n");
SG_SINFO("Using %d number of samples for training!\n", m_num_samples);
}
else
{
m_num_samples=m_train_test_details.get_num_test_samples();
SG_SINFO("Using %d number of samples for testing!\n", m_num_samples);
start_index=m_train_test_details.get_num_training_samples();
if (start_index==0)
{
Expand Down
4 changes: 2 additions & 2 deletions src/shogun/statistical_testing/internals/DataFetcher.h
Expand Up @@ -52,9 +52,9 @@ class DataFetcher
public:
DataFetcher(CFeatures* samples);
virtual ~DataFetcher();
void set_train_test_ratio(float64_t train_test_ratio);
virtual void set_train_test_ratio(float64_t train_test_ratio);
float64_t get_train_test_ratio() const;
void set_train_mode(bool train_mode);
virtual void set_train_mode(bool train_mode);
void set_xvalidation_mode(bool xvalidation_mode);
index_t get_num_folds() const;
void use_fold(index_t idx);
Expand Down
29 changes: 26 additions & 3 deletions src/shogun/statistical_testing/internals/StreamingDataFetcher.cpp
Expand Up @@ -48,14 +48,38 @@ const char* StreamingDataFetcher::get_name() const
void StreamingDataFetcher::set_num_samples(index_t num_samples)
{
m_num_samples=num_samples;
m_train_test_details.set_total_num_samples(m_num_samples);
}

void StreamingDataFetcher::set_train_test_ratio(float64_t train_test_ratio)
{
if (m_train_test_details.get_total_num_samples()==0)
m_train_test_details.set_total_num_samples(m_num_samples);
DataFetcher::set_train_test_ratio(train_test_ratio);
}

void StreamingDataFetcher::set_train_mode(bool train_mode)
{
if (train_mode)
{
m_num_samples=m_train_test_details.get_num_training_samples();
if (m_num_samples==0)
SG_SERROR("The number of training samples is 0! Please set a valid train-test ratio\n");
SG_SINFO("Using %d number of samples for training!\n", m_num_samples);
}
else
{
m_num_samples=m_train_test_details.get_num_test_samples();
SG_SINFO("Using %d number of samples for testing!\n", m_num_samples);
}
}

void StreamingDataFetcher::start()
{
REQUIRE(m_num_samples>0, "Number of samples is not set! It is MANDATORY for streaming features!\n");
if (m_block_details.m_blocksize==0)
if (m_block_details.m_full_data || m_block_details.m_blocksize>m_num_samples)
{
SG_SINFO("Block details not set! Fetching entire data (%d samples)!\n", m_num_samples);
SG_SINFO("Fetching entire data (%d samples)!\n", m_num_samples);
m_block_details.with_blocksize(m_num_samples);
}
m_block_details.m_total_num_blocks=m_num_samples/m_block_details.m_blocksize;
Expand All @@ -64,7 +88,6 @@ void StreamingDataFetcher::start()
{
m_samples->start_parser();
parser_running=true;
// TODO check if resetting the stream is required
}
}

Expand Down
Expand Up @@ -39,6 +39,8 @@ class StreamingDataFetcher : public DataFetcher
public:
StreamingDataFetcher(CStreamingFeatures* samples);
virtual ~StreamingDataFetcher() override;
virtual void set_train_test_ratio(float64_t train_test_ratio) override;
virtual void set_train_mode(bool train_mode) override;
virtual void start() override;
virtual CFeatures* next() override;
virtual void reset() override;
Expand Down
201 changes: 201 additions & 0 deletions tests/unit/statistical_testing/internals/DataManager_unittest.cc
Expand Up @@ -674,3 +674,204 @@ TEST(DataManager, train_data_two_distributions_normal_feats_blockwise)
ASSERT_TRUE(total==num_vec);
mgr.end();
}

TEST(DataManager, train_data_two_distributions_streaming_feats)
{
const index_t dim=3;
const index_t num_vec=8;
const index_t num_distributions=2;
const index_t train_test_ratio=3;

SGMatrix<float64_t> data_p(dim, num_vec);
std::iota(data_p.matrix, data_p.matrix+dim*num_vec, 0);

SGMatrix<float64_t> data_q(dim, num_vec);
std::iota(data_q.matrix, data_q.matrix+dim*num_vec, dim*num_vec);

using feat_type=CDenseFeatures<float64_t>;
auto feats_p=new feat_type(data_p);
auto feats_q=new feat_type(data_q);

DataManager mgr(num_distributions);
mgr.samples_at(0)=new CStreamingDenseFeatures<float64_t>(feats_p);
mgr.samples_at(1)=new CStreamingDenseFeatures<float64_t>(feats_q);
mgr.num_samples_at(0)=num_vec;
mgr.num_samples_at(1)=num_vec;

// training data
mgr.set_train_test_ratio(train_test_ratio);
mgr.set_train_mode(true);
mgr.start();

auto next_burst=mgr.next();
ASSERT_TRUE(!next_burst.empty());
ASSERT_TRUE(next_burst.num_blocks()==1);

auto tmp_p=dynamic_cast<feat_type*>(next_burst[0][0].get());
auto tmp_q=dynamic_cast<feat_type*>(next_burst[1][0].get());

ASSERT_TRUE(tmp_p!=nullptr);
ASSERT_TRUE(tmp_q!=nullptr);
ASSERT_TRUE(tmp_p->get_num_vectors()==num_vec*train_test_ratio/(train_test_ratio+1));
ASSERT_TRUE(tmp_q->get_num_vectors()==num_vec*train_test_ratio/(train_test_ratio+1));

next_burst=mgr.next();
ASSERT_TRUE(next_burst.empty());
mgr.end();

// test data
mgr.set_train_mode(false);
mgr.start();

next_burst=mgr.next();
ASSERT_TRUE(!next_burst.empty());
ASSERT_TRUE(next_burst.num_blocks()==1);

tmp_p=dynamic_cast<feat_type*>(next_burst[0][0].get());
tmp_q=dynamic_cast<feat_type*>(next_burst[1][0].get());

ASSERT_TRUE(tmp_p!=nullptr);
ASSERT_TRUE(tmp_q!=nullptr);
ASSERT_TRUE(tmp_p->get_num_vectors()==num_vec/(train_test_ratio+1));
ASSERT_TRUE(tmp_q->get_num_vectors()==num_vec/(train_test_ratio+1));

next_burst=mgr.next();
ASSERT_TRUE(next_burst.empty());
mgr.end();

// full data
mgr.set_train_test_ratio(0);
mgr.set_train_mode(false);
mgr.reset();
mgr.start();

next_burst=mgr.next();
ASSERT_TRUE(!next_burst.empty());
ASSERT_TRUE(next_burst.num_blocks()==1);

tmp_p=dynamic_cast<feat_type*>(next_burst[0][0].get());
tmp_q=dynamic_cast<feat_type*>(next_burst[1][0].get());

ASSERT_TRUE(tmp_p!=nullptr);
ASSERT_TRUE(tmp_q!=nullptr);
ASSERT_TRUE(tmp_p->get_num_vectors()==num_vec);
ASSERT_TRUE(tmp_q->get_num_vectors()==num_vec);

next_burst=mgr.next();
ASSERT_TRUE(next_burst.empty());
mgr.end();
}

TEST(DataManager, train_data_two_distributions_streaming_feats_blockwise)
{
const index_t dim=3;
const index_t num_vec=8;
const index_t blocksize=2;
const index_t num_blocks_per_burst=2;
const index_t num_distributions=2;
const index_t train_test_ratio=3;

SGMatrix<float64_t> data_p(dim, num_vec);
std::iota(data_p.matrix, data_p.matrix+dim*num_vec, 0);

SGMatrix<float64_t> data_q(dim, num_vec);
std::iota(data_q.matrix, data_q.matrix+dim*num_vec, dim*num_vec);

using feat_type=CDenseFeatures<float64_t>;
auto feats_p=new feat_type(data_p);
auto feats_q=new feat_type(data_q);

DataManager mgr(num_distributions);
mgr.samples_at(0)=new CStreamingDenseFeatures<float64_t>(feats_p);
mgr.samples_at(1)=new CStreamingDenseFeatures<float64_t>(feats_q);
mgr.num_samples_at(0)=num_vec;
mgr.num_samples_at(1)=num_vec;

mgr.set_blocksize(blocksize);
mgr.set_num_blocks_per_burst(num_blocks_per_burst);

mgr.set_train_test_ratio(train_test_ratio);

// train data
mgr.set_train_mode(true);
mgr.start();

auto next_burst=mgr.next();
ASSERT_TRUE(!next_burst.empty());

auto total=0;

while (!next_burst.empty())
{
ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst);
for (auto i=0; i<next_burst.num_blocks(); ++i)
{
auto tmp_p=dynamic_cast<feat_type*>(next_burst[0][i].get());
auto tmp_q=dynamic_cast<feat_type*>(next_burst[1][i].get());
ASSERT_TRUE(tmp_p!=nullptr);
ASSERT_TRUE(tmp_q!=nullptr);
ASSERT_TRUE(tmp_p->get_num_vectors()==blocksize/2);
ASSERT_TRUE(tmp_q->get_num_vectors()==blocksize/2);
total+=tmp_p->get_num_vectors();
}
next_burst=mgr.next();
}
ASSERT_TRUE(total==num_vec*train_test_ratio/(train_test_ratio+1));
mgr.end();

// test data
mgr.set_train_mode(false);
mgr.start();

next_burst=mgr.next();
ASSERT_TRUE(!next_burst.empty());

total=0;

while (!next_burst.empty())
{
ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst);
for (auto i=0; i<next_burst.num_blocks(); ++i)
{
auto tmp_p=dynamic_cast<feat_type*>(next_burst[0][i].get());
auto tmp_q=dynamic_cast<feat_type*>(next_burst[1][i].get());
ASSERT_TRUE(tmp_p!=nullptr);
ASSERT_TRUE(tmp_q!=nullptr);
ASSERT_TRUE(tmp_p->get_num_vectors()==blocksize/2);
ASSERT_TRUE(tmp_q->get_num_vectors()==blocksize/2);
total+=tmp_p->get_num_vectors();
}
next_burst=mgr.next();
}
ASSERT_TRUE(total==num_vec/(train_test_ratio+1));
mgr.end();

// full data
mgr.set_train_test_ratio(0);
mgr.set_train_mode(false);
mgr.reset();
mgr.start();

next_burst=mgr.next();
ASSERT_TRUE(!next_burst.empty());

total=0;

while (!next_burst.empty())
{
ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst);
for (auto i=0; i<next_burst.num_blocks(); ++i)
{
auto tmp_p=dynamic_cast<feat_type*>(next_burst[0][i].get());
auto tmp_q=dynamic_cast<feat_type*>(next_burst[1][i].get());
ASSERT_TRUE(tmp_p!=nullptr);
ASSERT_TRUE(tmp_q!=nullptr);
ASSERT_TRUE(tmp_p->get_num_vectors()==blocksize/2);
ASSERT_TRUE(tmp_q->get_num_vectors()==blocksize/2);
total+=tmp_p->get_num_vectors();
}
next_burst=mgr.next();
}
ASSERT_TRUE(total==num_vec);
mgr.end();
}

0 comments on commit 6beb360

Please sign in to comment.