Skip to content

Commit

Permalink
updated data fetchers to return naked pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 3, 2016
1 parent 21bb60b commit 9340b9f
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 71 deletions.
56 changes: 27 additions & 29 deletions src/shogun/statistical_testing/internals/DataFetcher.cpp
Expand Up @@ -19,24 +19,26 @@
#include <algorithm>
#include <shogun/features/Features.h>
#include <shogun/statistical_testing/internals/DataFetcher.h>

#include <shogun/statistical_testing/internals/FeaturesUtil.h>

using namespace shogun;
using namespace internal;

DataFetcher::DataFetcher() : m_num_samples(0)
DataFetcher::DataFetcher() : m_num_samples(0), m_samples(nullptr)
{
}

DataFetcher::DataFetcher(CFeatures* samples)
DataFetcher::DataFetcher(CFeatures* samples) : m_samples(samples)
{
SG_REF(samples);
m_samples = std::shared_ptr<CFeatures>(samples, [](CFeatures* ptr) { SG_UNREF(ptr); });
m_num_samples = m_samples->get_num_vectors();
REQUIRE(m_samples!=nullptr, "Samples cannot be null!\n");
SG_REF(m_samples);
m_num_samples=m_samples->get_num_vectors();
}

DataFetcher::~DataFetcher()
{
end();
SG_UNREF(m_samples);
}

const char* DataFetcher::get_name() const
Expand All @@ -46,47 +48,43 @@ const char* DataFetcher::get_name() const

void DataFetcher::start()
{
if (m_block_details.m_blocksize == 0)
REQUIRE(m_num_samples>0, "Number of samples is 0!\n");
if (m_block_details.m_blocksize==0)
{
SG_SINFO("Block details not set! Fetching entire data (%d samples)!\n", m_num_samples);
m_block_details.with_blocksize(m_num_samples);
}
m_block_details.m_total_num_blocks = m_num_samples / m_block_details.m_blocksize;
m_block_details.m_total_num_blocks=m_num_samples/m_block_details.m_blocksize;
reset();
}

std::shared_ptr<CFeatures> DataFetcher::next()
CFeatures* DataFetcher::next()
{
auto num_more_samples = m_num_samples - m_block_details.m_next_block_index * m_block_details.m_blocksize;
if (num_more_samples > 0)
CFeatures* next_samples=nullptr;
// figure out how many samples to fetch in this burst
auto num_already_fetched=m_block_details.m_next_block_index*m_block_details.m_blocksize;
auto num_more_samples=m_num_samples-num_already_fetched;
if (num_more_samples>0)
{
auto num_samples_this_burst = m_block_details.m_max_num_samples_per_burst;
if (num_samples_this_burst > num_more_samples)
{
num_samples_this_burst = num_more_samples;
}
if (num_samples_this_burst < m_num_samples)
{
m_samples->remove_subset();
SGVector<index_t> inds(num_samples_this_burst);
std::iota(inds.vector, inds.vector + inds.vlen, m_block_details.m_next_block_index * m_block_details.m_blocksize);
m_samples->add_subset(inds);
}
auto num_samples_this_burst=std::min(m_block_details.m_max_num_samples_per_burst, num_more_samples);
// create a shallow copy and add proper index subset
next_samples=FeaturesUtil::create_shallow_copy(m_samples);
SGVector<index_t> inds(num_samples_this_burst);
std::iota(inds.vector, inds.vector+inds.vlen, num_already_fetched);
next_samples->add_subset(inds);

m_block_details.m_next_block_index += m_block_details.m_num_blocks_per_burst;
return m_samples;
m_block_details.m_next_block_index+=m_block_details.m_num_blocks_per_burst;
}
return nullptr;
return next_samples;
}

void DataFetcher::reset()
{
m_block_details.m_next_block_index = 0;
m_samples->remove_all_subsets();
m_block_details.m_next_block_index=0;
}

void DataFetcher::end()
{
m_samples->remove_all_subsets();
}

