diff --git a/src/shogun/statistical_testing/internals/DataManager.cpp b/src/shogun/statistical_testing/internals/DataManager.cpp index c1e1dbe7333..f878904ae25 100644 --- a/src/shogun/statistical_testing/internals/DataManager.cpp +++ b/src/shogun/statistical_testing/internals/DataManager.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include using namespace shogun; @@ -128,9 +129,12 @@ void DataManager::set_num_blocks_per_burst(index_t num_blocks_per_burst) "Blocksizes are not set!\n"); index_t max_num_blocks_per_burst=get_num_samples()/blocksize; - REQUIRE(num_blocks_per_burst<=max_num_blocks_per_burst, - "There can only be %d blocks per burst given the blocksize (%d)!", - max_num_blocks_per_burst, blocksize); + if (num_blocks_per_burst>max_num_blocks_per_burst) + { + SG_SPRINT("There can only be %d blocks per burst given the blocksize (%d)!\n", max_num_blocks_per_burst, blocksize); + SG_SPRINT("Setting num blocks per burst to be %d instead!\n", max_num_blocks_per_burst); + num_blocks_per_burst=max_num_blocks_per_burst; + } for (size_t i=0; ifetch_blockwise().with_num_blocks_per_burst(num_blocks_per_burst); @@ -212,17 +216,23 @@ NextSamples DataManager::next() else ASSERT(next_samples.m_num_blocks==num_blocks_curr_burst); + // TODO remove + SG_SPRINT("blocksize is %d!\n", blocksize); + SG_SPRINT("number of blocks to be fetched is %d!\n", num_blocks_curr_burst); + // next samples are gonna hold one feats obj per block for this burst next_samples[i].resize(num_blocks_curr_burst); - SGVector inds(blocksize); - std::iota(inds.vector, inds.vector + inds.vlen, 0); for (auto j=0; jadd_subset(inds); - auto block=static_cast(feats->clone()); + // create a shallow copy and subset each block separately + auto block=FeaturesUtil::create_shallow_copy(feats.get()); + SGVector inds(blocksize); + std::iota(inds.vector, inds.vector + inds.vlen, j*blocksize); + block->add_subset(inds); + SG_REF(block); next_samples[i][j]=std::shared_ptr(block, [](CFeatures* ptr) { SG_UNREF(ptr); }); - feats->remove_subset(); + // TODO remove + SG_SPRINT("Number of samples fetched is %d!\n", next_samples[i][j]->get_num_vectors()); std::for_each(inds.vector, inds.vector+inds.vlen, [&blocksize](index_t& val) { val+=blocksize; }); } } diff --git a/tests/unit/statistical_testing/internals/DataManager_unittest.cc b/tests/unit/statistical_testing/internals/DataManager_unittest.cc index 8c4c881ab8d..5e54a167993 100644 --- a/tests/unit/statistical_testing/internals/DataManager_unittest.cc +++ b/tests/unit/statistical_testing/internals/DataManager_unittest.cc @@ -82,7 +82,6 @@ TEST(DataManager, full_data_one_distribution_streaming_feats) auto feats_p=new CDenseFeatures(data_p); auto streaming_p=new CStreamingDenseFeatures(feats_p); - SG_REF(streaming_p); DataManager mgr(num_distributions); mgr.samples_at(0)=streaming_p; @@ -159,8 +158,6 @@ TEST(DataManager, full_data_two_distributions_streaming_feats) auto feats_q=new feat_type(data_q); auto streaming_p=new CStreamingDenseFeatures(feats_p); auto streaming_q=new CStreamingDenseFeatures(feats_q); - SG_REF(streaming_p); - SG_REF(streaming_q); DataManager mgr(num_distributions); mgr.samples_at(0)=streaming_p; @@ -239,7 +236,6 @@ TEST(DataManager, block_data_one_distribution_streaming_feats) auto feats_p=new CDenseFeatures(data_p); auto streaming_p=new CStreamingDenseFeatures(feats_p); - SG_REF(streaming_p); DataManager mgr(num_distributions); mgr.samples_at(0)=streaming_p; @@ -337,8 +333,6 @@ TEST(DataManager, block_data_two_distributions_streaming_feats_equal_blocksize) auto feats_q=new feat_type(data_q); auto streaming_p=new CStreamingDenseFeatures(feats_p); auto streaming_q=new CStreamingDenseFeatures(feats_q); - SG_REF(streaming_p); - SG_REF(streaming_q); DataManager mgr(num_distributions); mgr.samples_at(0)=streaming_p; @@ -451,8 +445,6 @@ TEST(DataManager, block_data_two_distributions_streaming_feats_different_blocksi auto feats_q=new feat_type(data_q); auto streaming_p=new CStreamingDenseFeatures(feats_p); auto streaming_q=new CStreamingDenseFeatures(feats_q); - SG_REF(streaming_p); - SG_REF(streaming_q); DataManager mgr(num_distributions); mgr.samples_at(0)=streaming_p;