Skip to content

Commit

Permalink
[PrematureStopping] Improve pause method using std::condition_variabl…
Browse files Browse the repository at this point in the history
…e and locks.

Add COMPUTATION_CONTROLLERS macro to manage pause/cancel.
It will works also inside OpenMP environment.

Guarded from SWIG some methods which use SG_FORCED_INLINE.
  • Loading branch information
geektoni authored and vigsterkr committed Jul 12, 2017
1 parent 7e3a23f commit d0a58e8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 14 deletions.
6 changes: 3 additions & 3 deletions src/shogun/machine/Machine.cpp
Expand Up @@ -19,7 +19,7 @@ using namespace shogun;

CMachine::CMachine()
: CSGObject(), m_max_train_time(0), m_labels(NULL), m_solver_type(ST_AUTO),
m_cancel_computation(false), m_pause_computation(false)
m_cancel_computation(false)
{
m_data_locked=false;
m_store_model_features=false;
Expand Down Expand Up @@ -61,8 +61,8 @@ bool CMachine::train(CFeatures* data)

auto sub = connect_to_signal_handler();
bool result = train_machine(data);
sub.unsubscribe();
reset_computation_variables();
sub.unsubscribe();
reset_computation_variables();

if (m_store_model_features)
store_model_features();
Expand Down
50 changes: 39 additions & 11 deletions src/shogun/machine/Machine.h
Expand Up @@ -23,6 +23,8 @@
#include <shogun/labels/LatentLabels.h>
#include <shogun/features/Features.h>

#include <condition_variable>
#include <mutex>
#include <rxcpp/rx.hpp>

namespace shogun
Expand Down Expand Up @@ -125,6 +127,11 @@ enum EProblemType
*/ \
virtual EProblemType get_machine_problem_type() const { return PT; }

#define COMPUTATION_CONTROLLERS \
if (cancel_computation()) \
continue; \
pause_computation();

/** @brief A generic learning machine interface.
*
* A machine takes as input CFeatures and CLabels (by default).
Expand Down Expand Up @@ -308,23 +315,36 @@ class CMachine : public CSGObject
return PT_BINARY;
}

#ifndef SWIG
/** @return whether the algorithm needs to be stopped */
inline bool cancel_computation() const
SG_FORCED_INLINE bool cancel_computation() const
{
return m_cancel_computation.load();
}
#endif

/** @return whether the algorithm needs to be paused */
inline bool pause_computation() const
#ifndef SWIG
/** Pause the algorithm if the flag is set */
SG_FORCED_INLINE void pause_computation()
{
return m_pause_computation.load();
if (m_pause_computation_flag.load())
{
std::unique_lock<std::mutex> lck(m_mutex);
while (m_pause_computation_flag.load())
m_pause_computation.wait(lck);
}
}
#endif

/** Unpause current computation (sets the flag) */
inline void unpause_computation()
#ifndef SWIG
/** Resume current computation (sets the flag) */
SG_FORCED_INLINE void resume_computation()
{
m_pause_computation = false;
std::unique_lock<std::mutex> lck(m_mutex);
m_pause_computation_flag = false;
m_pause_computation.notify_all();
}
#endif

virtual const char* get_name() const { return "Machine"; }

Expand Down Expand Up @@ -384,7 +404,7 @@ class CMachine : public CSGObject
void reset_computation_variables()
{
m_cancel_computation = false;
m_pause_computation = false;
m_pause_computation_flag = false;
}

/** The action which will be done when the user decides to
Expand All @@ -398,7 +418,9 @@ class CMachine : public CSGObject
* pause the CMachine execution */
virtual void on_pause()
{
m_pause_computation.store(true);
m_pause_computation_flag.store(true);
/* Here there should be the actual code*/
resume_computation();
}

/** The action which will be done when the user decides to
Expand Down Expand Up @@ -426,8 +448,14 @@ class CMachine : public CSGObject
/** Cancel computation */
std::atomic<bool> m_cancel_computation;

/** Pause computation */
std::atomic<bool> m_pause_computation;
/** Pause computation flag */
std::atomic<bool> m_pause_computation_flag;

/** Conditional variable to make threads wait */
std::condition_variable m_pause_computation;

/** Mutex used to pause threads */
std::mutex m_mutex;
};
}
#endif // _MACHINE_H__

0 comments on commit d0a58e8

Please sign in to comment.