From f14235704a0dd7c98b4569daae771378128efdd1 Mon Sep 17 00:00:00 2001 From: Shubham Shukla Date: Fri, 1 Jun 2018 16:00:21 +0530 Subject: [PATCH] connect Kernel and Linear machine to a signal handler (#4287) * connect Kernel and Linear machine to a signal handler * using base class methods --- src/shogun/machine/KernelMachine.cpp | 16 +++++---------- src/shogun/machine/KernelMachine.h | 9 +++++---- src/shogun/machine/LinearMachine.cpp | 29 +++++++--------------------- src/shogun/machine/Machine.cpp | 25 ++++++++++++++++++------ src/shogun/machine/Machine.h | 8 ++++++++ 5 files changed, 44 insertions(+), 43 deletions(-) diff --git a/src/shogun/machine/KernelMachine.cpp b/src/shogun/machine/KernelMachine.cpp index 9ac60a23738..c82ab121772 100644 --- a/src/shogun/machine/KernelMachine.cpp +++ b/src/shogun/machine/KernelMachine.cpp @@ -6,14 +6,14 @@ * Fernando Iglesias, Thoralf Klein */ +#include #include #include -#include -#include - -#include #include +#include #include +#include +#include #ifdef HAVE_OPENMP #include @@ -421,10 +421,6 @@ void CKernelMachine::store_model_features() bool CKernelMachine::train_locked(SGVector indices) { - SG_DEBUG("entering %s::train_locked()\n", get_name()) - if (!is_data_locked()) - SG_ERROR("CKernelMachine::train_locked() call data_lock() before!\n") - /* this is asusmed here */ ASSERT(m_custom_kernel==kernel) @@ -443,15 +439,13 @@ bool CKernelMachine::train_locked(SGVector indices) /* dont do train because model should not be stored (no acutal features) * and train does data_unlock */ - bool result=train_machine(); - + bool result = CMachine::train_locked(); /* remove last col subset of custom kernel */ m_custom_kernel->remove_col_subset(); /* remove label subset after training */ m_labels->remove_subset(); - SG_DEBUG("leaving %s::train_locked()\n", get_name()) return result; } diff --git a/src/shogun/machine/KernelMachine.h b/src/shogun/machine/KernelMachine.h index 2129852c1e9..499cc7def27 100644 --- a/src/shogun/machine/KernelMachine.h +++ b/src/shogun/machine/KernelMachine.h @@ -227,11 +227,12 @@ class CKernelMachine : public CMachine virtual float64_t apply_one(int32_t num); #ifndef SWIG // SWIG should skip this part - /** Trains a locked machine on a set of indices. Error if machine is - * not locked + + /** This precomputes the kernel matrix and stores it * - * @param indices index vector (of locked features) that is used for training - * @return whether training was successful + * @param indices index vector (of locked features) that is used for ++ * training + * @return whether training was successful */ virtual bool train_locked(SGVector indices); diff --git a/src/shogun/machine/LinearMachine.cpp b/src/shogun/machine/LinearMachine.cpp index b4efdf72c0a..2bb5edca005 100644 --- a/src/shogun/machine/LinearMachine.cpp +++ b/src/shogun/machine/LinearMachine.cpp @@ -6,10 +6,11 @@ * Fernando Iglesias */ -#include -#include +#include #include #include +#include +#include #include using namespace shogun; @@ -166,25 +167,9 @@ void CLinearMachine::compute_bias(CFeatures* data) bool CLinearMachine::train(CFeatures* data) { - /* not allowed to train on locked data */ - if (m_data_locked) - { - SG_ERROR("train data_lock() was called, only train_locked() is" - " possible. Call data_unlock if you want to call train()\n", - get_name()); - } - - if (train_require_labels()) - { - REQUIRE(m_labels,"No labels given",this->get_name()); - - m_labels->ensure_valid(get_name()); - } - - bool result = train_machine(data); - - if(m_compute_bias) - compute_bias(data); + bool result = CMachine::train(data); + if (m_compute_bias) + compute_bias(data); - return result; + return result; } diff --git a/src/shogun/machine/Machine.cpp b/src/shogun/machine/Machine.cpp index 321baf6682c..8e6f647c231 100644 --- a/src/shogun/machine/Machine.cpp +++ b/src/shogun/machine/Machine.cpp @@ -40,12 +40,11 @@ CMachine::~CMachine() bool CMachine::train(CFeatures* data) { /* not allowed to train on locked data */ - if (m_data_locked) - { - SG_ERROR("%s::train data_lock() was called, only train_locked() is" - " possible. Call data_unlock if you want to call train()\n", - get_name()); - } + REQUIRE( + !m_data_locked, "(%s)::train data_lock() was called, only " + "train_locked() is possible. Call data_unlock if you " + "want to call train()\n", + get_name()); if (train_require_labels()) { @@ -66,6 +65,20 @@ bool CMachine::train(CFeatures* data) return result; } +bool CMachine::train_locked() +{ + /*train machine without any actual features(data is locked)*/ + REQUIRE( + is_data_locked(), + "Data needs to be locked for training, call data_lock()\n") + + auto sub = connect_to_signal_handler(); + bool result = train_machine(); + sub.unsubscribe(); + reset_computation_variables(); + return result; +} + void CMachine::set_labels(CLabels* lab) { if (lab != NULL) diff --git a/src/shogun/machine/Machine.h b/src/shogun/machine/Machine.h index 5b945324e83..e6787f694b1 100644 --- a/src/shogun/machine/Machine.h +++ b/src/shogun/machine/Machine.h @@ -159,6 +159,14 @@ class CMachine : public CStoppableSGObject */ virtual bool train(CFeatures* data=NULL); +#ifndef SWIG // SWIG should skip this part + /** Trains a locked machine on a set of indices. Error if machine is + * not locked + * @return whether training was successful + */ + virtual bool train_locked(); +#endif + /** apply machine to data * if data is not specified apply to the current features *