Skip to content

Commit

Permalink
made the data manager work with shallow copy
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 3, 2016
1 parent 7fad933 commit d681ea6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
28 changes: 19 additions & 9 deletions src/shogun/statistical_testing/internals/DataManager.cpp
Expand Up @@ -33,6 +33,7 @@
#include <shogun/statistical_testing/internals/DataManager.h>
#include <shogun/statistical_testing/internals/NextSamples.h>
#include <shogun/statistical_testing/internals/DataFetcher.h>
#include <shogun/statistical_testing/internals/FeaturesUtil.h>
#include <shogun/statistical_testing/internals/DataFetcherFactory.h>

using namespace shogun;
Expand Down Expand Up @@ -128,9 +129,12 @@ void DataManager::set_num_blocks_per_burst(index_t num_blocks_per_burst)
"Blocksizes are not set!\n");

index_t max_num_blocks_per_burst=get_num_samples()/blocksize;
REQUIRE(num_blocks_per_burst<=max_num_blocks_per_burst,
"There can only be %d blocks per burst given the blocksize (%d)!",
max_num_blocks_per_burst, blocksize);
if (num_blocks_per_burst>max_num_blocks_per_burst)
{
SG_SPRINT("There can only be %d blocks per burst given the blocksize (%d)!\n", max_num_blocks_per_burst, blocksize);
SG_SPRINT("Setting num blocks per burst to be %d instead!\n", max_num_blocks_per_burst);
num_blocks_per_burst=max_num_blocks_per_burst;
}

for (size_t i=0; i<fetchers.size(); ++i)
fetchers[i]->fetch_blockwise().with_num_blocks_per_burst(num_blocks_per_burst);
Expand Down Expand Up @@ -212,17 +216,23 @@ NextSamples DataManager::next()
else
ASSERT(next_samples.m_num_blocks==num_blocks_curr_burst);

// TODO remove
SG_SPRINT("blocksize is %d!\n", blocksize);
SG_SPRINT("number of blocks to be fetched is %d!\n", num_blocks_curr_burst);

// next samples are gonna hold one feats obj per block for this burst
next_samples[i].resize(num_blocks_curr_burst);
SGVector<index_t> inds(blocksize);
std::iota(inds.vector, inds.vector + inds.vlen, 0);
for (auto j=0; j<num_blocks_curr_burst; ++j)
{
// subset each block and clone it separately
feats->add_subset(inds);
auto block=static_cast<CFeatures*>(feats->clone());
// create a shallow copy and subset each block separately
auto block=FeaturesUtil::create_shallow_copy(feats.get());
SGVector<index_t> inds(blocksize);
std::iota(inds.vector, inds.vector + inds.vlen, j*blocksize);
block->add_subset(inds);
SG_REF(block);
next_samples[i][j]=std::shared_ptr<CFeatures>(block, [](CFeatures* ptr) { SG_UNREF(ptr); });
feats->remove_subset();
// TODO remove
SG_SPRINT("Number of samples fetched is %d!\n", next_samples[i][j]->get_num_vectors());
std::for_each(inds.vector, inds.vector+inds.vlen, [&blocksize](index_t& val) { val+=blocksize; });
}
}
Expand Down
Expand Up @@ -82,7 +82,6 @@ TEST(DataManager, full_data_one_distribution_streaming_feats)

auto feats_p=new CDenseFeatures<float64_t>(data_p);
auto streaming_p=new CStreamingDenseFeatures<float64_t>(feats_p);
SG_REF(streaming_p);

DataManager mgr(num_distributions);
mgr.samples_at(0)=streaming_p;
Expand Down Expand Up @@ -159,8 +158,6 @@ TEST(DataManager, full_data_two_distributions_streaming_feats)
auto feats_q=new feat_type(data_q);
auto streaming_p=new CStreamingDenseFeatures<float64_t>(feats_p);
auto streaming_q=new CStreamingDenseFeatures<float64_t>(feats_q);
SG_REF(streaming_p);
SG_REF(streaming_q);

DataManager mgr(num_distributions);
mgr.samples_at(0)=streaming_p;
Expand Down Expand Up @@ -239,7 +236,6 @@ TEST(DataManager, block_data_one_distribution_streaming_feats)

auto feats_p=new CDenseFeatures<float64_t>(data_p);
auto streaming_p=new CStreamingDenseFeatures<float64_t>(feats_p);
SG_REF(streaming_p);

DataManager mgr(num_distributions);
mgr.samples_at(0)=streaming_p;
Expand Down Expand Up @@ -337,8 +333,6 @@ TEST(DataManager, block_data_two_distributions_streaming_feats_equal_blocksize)
auto feats_q=new feat_type(data_q);
auto streaming_p=new CStreamingDenseFeatures<float64_t>(feats_p);
auto streaming_q=new CStreamingDenseFeatures<float64_t>(feats_q);
SG_REF(streaming_p);
SG_REF(streaming_q);

DataManager mgr(num_distributions);
mgr.samples_at(0)=streaming_p;
Expand Down Expand Up @@ -451,8 +445,6 @@ TEST(DataManager, block_data_two_distributions_streaming_feats_different_blocksi
auto feats_q=new feat_type(data_q);
auto streaming_p=new CStreamingDenseFeatures<float64_t>(feats_p);
auto streaming_q=new CStreamingDenseFeatures<float64_t>(feats_q);
SG_REF(streaming_p);
SG_REF(streaming_q);

DataManager mgr(num_distributions);
mgr.samples_at(0)=streaming_p;
Expand Down

0 comments on commit d681ea6

Please sign in to comment.