Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

connect Kernel and Linear machine to a signal handler #4287

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 5 additions & 11 deletions src/shogun/machine/KernelMachine.cpp
Expand Up @@ -6,14 +6,14 @@
* Fernando Iglesias, Thoralf Klein
*/

#include <rxcpp/rx-lite.hpp>
#include <shogun/base/progress.h>
#include <shogun/io/SGIO.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/machine/KernelMachine.h>

#include <shogun/kernel/Kernel.h>
#include <shogun/kernel/CustomKernel.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/machine/KernelMachine.h>

#ifdef HAVE_OPENMP
#include <omp.h>
Expand Down Expand Up @@ -421,10 +421,6 @@ void CKernelMachine::store_model_features()

bool CKernelMachine::train_locked(SGVector<index_t> indices)
{
SG_DEBUG("entering %s::train_locked()\n", get_name())
if (!is_data_locked())
SG_ERROR("CKernelMachine::train_locked() call data_lock() before!\n")

/* this is asusmed here */
ASSERT(m_custom_kernel==kernel)

Expand All @@ -443,15 +439,13 @@ bool CKernelMachine::train_locked(SGVector<index_t> indices)

/* dont do train because model should not be stored (no acutal features)
* and train does data_unlock */
bool result=train_machine();

bool result = CMachine::train_locked();
/* remove last col subset of custom kernel */
m_custom_kernel->remove_col_subset();

/* remove label subset after training */
m_labels->remove_subset();

SG_DEBUG("leaving %s::train_locked()\n", get_name())
return result;
}

Expand Down
9 changes: 5 additions & 4 deletions src/shogun/machine/KernelMachine.h
Expand Up @@ -227,11 +227,12 @@ class CKernelMachine : public CMachine
virtual float64_t apply_one(int32_t num);

#ifndef SWIG // SWIG should skip this part
/** Trains a locked machine on a set of indices. Error if machine is
* not locked

/** This precomputes the kernel matrix and stores it
*
* @param indices index vector (of locked features) that is used for training
* @return whether training was successful
* @param indices index vector (of locked features) that is used for
+ * training
* @return whether training was successful
*/
virtual bool train_locked(SGVector<index_t> indices);

Expand Down
29 changes: 7 additions & 22 deletions src/shogun/machine/LinearMachine.cpp
Expand Up @@ -6,10 +6,11 @@
* Fernando Iglesias
*/

#include <shogun/machine/LinearMachine.h>
#include <shogun/labels/RegressionLabels.h>
#include <rxcpp/rx-lite.hpp>
#include <shogun/features/DotFeatures.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/machine/LinearMachine.h>
#include <shogun/mathematics/eigen3.h>

using namespace shogun;
Expand Down Expand Up @@ -166,25 +167,9 @@ void CLinearMachine::compute_bias(CFeatures* data)

bool CLinearMachine::train(CFeatures* data)
{
/* not allowed to train on locked data */
if (m_data_locked)
{
SG_ERROR("train data_lock() was called, only train_locked() is"
" possible. Call data_unlock if you want to call train()\n",
get_name());
}

if (train_require_labels())
{
REQUIRE(m_labels,"No labels given",this->get_name());

m_labels->ensure_valid(get_name());
}

bool result = train_machine(data);

if(m_compute_bias)
compute_bias(data);
bool result = CMachine::train(data);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for this change!

if (m_compute_bias)
compute_bias(data);

return result;
return result;
}
25 changes: 19 additions & 6 deletions src/shogun/machine/Machine.cpp
Expand Up @@ -40,12 +40,11 @@ CMachine::~CMachine()
bool CMachine::train(CFeatures* data)
{
/* not allowed to train on locked data */
if (m_data_locked)
{
SG_ERROR("%s::train data_lock() was called, only train_locked() is"
" possible. Call data_unlock if you want to call train()\n",
get_name());
}
REQUIRE(
!m_data_locked, "(%s)::train data_lock() was called, only "
"train_locked() is possible. Call data_unlock if you "
"want to call train()\n",
get_name());

if (train_require_labels())
{
Expand All @@ -66,6 +65,20 @@ bool CMachine::train(CFeatures* data)
return result;
}

bool CMachine::train_locked()
{
/*train machine without any actual features(data is locked)*/
REQUIRE(
is_data_locked(),
"Data needs to be locked for training, call data_lock()\n")

auto sub = connect_to_signal_handler();
bool result = train_machine();
sub.unsubscribe();
reset_computation_variables();
return result;
}

void CMachine::set_labels(CLabels* lab)
{
if (lab != NULL)
Expand Down
8 changes: 8 additions & 0 deletions src/shogun/machine/Machine.h
Expand Up @@ -159,6 +159,14 @@ class CMachine : public CStoppableSGObject
*/
virtual bool train(CFeatures* data=NULL);

#ifndef SWIG // SWIG should skip this part
/** Trains a locked machine on a set of indices. Error if machine is
* not locked
* @return whether training was successful
*/
virtual bool train_locked();
#endif

/** apply machine to data
* if data is not specified apply to the current features
*
Expand Down