Skip to content

Commit

Permalink
Improve unit tests for StoppableSGObject.
Browse files Browse the repository at this point in the history
  • Loading branch information
geektoni authored and vigsterkr committed May 25, 2018
1 parent 7cab56e commit 8b44285
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions tests/unit/lib/StoppableSGObject_unittest.cc
Expand Up @@ -13,26 +13,28 @@
using namespace shogun;
using namespace std;

class Mock_model : public CMachine
/**
* Mock model to show the use of the callback.
*/
class MockModel : public CMachine
{
public:
Mock_model() : m_check(0), m_i(0)
MockModel() : m_check(0), m_last_iteration(0)
{
// Set up the custom callback
function<bool()> callback = [this]()
{
// Stop if we did more than 5 steps
if (m_i >= 5)
if (m_last_iteration >= 5)
{
get_global_signal()->get_subscriber()->on_next(SG_BLOCK_COMP);
return true;
}
m_i++;
m_last_iteration++;
return false;
};

// We then add the callback
this->add_callback(callback);
this->set_callback(callback);
};

int get_check()
Expand All @@ -48,28 +50,39 @@ class Mock_model : public CMachine

protected:

/** returns whether machine require labels for training */
virtual bool train_require_labels() const { return false; }

// Custom train machine
/* Custom train machine */
virtual bool train_machine(CFeatures* data=NULL)
{
for (int k=0; k<10; k++)
for (int num_iterations_train=0; num_iterations_train<10; num_iterations_train++)
{
COMPUTATION_CONTROLLERS
m_check++;
}
return true;
}

/* Control variable, used to check that we stopped the training at the
* exact number of iterations (it will be equal to m_last_iteration)*/
int m_check;
int m_i;

/* Addition control variable that is incremented each time by the callback.*/
int m_last_iteration;
};


TEST(StoppableSGObject, empty_callback)
{
MockModel a;
a.set_callback(nullptr);
a.train();
EXPECT_TRUE(a.get_check() == 10);
}

TEST(StoppableSGObject, custom_callback)
TEST(StoppableSGObject, default_callback)
{
Mock_model a;
MockModel a;
a.train();
EXPECT_TRUE(a.get_check() == 5);
}
Expand All @@ -88,8 +101,8 @@ TEST(StoppableSGObject, custom_callback_by_user)
};


Mock_model a;
a.add_callback(callback);
MockModel a;
a.set_callback(callback);
a.train();
EXPECT_TRUE(a.get_check() == 3);
}
Expand Down

0 comments on commit 8b44285

Please sign in to comment.