Merge pull request #4277 from shubham808/feature/refactor_labels_factory
remove methods from labels factory
karlnapf committed May 10, 2018
2 parents 507a704 + 0e603b4 commit bf67562
Showing 12 changed files with 31 additions and 94 deletions.
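A note on the pattern: call sites that previously went through the static CLabelsFactory::to_* helpers now use either a free conversion function (binary_labels, multiclass_labels) or the templated CSGObject::as<T>() cast, as the hunks below show. A minimal before/after sketch, where some_labels is a hypothetical CLabels* used only for illustration:

    // before this commit: static factory method
    CBinaryLabels* b = CLabelsFactory::to_binary(some_labels);

    // after this commit: free conversion function, or templated cast
    CBinaryLabels* b1 = binary_labels(some_labels);
    CBinaryLabels* b2 = some_labels->as<CBinaryLabels>();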
44 changes: 0 additions & 44 deletions src/shogun/labels/LabelsFactory.cpp
@@ -10,50 +10,6 @@

using namespace shogun;

-CBinaryLabels* CLabelsFactory::to_binary(CLabels* base_labels)
-{
-    ASSERT(base_labels != NULL)
-    if (base_labels->get_label_type() == LT_BINARY)
-        return static_cast<CBinaryLabels*>(base_labels);
-    else
-        SG_SERROR("base_labels must be of dynamic type CBinaryLabels")
-
-    return NULL;
-}
-
-CLatentLabels* CLabelsFactory::to_latent(CLabels* base_labels)
-{
-    ASSERT(base_labels != NULL)
-    if (base_labels->get_label_type() == LT_LATENT)
-        return static_cast<CLatentLabels*>(base_labels);
-    else
-        SG_SERROR("base_labels must be of dynamic type CLatentLabels\n")
-
-    return NULL;
-}
-
-CMulticlassLabels* CLabelsFactory::to_multiclass(CLabels* base_labels)
-{
-    ASSERT(base_labels != NULL)
-    if (base_labels->get_label_type() == LT_MULTICLASS)
-        return static_cast<CMulticlassLabels*>(base_labels);
-    else
-        SG_SERROR("base_labels must be of dynamic type CMulticlassLabels\n")
-
-    return NULL;
-}
-
-CRegressionLabels* CLabelsFactory::to_regression(CLabels* base_labels)
-{
-    ASSERT(base_labels != NULL)
-    if (base_labels->get_label_type() == LT_REGRESSION)
-        return static_cast<CRegressionLabels*>(base_labels);
-    else
-        SG_SERROR("base_labels must be of dynamic type CRegressionLabels")
-
-    return NULL;
-}

CStructuredLabels* CLabelsFactory::to_structured(CLabels* base_labels)
{
ASSERT(base_labels != NULL)
23 changes: 0 additions & 23 deletions src/shogun/labels/LabelsFactory.h
@@ -29,29 +29,6 @@ namespace shogun
class CLabelsFactory : public CSGObject
{
public:
-    /** specialize a base class instance to CBinaryLabels
-     *
-     * @param base_labels its dynamic type must be CBinaryLabels
-     */
-    static CBinaryLabels* to_binary(CLabels* base_labels);
-
-    /** specialize a base class instance to CLatentLabels
-     *
-     * @param base_labels its dynamic type must be CLatentLabels
-     */
-    static CLatentLabels* to_latent(CLabels* base_labels);
-
-    /** specialize a base class instance to CMulticlassLabels
-     *
-     * @param base_labels its dynamic type must be CMulticlassLabels
-     */
-    static CMulticlassLabels* to_multiclass(CLabels* base_labels);
-
-    /** specialize a base class instance to CRegressionLabels
-     *
-     * @param base_labels its dynamic type must be CRegressionLabels
-     */
-    static CRegressionLabels* to_regression(CLabels* base_labels);

/** specialize a base class instance to CStructuredLabels
*
2 changes: 1 addition & 1 deletion src/shogun/latent/LatentModel.cpp
@@ -64,7 +64,7 @@ void CLatentModel::set_features(CLatentFeatures* feats)
void CLatentModel::argmax_h(const SGVector<float64_t>& w)
{
int32_t num = get_num_vectors();
-CBinaryLabels* y = CLabelsFactory::to_binary(m_labels->get_labels());
+CBinaryLabels* y = binary_labels(m_labels->get_labels());
ASSERT(num > 0)
ASSERT(num == m_labels->get_num_labels())

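The binary_labels() call above is a free function in the shogun namespace. Judging from the removed CLabelsFactory::to_binary() body earlier in this diff, it should amount to the same checked downcast; a sketch of the presumed semantics, not the actual library source:

    // presumed behaviour of shogun::binary_labels(), inferred from the
    // removed CLabelsFactory::to_binary() above; not the real implementation
    CBinaryLabels* binary_labels_sketch(CLabels* orig)
    {
        ASSERT(orig != NULL)
        if (orig->get_label_type() == LT_BINARY)
            return static_cast<CBinaryLabels*>(orig);
        SG_SERROR("labels must be of dynamic type CBinaryLabels\n")
        return NULL;
    }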
2 changes: 1 addition & 1 deletion src/shogun/metric/LMNN.cpp
@@ -60,7 +60,7 @@ void CLMNN::train(SGMatrix<float64_t> init_transform)

// cast is safe, check_training_setup ensures features are dense
CDenseFeatures<float64_t>* x = static_cast<CDenseFeatures<float64_t>*>(m_features);
-CMulticlassLabels* y = CLabelsFactory::to_multiclass(m_labels);
+CMulticlassLabels* y = multiclass_labels(m_labels);
SG_DEBUG("%d input vectors with %d dimensions.\n", x->get_num_vectors(), x->get_num_features());

auto& L = init_transform;
2 changes: 1 addition & 1 deletion src/shogun/metric/LMNNImpl.cpp
@@ -66,7 +66,7 @@ void CLMNNImpl::check_training_setup(

void CLMNNImpl::check_maximum_k(CLabels* labels, int32_t k)
{
-CMulticlassLabels* y = CLabelsFactory::to_multiclass(labels);
+CMulticlassLabels* y = multiclass_labels(labels);
SGVector<int32_t> int_labels = y->get_int_labels();

// back-up initial values because they will be overwritten by unique
4 changes: 2 additions & 2 deletions src/shogun/multiclass/tree/CHAIDTree.cpp
@@ -100,14 +100,14 @@ CMulticlassLabels* CCHAIDTree::apply_multiclass(CFeatures* data)
{
REQUIRE(data, "Data required for classification in apply_multiclass\n")

-return CLabelsFactory::to_multiclass(apply_tree(data));
+return apply_tree(data)->as<CMulticlassLabels>();
}

CRegressionLabels* CCHAIDTree::apply_regression(CFeatures* data)
{
REQUIRE(data, "Data required for regression in apply_regression\n")

-return CLabelsFactory::to_regression(apply_tree(data));
+return apply_tree(data)->as<CRegressionLabels>();
}

void CCHAIDTree::set_weights(SGVector<float64_t> w)
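The as<CMulticlassLabels>() and as<CRegressionLabels>() calls replace the factory lookup with a member-template cast on CSGObject. A hedged sketch of that checked-downcast pattern, written as a free function since the real member implementation may differ (e.g. in how it reports a failed cast):

    // checked downcast in the style of CSGObject::as<T>(); a sketch only
    template <class T>
    T* as_sketch(CSGObject* obj)
    {
        T* casted = dynamic_cast<T*>(obj);
        ASSERT(casted != NULL) // the real as<T>() presumably errors out here
        return casted;
    }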
4 changes: 2 additions & 2 deletions tests/unit/base/Serialization_unittest.cc
@@ -99,7 +99,7 @@ TEST(Serialization, liblinear)
liblin->set_epsilon(1e-5);
liblin->train();

-CBinaryLabels* pred = CLabelsFactory::to_binary(liblin->apply(test_feats));
+CBinaryLabels* pred = liblin->apply(test_feats)->as<CBinaryLabels>();
for (int i = 0; i < num_samples; ++i)
EXPECT_EQ(ground_truth->get_int_label(i), pred->get_int_label(i));
SG_UNREF(pred);
@@ -119,7 +119,7 @@ TEST(Serialization, liblinear)
SG_UNREF(file);

/* classify with the deserialized model */
-pred = CLabelsFactory::to_binary(liblin_loaded->apply(test_feats));
+pred = liblin_loaded->apply(test_feats)->as<CBinaryLabels>();
for (int i = 0; i < num_samples; ++i)
EXPECT_EQ(ground_truth->get_int_label(i), pred->get_int_label(i));

3 changes: 2 additions & 1 deletion tests/unit/clustering/kmeans_unittest.cc
@@ -43,7 +43,8 @@ TEST(KMeans, manual_center_initialization_test)
{
clustering->train(features);

-CMulticlassLabels* result=CLabelsFactory::to_multiclass(clustering->apply());
+CMulticlassLabels* result =
+    clustering->apply()->as<CMulticlassLabels>();

EXPECT_EQ(0.000000, result->get_label(0));
EXPECT_EQ(0.000000, result->get_label(1));
29 changes: 16 additions & 13 deletions tests/unit/multiclass/KNN_unittest.cc
@@ -39,23 +39,24 @@ TEST(KNN, brute_solver)
test.random(0, classes*num-1);

CMulticlassLabels* labels = new CMulticlassLabels(lab);

CDenseFeatures< float64_t >* features = new CDenseFeatures< float64_t >(feat);
-CFeatures* features_test = (CFeatures*) features->clone();
+CFeatures* features_test = (CFeatures*)features->clone();
CLabels* labels_test = (CLabels*) labels->clone();

int32_t k=4;
-CEuclideanDistance* distance = new CEuclideanDistance();
+CEuclideanDistance* distance = new CEuclideanDistance();
CKNN* knn=new CKNN (k, distance, labels, KNN_BRUTE);
SG_REF(knn);

features->add_subset(train);
-labels->add_subset(train);
+labels->add_subset(train);
knn->train(features);

features_test->add_subset(test);
labels_test->add_subset(test);
-CMulticlassLabels* output=CLabelsFactory::to_multiclass(knn->apply(features_test));
+CMulticlassLabels* output =
+    knn->apply(features_test)->as<CMulticlassLabels>();
SG_REF(output);
features_test->remove_subset();

@@ -86,21 +87,22 @@ TEST(KNN, kdtree_solver)

CMulticlassLabels* labels = new CMulticlassLabels(lab);
CDenseFeatures< float64_t >* features = new CDenseFeatures< float64_t >(feat);
-CFeatures* features_test = (CFeatures*) features->clone();
+CFeatures* features_test = (CFeatures*)features->clone();
CLabels* labels_test = (CLabels*) labels->clone();

int32_t k=4;
-CEuclideanDistance* distance = new CEuclideanDistance();
+CEuclideanDistance* distance = new CEuclideanDistance();
CKNN* knn=new CKNN (k, distance, labels, KNN_KDTREE);
SG_REF(knn);

features->add_subset(train);
-labels->add_subset(train);
+labels->add_subset(train);
knn->train(features);

features_test->add_subset(test);
labels_test->add_subset(test);
-CMulticlassLabels* output=CLabelsFactory::to_multiclass(knn->apply(features_test));
+CMulticlassLabels* output =
+    knn->apply(features_test)->as<CMulticlassLabels>();
SG_REF(output);
features_test->remove_subset();

@@ -132,21 +134,22 @@ TEST(KNN, lsh_solver)

CMulticlassLabels* labels = new CMulticlassLabels(lab);
CDenseFeatures< float64_t >* features = new CDenseFeatures< float64_t >(feat);
-CFeatures* features_test = (CFeatures*) features->clone();
+CFeatures* features_test = (CFeatures*)features->clone();
CLabels* labels_test = (CLabels*) labels->clone();

int32_t k=4;
-CEuclideanDistance* distance = new CEuclideanDistance();
+CEuclideanDistance* distance = new CEuclideanDistance();
CKNN* knn=new CKNN (k, distance, labels, KNN_LSH);
SG_REF(knn);

features->add_subset(train);
-labels->add_subset(train);
+labels->add_subset(train);
knn->train(features);

features_test->add_subset(test);
labels_test->add_subset(test);
-CMulticlassLabels* output=CLabelsFactory::to_multiclass(knn->apply(features_test));
+CMulticlassLabels* output =
+    knn->apply(features_test)->as<CMulticlassLabels>();
SG_REF(output);
features_test->remove_subset();

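Aside on ownership in these tests: apply() returns a new labels object, so each test takes a reference with SG_REF while it inspects the result and releases it with SG_UNREF afterwards (Shogun's manual reference counting). The recurring pattern, using the names from the tests above:

    CMulticlassLabels* output =
        knn->apply(features_test)->as<CMulticlassLabels>();
    SG_REF(output);   // hold a reference while the EXPECT_EQ checks run
    // ... assertions on output ...
    SG_UNREF(output); // release when done
    SG_UNREF(knn);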
2 changes: 1 addition & 1 deletion tests/unit/multiclass/MCLDA_unittest.cc
@@ -28,7 +28,7 @@ TEST(MCLDA, train_and_apply)
SG_REF(lda);
lda->train();

-CMulticlassLabels* output=CLabelsFactory::to_multiclass(lda->apply());
+CMulticlassLabels* output = lda->apply()->as<CMulticlassLabels>();
SG_REF(output);

// Test
2 changes: 1 addition & 1 deletion tests/unit/multiclass/QDA_unittest.cc
@@ -28,7 +28,7 @@ TEST(QDA, train_and_apply)
SG_REF(qda);
qda->train();

-CMulticlassLabels* output=CLabelsFactory::to_multiclass(qda->apply());
+CMulticlassLabels* output = qda->apply()->as<CMulticlassLabels>();
SG_REF(output);

// Test
8 changes: 4 additions & 4 deletions tests/unit/regression/LibSVR_unittest.cc
@@ -58,8 +58,8 @@ TEST(LibSVR,epsilon_svr_apply)
svm->train();

/* predict */
-CRegressionLabels* predicted_labels=CLabelsFactory::to_regression(
-    svm->apply(features_test));
+CRegressionLabels* predicted_labels =
+    svm->apply(features_test)->as<CRegressionLabels>();

/* LibSVM regression comparison (with easy.py script) */
EXPECT_NEAR(predicted_labels->get_labels()[0], 2.44343, 1E-5);
@@ -123,8 +123,8 @@ TEST(LibSVR,nu_svr_apply)
svm->train();

/* predict */
-CRegressionLabels* predicted_labels=CLabelsFactory::to_regression(
-    svm->apply(features_test));
+CRegressionLabels* predicted_labels =
+    svm->apply(features_test)->as<CRegressionLabels>();

/* LibSVM regression comparison (with easy.py script) */
EXPECT_NEAR(predicted_labels->get_labels()[0], 2.18062, 1E-5);
