Skip to content

Commit

Permalink
state update in each iteration for perceptron (#4320)
Browse files Browse the repository at this point in the history
* state update in each iteration for perceptron
* test for custom init
  • Loading branch information
shubham808 authored and karlnapf committed Jun 8, 2018
1 parent 1878190 commit cb39003
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/shogun/classifier/Perceptron.cpp
Expand Up @@ -35,7 +35,9 @@ void CPerceptron::init()
max_iter = 1000;
learn_rate = 0.1;
m_initialize_hyperplane = true;
SG_ADD(&max_iter, "initialize_hyperplane", "Whether to initialize hyperplane.", MS_AVAILABLE);
SG_ADD(
&m_initialize_hyperplane, "initialize_hyperplane",
"Whether to initialize hyperplane.", MS_AVAILABLE);
SG_ADD(&max_iter, "max_iter", "Maximum number of iterations.", MS_AVAILABLE);
SG_ADD(&learn_rate, "learn_rate", "Learning rate.", MS_AVAILABLE);
}
Expand Down Expand Up @@ -65,15 +67,20 @@ bool CPerceptron::train_machine(CFeatures* data)
ASSERT(num_vec==train_labels.vlen)
SGVector<float64_t> output(num_vec);

SGVector<float64_t> w = get_w();
SGVector<float64_t> w;
if (m_initialize_hyperplane)
{
w = SGVector<float64_t>(num_feat);
set_w(w);

//start with uniform w, bias=0
w.set_const(1.0 / num_feat);
bias=0;
linalg::add_scalar(w, 1.0 / num_feat);
}

else
{
w = get_w();
}
//loop till we either get everything classified right or reach max_iter
while (!converged && iter < max_iter)
{
Expand Down Expand Up @@ -105,8 +112,6 @@ bool CPerceptron::train_machine(CFeatures* data)
else
SG_WARNING("Perceptron algorithm did not converge after %d iterations.\n", max_iter)

set_w(w);

return converged;
}

Expand Down
22 changes: 22 additions & 0 deletions tests/unit/classifier/Perceptron_unittest.cc
Expand Up @@ -59,3 +59,25 @@ TEST(Perceptron, train)
auto acc = some<CAccuracyMeasure>();
EXPECT_EQ(acc->evaluate(results, test_labels), 1.0);
}

TEST(Perceptron, custom_hyperplane_initialization)
{
auto env = linear_test_env->getBinaryLabelData();
auto features = wrap(env->get_features_train());
auto labels = wrap(env->get_labels_train());
auto test_features = wrap(env->get_features_test());
auto test_labels = wrap(env->get_labels_test());

auto perceptron = some<CPerceptron>(features, labels);
perceptron->train();

auto weights = perceptron->get_w();

auto perceptron_initialized = some<CPerceptron>(features, labels);
perceptron_initialized->set_initialize_hyperplane(false);
perceptron_initialized->set_w(weights);
perceptron_initialized->set_max_iter(1);

perceptron_initialized->train();
EXPECT_TRUE(perceptron_initialized->get_w().equals(weights));
}

0 comments on commit cb39003

Please sign in to comment.