-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathIterativeMachine.h
135 lines (115 loc) · 2.98 KB
/
IterativeMachine.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Shubham Shukla
*/
#ifndef _ITERATIVEMACHINE_H__
#define _ITERATIVEMACHINE_H__
#include <shogun/lib/config.h>
#include <shogun/base/progress.h>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/common.h>
#include <shogun/machine/LinearMachine.h>
namespace shogun
{
	class Features;
	class Labels;

	/** @brief Mix-in class that implements an iterative model
	 * whose training can be prematurely stopped, and in particular be
	 * resumed, anytime.
	 *
	 * @tparam T machine base class being extended. It must provide the
	 * virtual continue_train() and train_machine() overridden here, as well
	 * as reset_computation_variables() and get_name().
	 */
	template <class T>
	class IterativeMachine : public T
	{
	public:
		/** Default constructor.
		 *
		 * Registers the training-state members as parameters. Every member
		 * has an in-class initializer, so all of them hold defined values
		 * before registration. m_max_iterations defaults to 10000;
		 * subclasses may assign their own limit in their constructors.
		 */
		IterativeMachine() : T()
		{
			SG_ADD(
			    &m_current_iteration, "current_iteration",
			    "Current Iteration of training");
			SG_ADD(
			    &m_max_iterations, "max_iterations",
			    "Maximum number of Iterations", ParameterProperties::HYPER);
			SG_ADD(
			    &m_complete, "complete", "Convergence status");
			SG_ADD(
			    &m_continue_features, "continue_features", "Continue Features");
		}

		~IterativeMachine() override = default;

		/** @return convergence status, i.e. whether iteration() has
		 * marked training as complete
		 */
		bool is_complete() const
		{
			return m_complete;
		}

		/** Runs, or resumes, the training loop.
		 *
		 * Calls iteration() repeatedly until it sets m_complete or
		 * m_current_iteration reaches m_max_iterations. The iteration counter
		 * is not reset here, so a second call continues from where the
		 * previous run stopped.
		 *
		 * end_training() is invoked only when training converged or hit the
		 * iteration limit. If the loop is left early for any other reason
		 * (presumably a cancel request handled by COMPUTATION_CONTROLLERS,
		 * which is defined elsewhere), end_training() is not called.
		 *
		 * @return convergence status (m_complete)
		 */
		bool continue_train() override
		{
			this->reset_computation_variables();
			// NOTE(review): m_continue_features is stored by train_machine()
			// but is not re-applied here; the line below is disabled.
			//this->put("features", m_continue_features);
			auto pb = SG_PROGRESS(range(m_max_iterations));
			while (m_current_iteration < m_max_iterations && !m_complete)
			{
				COMPUTATION_CONTROLLERS
				iteration();
				m_current_iteration++;
				pb.print_progress();
			}
			pb.complete();

			if (m_complete)
			{
				io::info(
				    "{} converged after {} iterations.", this->get_name(),
				    m_current_iteration);
				this->end_training();
			}
			// Not converged: only finalize if the iteration budget is used up.
			else if (m_current_iteration >= m_max_iterations)
			{
				io::warn(
				    "{} did not converge after the maximum number of {} "
				    "iterations.",
				    this->get_name(), m_current_iteration);
				this->end_training();
			}
			return m_complete;
		}

	protected:
		/** Trains the machine from scratch.
		 *
		 * When data is non-null it is remembered in m_continue_features.
		 * The iteration counter and completion flag are reset, then
		 * init_model(data) is called (also when data is null) before
		 * entering the training loop via continue_train().
		 *
		 * @param data training features (may be nullptr)
		 * @return convergence status returned by continue_train()
		 */
		bool train_machine(std::shared_ptr<Features> data = nullptr) override
		{
			if (data)
			{
				m_continue_features = data;
			}
			m_current_iteration = 0;
			m_complete = false;
			init_model(data);
			return continue_train();
		}

		/** To be overridden by subclasses to implement a single iteration
		 * of the training loop. Setting m_complete to true ends the loop.
		 */
		virtual void iteration() = 0;

		/** To be overridden by subclasses to initialize the model for
		 * training.
		 *
		 * @param data training features as passed to train_machine()
		 * (may be nullptr)
		 */
		virtual void init_model(const std::shared_ptr<Features> data = nullptr) = 0;

		/** Can be overridden by subclasses to show more information
		 * and/or clean up state once training converged or reached
		 * m_max_iterations.
		 */
		virtual void end_training()
		{
		}

		/** Stores features to continue training */
		std::shared_ptr<Features> m_continue_features{nullptr};
		/** Maximum number of iterations (default 10000; subclasses may
		 * override it)
		 */
		int32_t m_max_iterations{10000};
		/** Current iteration of training loop */
		int32_t m_current_iteration{0};
		/** Completion status */
		bool m_complete{false};
	};
}
#endif