Skip to content

Commit

Permalink
CStoppableSGObject class (#4280)
Browse files Browse the repository at this point in the history
* StoppableSGObject class
* add base class to swig
  • Loading branch information
shubham808 authored and karlnapf committed May 11, 2018
1 parent bf67562 commit 986f97e
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 98 deletions.
2 changes: 2 additions & 0 deletions src/interfaces/swig/SGBase.i
Expand Up @@ -109,6 +109,7 @@ public void readExternal(java.io.ObjectInput in) throws java.io.IOException, jav
#include <shogun/base/Version.h>
#include <shogun/base/Parallel.h>
#include <shogun/base/SGObject.h>
#include <shogun/lib/StoppableSGObject.h>

extern void sg_global_print_message(FILE* target, const char* str);
extern void sg_global_print_warning(FILE* target, const char* str);
Expand Down Expand Up @@ -327,6 +328,7 @@ namespace std {
%include <shogun/io/SGIO.h>
%include <shogun/base/Version.h>
%include <shogun/base/Parallel.h>
%include <shogun/lib/StoppableSGObject.h>

#ifdef SWIGPYTHON
namespace shogun
Expand Down
74 changes: 74 additions & 0 deletions src/shogun/lib/StoppableSGObject.cpp
@@ -0,0 +1,74 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Shubham Shukla
*/

#include <rxcpp/rx-lite.hpp>
#include <shogun/base/init.h>
#include <shogun/lib/Signal.h>
#include <shogun/lib/StoppableSGObject.h>

using namespace shogun;

#define COMPUTATION_CONTROLLERS \
if (cancel_computation()) \
continue; \
pause_computation();

CStoppableSGObject::CStoppableSGObject() : CSGObject(){};

CStoppableSGObject::~CStoppableSGObject(){};

void CStoppableSGObject::init_stoppable()
{
m_cancel_computation = false;
m_pause_computation_flag = false;
}

rxcpp::subscription CStoppableSGObject::connect_to_signal_handler()
{
// Subscribe this algorithm to the signal handler
auto subscriber = rxcpp::make_subscriber<int>(
[this](int i) {
if (i == SG_PAUSE_COMP)
this->on_pause();
else
this->on_next();
},
[this]() { this->on_complete(); });
return get_global_signal()->get_observable()->subscribe(subscriber);
}

void CStoppableSGObject::reset_computation_variables()
{
m_cancel_computation = false;
m_pause_computation_flag = false;
}

void CStoppableSGObject::on_next()
{
m_cancel_computation.store(true);
on_next_impl();
}

void CStoppableSGObject::on_pause()
{
m_pause_computation_flag.store(true);
on_pause_impl();
resume_computation();
}

void CStoppableSGObject::on_complete()
{
on_complete_impl();
}
void CStoppableSGObject::on_next_impl()
{
}
void CStoppableSGObject::on_pause_impl()
{
}
void CStoppableSGObject::on_complete_impl()
{
}
113 changes: 113 additions & 0 deletions src/shogun/lib/StoppableSGObject.h
@@ -0,0 +1,113 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Shubham Shukla
*/

#ifndef __STOPPABLESGOBJECT_H_
#define __STOPPABLESGOBJECT_H_

#include <shogun/base/SGObject.h>
#include <shogun/base/init.h>

#include <condition_variable>
#include <mutex>

namespace shogun
{
#define COMPUTATION_CONTROLLERS \
if (cancel_computation()) \
continue; \
pause_computation();

/**
* Class that abstracts all premature stopping code
*/
class CStoppableSGObject : public CSGObject
{
public:
/** constructor */
CStoppableSGObject();

/** destructor */
virtual ~CStoppableSGObject();

/** init flags to false as default */
void init_stoppable();

#ifndef SWIG
/** @return whether the algorithm needs to be stopped */
SG_FORCED_INLINE bool cancel_computation() const
{
return m_cancel_computation.load();
}
#endif

#ifndef SWIG
/** Pause the algorithm if the flag is set */
SG_FORCED_INLINE void pause_computation()
{
if (m_pause_computation_flag.load())
{
std::unique_lock<std::mutex> lck(m_mutex);
while (m_pause_computation_flag.load())
m_pause_computation.wait(lck);
}
}
#endif

#ifndef SWIG
/** Resume current computation (sets the flag) */
SG_FORCED_INLINE void resume_computation()
{
std::unique_lock<std::mutex> lck(m_mutex);
m_pause_computation_flag = false;
m_pause_computation.notify_all();
}
#endif

virtual const char* get_name() const
{
return "StoppableSGObject";
}

protected:
/** connect the machine instance to the signal handler */
rxcpp::subscription connect_to_signal_handler();

/** reset the computation variables */
void reset_computation_variables();

/** sets cancel computation flag */
void on_next();

/** The action which will be done when the user decides to
* premature stop the CMachine execution */
virtual void on_next_impl();

/** sets pause computation flag and resumes after action is complete */
void on_pause();

/** The action which will be done when the user decides to
* pause the CMachine execution */
virtual void on_pause_impl();

/** These actions which will be done when the user decides to
* return to prompt and terminate the program execution */
void on_complete();
virtual void on_complete_impl();

/** Cancel computation */
std::atomic<bool> m_cancel_computation;

/** Pause computation flag */
std::atomic<bool> m_pause_computation_flag;

/** Conditional variable to make threads wait */
std::condition_variable m_pause_computation;

/** Mutex used to pause threads */
std::mutex m_mutex;
};
}
#endif
19 changes: 3 additions & 16 deletions src/shogun/machine/Machine.cpp
Expand Up @@ -14,9 +14,10 @@
using namespace shogun;

CMachine::CMachine()
: CSGObject(), m_max_train_time(0), m_labels(NULL), m_solver_type(ST_AUTO),
m_cancel_computation(false), m_pause_computation_flag(false)
: CStoppableSGObject(), m_max_train_time(0), m_labels(NULL),
m_solver_type(ST_AUTO)
{
init_stoppable();
m_data_locked=false;
m_store_model_features=false;

Expand Down Expand Up @@ -273,17 +274,3 @@ CLatentLabels* CMachine::apply_locked_latent(SGVector<index_t> indices)
"for %s\n", get_name());
return NULL;
}

rxcpp::subscription CMachine::connect_to_signal_handler()
{
// Subscribe this algorithm to the signal handler
auto subscriber = rxcpp::make_subscriber<int>(
[this](int i) {
if (i == SG_PAUSE_COMP)
this->on_pause();
else
this->on_next();
},
[this]() { this->on_complete(); });
return get_global_signal()->get_observable()->subscribe(subscriber);
}
84 changes: 2 additions & 82 deletions src/shogun/machine/Machine.h
Expand Up @@ -11,14 +11,14 @@
#ifndef _MACHINE_H__
#define _MACHINE_H__

#include <shogun/base/SGObject.h>
#include <shogun/base/class_list.h>
#include <shogun/features/Features.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/labels/LatentLabels.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/labels/StructuredLabels.h>
#include <shogun/lib/StoppableSGObject.h>
#include <shogun/lib/common.h>
#include <shogun/lib/config.h>

Expand Down Expand Up @@ -125,11 +125,6 @@ enum EProblemType
*/ \
virtual EProblemType get_machine_problem_type() const { return PT; }

#define COMPUTATION_CONTROLLERS \
if (cancel_computation()) \
continue; \
pause_computation();

/** @brief A generic learning machine interface.
*
* A machine takes as input CFeatures and CLabels (by default).
Expand All @@ -147,7 +142,7 @@ enum EProblemType
* locking.
*
*/
class CMachine : public CSGObject
class CMachine : public CStoppableSGObject
{
public:
/** constructor */
Expand Down Expand Up @@ -313,37 +308,6 @@ class CMachine : public CSGObject
return PT_BINARY;
}

#ifndef SWIG
/** @return whether the algorithm needs to be stopped */
SG_FORCED_INLINE bool cancel_computation() const
{
return m_cancel_computation.load();
}
#endif

#ifndef SWIG
/** Pause the algorithm if the flag is set */
SG_FORCED_INLINE void pause_computation()
{
if (m_pause_computation_flag.load())
{
std::unique_lock<std::mutex> lck(m_mutex);
while (m_pause_computation_flag.load())
m_pause_computation.wait(lck);
}
}
#endif

#ifndef SWIG
/** Resume current computation (sets the flag) */
SG_FORCED_INLINE void resume_computation()
{
std::unique_lock<std::mutex> lck(m_mutex);
m_pause_computation_flag = false;
m_pause_computation.notify_all();
}
#endif

virtual const char* get_name() const { return "Machine"; }

protected:
Expand Down Expand Up @@ -395,38 +359,6 @@ class CMachine : public CSGObject
/** returns whether machine require labels for training */
virtual bool train_require_labels() const { return true; }

/** connect the machine instance to the signal handler */
rxcpp::subscription connect_to_signal_handler();

/** reset the computation variables */
void reset_computation_variables()
{
m_cancel_computation = false;
m_pause_computation_flag = false;
}

/** The action which will be done when the user decides to
* premature stop the CMachine execution */
virtual void on_next()
{
m_cancel_computation.store(true);
}

/** The action which will be done when the user decides to
* pause the CMachine execution */
virtual void on_pause()
{
m_pause_computation_flag.store(true);
/* Here there should be the actual code*/
resume_computation();
}

/** The action which will be done when the user decides to
* return to prompt and terminate the program execution */
virtual void on_complete()
{
}

protected:
/** maximum training time */
float64_t m_max_train_time;
Expand All @@ -442,18 +374,6 @@ class CMachine : public CSGObject

/** whether data is locked */
bool m_data_locked;

/** Cancel computation */
std::atomic<bool> m_cancel_computation;

/** Pause computation flag */
std::atomic<bool> m_pause_computation_flag;

/** Conditional variable to make threads wait */
std::condition_variable m_pause_computation;

/** Mutex used to pause threads */
std::mutex m_mutex;
};
}
#endif // _MACHINE_H__

0 comments on commit 986f97e

Please sign in to comment.