const index_t DataFetcher::get_num_samples() const
Expand Down
5 changes: 2 additions & 3 deletions src/shogun/statistical_testing/internals/DataFetcher.h
Expand Up @@ -28,7 +28,6 @@
* either expressed or implied, of the Shogun Development Team.
*/

#include <memory>
#include <shogun/lib/common.h>
#include <shogun/statistical_testing/internals/BlockwiseDetails.h>

Expand All @@ -53,7 +52,7 @@ class DataFetcher
DataFetcher(CFeatures* samples);
virtual ~DataFetcher();
virtual void start();
virtual std::shared_ptr<CFeatures> next();
virtual CFeatures* next();
virtual void reset();
virtual void end();
const index_t get_num_samples() const;
Expand All @@ -64,7 +63,7 @@ class DataFetcher
BlockwiseDetails m_block_details;
index_t m_num_samples;
private:
std::shared_ptr<CFeatures> m_samples;
CFeatures* m_samples;
};

}
Expand Down
51 changes: 26 additions & 25 deletions src/shogun/statistical_testing/internals/StreamingDataFetcher.cpp
Expand Up @@ -17,19 +17,22 @@
*/

#include <algorithm>
#include <shogun/io/SGIO.h>
#include <shogun/features/Features.h>
#include <shogun/features/streaming/StreamingFeatures.h>
#include <shogun/statistical_testing/internals/StreamingDataFetcher.h>

#include <shogun/statistical_testing/internals/BlockwiseDetails.h>

using namespace shogun;
using namespace internal;

StreamingDataFetcher::StreamingDataFetcher(CStreamingFeatures* samples) : DataFetcher(), parser_running(false)
StreamingDataFetcher::StreamingDataFetcher(CStreamingFeatures* samples)
: DataFetcher(), parser_running(false)
{
REQUIRE(samples!=nullptr, "Samples cannot be null!\n");
SG_REF(samples);
m_samples = std::shared_ptr<CStreamingFeatures>(samples, [](CFeatures* ptr) { SG_UNREF(ptr); });
m_num_samples = 0;
m_samples=std::shared_ptr<CStreamingFeatures>(samples, [](CStreamingFeatures* ptr) { SG_UNREF(ptr); });
m_num_samples=0;
}

StreamingDataFetcher::~StreamingDataFetcher()
Expand All @@ -44,47 +47,45 @@ const char* StreamingDataFetcher::get_name() const

void StreamingDataFetcher::set_num_samples(index_t num_samples)
{
m_num_samples = num_samples;
m_num_samples=num_samples;
}

void StreamingDataFetcher::start()
{
ASSERT(m_num_samples);
if (m_block_details.m_blocksize == 0)
REQUIRE(m_num_samples>0, "Number of samples is not set! It is MANDATORY for streaming features!\n");
if (m_block_details.m_blocksize==0)
{
SG_SINFO("Block details not set! Fetching entire data (%d samples)!\n", m_num_samples);
m_block_details.with_blocksize(m_num_samples);
}
m_block_details.m_total_num_blocks = m_num_samples / m_block_details.m_blocksize;
m_block_details.m_next_block_index = 0;
m_block_details.m_total_num_blocks=m_num_samples/m_block_details.m_blocksize;
m_block_details.m_next_block_index=0;
if (!parser_running)
{
m_samples->start_parser();
parser_running = true;
parser_running=true;
// TODO check if resetting the stream is required
}
}

