-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
DataFetcher.cpp
100 lines (86 loc) · 2.56 KB
/
DataFetcher.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/*
* Restructuring Shogun's statistical hypothesis testing framework.
* Copyright (C) 2016 Soumyajit De
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include <algorithm>
#include <memory>
#include <numeric>
#include <shogun/features/Features.h>
#include <shogun/statistics/experimental/internals/DataFetcher.h>
using namespace shogun;
using namespace internal;
DataFetcher::DataFetcher() : m_num_samples(0)
{
}
/**
 * Constructs a fetcher that serves samples from the given feature object.
 *
 * The feature object is reference-counted through SG_REF here and released
 * through SG_UNREF by the deleter of the owning shared pointer, so it stays
 * alive for as long as any copy of that shared pointer exists.
 *
 * @param samples the feature object to fetch from; when it is a null pointer
 * the fetcher is left without samples (sample count zero), matching the state
 * produced by the default constructor
 */
DataFetcher::DataFetcher(CFeatures* samples) : m_num_samples(0)
{
	// Querying the number of vectors on a null pointer would be undefined
	// behaviour, so only attach (and count) a real feature object.
	if (samples != nullptr)
	{
		SG_REF(samples);
		m_samples = std::shared_ptr<CFeatures>(samples, [](CFeatures* ptr) { SG_UNREF(ptr); });
		m_num_samples = m_samples->get_num_vectors();
	}
}
/**
 * Destructor. Nothing is released by hand; the members clean up after
 * themselves when they are destroyed.
 */
DataFetcher::~DataFetcher() = default;
/**
 * @return the name of this class as a C string
 */
const char* DataFetcher::get_name() const
{
	constexpr const char* class_name = "DataFetcher";
	return class_name;
}
/**
 * Prepares the fetcher for a pass over the samples.
 *
 * When no block size has been configured (it is zero), the total number of
 * samples is handed to with_blocksize(). The number of whole blocks is then
 * obtained by integer division of the sample count by the block size, and
 * the fetcher is rewound through reset().
 */
void DataFetcher::start()
{
	if (m_block_details.m_blocksize == 0)
	{
		m_block_details.with_blocksize(m_num_samples);
	}

	// The block size may still be zero at this point (presumably when the
	// fetcher holds no samples). Dividing by it would be undefined behaviour,
	// so report zero blocks instead.
	if (m_block_details.m_blocksize == 0)
	{
		m_block_details.m_total_num_blocks = 0;
	}
	else
	{
		m_block_details.m_total_num_blocks = m_num_samples / m_block_details.m_blocksize;
	}
	reset();
}
std::shared_ptr<CFeatures> DataFetcher::next()
{
auto num_more_samples = m_num_samples - m_block_details.m_next_block_index * m_block_details.m_blocksize;
if (num_more_samples > 0)
{
auto num_samples_this_burst = m_block_details.m_max_num_samples_per_burst;
if (num_samples_this_burst > num_more_samples)
{
num_samples_this_burst = num_more_samples;
}
if (num_samples_this_burst < m_num_samples)
{
m_samples->remove_subset();
SGVector<index_t> inds(num_samples_this_burst);
std::iota(inds.vector, inds.vector + inds.vlen, m_block_details.m_next_block_index * m_block_details.m_blocksize);
m_samples->add_subset(inds);
}
m_block_details.m_next_block_index += m_block_details.m_num_blocks_per_burst;
return m_samples;
}
return nullptr;
}
/**
 * Rewinds the fetcher: the next block index goes back to zero and all
 * subsets are removed from the feature object.
 */
void DataFetcher::reset()
{
	m_block_details.m_next_block_index = 0;

	// A default-constructed fetcher has no feature object attached;
	// dereferencing the empty shared pointer would be undefined behaviour.
	if (m_samples)
	{
		m_samples->remove_all_subsets();
	}
}
/**
 * Finishes a pass over the samples by removing all subsets from the
 * feature object.
 */
void DataFetcher::end()
{
	// A default-constructed fetcher has no feature object attached;
	// dereferencing the empty shared pointer would be undefined behaviour.
	if (m_samples)
	{
		m_samples->remove_all_subsets();
	}
}
/**
 * @return the sample count stored in this fetcher
 */
const index_t DataFetcher::get_num_samples() const
{
	return this->m_num_samples;
}
/**
 * Gives access to the block-wise fetching details of this fetcher.
 *
 * @return a mutable reference to the block details member, so changes made
 * through it affect this fetcher directly
 */
BlockwiseDetails& DataFetcher::fetch_blockwise()
{
	return this->m_block_details;
}