Skip to content

Commit

Permalink
Clean up error messages in KNN (#4143)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored and karlnapf committed Feb 4, 2018
1 parent cf7ba2c commit 93d06f8
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions src/shogun/multiclass/KNN.cpp
Expand Up @@ -33,8 +33,8 @@ CKNN::CKNN(int32_t k, CDistance* d, CLabels* trainlab, KNN_SOLVER knn_solver)

m_k=k;

ASSERT(d)
ASSERT(trainlab)
REQUIRE(d, "Distance not set.\n");
REQUIRE(trainlab, "Training labels not set.\n");

set_distance(d);
set_labels(trainlab);
Expand Down Expand Up @@ -74,19 +74,22 @@ CKNN::~CKNN()

bool CKNN::train_machine(CFeatures* data)
{
ASSERT(m_labels)
ASSERT(distance)
REQUIRE(m_labels, "No training labels provided.\n");
REQUIRE(distance, "No training distance provided.\n");

if (data)
{
if (m_labels->get_num_labels() != data->get_num_vectors())
SG_ERROR("Number of training vectors does not match number of labels\n")
REQUIRE(
m_labels->get_num_labels() == data->get_num_vectors(),
"Number of training vectors (%d) does not match number of labels "
"(%d)\n",
data->get_num_vectors(), m_labels->get_num_labels());
distance->init(data, data);
}

SGVector<int32_t> lab=((CMulticlassLabels*) m_labels)->get_int_labels();
m_train_labels=lab.clone();
ASSERT(m_train_labels.vlen>0)
REQUIRE(m_train_labels.vlen > 0, "Provided training labels are empty\n");

// find minimal and maximal class
auto min_class = CMath::min(m_train_labels.vector, m_train_labels.vlen);
Expand Down Expand Up @@ -161,9 +164,9 @@ CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data)
if (m_k == 1)
return classify_NN();

ASSERT(m_num_classes>0)
ASSERT(distance)
ASSERT(distance->get_num_vec_rhs())
REQUIRE(m_num_classes > 0, "Machine not trained.\n");
REQUIRE(distance, "Distance not set.\n");
REQUIRE(distance->get_num_vec_rhs(), "No vectors on right hand side.\n");

int32_t num_lab=distance->get_num_vec_rhs();
ASSERT(m_k<=distance->get_num_vec_lhs())
Expand All @@ -187,11 +190,11 @@ CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data)

CMulticlassLabels* CKNN::classify_NN()
{
ASSERT(distance)
ASSERT(m_num_classes>0)
REQUIRE(distance, "Distance not set.\n");
REQUIRE(m_num_classes > 0, "Machine not trained.\n");

int32_t num_lab = distance->get_num_vec_rhs();
ASSERT(num_lab)
REQUIRE(num_lab, "No vectors on right hand side\n");

CMulticlassLabels* output = new CMulticlassLabels(num_lab);
SGVector<float64_t> distances(m_train_labels.vlen);
Expand Down Expand Up @@ -237,12 +240,15 @@ CMulticlassLabels* CKNN::classify_NN()

SGMatrix<int32_t> CKNN::classify_for_multiple_k()
{
ASSERT(m_num_classes>0)
ASSERT(distance)
ASSERT(distance->get_num_vec_rhs())
REQUIRE(distance, "Distance not set.\n");
REQUIRE(m_num_classes > 0, "Machine not trained.\n");

int32_t num_lab=distance->get_num_vec_rhs();
ASSERT(m_k<=num_lab)
REQUIRE(num_lab, "No vectors on right hand side\n");

REQUIRE(
m_k <= num_lab, "Number of labels (%d) must be at least K (%d).\n",
num_lab, m_k);

//working buffer of m_train_labels
SGVector<int32_t> train_lab(m_k);
Expand All @@ -263,8 +269,7 @@ SGMatrix<int32_t> CKNN::classify_for_multiple_k()

void CKNN::init_distance(CFeatures* data)
{
if (!distance)
SG_ERROR("No distance assigned!\n")
REQUIRE(distance, "Distance not set.\n");
CFeatures* lhs=distance->get_lhs();
if (!lhs || !lhs->get_num_vectors())
{
Expand Down

0 comments on commit 93d06f8

Please sign in to comment.