std::shared_ptr<CFeatures> StreamingDataFetcher::next()
CFeatures* StreamingDataFetcher::next()
{
auto num_more_samples = m_num_samples - m_block_details.m_next_block_index * m_block_details.m_blocksize;
if (num_more_samples > 0)
CFeatures* next_samples=nullptr;
// figure out how many samples to fetch in this burst
auto num_already_fetched=m_block_details.m_next_block_index*m_block_details.m_blocksize;
auto num_more_samples=m_num_samples-num_already_fetched;
if (num_more_samples>0)
{
auto num_samples_this_burst = m_block_details.m_max_num_samples_per_burst;
if (num_samples_this_burst > num_more_samples)
{
num_samples_this_burst = num_more_samples;
}

CFeatures* streamed = m_samples->get_streamed_features(num_samples_this_burst);
m_block_details.m_next_block_index += m_block_details.m_num_blocks_per_burst;
return std::shared_ptr<CFeatures>(streamed, [](CFeatures* ptr) { SG_UNREF(ptr); });
auto num_samples_this_burst=std::min(m_block_details.m_max_num_samples_per_burst, num_more_samples);
next_samples=m_samples->get_streamed_features(num_samples_this_burst);
m_block_details.m_next_block_index+=m_block_details.m_num_blocks_per_burst;
}
return nullptr;
return next_samples;
}

void StreamingDataFetcher::reset()
{
m_block_details.m_next_block_index = 0;
m_block_details.m_next_block_index=0;
m_samples->reset_stream();
}

Expand All @@ -93,6 +94,6 @@ void StreamingDataFetcher::end()
if (parser_running)
{
m_samples->end_parser();
parser_running = false;
parser_running=false;
}
}
Expand Up @@ -19,7 +19,6 @@
#include <memory>
#include <shogun/lib/common.h>
#include <shogun/statistical_testing/internals/DataFetcher.h>
#include <shogun/statistical_testing/internals/BlockwiseDetails.h>

#ifndef STREMING_DATA_FETCHER_H__
#define STREMING_DATA_FETCHER_H__
Expand All @@ -41,7 +40,7 @@ class StreamingDataFetcher : public DataFetcher
StreamingDataFetcher(CStreamingFeatures* samples);
virtual ~StreamingDataFetcher() override;
virtual void start() override;
virtual std::shared_ptr<CFeatures> next() override;
virtual CFeatures* next() override;
virtual void reset() override;
virtual void end() override;
void set_num_samples(index_t num_samples);
Expand Down
15 changes: 11 additions & 4 deletions tests/unit/statistical_testing/internals/DataFetcher_unittest.cc
Expand Up @@ -56,9 +56,11 @@ TEST(DataFetcher, full_data)
auto curr=fetcher.next();
ASSERT_TRUE(curr!=nullptr);

auto tmp=dynamic_cast<feat_type*>(curr.get());
auto tmp=dynamic_cast<feat_type*>(curr);
ASSERT_TRUE(tmp!=nullptr);

SG_UNREF(curr);

curr=fetcher.next();
ASSERT_TRUE(curr==nullptr);
fetcher.end();
Expand Down Expand Up @@ -88,9 +90,11 @@ TEST(DataFetcher, block_data)
ASSERT_TRUE(curr!=nullptr);
while (curr!=nullptr)
{
auto tmp=dynamic_cast<feat_type*>(curr.get());
auto tmp=dynamic_cast<feat_type*>(curr);
ASSERT_TRUE(tmp!=nullptr);
ASSERT_TRUE(tmp->get_num_vectors()==blocksize*num_blocks_per_burst);

SG_UNREF(curr);
curr=fetcher.next();
}
fetcher.end();
Expand All @@ -115,9 +119,11 @@ TEST(DataFetcher, reset_functionality)
auto curr=fetcher.next();
ASSERT_TRUE(curr!=nullptr);

auto tmp=dynamic_cast<feat_type*>(curr.get());
auto tmp=dynamic_cast<feat_type*>(curr);
ASSERT_TRUE(tmp!=nullptr);

SG_UNREF(curr);

curr=fetcher.next();
ASSERT_TRUE(curr==nullptr);

