Skip to content

Commit

Permalink
m_continue_features and iterative machine tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shubham808 committed Nov 18, 2018
1 parent ed31a10 commit cf9737a
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 6 deletions.
4 changes: 4 additions & 0 deletions src/shogun/classifier/Perceptron.cpp
Expand Up @@ -39,6 +39,10 @@ void CPerceptron::init_model(CFeatures* data)
if (!data->has_property(FP_DOT))
SG_ERROR("Specified features are not of type CDotFeatures\n")
set_features((CDotFeatures*) data);

SG_REF(data);
SG_UNREF(m_continue_features);
m_continue_features = data->as<CDotFeatures>();
}

int32_t num_feat = features->get_dim_feature_space();
Expand Down
4 changes: 4 additions & 0 deletions src/shogun/classifier/svm/NewtonSVM.cpp
Expand Up @@ -54,6 +54,10 @@ void CNewtonSVM::init_model(CFeatures* data)
if (!data->has_property(FP_DOT))
SG_ERROR("Specified features are not of type CDotFeatures\n")
set_features((CDotFeatures*) data);

SG_REF(data);
SG_UNREF(m_continue_features);
m_continue_features = data->as<CDotFeatures>();
}

ASSERT(features)
Expand Down
11 changes: 8 additions & 3 deletions src/shogun/machine/IterativeMachine.h
Expand Up @@ -32,7 +32,8 @@ namespace shogun
{
m_current_iteration = 0;
m_complete = false;

m_continue_features = NULL;

SG_ADD(
&m_current_iteration, "current_iteration",
"Current Iteration of training", MS_NOT_AVAILABLE);
Expand All @@ -46,6 +47,7 @@ namespace shogun

virtual ~CIterativeMachine()
{
SG_UNREF(m_continue_features);
}

/** Returns convergence status */
Expand All @@ -57,7 +59,8 @@ namespace shogun
virtual bool continue_train()
{
this->reset_computation_variables();

this->set_features(m_continue_features);

auto pb = SG_PROGRESS(range(m_max_iterations));
while (m_current_iteration < m_max_iterations && !m_complete)
{
Expand Down Expand Up @@ -112,7 +115,9 @@ namespace shogun
virtual void end_training()
{
}


/** Stores features to continue training */
CDotFeatures* m_continue_features;
/** Maximum Iterations */
int32_t m_max_iterations;
/** Current iteration of training loop */
Expand Down
87 changes: 87 additions & 0 deletions tests/unit/classifier/NewtonSVM_unittest.cc
@@ -0,0 +1,87 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Shubham Shukla
*/
#include <functional>
#include <rxcpp/rx-lite.hpp>
#include <shogun/lib/Signal.h>

#include "environments/LinearTestEnvironment.h"
#include <gtest/gtest.h>
#include <shogun/base/some.h>
#include <shogun/classifier/svm/NewtonSVM.h>
#include <shogun/evaluation/ContingencyTableEvaluation.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/BinaryLabels.h>

using namespace shogun;

extern LinearTestEnvironment* linear_test_env;

TEST(NewtonSVM, continue_training_consistency)
{
auto env = linear_test_env->getBinaryLabelData();
auto features = wrap(env->get_features_train());
auto labels = wrap(env->get_labels_train());
auto test_features = wrap(env->get_features_test());
auto test_labels = wrap(env->get_labels_test());

auto svm = some<CNewtonSVM>();
svm->set_labels(labels);
svm->train(features);

// reference for completly trained model
auto results = svm->apply(test_features);

index_t iter = 0;

// prematurely stopped at 5 iterations
auto svm_stop = some<CNewtonSVM>();

std::function<bool()> callback = [&iter]() {
if (iter >= 5)
{
get_global_signal()->get_subscriber()->on_next(SG_BLOCK_COMP);
return true;
}
iter++;
return false;
};
svm_stop->set_callback(callback);

svm_stop->set_labels(labels);
svm_stop->train(features);

// callback executes, model should not be converged
ASSERT(!svm_stop->is_complete());

// reference model for intermediate state
auto svm_one = some<CNewtonSVM>();
svm_one->set_labels(labels);

// trained only till 5 iterations
svm_one->put<int32_t>("max_iterations", 5);
svm_one->train(features);

auto results_one = svm_one->apply(test_features);
auto results_stop = svm_stop->apply(test_features);

// prematurely stopped model and intermediate reference model should be
// consistent
EXPECT_TRUE(results_one->equals(results_stop));

svm_stop->set_callback(nullptr);

// continue training until converged
svm_stop->continue_train();
auto results_complete = svm_stop->apply(test_features);

// compare model with completely trained reference
EXPECT_TRUE(results_complete->equals(results));

SG_UNREF(results_one);
SG_UNREF(results_stop);
SG_UNREF(results);
SG_UNREF(results_complete);
}
17 changes: 14 additions & 3 deletions tests/unit/classifier/Perceptron_unittest.cc
Expand Up @@ -101,7 +101,7 @@ TEST(Perceptron, continue_training_consistency)
perceptron->set_labels(labels);
perceptron->train(features);

auto results = perceptron->apply_binary(test_features);
auto results = perceptron->apply(test_features);

index_t iter = 0;

Expand All @@ -123,13 +123,24 @@ TEST(Perceptron, continue_training_consistency)

ASSERT(!perceptron_stop->is_complete());

auto perceptron_one = some<CPerceptron>();
perceptron_one->set_labels(labels);
perceptron_one->put<int32_t>("max_iterations", 1);
perceptron_one->train(features);

auto results_one = perceptron_one->apply(test_features);
auto results_stop = perceptron_stop->apply(test_features);
EXPECT_TRUE(results_one->equals(results_stop));

perceptron_stop->set_callback(nullptr);
perceptron_stop->continue_train();

auto results_complete = perceptron_stop->apply_binary(test_features);
auto results_complete = perceptron_stop->apply(test_features);

EXPECT_TRUE(results_complete->get_labels().equals(results->get_labels()));
EXPECT_TRUE(results_complete->equals(results));

SG_UNREF(results_one);
SG_UNREF(results_stop);
SG_UNREF(results);
SG_UNREF(results_complete);
}

0 comments on commit cf9737a

Please sign in to comment.