diff --git a/src/shogun/classifier/Perceptron.cpp b/src/shogun/classifier/Perceptron.cpp index ffc02df0d1f..854ede85c29 100644 --- a/src/shogun/classifier/Perceptron.cpp +++ b/src/shogun/classifier/Perceptron.cpp @@ -35,7 +35,9 @@ void CPerceptron::init() max_iter = 1000; learn_rate = 0.1; m_initialize_hyperplane = true; - SG_ADD(&max_iter, "initialize_hyperplane", "Whether to initialize hyperplane.", MS_AVAILABLE); + SG_ADD( + &m_initialize_hyperplane, "initialize_hyperplane", + "Whether to initialize hyperplane.", MS_AVAILABLE); SG_ADD(&max_iter, "max_iter", "Maximum number of iterations.", MS_AVAILABLE); SG_ADD(&learn_rate, "learn_rate", "Learning rate.", MS_AVAILABLE); } @@ -65,15 +67,20 @@ bool CPerceptron::train_machine(CFeatures* data) ASSERT(num_vec==train_labels.vlen) SGVector output(num_vec); - SGVector w = get_w(); + SGVector w; if (m_initialize_hyperplane) { w = SGVector(num_feat); + set_w(w); + //start with uniform w, bias=0 + w.set_const(1.0 / num_feat); bias=0; - linalg::add_scalar(w, 1.0 / num_feat); } - + else + { + w = get_w(); + } //loop till we either get everything classified right or reach max_iter while (!converged && iter < max_iter) { @@ -105,8 +112,6 @@ bool CPerceptron::train_machine(CFeatures* data) else SG_WARNING("Perceptron algorithm did not converge after %d iterations.\n", max_iter) - set_w(w); - return converged; } diff --git a/tests/unit/classifier/Perceptron_unittest.cc b/tests/unit/classifier/Perceptron_unittest.cc index 4092a9608f4..16b99271efc 100644 --- a/tests/unit/classifier/Perceptron_unittest.cc +++ b/tests/unit/classifier/Perceptron_unittest.cc @@ -59,3 +59,25 @@ TEST(Perceptron, train) auto acc = some(); EXPECT_EQ(acc->evaluate(results, test_labels), 1.0); } + +TEST(Perceptron, custom_hyperplane_initialization) +{ + auto env = linear_test_env->getBinaryLabelData(); + auto features = wrap(env->get_features_train()); + auto labels = wrap(env->get_labels_train()); + auto test_features = wrap(env->get_features_test()); + auto test_labels = wrap(env->get_labels_test()); + + auto perceptron = some(features, labels); + perceptron->train(); + + auto weights = perceptron->get_w(); + + auto perceptron_initialized = some(features, labels); + perceptron_initialized->set_initialize_hyperplane(false); + perceptron_initialized->set_w(weights); + perceptron_initialized->set_max_iter(1); + + perceptron_initialized->train(); + EXPECT_TRUE(perceptron_initialized->get_w().equals(weights)); +}