From 986f97ee9d25898b8a3f75fa9976cc55305b4fdc Mon Sep 17 00:00:00 2001 From: Shubham Shukla Date: Fri, 11 May 2018 20:47:51 +0530 Subject: [PATCH] CStoppableSGObject class (#4280) * StoppableSGObject class * add base class to swig --- src/interfaces/swig/SGBase.i | 2 + src/shogun/lib/StoppableSGObject.cpp | 74 ++++++++++++++++++ src/shogun/lib/StoppableSGObject.h | 113 +++++++++++++++++++++++++++ src/shogun/machine/Machine.cpp | 19 +---- src/shogun/machine/Machine.h | 84 +------------------- 5 files changed, 194 insertions(+), 98 deletions(-) create mode 100644 src/shogun/lib/StoppableSGObject.cpp create mode 100644 src/shogun/lib/StoppableSGObject.h diff --git a/src/interfaces/swig/SGBase.i b/src/interfaces/swig/SGBase.i index 6dc69290c98..3ddac31bded 100644 --- a/src/interfaces/swig/SGBase.i +++ b/src/interfaces/swig/SGBase.i @@ -109,6 +109,7 @@ public void readExternal(java.io.ObjectInput in) throws java.io.IOException, jav #include #include #include + #include extern void sg_global_print_message(FILE* target, const char* str); extern void sg_global_print_warning(FILE* target, const char* str); @@ -327,6 +328,7 @@ namespace std { %include %include %include +%include #ifdef SWIGPYTHON namespace shogun diff --git a/src/shogun/lib/StoppableSGObject.cpp b/src/shogun/lib/StoppableSGObject.cpp new file mode 100644 index 00000000000..a5b3122a78d --- /dev/null +++ b/src/shogun/lib/StoppableSGObject.cpp @@ -0,0 +1,74 @@ +/* +* This software is distributed under BSD 3-clause license (see LICENSE file). +* +* Authors: Shubham Shukla +*/ + +#include +#include +#include +#include + +using namespace shogun; + +#define COMPUTATION_CONTROLLERS \ + if (cancel_computation()) \ + continue; \ + pause_computation(); + +CStoppableSGObject::CStoppableSGObject() : CSGObject(){}; + +CStoppableSGObject::~CStoppableSGObject(){}; + +void CStoppableSGObject::init_stoppable() +{ + m_cancel_computation = false; + m_pause_computation_flag = false; +} + +rxcpp::subscription CStoppableSGObject::connect_to_signal_handler() +{ + // Subscribe this algorithm to the signal handler + auto subscriber = rxcpp::make_subscriber( + [this](int i) { + if (i == SG_PAUSE_COMP) + this->on_pause(); + else + this->on_next(); + }, + [this]() { this->on_complete(); }); + return get_global_signal()->get_observable()->subscribe(subscriber); +} + +void CStoppableSGObject::reset_computation_variables() +{ + m_cancel_computation = false; + m_pause_computation_flag = false; +} + +void CStoppableSGObject::on_next() +{ + m_cancel_computation.store(true); + on_next_impl(); +} + +void CStoppableSGObject::on_pause() +{ + m_pause_computation_flag.store(true); + on_pause_impl(); + resume_computation(); +} + +void CStoppableSGObject::on_complete() +{ + on_complete_impl(); +} +void CStoppableSGObject::on_next_impl() +{ +} +void CStoppableSGObject::on_pause_impl() +{ +} +void CStoppableSGObject::on_complete_impl() +{ +} diff --git a/src/shogun/lib/StoppableSGObject.h b/src/shogun/lib/StoppableSGObject.h new file mode 100644 index 00000000000..4babe45ebd6 --- /dev/null +++ b/src/shogun/lib/StoppableSGObject.h @@ -0,0 +1,113 @@ +/* + * This software is distributed under BSD 3-clause license (see LICENSE file). + * + * Authors: Shubham Shukla + */ + +#ifndef __STOPPABLESGOBJECT_H_ +#define __STOPPABLESGOBJECT_H_ + +#include +#include + +#include +#include + +namespace shogun +{ +#define COMPUTATION_CONTROLLERS \ + if (cancel_computation()) \ + continue; \ + pause_computation(); + + /** + * Class that abstracts all premature stopping code + */ + class CStoppableSGObject : public CSGObject + { + public: + /** constructor */ + CStoppableSGObject(); + + /** destructor */ + virtual ~CStoppableSGObject(); + + /** init flags to false as default */ + void init_stoppable(); + +#ifndef SWIG + /** @return whether the algorithm needs to be stopped */ + SG_FORCED_INLINE bool cancel_computation() const + { + return m_cancel_computation.load(); + } +#endif + +#ifndef SWIG + /** Pause the algorithm if the flag is set */ + SG_FORCED_INLINE void pause_computation() + { + if (m_pause_computation_flag.load()) + { + std::unique_lock lck(m_mutex); + while (m_pause_computation_flag.load()) + m_pause_computation.wait(lck); + } + } +#endif + +#ifndef SWIG + /** Resume current computation (sets the flag) */ + SG_FORCED_INLINE void resume_computation() + { + std::unique_lock lck(m_mutex); + m_pause_computation_flag = false; + m_pause_computation.notify_all(); + } +#endif + + virtual const char* get_name() const + { + return "StoppableSGObject"; + } + + protected: + /** connect the machine instance to the signal handler */ + rxcpp::subscription connect_to_signal_handler(); + + /** reset the computation variables */ + void reset_computation_variables(); + + /** sets cancel computation flag */ + void on_next(); + + /** The action which will be done when the user decides to + * premature stop the CMachine execution */ + virtual void on_next_impl(); + + /** sets pause computation flag and resumes after action is complete */ + void on_pause(); + + /** The action which will be done when the user decides to + * pause the CMachine execution */ + virtual void on_pause_impl(); + + /** These actions which will be done when the user decides to + * return to prompt and terminate the program execution */ + void on_complete(); + virtual void on_complete_impl(); + + /** Cancel computation */ + std::atomic m_cancel_computation; + + /** Pause computation flag */ + std::atomic m_pause_computation_flag; + + /** Conditional variable to make threads wait */ + std::condition_variable m_pause_computation; + + /** Mutex used to pause threads */ + std::mutex m_mutex; + }; +} +#endif diff --git a/src/shogun/machine/Machine.cpp b/src/shogun/machine/Machine.cpp index b5a53ec3388..55cbd995860 100644 --- a/src/shogun/machine/Machine.cpp +++ b/src/shogun/machine/Machine.cpp @@ -14,9 +14,10 @@ using namespace shogun; CMachine::CMachine() - : CSGObject(), m_max_train_time(0), m_labels(NULL), m_solver_type(ST_AUTO), - m_cancel_computation(false), m_pause_computation_flag(false) + : CStoppableSGObject(), m_max_train_time(0), m_labels(NULL), + m_solver_type(ST_AUTO) { + init_stoppable(); m_data_locked=false; m_store_model_features=false; @@ -273,17 +274,3 @@ CLatentLabels* CMachine::apply_locked_latent(SGVector indices) "for %s\n", get_name()); return NULL; } - -rxcpp::subscription CMachine::connect_to_signal_handler() -{ - // Subscribe this algorithm to the signal handler - auto subscriber = rxcpp::make_subscriber( - [this](int i) { - if (i == SG_PAUSE_COMP) - this->on_pause(); - else - this->on_next(); - }, - [this]() { this->on_complete(); }); - return get_global_signal()->get_observable()->subscribe(subscriber); -} diff --git a/src/shogun/machine/Machine.h b/src/shogun/machine/Machine.h index fa699b76d90..eb0188debd1 100644 --- a/src/shogun/machine/Machine.h +++ b/src/shogun/machine/Machine.h @@ -11,7 +11,6 @@ #ifndef _MACHINE_H__ #define _MACHINE_H__ -#include #include #include #include @@ -19,6 +18,7 @@ #include #include #include +#include #include #include @@ -125,11 +125,6 @@ enum EProblemType */ \ virtual EProblemType get_machine_problem_type() const { return PT; } -#define COMPUTATION_CONTROLLERS \ - if (cancel_computation()) \ - continue; \ - pause_computation(); - /** @brief A generic learning machine interface. * * A machine takes as input CFeatures and CLabels (by default). @@ -147,7 +142,7 @@ enum EProblemType * locking. * */ -class CMachine : public CSGObject +class CMachine : public CStoppableSGObject { public: /** constructor */ @@ -313,37 +308,6 @@ class CMachine : public CSGObject return PT_BINARY; } -#ifndef SWIG - /** @return whether the algorithm needs to be stopped */ - SG_FORCED_INLINE bool cancel_computation() const - { - return m_cancel_computation.load(); - } -#endif - -#ifndef SWIG - /** Pause the algorithm if the flag is set */ - SG_FORCED_INLINE void pause_computation() - { - if (m_pause_computation_flag.load()) - { - std::unique_lock lck(m_mutex); - while (m_pause_computation_flag.load()) - m_pause_computation.wait(lck); - } - } -#endif - -#ifndef SWIG - /** Resume current computation (sets the flag) */ - SG_FORCED_INLINE void resume_computation() - { - std::unique_lock lck(m_mutex); - m_pause_computation_flag = false; - m_pause_computation.notify_all(); - } -#endif - virtual const char* get_name() const { return "Machine"; } protected: @@ -395,38 +359,6 @@ class CMachine : public CSGObject /** returns whether machine require labels for training */ virtual bool train_require_labels() const { return true; } - /** connect the machine instance to the signal handler */ - rxcpp::subscription connect_to_signal_handler(); - - /** reset the computation variables */ - void reset_computation_variables() - { - m_cancel_computation = false; - m_pause_computation_flag = false; - } - - /** The action which will be done when the user decides to - * premature stop the CMachine execution */ - virtual void on_next() - { - m_cancel_computation.store(true); - } - - /** The action which will be done when the user decides to - * pause the CMachine execution */ - virtual void on_pause() - { - m_pause_computation_flag.store(true); - /* Here there should be the actual code*/ - resume_computation(); - } - - /** The action which will be done when the user decides to - * return to prompt and terminate the program execution */ - virtual void on_complete() - { - } - protected: /** maximum training time */ float64_t m_max_train_time; @@ -442,18 +374,6 @@ class CMachine : public CSGObject /** whether data is locked */ bool m_data_locked; - - /** Cancel computation */ - std::atomic m_cancel_computation; - - /** Pause computation flag */ - std::atomic m_pause_computation_flag; - - /** Conditional variable to make threads wait */ - std::condition_variable m_pause_computation; - - /** Mutex used to pause threads */ - std::mutex m_mutex; }; } #endif // _MACHINE_H__