From 93d06f857b52e9988b3b9fa3a2249ee21e854155 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sun, 4 Feb 2018 19:07:31 +0800 Subject: [PATCH] Clean up error messages in KNN (#4143) --- src/shogun/multiclass/KNN.cpp | 43 +++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/src/shogun/multiclass/KNN.cpp b/src/shogun/multiclass/KNN.cpp index e7688530f8a..52c80cc4e05 100644 --- a/src/shogun/multiclass/KNN.cpp +++ b/src/shogun/multiclass/KNN.cpp @@ -33,8 +33,8 @@ CKNN::CKNN(int32_t k, CDistance* d, CLabels* trainlab, KNN_SOLVER knn_solver) m_k=k; - ASSERT(d) - ASSERT(trainlab) + REQUIRE(d, "Distance not set.\n"); + REQUIRE(trainlab, "Training labels not set.\n"); set_distance(d); set_labels(trainlab); @@ -74,19 +74,22 @@ CKNN::~CKNN() bool CKNN::train_machine(CFeatures* data) { - ASSERT(m_labels) - ASSERT(distance) + REQUIRE(m_labels, "No training labels provided.\n"); + REQUIRE(distance, "No training distance provided.\n"); if (data) { - if (m_labels->get_num_labels() != data->get_num_vectors()) - SG_ERROR("Number of training vectors does not match number of labels\n") + REQUIRE( + m_labels->get_num_labels() == data->get_num_vectors(), + "Number of training vectors (%d) does not match number of labels " + "(%d)\n", + data->get_num_vectors(), m_labels->get_num_labels()); distance->init(data, data); } SGVector lab=((CMulticlassLabels*) m_labels)->get_int_labels(); m_train_labels=lab.clone(); - ASSERT(m_train_labels.vlen>0) + REQUIRE(m_train_labels.vlen > 0, "Provided training labels are empty\n"); // find minimal and maximal class auto min_class = CMath::min(m_train_labels.vector, m_train_labels.vlen); @@ -161,9 +164,9 @@ CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data) if (m_k == 1) return classify_NN(); - ASSERT(m_num_classes>0) - ASSERT(distance) - ASSERT(distance->get_num_vec_rhs()) + REQUIRE(m_num_classes > 0, "Machine not trained.\n"); + REQUIRE(distance, "Distance not set.\n"); + REQUIRE(distance->get_num_vec_rhs(), "No vectors on right hand side.\n"); int32_t num_lab=distance->get_num_vec_rhs(); ASSERT(m_k<=distance->get_num_vec_lhs()) @@ -187,11 +190,11 @@ CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data) CMulticlassLabels* CKNN::classify_NN() { - ASSERT(distance) - ASSERT(m_num_classes>0) + REQUIRE(distance, "Distance not set.\n"); + REQUIRE(m_num_classes > 0, "Machine not trained.\n"); int32_t num_lab = distance->get_num_vec_rhs(); - ASSERT(num_lab) + REQUIRE(num_lab, "No vectors on right hand side\n"); CMulticlassLabels* output = new CMulticlassLabels(num_lab); SGVector distances(m_train_labels.vlen); @@ -237,12 +240,15 @@ CMulticlassLabels* CKNN::classify_NN() SGMatrix CKNN::classify_for_multiple_k() { - ASSERT(m_num_classes>0) - ASSERT(distance) - ASSERT(distance->get_num_vec_rhs()) + REQUIRE(distance, "Distance not set.\n"); + REQUIRE(m_num_classes > 0, "Machine not trained.\n"); int32_t num_lab=distance->get_num_vec_rhs(); - ASSERT(m_k<=num_lab) + REQUIRE(num_lab, "No vectors on right hand side\n"); + + REQUIRE( + m_k <= num_lab, "Number of labels (%d) must be at least K (%d).\n", + num_lab, m_k); //working buffer of m_train_labels SGVector train_lab(m_k); @@ -263,8 +269,7 @@ SGMatrix CKNN::classify_for_multiple_k() void CKNN::init_distance(CFeatures* data) { - if (!distance) - SG_ERROR("No distance assigned!\n") + REQUIRE(distance, "Distance not set.\n"); CFeatures* lhs=distance->get_lhs(); if (!lhs || !lhs->get_num_vectors()) {