Skip to content

Commit

Permalink
Revert CombinedKernel changes and fix remaining errors
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Apr 15, 2018
1 parent afd4359 commit d2852ba
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 14 deletions.
Expand Up @@ -52,11 +52,10 @@ real stddev = result.get_real("std_dev")
#![get_fold_machine]
CrossValidationStorage obs = mkl_obs.get_observation(0)
CrossValidationFoldStorage fold = obs.get_fold(0)
MKLClassification machine = fold.get_trained_machine()
MKLClassification machine = MKLClassification:obtain_from_generic(fold.get_trained_machine())
#![get_fold_machine]

#![get_weights]
SGObject k = machine.get("kernel")
CombinedKernel ck = CombinedKernel:obtain_from_generic(k)
RealVector w = ck.get_subkernel_weights()
CombinedKernel k = CombinedKernel:obtain_from_generic(machine.get_kernel())
RealVector w = k.get_subkernel_weights()
#![get_weights]
5 changes: 0 additions & 5 deletions src/shogun/kernel/CombinedKernel.cpp
Expand Up @@ -907,11 +907,6 @@ CCombinedKernel* CCombinedKernel::obtain_from_generic(CKernel* kernel)
return (CCombinedKernel*)kernel;
}

CCombinedKernel* CCombinedKernel::obtain_from_generic(Some<CSGObject> object)
{
return CCombinedKernel::obtain_from_generic((CKernel*)object.get());
}

CList* CCombinedKernel::combine_kernels(CList* kernel_list)
{
CList* return_list = new CList(true);
Expand Down
3 changes: 0 additions & 3 deletions src/shogun/kernel/CombinedKernel.h
Expand Up @@ -366,9 +366,6 @@ class CCombinedKernel : public CKernel
*/
static CCombinedKernel* obtain_from_generic(CKernel* kernel);

// TODO: remove
static CCombinedKernel* obtain_from_generic(Some<CSGObject> object);

/** return derivative with respect to specified parameter
*
* @param param the parameter
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/labels/MulticlassLabels_unittest.cc
Expand Up @@ -93,11 +93,11 @@ TEST_F(MulticlassLabels, multiclass_labels_from_dense_not_contiguous)
auto labels = some<CDenseLabels>(labels_true.size());
labels->set_labels({0, 1, 3});
auto converted = multiclass_labels(labels);
ASSERT_NE(converted, nullptr);
ASSERT_NE(converted.get(), nullptr);
EXPECT_TRUE(converted->get_labels().equals({0, 1, 2}));

labels->set_labels({-1, 1, 1});
auto converted2 = multiclass_labels(labels);
ASSERT_NE(converted2, nullptr);
ASSERT_NE(converted2.get(), nullptr);
EXPECT_TRUE(converted2->get_labels().equals({0, 1, 1}));
}

0 comments on commit d2852ba

Please sign in to comment.