From 44c3bd6dea6d71af4366fbb69e3856974cfa8fc4 Mon Sep 17 00:00:00 2001 From: Giovanni De Toni Date: Wed, 14 Jun 2017 17:59:52 +0200 Subject: [PATCH] [PrematureStopping] Refactor CSignal class to use RxCpp utilities. Add basic unit tests. --- src/shogun/classifier/AveragedPerceptron.cpp | 4 +- src/shogun/classifier/LPBoost.cpp | 1 - src/shogun/classifier/Perceptron.cpp | 1 - src/shogun/classifier/mkl/MKL.cpp | 1 - src/shogun/classifier/mkl/MKLMulticlass.cpp | 1 - src/shogun/classifier/svm/LibLinear.cpp | 2 +- src/shogun/classifier/svm/MPDSVM.cpp | 2 +- src/shogun/classifier/svm/NewtonSVM.cpp | 2 +- src/shogun/classifier/svm/OnlineSVMSGD.cpp | 2 +- src/shogun/classifier/svm/SGDQN.cpp | 2 +- src/shogun/classifier/svm/SVMLight.cpp | 2 +- src/shogun/classifier/svm/SVMSGD.cpp | 2 +- src/shogun/classifier/vw/VowpalWabbit.cpp | 2 +- src/shogun/features/DotFeatures.cpp | 4 +- .../hashed/HashedWDFeaturesTransposed.cpp | 4 +- .../WeightedDegreePositionStringKernel.cpp | 4 +- .../string/WeightedDegreeStringKernel.cpp | 4 +- src/shogun/lib/ShogunException.cpp | 1 - src/shogun/lib/Signal.cpp | 149 ++++++------------ src/shogun/lib/Signal.h | 54 +++---- src/shogun/lib/external/shogun_libsvm.cpp | 1 - src/shogun/machine/KernelMachine.cpp | 4 +- src/shogun/multiclass/KNN.cpp | 6 +- src/shogun/multiclass/LaRank.cpp | 2 +- src/shogun/multiclass/MulticlassLibLinear.cpp | 2 +- src/shogun/optimization/liblinear/tron.cpp | 2 +- .../regression/svr/LibLinearRegression.cpp | 2 +- .../transfer/multitask/LibLinearMTL.cpp | 2 +- tests/unit/lib/Signal_unittest.cc | 43 +++++ 29 files changed, 137 insertions(+), 171 deletions(-) create mode 100644 tests/unit/lib/Signal_unittest.cc diff --git a/src/shogun/classifier/AveragedPerceptron.cpp b/src/shogun/classifier/AveragedPerceptron.cpp index 35d89daa710..c0b09871bc3 100644 --- a/src/shogun/classifier/AveragedPerceptron.cpp +++ b/src/shogun/classifier/AveragedPerceptron.cpp @@ -60,15 +60,13 @@ bool CAveragedPerceptron::train_machine(CFeatures* data) for (int32_t i=0; idense_dot(i, w.vector, w.vlen) + bias; diff --git a/src/shogun/classifier/LPBoost.cpp b/src/shogun/classifier/LPBoost.cpp index b1b9bd5618c..df2dc473e9c 100644 --- a/src/shogun/classifier/LPBoost.cpp +++ b/src/shogun/classifier/LPBoost.cpp @@ -124,7 +124,6 @@ bool CLPBoost::train_machine(CFeatures* data) int32_t num_hypothesis=0; CTime time; - CSignal::clear_cancel(); while (!(CSignal::cancel_computations())) { diff --git a/src/shogun/classifier/Perceptron.cpp b/src/shogun/classifier/Perceptron.cpp index 56c67ed7757..e724d05feac 100644 --- a/src/shogun/classifier/Perceptron.cpp +++ b/src/shogun/classifier/Perceptron.cpp @@ -65,7 +65,6 @@ bool CPerceptron::train_machine(CFeatures* data) w.vector[i]=1.0/num_feat; } - CSignal::clear_cancel(); //loop till we either get everything classified right or reach max_iter while (!(CSignal::cancel_computations()) && (!converged && iterget_label_type() == LT_BINARY) diff --git a/src/shogun/classifier/svm/MPDSVM.cpp b/src/shogun/classifier/svm/MPDSVM.cpp index 33c7a60f9cf..51ac838d901 100644 --- a/src/shogun/classifier/svm/MPDSVM.cpp +++ b/src/shogun/classifier/svm/MPDSVM.cpp @@ -79,7 +79,7 @@ bool CMPDSVM::train_machine(CFeatures* data) bool primalcool; bool dualcool; - CSignal::clear_cancel(); + //if (nustop) //etas[1] = 1; diff --git a/src/shogun/classifier/svm/NewtonSVM.cpp b/src/shogun/classifier/svm/NewtonSVM.cpp index 5431d6e3aa0..2ff073b9dc7 100644 --- a/src/shogun/classifier/svm/NewtonSVM.cpp +++ b/src/shogun/classifier/svm/NewtonSVM.cpp @@ -49,7 +49,7 @@ CNewtonSVM::~CNewtonSVM() bool CNewtonSVM::train_machine(CFeatures* data) { - CSignal::clear_cancel(); + ASSERT(m_labels) ASSERT(m_labels->get_label_type() == LT_BINARY) diff --git a/src/shogun/classifier/svm/OnlineSVMSGD.cpp b/src/shogun/classifier/svm/OnlineSVMSGD.cpp index 6e7495b97b9..9e83b3645ff 100644 --- a/src/shogun/classifier/svm/OnlineSVMSGD.cpp +++ b/src/shogun/classifier/svm/OnlineSVMSGD.cpp @@ -100,7 +100,7 @@ bool COnlineSVMSGD::train(CFeatures* data) if (features->is_seekable()) features->reset_stream(); - CSignal::clear_cancel(); + ELossType loss_type = loss->get_loss_type(); bool is_log_loss = false; diff --git a/src/shogun/classifier/svm/SGDQN.cpp b/src/shogun/classifier/svm/SGDQN.cpp index e005485b7cc..bdc856243f9 100644 --- a/src/shogun/classifier/svm/SGDQN.cpp +++ b/src/shogun/classifier/svm/SGDQN.cpp @@ -137,7 +137,7 @@ bool CSGDQN::train(CFeatures* data) calibrate(); SG_INFO("Training on %d vectors\n", num_vec) - CSignal::clear_cancel(); + ELossType loss_type = loss->get_loss_type(); bool is_log_loss = false; diff --git a/src/shogun/classifier/svm/SVMLight.cpp b/src/shogun/classifier/svm/SVMLight.cpp index c312cc91836..93a1cb36cb5 100644 --- a/src/shogun/classifier/svm/SVMLight.cpp +++ b/src/shogun/classifier/svm/SVMLight.cpp @@ -646,7 +646,7 @@ int32_t CSVMLight::optimize_to_convergence(int32_t* docs, int32_t* label, int32_ #ifdef CYGWIN for (;((iteration<100 || (!mkl_converged && callback) ) || (retrain && (!terminate))); iteration++){ #else - CSignal::clear_cancel(); + for (;((!CSignal::cancel_computations()) && ((iteration<3 || (!mkl_converged && callback) ) || (retrain && (!terminate)))); iteration++){ #endif diff --git a/src/shogun/classifier/svm/SVMSGD.cpp b/src/shogun/classifier/svm/SVMSGD.cpp index f608ff3eef0..39130d210a1 100644 --- a/src/shogun/classifier/svm/SVMSGD.cpp +++ b/src/shogun/classifier/svm/SVMSGD.cpp @@ -109,7 +109,7 @@ bool CSVMSGD::train_machine(CFeatures* data) calibrate(); SG_INFO("Training on %d vectors\n", num_vec) - CSignal::clear_cancel(); + ELossType loss_type = loss->get_loss_type(); bool is_log_loss = false; diff --git a/src/shogun/classifier/vw/VowpalWabbit.cpp b/src/shogun/classifier/vw/VowpalWabbit.cpp index 6d5928f1762..2c32ac8ad2c 100644 --- a/src/shogun/classifier/vw/VowpalWabbit.cpp +++ b/src/shogun/classifier/vw/VowpalWabbit.cpp @@ -163,7 +163,7 @@ bool CVowpalWabbit::train_machine(CFeatures* feat) "loss", "last", "counter", "weight", "label", "predict", "features"); } - CSignal::clear_cancel(); + features->start_parser(); while (!(CSignal::cancel_computations()) && (env->passes_complete < env->num_passes)) { diff --git a/src/shogun/features/DotFeatures.cpp b/src/shogun/features/DotFeatures.cpp index 597521a60bf..fb2c69032c9 100644 --- a/src/shogun/features/DotFeatures.cpp +++ b/src/shogun/features/DotFeatures.cpp @@ -62,7 +62,7 @@ void CDotFeatures::dense_dot_range(float64_t* output, int32_t start, int32_t sto int32_t num_vectors=stop-start; ASSERT(num_vectors>0) - CSignal::clear_cancel(); + int32_t num_threads; int32_t step; @@ -113,7 +113,7 @@ void CDotFeatures::dense_dot_range_subset(int32_t* sub_index, int32_t num, float ASSERT(sub_index) ASSERT(output) - CSignal::clear_cancel(); + auto pb = progress(range(num), *this->io); int32_t num_threads; diff --git a/src/shogun/features/hashed/HashedWDFeaturesTransposed.cpp b/src/shogun/features/hashed/HashedWDFeaturesTransposed.cpp index 80e9f554bfc..bb5c8628435 100644 --- a/src/shogun/features/hashed/HashedWDFeaturesTransposed.cpp +++ b/src/shogun/features/hashed/HashedWDFeaturesTransposed.cpp @@ -222,7 +222,7 @@ void CHashedWDFeaturesTransposed::dense_dot_range(float64_t* output, int32_t sta #endif ASSERT(num_threads>0) - CSignal::clear_cancel(); + if (dim != w_dim) SG_ERROR("Dimensions don't match, vec_len=%d, w_dim=%d\n", dim, w_dim) @@ -315,7 +315,7 @@ void CHashedWDFeaturesTransposed::dense_dot_range_subset(int32_t* sub_index, int #endif ASSERT(num_threads>0) - CSignal::clear_cancel(); + if (dim != w_dim) SG_ERROR("Dimensions don't match, vec_len=%d, w_dim=%d\n", dim, w_dim) diff --git a/src/shogun/kernel/string/WeightedDegreePositionStringKernel.cpp b/src/shogun/kernel/string/WeightedDegreePositionStringKernel.cpp index 16a3965be6a..4096b769cfd 100644 --- a/src/shogun/kernel/string/WeightedDegreePositionStringKernel.cpp +++ b/src/shogun/kernel/string/WeightedDegreePositionStringKernel.cpp @@ -1252,7 +1252,7 @@ void CWeightedDegreePositionStringKernel::compute_batch( if (num_threads < 2) { - CSignal::clear_cancel(); + auto pb = progress(range(num_feat), *this->io); for (int32_t j=0; jio); for (int32_t j=0; j #include -#include #include #include #include +#include +#include using namespace shogun; +using namespace rxcpp; int CSignal::signals[NUMTRAPPEDSIGS]={SIGINT, SIGURG}; struct sigaction CSignal::oldsigaction[NUMTRAPPEDSIGS]; -bool CSignal::active=false; -bool CSignal::cancel_computation=false; -bool CSignal::cancel_immediately=false; + +rxcpp::connectable_observable CSignal::m_sigint_observable = rxcpp::observable<>::create( + [](rxcpp::subscriber s){ + s.on_completed(); + } +).publish(); + +rxcpp::connectable_observable CSignal::m_sigurg_observable = rxcpp::observable<>::create( + [](rxcpp::subscriber s){ + s.on_next(1); + } +).publish(); CSignal::CSignal() : CSGObject() { + // Set if the signal handler is active or not + m_active = true; } -CSignal::~CSignal() +CSignal::CSignal(bool active) +: CSGObject() { - if (!unset_handler()) - SG_PRINT("error uninitalizing signal handler\n") + // Set if the signal handler is active or not + m_active = active; } -void CSignal::handler(int signal) +CSignal::~CSignal() { - if (signal == SIGINT) - { - SG_SPRINT("\nImmediately return to prompt / Prematurely finish computations / Do nothing (I/P/D)? ") - char answer=fgetc(stdin); - - if (answer == 'I') - { - unset_handler(); - set_cancel(true); - if (sg_print_error) - sg_print_error(stdout, "sg stopped by SIGINT\n"); - } - else if (answer == 'P') - set_cancel(); - else - SG_SPRINT("Continuing...\n") - } - else if (signal == SIGURG) - set_cancel(); - else - SG_SPRINT("unknown signal %d received\n", signal) } -bool CSignal::set_handler() +rxcpp::connectable_observable CSignal::get_SIGINT_observable() { - if (!active) - { - struct sigaction act; - sigset_t st; - - sigemptyset(&st); - for (int32_t i=0; i=0; j--) - sigaction(signals[i], &oldsigaction[i], NULL); - - clear(); - return false; - } - } + return m_sigint_observable; +} - active=true; - return true; - } - else - return false; +rxcpp::connectable_observable CSignal::get_SIGURG_observable() +{ + return m_sigurg_observable; } -bool CSignal::unset_handler() +void CSignal::handler(int signal) { - if (active) + if (signal == SIGINT) { - bool result=true; - - for (int32_t i=0; i +#include #if defined(__MINGW64__) || defined(_MSC_VER) typedef unsigned long sigset_t; @@ -73,36 +74,28 @@ namespace shogun class CSignal : public CSGObject { public: - /** default constructor */ + CSignal(); + CSignal(bool active); virtual ~CSignal(); - /** handler + /** Signal handler. Need to be registered with std::signal. * * @param signal signal number */ static void handler(int signal); - /** set handler - * - * @return if setting was successful - */ - static bool set_handler(); - - /** unset handler - * - * @return if unsetting was successful + /** + * Get SIGINT observable + * @return observable */ - static bool unset_handler(); - - /** clear signals */ - static void clear(); + rxcpp::connectable_observable get_SIGINT_observable(); - /** clear cancel flag signals */ - static void clear_cancel(); - - /** set cancel flag signals */ - static void set_cancel(bool immediately=false); + /** + * Get SIGURG observable + * @ return observable + */ + rxcpp::connectable_observable get_SIGURG_observable(); /** cancel computations * @@ -110,20 +103,14 @@ class CSignal : public CSGObject */ static inline bool cancel_computations() { -#ifndef DISABLE_CANCEL_CALLBACK - if (sg_cancel_computations) - sg_cancel_computations(cancel_computation, cancel_immediately); -#endif - if (cancel_immediately) - throw ShogunException("Computations have been cancelled immediately"); - - return cancel_computation; + return false; } /** @return object name */ virtual const char* get_name() const { return "Signal"; } - protected: + private: + /** signals; handling external lib */ static int signals[NUMTRAPPEDSIGS]; @@ -131,13 +118,10 @@ class CSignal : public CSGObject static struct sigaction oldsigaction[NUMTRAPPEDSIGS]; /** active signal */ - static bool active; - - /** if computation should be cancelled */ - static bool cancel_computation; + bool m_active; - /** if shogun should return ASAP */ - static bool cancel_immediately; + static rxcpp::connectable_observable m_sigint_observable; + static rxcpp::connectable_observable m_sigurg_observable; }; } #endif // __SIGNAL__H_ diff --git a/src/shogun/lib/external/shogun_libsvm.cpp b/src/shogun/lib/external/shogun_libsvm.cpp index ed8914936c4..40609683a6b 100644 --- a/src/shogun/lib/external/shogun_libsvm.cpp +++ b/src/shogun/lib/external/shogun_libsvm.cpp @@ -430,7 +430,6 @@ void Solver::Solve( } // initialize gradient - CSignal::clear_cancel(); CTime start_time; { auto pb = progress(range(l)); diff --git a/src/shogun/machine/KernelMachine.cpp b/src/shogun/machine/KernelMachine.cpp index 0fcc6908b35..b12a7bfcabb 100644 --- a/src/shogun/machine/KernelMachine.cpp +++ b/src/shogun/machine/KernelMachine.cpp @@ -298,7 +298,7 @@ SGVector CKernelMachine::apply_get_outputs(CFeatures* data) { SG_DEBUG("computing output on %d test examples\n", num_vectors) - CSignal::clear_cancel(); + if (io->get_show_progress()) io->enable_progress(); @@ -495,7 +495,7 @@ SGVector CKernelMachine::apply_locked_get_output( int32_t num_inds=indices.vlen; SGVector output(num_inds); - CSignal::clear_cancel(); + if (io->get_show_progress()) io->enable_progress(); diff --git a/src/shogun/multiclass/KNN.cpp b/src/shogun/multiclass/KNN.cpp index 715973911e8..15a51eaea68 100644 --- a/src/shogun/multiclass/KNN.cpp +++ b/src/shogun/multiclass/KNN.cpp @@ -181,7 +181,7 @@ CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data) SGVector train_lab(m_k); SG_INFO("%d test examples\n", num_lab) - CSignal::clear_cancel(); + //histogram of classes and returned output SGVector classes(m_num_classes); @@ -207,7 +207,7 @@ CMulticlassLabels* CKNN::classify_NN() SGVector distances(m_train_labels.vlen); SG_INFO("%d test examples\n", num_lab) - CSignal::clear_cancel(); + distance->precompute_lhs(); @@ -262,7 +262,7 @@ SGMatrix CKNN::classify_for_multiple_k() SGVector classes(m_num_classes); SG_INFO("%d test examples\n", num_lab) - CSignal::clear_cancel(); + init_solver(m_knn_solver); diff --git a/src/shogun/multiclass/LaRank.cpp b/src/shogun/multiclass/LaRank.cpp index 51e2e800d99..dbafc9207ef 100644 --- a/src/shogun/multiclass/LaRank.cpp +++ b/src/shogun/multiclass/LaRank.cpp @@ -619,7 +619,7 @@ bool CLaRank::train_machine(CFeatures* data) ASSERT(m_labels && m_labels->get_num_labels()) ASSERT(m_labels->get_label_type() == LT_MULTICLASS) - CSignal::clear_cancel(); + if (data) { diff --git a/src/shogun/multiclass/MulticlassLibLinear.cpp b/src/shogun/multiclass/MulticlassLibLinear.cpp index 6d56974c461..72df37460f3 100644 --- a/src/shogun/multiclass/MulticlassLibLinear.cpp +++ b/src/shogun/multiclass/MulticlassLibLinear.cpp @@ -121,7 +121,7 @@ bool CMulticlassLibLinear::train_machine(CFeatures* data) for (int32_t i=0; i +#include +#include + +#include + +using namespace shogun; +using namespace rxcpp; + + +TEST(Signal, SIGINT_test) { + + CSignal tmp; + int on_next_v=0; + int on_complete_v=0; + auto sub = rxcpp::make_subscriber( + [&on_next_v](int v) {on_next_v=1;}, [&]() {on_complete_v=1;} + ); + + tmp.get_SIGINT_observable().subscribe(sub); + + tmp.handler(SIGINT); + + EXPECT_TRUE(on_next_v == 0); + EXPECT_TRUE(on_complete_v == 1); +} + +TEST(Signal, SIGURG_test) { + + CSignal tmp; + int on_next_v=0; + int on_complete_v=0; + auto sub = rxcpp::make_subscriber( + [&](int v) {on_next_v++;}, [&]() {on_complete_v++;} + ); + + tmp.get_SIGURG_observable().subscribe(sub); + + tmp.handler(SIGURG); + + EXPECT_TRUE(on_next_v == 1); + EXPECT_TRUE(on_complete_v == 0); +}