Skip to content

Commit

Permalink
[PrematureStopping] Refactor CSignal class to use RxCpp utilities.
Browse files Browse the repository at this point in the history
Add basic unit tests.
  • Loading branch information
geektoni authored and vigsterkr committed Jun 30, 2017
1 parent 32a46e6 commit 44c3bd6
Show file tree
Hide file tree
Showing 29 changed files with 137 additions and 171 deletions.
4 changes: 1 addition & 3 deletions src/shogun/classifier/AveragedPerceptron.cpp
Expand Up @@ -60,15 +60,13 @@ bool CAveragedPerceptron::train_machine(CFeatures* data)
for (int32_t i=0; i<num_feat; i++)
w[i]=1.0/num_feat;

CSignal::clear_cancel();

//loop till we either get everything classified right or reach max_iter

while (!(CSignal::cancel_computations()) && (!converged && iter<max_iter))
{
converged=true;
SG_INFO("Iteration Number : %d of max %d\n", iter, max_iter);

for (int32_t i=0; i<num_vec; i++)
{
output[i] = features->dense_dot(i, w.vector, w.vlen) + bias;
Expand Down
1 change: 0 additions & 1 deletion src/shogun/classifier/LPBoost.cpp
Expand Up @@ -124,7 +124,6 @@ bool CLPBoost::train_machine(CFeatures* data)

int32_t num_hypothesis=0;
CTime time;
CSignal::clear_cancel();

while (!(CSignal::cancel_computations()))
{
Expand Down
1 change: 0 additions & 1 deletion src/shogun/classifier/Perceptron.cpp
Expand Up @@ -65,7 +65,6 @@ bool CPerceptron::train_machine(CFeatures* data)
w.vector[i]=1.0/num_feat;
}

CSignal::clear_cancel();

//loop till we either get everything classified right or reach max_iter
while (!(CSignal::cancel_computations()) && (!converged && iter<max_iter))
Expand Down
1 change: 0 additions & 1 deletion src/shogun/classifier/mkl/MKL.cpp
Expand Up @@ -427,7 +427,6 @@ bool CMKL::train_machine(CFeatures* data)
#endif

mkl_iterations = 0;
CSignal::clear_cancel();

training_time_clock.start();

Expand Down
1 change: 0 additions & 1 deletion src/shogun/classifier/mkl/MKLMulticlass.cpp
Expand Up @@ -370,7 +370,6 @@ bool CMKLMulticlass::train_machine(CFeatures* data)

int32_t numberofsilpiterations=0;
bool final=false;
CSignal::clear_cancel();

while (!(CSignal::cancel_computations()) && !final)
{
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/classifier/svm/LibLinear.cpp
Expand Up @@ -78,7 +78,7 @@ CLibLinear::~CLibLinear()

bool CLibLinear::train_machine(CFeatures* data)
{
CSignal::clear_cancel();
ASSERT(m_labels)
ASSERT(m_labels->get_label_type() == LT_BINARY)

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/classifier/svm/MPDSVM.cpp
Expand Up @@ -79,7 +79,7 @@ bool CMPDSVM::train_machine(CFeatures* data)

bool primalcool;
bool dualcool;
CSignal::clear_cancel();

//if (nustop)
//etas[1] = 1;
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/classifier/svm/NewtonSVM.cpp
Expand Up @@ -49,7 +49,7 @@ CNewtonSVM::~CNewtonSVM()

bool CNewtonSVM::train_machine(CFeatures* data)
{
CSignal::clear_cancel();
ASSERT(m_labels)
ASSERT(m_labels->get_label_type() == LT_BINARY)

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/classifier/svm/OnlineSVMSGD.cpp
Expand Up @@ -100,7 +100,7 @@ bool COnlineSVMSGD::train(CFeatures* data)
if (features->is_seekable())
features->reset_stream();

CSignal::clear_cancel();

ELossType loss_type = loss->get_loss_type();
bool is_log_loss = false;
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/classifier/svm/SGDQN.cpp
Expand Up @@ -137,7 +137,7 @@ bool CSGDQN::train(CFeatures* data)
calibrate();

SG_INFO("Training on %d vectors\n", num_vec)
CSignal::clear_cancel();

ELossType loss_type = loss->get_loss_type();
bool is_log_loss = false;
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/classifier/svm/SVMLight.cpp
Expand Up @@ -646,7 +646,7 @@ int32_t CSVMLight::optimize_to_convergence(int32_t* docs, int32_t* label, int32_
#ifdef CYGWIN
for (;((iteration<100 || (!mkl_converged && callback) ) || (retrain && (!terminate))); iteration++){
#else
CSignal::clear_cancel();
for (;((!CSignal::cancel_computations()) && ((iteration<3 || (!mkl_converged && callback) ) || (retrain && (!terminate)))); iteration++){
#endif

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/classifier/svm/SVMSGD.cpp
Expand Up @@ -109,7 +109,7 @@ bool CSVMSGD::train_machine(CFeatures* data)
calibrate();

SG_INFO("Training on %d vectors\n", num_vec)
CSignal::clear_cancel();

ELossType loss_type = loss->get_loss_type();
bool is_log_loss = false;
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/classifier/vw/VowpalWabbit.cpp
Expand Up @@ -163,7 +163,7 @@ bool CVowpalWabbit::train_machine(CFeatures* feat)
"loss", "last", "counter", "weight", "label", "predict", "features");
}

CSignal::clear_cancel();
features->start_parser();
while (!(CSignal::cancel_computations()) && (env->passes_complete < env->num_passes))
{
Expand Down
4 changes: 2 additions & 2 deletions src/shogun/features/DotFeatures.cpp
Expand Up @@ -62,7 +62,7 @@ void CDotFeatures::dense_dot_range(float64_t* output, int32_t start, int32_t sto
int32_t num_vectors=stop-start;
ASSERT(num_vectors>0)

CSignal::clear_cancel();

int32_t num_threads;
int32_t step;
Expand Down Expand Up @@ -113,7 +113,7 @@ void CDotFeatures::dense_dot_range_subset(int32_t* sub_index, int32_t num, float
ASSERT(sub_index)
ASSERT(output)

CSignal::clear_cancel();

auto pb = progress(range(num), *this->io);
int32_t num_threads;
Expand Down
4 changes: 2 additions & 2 deletions src/shogun/features/hashed/HashedWDFeaturesTransposed.cpp
Expand Up @@ -222,7 +222,7 @@ void CHashedWDFeaturesTransposed::dense_dot_range(float64_t* output, int32_t sta
#endif
ASSERT(num_threads>0)

CSignal::clear_cancel();

if (dim != w_dim)
SG_ERROR("Dimensions don't match, vec_len=%d, w_dim=%d\n", dim, w_dim)
Expand Down Expand Up @@ -315,7 +315,7 @@ void CHashedWDFeaturesTransposed::dense_dot_range_subset(int32_t* sub_index, int
#endif
ASSERT(num_threads>0)

CSignal::clear_cancel();

if (dim != w_dim)
SG_ERROR("Dimensions don't match, vec_len=%d, w_dim=%d\n", dim, w_dim)
Expand Down
Expand Up @@ -1252,7 +1252,7 @@ void CWeightedDegreePositionStringKernel::compute_batch(

if (num_threads < 2)
{
CSignal::clear_cancel();
auto pb = progress(range(num_feat), *this->io);
for (int32_t j=0; j<num_feat && !CSignal::cancel_computations(); j++)
{
Expand Down Expand Up @@ -1281,7 +1281,7 @@ void CWeightedDegreePositionStringKernel::compute_batch(
else
{

CSignal::clear_cancel();
auto pb = progress(range(num_feat), *this->io);
for (int32_t j=0; j<num_feat && !CSignal::cancel_computations(); j++)
{
Expand Down
4 changes: 2 additions & 2 deletions src/shogun/kernel/string/WeightedDegreeStringKernel.cpp
Expand Up @@ -884,7 +884,7 @@ void CWeightedDegreeStringKernel::compute_batch(

if (num_threads < 2)
{
CSignal::clear_cancel();
for (int32_t j=0; j<num_feat && !CSignal::cancel_computations(); j++)
{
init_optimization(num_suppvec, IDX, alphas, j);
Expand All @@ -909,7 +909,7 @@ void CWeightedDegreeStringKernel::compute_batch(
#ifdef HAVE_PTHREAD
else
{
CSignal::clear_cancel();
for (int32_t j=0; j<num_feat && !CSignal::cancel_computations(); j++)
{
init_optimization(num_suppvec, IDX, alphas, j);
Expand Down
1 change: 0 additions & 1 deletion src/shogun/lib/ShogunException.cpp
Expand Up @@ -35,7 +35,6 @@ ShogunException::init(const char* str)
ShogunException::ShogunException(const char* str)
{
#ifndef WIN32
CSignal::unset_handler();
#endif

init(str);
Expand Down
149 changes: 48 additions & 101 deletions src/shogun/lib/Signal.cpp
Expand Up @@ -11,137 +11,84 @@
#include <shogun/lib/config.h>

#include <stdlib.h>
#include <string.h>

#include <shogun/io/SGIO.h>
#include <shogun/lib/Signal.h>
#include <shogun/base/init.h>
#include <rxcpp/rx-includes.hpp>
#include <rxcpp/rx.hpp>

using namespace shogun;
using namespace rxcpp;

int CSignal::signals[NUMTRAPPEDSIGS]={SIGINT, SIGURG};
struct sigaction CSignal::oldsigaction[NUMTRAPPEDSIGS];
bool CSignal::active=false;
bool CSignal::cancel_computation=false;
bool CSignal::cancel_immediately=false;

rxcpp::connectable_observable<int> CSignal::m_sigint_observable = rxcpp::observable<>::create<int>(
[](rxcpp::subscriber<int> s){
s.on_completed();
}
).publish();

rxcpp::connectable_observable<int> CSignal::m_sigurg_observable = rxcpp::observable<>::create<int>(
[](rxcpp::subscriber<int> s){
s.on_next(1);
}
).publish();

CSignal::CSignal()
: CSGObject()
{
// Set if the signal handler is active or not
m_active = true;
}

CSignal::~CSignal()
CSignal::CSignal(bool active)
: CSGObject()
{
if (!unset_handler())
SG_PRINT("error uninitalizing signal handler\n")
// Set if the signal handler is active or not
m_active = active;
}

void CSignal::handler(int signal)
CSignal::~CSignal()
{
if (signal == SIGINT)
{
SG_SPRINT("\nImmediately return to prompt / Prematurely finish computations / Do nothing (I/P/D)? ")
char answer=fgetc(stdin);

if (answer == 'I')
{
unset_handler();
set_cancel(true);
if (sg_print_error)
sg_print_error(stdout, "sg stopped by SIGINT\n");
}
else if (answer == 'P')
set_cancel();
else
SG_SPRINT("Continuing...\n")
}
else if (signal == SIGURG)
set_cancel();
else
SG_SPRINT("unknown signal %d received\n", signal)
}

bool CSignal::set_handler()
rxcpp::connectable_observable<int> CSignal::get_SIGINT_observable()
{
if (!active)
{
struct sigaction act;
sigset_t st;

sigemptyset(&st);
for (int32_t i=0; i<NUMTRAPPEDSIGS; i++)
sigaddset(&st, signals[i]);

#if !(defined(__INTERIX) || defined(__MINGW64__) || defined(_MSC_VER) || defined(__MINGW32__))
act.sa_sigaction=NULL; //just in case
#endif
act.sa_handler=CSignal::handler;
act.sa_mask = st;
act.sa_flags = 0;

for (int32_t i=0; i<NUMTRAPPEDSIGS; i++)
{
if (sigaction(signals[i], &act, &oldsigaction[i]))
{
SG_SPRINT("Error trapping signals!\n")
for (int32_t j=i-1; j>=0; j--)
sigaction(signals[i], &oldsigaction[i], NULL);

clear();
return false;
}
}
return m_sigint_observable;
}

active=true;
return true;
}
else
return false;
rxcpp::connectable_observable<int> CSignal::get_SIGURG_observable()
{
return m_sigurg_observable;
}

bool CSignal::unset_handler()
void CSignal::handler(int signal)
{
if (active)
if (signal == SIGINT)
{
bool result=true;

for (int32_t i=0; i<NUMTRAPPEDSIGS; i++)
{
if (sigaction(signals[i], &oldsigaction[i], NULL))
{
SG_SPRINT("error uninitalizing signal handler for signal %d\n", signals[i])
result=false;
}
}

if (result)
clear();
//SG_SPRINT("\nImmediately return to prompt / Prematurely finish computations / Do nothing (I/P/D)? ")
//char answer=fgetc(stdin);
/*switch (answer){
case 'I':
m_sigint_observable.connect();
break;
case 'P':
m_sigurg_observable.connect();
break;
default:
SG_SPRINT("Continuing...\n")
break;
}*/
SG_SPRINT("Killing the application...\n");
m_sigint_observable.connect();

return result;
}
else if (signal == SIGURG)
m_sigurg_observable.connect();
else
return false;
}

void CSignal::clear_cancel()
{
cancel_computation=false;
cancel_immediately=false;
}

void CSignal::set_cancel(bool immediately)
{
cancel_computation=true;

if (immediately)
cancel_immediately=true;
}

void CSignal::clear()
{
clear_cancel();
active=false;
memset(&CSignal::oldsigaction, 0, sizeof(CSignal::oldsigaction));
SG_SPRINT("unknown signal %d received\n", signal)
}

#if defined(__MINGW64__) || defined(_MSC_VER) || defined(__MINGW32__)
Expand Down

0 comments on commit 44c3bd6

Please sign in to comment.