Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve cancel_computation to enable testing. #4293

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/shogun/lib/StoppableSGObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ using namespace shogun;
continue; \
pause_computation();

CStoppableSGObject::CStoppableSGObject() : CSGObject(){};

CStoppableSGObject::~CStoppableSGObject(){};

void CStoppableSGObject::init_stoppable()
CStoppableSGObject::CStoppableSGObject() : CSGObject()
{
m_cancel_computation = false;
m_pause_computation_flag = false;
}

m_callback = nullptr;
};

CStoppableSGObject::~CStoppableSGObject(){};

rxcpp::subscription CStoppableSGObject::connect_to_signal_handler()
{
Expand All @@ -40,6 +40,11 @@ rxcpp::subscription CStoppableSGObject::connect_to_signal_handler()
return get_global_signal()->get_observable()->subscribe(subscriber);
}

void CStoppableSGObject::set_callback(std::function<bool()> callback)
{
m_callback = callback;
}

void CStoppableSGObject::reset_computation_variables()
{
m_cancel_computation = false;
Expand Down
17 changes: 13 additions & 4 deletions src/shogun/lib/StoppableSGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@ namespace shogun
/** destructor */
virtual ~CStoppableSGObject();

/** init flags to false as default */
void init_stoppable();

#ifndef SWIG
/** @return whether the algorithm needs to be stopped */
SG_FORCED_INLINE bool cancel_computation() const
{
return m_cancel_computation.load();
/* Execute the callback, if present*/
return (m_callback) ? (m_cancel_computation.load() || m_callback())
: m_cancel_computation.load();
}
#endif

Expand All @@ -66,6 +65,13 @@ namespace shogun
}
#endif

/**
* Set an additional stopping condition
* @param callback method that implements an additional stopping
* condition
*/
void set_callback(std::function<bool()> callback);

virtual const char* get_name() const
{
return "StoppableSGObject";
Expand Down Expand Up @@ -97,6 +103,7 @@ namespace shogun
void on_complete();
virtual void on_complete_impl();

protected:
/** Cancel computation */
std::atomic<bool> m_cancel_computation;

Expand All @@ -108,6 +115,8 @@ namespace shogun

/** Mutex used to pause threads */
std::mutex m_mutex;

std::function<bool(void)> m_callback;
};
}
#endif
1 change: 0 additions & 1 deletion src/shogun/machine/Machine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ CMachine::CMachine()
: CStoppableSGObject(), m_max_train_time(0), m_labels(NULL),
m_solver_type(ST_AUTO)
{
init_stoppable();
m_data_locked=false;
m_store_model_features=false;

Expand Down
108 changes: 108 additions & 0 deletions tests/unit/lib/StoppableSGObject_unittest.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Giovanni De Toni
*/

#include <functional>
#include <gtest/gtest.h>
#include <rxcpp/rx-lite.hpp>
#include <shogun/lib/Signal.h>
#include <shogun/machine/Machine.h>

using namespace shogun;
using namespace std;

/**
* Mock model to show the use of the callback.
*/
class MockModel : public CMachine
{
public:
MockModel() : m_check(0), m_last_iteration(0)
{
// Set up the custom callback
function<bool()> callback = [this]() {
// Stop if we did more than 5 steps
if (m_last_iteration >= 5)
{
get_global_signal()->get_subscriber()->on_next(SG_BLOCK_COMP);
return true;
}
m_last_iteration++;
return false;
};

this->set_callback(callback);
};

int get_check()
{
return m_check;
}

virtual const char* get_name() const
{
return "MockModel";
}

protected:
virtual bool train_require_labels() const
{
return false;
}

/* Custom train machine */
virtual bool train_machine(CFeatures* data = NULL)
{
for (int num_iterations_train = 0; num_iterations_train < 10;
num_iterations_train++)
{
COMPUTATION_CONTROLLERS
m_check++;
}
return true;
}

/* Control variable, used to check that we stopped the training at the
* exact number of iterations (it will be equal to m_last_iteration)*/
int m_check;

/* Addition control variable that is incremented each time by the
* callback.*/
int m_last_iteration;
};

TEST(StoppableSGObject, empty_callback)
{
MockModel a;
a.set_callback(nullptr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you either use the default (nullptr) valued set_callback, i.e. use

a.set_callback();

or remove the default value of the function arg. i would opt for the latter

a.train();
EXPECT_TRUE(a.get_check() == 10);
}

TEST(StoppableSGObject, default_callback)
{
MockModel a;
a.train();
EXPECT_TRUE(a.get_check() == 5);
}

TEST(StoppableSGObject, custom_callback_by_user)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test shows more clearly what we are going to do in the future when testing the premature stopping feature.

{
int i = 0;
function<bool()> callback = [&i]() {
if (i >= 3)
{
get_global_signal()->get_subscriber()->on_next(SG_BLOCK_COMP);
return true;
}
i++;
return false;
};

MockModel a;
a.set_callback(callback);
a.train();
EXPECT_TRUE(a.get_check() == 3);
}