Skip to content

Commit

Permalink
updated data manager to work with blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 4, 2016
1 parent cadd06a commit b84db75
Showing 1 changed file with 14 additions and 22 deletions.
36 changes: 14 additions & 22 deletions src/shogun/statistical_testing/internals/DataManager.cpp
Expand Up @@ -28,8 +28,11 @@
* either expressed or implied, of the Shogun Development Team.
*/

#include <memory>
#include <shogun/io/SGIO.h>
#include <shogun/features/Features.h>
#include <shogun/features/DenseFeatures.h> // TODO remove
#include <shogun/statistical_testing/internals/Block.h>
#include <shogun/statistical_testing/internals/DataManager.h>
#include <shogun/statistical_testing/internals/NextSamples.h>
#include <shogun/statistical_testing/internals/DataFetcher.h>
Expand Down Expand Up @@ -131,8 +134,8 @@ void DataManager::set_num_blocks_per_burst(index_t num_blocks_per_burst)
index_t max_num_blocks_per_burst=get_num_samples()/blocksize;
if (num_blocks_per_burst>max_num_blocks_per_burst)
{
SG_SPRINT("There can only be %d blocks per burst given the blocksize (%d)!\n", max_num_blocks_per_burst, blocksize);
SG_SPRINT("Setting num blocks per burst to be %d instead!\n", max_num_blocks_per_burst);
SG_SINFO("There can only be %d blocks per burst given the blocksize (%d)!\n", max_num_blocks_per_burst, blocksize);
SG_SINFO("Setting num blocks per burst to be %d instead!\n", max_num_blocks_per_burst);
num_blocks_per_burst=max_num_blocks_per_burst;
}

Expand All @@ -158,7 +161,7 @@ CFeatures* DataManager::samples_at(size_t i) const
"Value of i (%d) should be between 0 and %d, inclusive!",
i, fetchers.size()-1);
SG_SDEBUG("Leaving!\n");
return fetchers[i]->m_samples.get();
return fetchers[i]->m_samples;
}

index_t& DataManager::num_samples_at(size_t i)
Expand Down Expand Up @@ -202,39 +205,28 @@ void DataManager::start()
NextSamples DataManager::next()
{
SG_SDEBUG("Entering!\n");

// sets the number of feature objects (number of distributions)
NextSamples next_samples(fetchers.size());

// fetch a number of blocks (per burst) from each distribution
for (size_t i=0; i<fetchers.size(); ++i)
{
auto feats=fetchers[i]->next();
if (feats!=nullptr)
{
ASSERT(feats->ref_count()==0);

auto blocksize=fetchers[i]->m_block_details.m_blocksize;
auto num_blocks_curr_burst=feats->get_num_vectors()/blocksize;

// use same number of blocks from all the distributions
if (next_samples.m_num_blocks==0)
next_samples.m_num_blocks=num_blocks_curr_burst;
else
ASSERT(next_samples.m_num_blocks==num_blocks_curr_burst);

// TODO remove
SG_SPRINT("blocksize is %d!\n", blocksize);
SG_SPRINT("number of blocks to be fetched is %d!\n", num_blocks_curr_burst);

// next samples are gonna hold one feats obj per block for this burst
next_samples[i].resize(num_blocks_curr_burst);
for (auto j=0; j<num_blocks_curr_burst; ++j)
{
// create a shallow copy and subset each block separately
auto block=FeaturesUtil::create_shallow_copy(feats.get());
SGVector<index_t> inds(blocksize);
std::iota(inds.vector, inds.vector + inds.vlen, j*blocksize);
block->add_subset(inds);
SG_REF(block);
next_samples[i][j]=std::shared_ptr<CFeatures>(block, [](CFeatures* ptr) { SG_UNREF(ptr); });
// TODO remove
SG_SPRINT("Number of samples fetched is %d!\n", next_samples[i][j]->get_num_vectors());
std::for_each(inds.vector, inds.vector+inds.vlen, [&blocksize](index_t& val) { val+=blocksize; });
}
next_samples[i]=Block::create_blocks(feats, num_blocks_curr_burst, blocksize);
}
}
SG_SDEBUG("Leaving!\n");
Expand Down

0 comments on commit b84db75

Please sign in to comment.