Skip to content

Commit

Permalink
connect Kernel and Linear machine to a signal handler (#4287)
Browse files Browse the repository at this point in the history
* connect Kernel and Linear machine to a signal handler
* using base class methods
  • Loading branch information
shubham808 authored and karlnapf committed Jun 1, 2018
1 parent d8873e6 commit f142357
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 43 deletions.
16 changes: 5 additions & 11 deletions src/shogun/machine/KernelMachine.cpp
Expand Up @@ -6,14 +6,14 @@
* Fernando Iglesias, Thoralf Klein
*/

#include <rxcpp/rx-lite.hpp>
#include <shogun/base/progress.h>
#include <shogun/io/SGIO.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/machine/KernelMachine.h>

#include <shogun/kernel/Kernel.h>
#include <shogun/kernel/CustomKernel.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/machine/KernelMachine.h>

#ifdef HAVE_OPENMP
#include <omp.h>
Expand Down Expand Up @@ -421,10 +421,6 @@ void CKernelMachine::store_model_features()

bool CKernelMachine::train_locked(SGVector<index_t> indices)
{
SG_DEBUG("entering %s::train_locked()\n", get_name())
if (!is_data_locked())
SG_ERROR("CKernelMachine::train_locked() call data_lock() before!\n")

/* this is asusmed here */
ASSERT(m_custom_kernel==kernel)

Expand All @@ -443,15 +439,13 @@ bool CKernelMachine::train_locked(SGVector<index_t> indices)

/* dont do train because model should not be stored (no acutal features)
* and train does data_unlock */
bool result=train_machine();

bool result = CMachine::train_locked();
/* remove last col subset of custom kernel */
m_custom_kernel->remove_col_subset();

/* remove label subset after training */
m_labels->remove_subset();

SG_DEBUG("leaving %s::train_locked()\n", get_name())
return result;
}

Expand Down
9 changes: 5 additions & 4 deletions src/shogun/machine/KernelMachine.h
Expand Up @@ -227,11 +227,12 @@ class CKernelMachine : public CMachine
virtual float64_t apply_one(int32_t num);

#ifndef SWIG // SWIG should skip this part
/** Trains a locked machine on a set of indices. Error if machine is
* not locked

/** This precomputes the kernel matrix and stores it
*
* @param indices index vector (of locked features) that is used for training
* @return whether training was successful
* @param indices index vector (of locked features) that is used for
+ * training
* @return whether training was successful
*/
virtual bool train_locked(SGVector<index_t> indices);

Expand Down
29 changes: 7 additions & 22 deletions src/shogun/machine/LinearMachine.cpp
Expand Up @@ -6,10 +6,11 @@
* Fernando Iglesias
*/

#include <shogun/machine/LinearMachine.h>
#include <shogun/labels/RegressionLabels.h>
#include <rxcpp/rx-lite.hpp>
#include <shogun/features/DotFeatures.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/machine/LinearMachine.h>
#include <shogun/mathematics/eigen3.h>

using namespace shogun;
Expand Down Expand Up @@ -166,25 +167,9 @@ void CLinearMachine::compute_bias(CFeatures* data)

bool CLinearMachine::train(CFeatures* data)
{
/* not allowed to train on locked data */
if (m_data_locked)
{
SG_ERROR("train data_lock() was called, only train_locked() is"
" possible. Call data_unlock if you want to call train()\n",
get_name());
}

if (train_require_labels())
{
REQUIRE(m_labels,"No labels given",this->get_name());

m_labels->ensure_valid(get_name());
}

bool result = train_machine(data);

if(m_compute_bias)
compute_bias(data);
bool result = CMachine::train(data);
if (m_compute_bias)
compute_bias(data);

return result;
return result;
}
25 changes: 19 additions & 6 deletions src/shogun/machine/Machine.cpp
Expand Up @@ -40,12 +40,11 @@ CMachine::~CMachine()
bool CMachine::train(CFeatures* data)
{
/* not allowed to train on locked data */
if (m_data_locked)
{
SG_ERROR("%s::train data_lock() was called, only train_locked() is"
" possible. Call data_unlock if you want to call train()\n",
get_name());
}
REQUIRE(
!m_data_locked, "(%s)::train data_lock() was called, only "
"train_locked() is possible. Call data_unlock if you "
"want to call train()\n",
get_name());

if (train_require_labels())
{
Expand All @@ -66,6 +65,20 @@ bool CMachine::train(CFeatures* data)
return result;
}

bool CMachine::train_locked()
{
/*train machine without any actual features(data is locked)*/
REQUIRE(
is_data_locked(),
"Data needs to be locked for training, call data_lock()\n")

auto sub = connect_to_signal_handler();
bool result = train_machine();
sub.unsubscribe();
reset_computation_variables();
return result;
}

void CMachine::set_labels(CLabels* lab)
{
if (lab != NULL)
Expand Down
8 changes: 8 additions & 0 deletions src/shogun/machine/Machine.h
Expand Up @@ -159,6 +159,14 @@ class CMachine : public CStoppableSGObject
*/
virtual bool train(CFeatures* data=NULL);

#ifndef SWIG // SWIG should skip this part
/** Trains a locked machine on a set of indices. Error if machine is
* not locked
* @return whether training was successful
*/
virtual bool train_locked();
#endif

/** apply machine to data
* if data is not specified apply to the current features
*
Expand Down

0 comments on commit f142357

Please sign in to comment.