Skip to content

Commit

Permalink
Apply clang-format and other minor fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
geektoni authored and vigsterkr committed May 25, 2018
1 parent b85355c commit 040692b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 24 deletions.
3 changes: 1 addition & 2 deletions src/shogun/lib/StoppableSGObject.cpp
Expand Up @@ -40,12 +40,11 @@ rxcpp::subscription CStoppableSGObject::connect_to_signal_handler()
return get_global_signal()->get_observable()->subscribe(subscriber);
}

void CStoppableSGObject::set_callback(std::function<bool()> callback=nullptr)
void CStoppableSGObject::set_callback(std::function<bool()> callback)
{
m_callback = callback;
}


void CStoppableSGObject::reset_computation_variables()
{
m_cancel_computation = false;
Expand Down
8 changes: 4 additions & 4 deletions src/shogun/lib/StoppableSGObject.h
Expand Up @@ -37,7 +37,8 @@ namespace shogun
SG_FORCED_INLINE bool cancel_computation() const
{
/* Execute the callback, if present*/
return (m_callback) ? (m_cancel_computation.load() || m_callback()) : m_cancel_computation.load();
return (m_callback) ? (m_cancel_computation.load() || m_callback())
: m_cancel_computation.load();
}
#endif

Expand Down Expand Up @@ -66,7 +67,8 @@ namespace shogun

/**
* Set an additional stopping condition
* @param callback method that implements an additional stopping condition
* @param callback method that implements an additional stopping
* condition
*/
void set_callback(std::function<bool()> callback);

Expand All @@ -76,7 +78,6 @@ namespace shogun
}

protected:

/** connect the machine instance to the signal handler */
rxcpp::subscription connect_to_signal_handler();

Expand Down Expand Up @@ -116,7 +117,6 @@ namespace shogun
std::mutex m_mutex;

std::function<bool(void)> m_callback;

};
}
#endif
35 changes: 17 additions & 18 deletions tests/unit/lib/StoppableSGObject_unittest.cc
Expand Up @@ -6,9 +6,9 @@

#include <functional>
#include <gtest/gtest.h>
#include <shogun/machine/Machine.h>
#include <shogun/lib/Signal.h>
#include <rxcpp/rx-lite.hpp>
#include <shogun/lib/Signal.h>
#include <shogun/machine/Machine.h>

using namespace shogun;
using namespace std;
Expand All @@ -22,8 +22,7 @@ class MockModel : public CMachine
MockModel() : m_check(0), m_last_iteration(0)
{
// Set up the custom callback
function<bool()> callback = [this]()
{
function<bool()> callback = [this]() {
// Stop if we did more than 5 steps
if (m_last_iteration >= 5)
{
Expand All @@ -44,18 +43,20 @@ class MockModel : public CMachine

virtual const char* get_name() const
{
return "Mock_model";
return "MockModel";
}


protected:

virtual bool train_require_labels() const { return false; }
virtual bool train_require_labels() const
{
return false;
}

/* Custom train machine */
virtual bool train_machine(CFeatures* data=NULL)
virtual bool train_machine(CFeatures* data = NULL)
{
for (int num_iterations_train=0; num_iterations_train<10; num_iterations_train++)
for (int num_iterations_train = 0; num_iterations_train < 10;
num_iterations_train++)
{
COMPUTATION_CONTROLLERS
m_check++;
Expand All @@ -67,11 +68,11 @@ class MockModel : public CMachine
* exact number of iterations (it will be equal to m_last_iteration)*/
int m_check;

/* Addition control variable that is incremented each time by the callback.*/
/* Addition control variable that is incremented each time by the
* callback.*/
int m_last_iteration;
};


TEST(StoppableSGObject, empty_callback)
{
MockModel a;
Expand All @@ -89,21 +90,19 @@ TEST(StoppableSGObject, default_callback)

TEST(StoppableSGObject, custom_callback_by_user)
{
int i=0;
function<bool()> callback = [&i]()
{
if (i>=3) {
int i = 0;
function<bool()> callback = [&i]() {
if (i >= 3)
{
get_global_signal()->get_subscriber()->on_next(SG_BLOCK_COMP);
return true;
}
i++;
return false;
};


MockModel a;
a.set_callback(callback);
a.train();
EXPECT_TRUE(a.get_check() == 3);
}

0 comments on commit 040692b

Please sign in to comment.