-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
StoppableSGObject.h
122 lines (99 loc) · 3.02 KB
/
StoppableSGObject.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Shubham Shukla
*/
#ifndef __STOPPABLESGOBJECT_H_
#define __STOPPABLESGOBJECT_H_
#include <shogun/base/SGObject.h>
#include <shogun/base/init.h>

#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
namespace shogun
{
/*
 * Cancel/pause checkpoint meant to be placed inside a loop body of a member
 * function of a class that provides cancel_computation() and
 * pause_computation() (e.g. a CStoppableSGObject subclass).
 * If cancel_computation() returns true the enclosing loop is left via
 * `break`; otherwise pause_computation() is called, which blocks while the
 * object's pause flag is set.
 *
 * The statements are wrapped in a brace block so the expansion acts as one
 * statement, e.g. after an un-braced `if`. The usual `do { } while (0)`
 * wrapper cannot be used: the `break` must reach the caller's loop, not the
 * wrapper. Both `COMPUTATION_CONTROLLERS` and `COMPUTATION_CONTROLLERS;`
 * remain valid uses.
 */
#define COMPUTATION_CONTROLLERS                                                \
	{                                                                          \
		if (this->cancel_computation())                                        \
			break;                                                             \
		this->pause_computation();                                             \
	}
/**
 * Class that abstracts all premature stopping code: a cancel flag, a pause
 * flag paired with a condition variable that paused threads block on, and an
 * optional extra stopping condition supplied by the user.
 */
class CStoppableSGObject : public CSGObject
{
public:
	/** constructor */
	CStoppableSGObject();

	/** destructor */
	virtual ~CStoppableSGObject();

#ifndef SWIG
	/** Check whether the running computation should stop.
	 *
	 * Returns true as soon as the cancel flag is set. Otherwise, if an
	 * extra stopping condition (m_callback) is present, its result is
	 * returned. The callback is not invoked when the cancel flag is
	 * already set.
	 *
	 * @return whether the algorithm needs to be stopped
	 */
	SG_FORCED_INLINE bool cancel_computation() const
	{
		return m_cancel_computation.load() || (m_callback && m_callback());
	}

	/** Block the calling thread while the pause flag is set.
	 *
	 * Returns immediately when the flag is clear; otherwise waits on
	 * m_pause_computation until resume_computation() clears the flag.
	 */
	SG_FORCED_INLINE void pause_computation()
	{
		// Lock-free fast path for the common, un-paused case.
		if (!m_pause_computation_flag.load())
			return;

		std::unique_lock<std::mutex> lck(m_mutex);
		// The predicate overload re-checks the flag under the lock after
		// every wakeup, so spurious wakeups are handled.
		m_pause_computation.wait(
		    lck, [this] { return !m_pause_computation_flag.load(); });
	}

	/** Resume a paused computation: clears the pause flag and wakes all
	 * threads blocked in pause_computation().
	 */
	SG_FORCED_INLINE void resume_computation()
	{
		// Clearing the flag while holding the mutex prevents a waiter from
		// testing it and going to sleep between the store and the notify.
		std::lock_guard<std::mutex> lck(m_mutex);
		m_pause_computation_flag = false;
		m_pause_computation.notify_all();
	}
#endif

	/**
	 * Set an additional stopping condition
	 * (presumably stored in m_callback, which cancel_computation()
	 * consults when the cancel flag is not set — confirm in the .cpp).
	 * @param callback method that implements an additional stopping
	 * condition; returning true requests a stop
	 */
	void set_callback(std::function<bool()> callback);

	/** @return name of the object */
	virtual const char* get_name() const
	{
		return "StoppableSGObject";
	}

protected:
	/** connect this object to the signal handler
	 * @return the subscription created for the connection
	 */
	rxcpp::subscription connect_to_signal_handler();

	/** reset the computation variables */
	void reset_computation_variables();

	/** sets cancel computation flag */
	void on_next();

	/** The action which will be done when the user decides to
	 * prematurely stop the execution; virtual so subclasses can
	 * customise it */
	virtual void on_next_impl();

	/** sets pause computation flag and resumes after action is complete */
	void on_pause();

	/** The action which will be done when the user decides to
	 * pause the execution; virtual so subclasses can customise it */
	virtual void on_pause_impl();

	/** These actions will be done when the user decides to
	 * return to prompt and terminate the program execution */
	void on_complete();

	/** The action performed on completion; virtual so subclasses can
	 * customise it */
	virtual void on_complete_impl();

protected:
	/** Cancel computation */
	std::atomic<bool> m_cancel_computation;

	/** Pause computation flag */
	std::atomic<bool> m_pause_computation_flag;

	/** Condition variable paused threads wait on */
	std::condition_variable m_pause_computation;

	/** Mutex guarding pause/resume transitions */
	std::mutex m_mutex;

	/** Optional extra stopping condition checked by cancel_computation() */
	std::function<bool()> m_callback;
};
}
#endif