Expand All @@ -131,9 +137,10 @@ TEST(DataFetcher, reset_functionality)
ASSERT_TRUE(curr!=nullptr);
while (curr!=nullptr)
{
tmp=dynamic_cast<feat_type*>(curr.get());
tmp=dynamic_cast<feat_type*>(curr);
ASSERT_TRUE(tmp!=nullptr);
ASSERT_TRUE(tmp->get_num_vectors()==blocksize*num_blocks_per_burst);
SG_UNREF(curr);
curr=fetcher.next();
}
fetcher.end();
Expand Down
Expand Up @@ -51,7 +51,6 @@ TEST(StreamingDataFetcher, full_data)
using feat_type=CDenseFeatures<float64_t>;
auto feats_p=new feat_type(data_p);
CStreamingFeatures *streaming_p = new CStreamingDenseFeatures<float64_t>(feats_p);
SG_REF(streaming_p); // TODO check why this refcount is required

StreamingDataFetcher fetcher(streaming_p);
fetcher.set_num_samples(num_vec);
Expand All @@ -60,9 +59,11 @@ TEST(StreamingDataFetcher, full_data)
auto curr=fetcher.next();
ASSERT_TRUE(curr!=nullptr);

auto tmp=dynamic_cast<feat_type*>(curr.get());
auto tmp=dynamic_cast<feat_type*>(curr);
ASSERT_TRUE(tmp!=nullptr);

SG_UNREF(curr);

curr=fetcher.next();
ASSERT_TRUE(curr==nullptr);
fetcher.end();
Expand All @@ -81,7 +82,6 @@ TEST(StreamingDataFetcher, block_data)
using feat_type=CDenseFeatures<float64_t>;
auto feats_p=new feat_type(data_p);
CStreamingFeatures *streaming_p = new CStreamingDenseFeatures<float64_t>(feats_p);
SG_REF(streaming_p); // TODO check why this refcount is required

StreamingDataFetcher fetcher(streaming_p);
fetcher.set_num_samples(num_vec);
Expand All @@ -95,15 +95,16 @@ TEST(StreamingDataFetcher, block_data)
ASSERT_TRUE(curr!=nullptr);
while (curr!=nullptr)
{
auto tmp=dynamic_cast<feat_type*>(curr.get());
auto tmp=dynamic_cast<feat_type*>(curr);
ASSERT_TRUE(tmp!=nullptr);
ASSERT_TRUE(tmp->get_num_vectors()==blocksize*num_blocks_per_burst);
SG_UNREF(curr);
curr=fetcher.next();
}
fetcher.end();
}

TEST(StreamingDataFetcher, reset_functionality)
TEST(StreamingDataFetcher, DISABLED_reset_functionality)
{
const index_t dim=3;
const index_t num_vec=8;
Expand All @@ -116,7 +117,6 @@ TEST(StreamingDataFetcher, reset_functionality)
using feat_type=CDenseFeatures<float64_t>;
auto feats_p=new feat_type(data_p);
CStreamingFeatures *streaming_p = new CStreamingDenseFeatures<float64_t>(feats_p);
SG_REF(streaming_p); // TODO check why this refcount is required

StreamingDataFetcher fetcher(streaming_p);
fetcher.set_num_samples(num_vec);
Expand All @@ -125,9 +125,11 @@ TEST(StreamingDataFetcher, reset_functionality)
auto curr=fetcher.next();
ASSERT_TRUE(curr!=nullptr);

auto tmp=dynamic_cast<feat_type*>(curr.get());
auto tmp=dynamic_cast<feat_type*>(curr);
ASSERT_TRUE(tmp!=nullptr);

SG_UNREF(curr);

curr=fetcher.next();
ASSERT_TRUE(curr==nullptr);

Expand All @@ -141,9 +143,10 @@ TEST(StreamingDataFetcher, reset_functionality)
ASSERT_TRUE(curr!=nullptr);
while (curr!=nullptr)
{
tmp=dynamic_cast<feat_type*>(curr.get());
tmp=dynamic_cast<feat_type*>(curr);
ASSERT_TRUE(tmp!=nullptr);
ASSERT_TRUE(tmp->get_num_vectors()==blocksize*num_blocks_per_burst);
SG_UNREF(curr);
curr=fetcher.next();
}
fetcher.end();
Expand Down

0 comments on commit 9340b9f

Please sign in to comment.