NOTE(review): the header names of the #include lines in the first hunk are not
recoverable from the text under review (the <...> parts were stripped); restore
them from the repository before applying this patch.

diff --git a/src/shogun/kernel/CombinedKernel.cpp b/src/shogun/kernel/CombinedKernel.cpp
index 9951fa03329..44aa245c1cf 100644
--- a/src/shogun/kernel/CombinedKernel.cpp
+++ b/src/shogun/kernel/CombinedKernel.cpp
@@ -13,7 +13,7 @@
 #include
 #include
 #include
-
+#include
 #include
 #include
 #include
@@ -925,3 +925,118 @@ CCombinedKernel* CCombinedKernel::obtain_from_generic(CKernel* kernel)
 	SG_REF(kernel);
 	return (CCombinedKernel*)kernel;
 }
+
+/** Builds one CCombinedKernel for every element of the cross-product of the
+ * given sub-lists of kernels.
+ *
+ * The first sub-list varies fastest: for sub-lists {a,b,c} and {x,y} the
+ * returned kernels combine {a,x} {b,x} {c,x} {a,y} {b,y} {c,y}, in this order.
+ *
+ * @param kernel_list list of CLists of CKernels; every sub-list must be
+ * non-empty and contain kernels of one single kernel type
+ *
+ * @return SG_REF'ed CList holding one CCombinedKernel per combination; it is
+ * empty when kernel_list is NULL or has no elements
+ */
+CList* CCombinedKernel::combine_kernels(CList* kernel_list)
+{
+	CList* return_list = new CList(true);
+	SG_REF(return_list);
+
+	if (!kernel_list || kernel_list->get_num_elements()==0)
+		return return_list;
+
+	int32_t num_combinations = 1;
+	int32_t list_index = 0;
+
+	/* validate the sub-lists and count the total number of combinations */
+	CSGObject* list = kernel_list->get_first_element();
+	while (list)
+	{
+		CList* c_list = dynamic_cast<CList*>(list);
+		if (!c_list)
+			SG_SERROR("CCombinedKernel::combine_kernels() : Failed to cast list of type "
+					"%s to type CList\n", list->get_name());
+
+		if (c_list->get_num_elements()==0)
+			SG_SERROR("CCombinedKernel::combine_kernels() : Sub-list in position %d "
+					"is empty.\n", list_index);
+
+		num_combinations *= c_list->get_num_elements();
+
+		if (kernel_list->get_delete_data())
+			SG_UNREF(list);
+
+		list = kernel_list->get_next_element();
+		++list_index;
+	}
+
+	/* create the (still empty) CCombinedKernels; kernel_array gives indexed
+	 * access to the kernels that return_list holds */
+	CDynamicObjectArray kernel_array(num_combinations);
+	for (index_t i=0; i<num_combinations; ++i)
+	{
+		CCombinedKernel* c_kernel = new CCombinedKernel();
+		return_list->append_element(c_kernel);
+		kernel_array.push_back(c_kernel);
+	}
+
+	/* freq is how many consecutive CCombinedKernels receive the same kernel of
+	 * the current sub-list, i.e. the product of the sizes of all previous
+	 * sub-lists. For the first sub-list (freq==1) this yields
+	 * a,b,c,d, a,b,c,d, ... and for later ones a,a,..,b,b,..,c,c,.. repeated
+	 * until all num_combinations kernels got one kernel of the sub-list */
+	int32_t freq = 1;
+	list_index = 0;
+	for (list=kernel_list->get_first_element(); list; list=kernel_list->get_next_element())
+	{
+		/* the cast and the size were already checked in the loop above */
+		CList* c_list = dynamic_cast<CList*>(list);
+		const int32_t num_kernels = c_list->get_num_elements();
+
+		/* index of the kernel in the sub-list */
+		index_t kernel_index = 0;
+		EKernelType prev_kernel_type;
+		bool first_kernel = true;
+		for (CSGObject* kernel=c_list->get_first_element(); kernel; kernel=c_list->get_next_element())
+		{
+			CKernel* c_kernel = dynamic_cast<CKernel*>(kernel);
+			if (!c_kernel)
+				SG_SERROR("CCombinedKernel::combine_kernels() : Failed to cast object of type "
+						"%s in sub-list %d to type CKernel\n", kernel->get_name(), list_index);
+
+			if (first_kernel)
+				first_kernel = false;
+			else if (c_kernel->get_kernel_type()!=prev_kernel_type)
+				SG_SERROR("CCombinedKernel::combine_kernels() : Sub-list in position "
+						"%d contains different types of kernels\n", list_index);
+
+			prev_kernel_type = c_kernel->get_kernel_type();
+
+			/* jump from one run of freq slots to the next run that belongs to
+			 * this kernel until the end of the CCombinedKernels is reached */
+			for (index_t base=kernel_index*freq; base<num_combinations; base+=num_kernels*freq)
+			{
+				/* insert the current kernel freq consecutive times */
+				for (index_t index=0; index<freq; ++index)
+				{
+					CCombinedKernel* comb_kernel =
+						dynamic_cast<CCombinedKernel*>(kernel_array.get_element(base+index));
+					comb_kernel->insert_kernel(c_kernel);
+					SG_UNREF(comb_kernel);
+				}
+			}
+			++kernel_index;
+
+			if (c_list->get_delete_data())
+				SG_UNREF(kernel);
+		}
+
+		freq *= num_kernels;
+		if (kernel_list->get_delete_data())
+			SG_UNREF(list);
+		++list_index;
+	}
+
+	return return_list;
+}
diff --git a/src/shogun/kernel/CombinedKernel.h b/src/shogun/kernel/CombinedKernel.h
index f64462b1ed9..66f7b014cf5 100644
--- a/src/shogun/kernel/CombinedKernel.h
+++ b/src/shogun/kernel/CombinedKernel.h
@@ -402,6 +402,14 @@ class CCombinedKernel : public CKernel
 		 */
 		inline CList* get_list() {SG_REF(kernel_list); return kernel_list;}
 
+		/** Returns a list of all the different CombinedKernels produced by the cross-product between the kernel lists
+		 *
+		 * @param kernel_list a list of lists of kernels. Each sub-list must contain kernels of the same type
+		 *
+		 * @return a list of CombinedKernels.
+		 */
+		static CList* combine_kernels(CList* kernel_list);
+
 	protected:
 		/** compute kernel function
 		 *
diff --git a/tests/unit/kernel/CombinedKernel_unittest.cc b/tests/unit/kernel/CombinedKernel_unittest.cc
index 9c4bff71305..cf9fbaed7f6 100644
--- a/tests/unit/kernel/CombinedKernel_unittest.cc
+++ b/tests/unit/kernel/CombinedKernel_unittest.cc
@@ -68,5 +68,99 @@ TEST(CombinedKernelTest,serialization)
 	EXPECT_EQ(weights[0], w[0]);
 	EXPECT_EQ(weights[1], w[1]);
 	EXPECT_EQ(weights[2], w[2]);
+	SG_UNREF(combined_read);
 	SG_UNREF(combined);
 }
+
+/* Checks that combined_list holds exactly num_combinations CCombinedKernels,
+ * each made of num_subkernels CGaussianKernels, and compares their widths with
+ * expected_widths, a row-major num_subkernels x num_combinations table in which
+ * row i belongs to the i-th sub-kernel of every combination. */
+static void check_combined_widths(CList* combined_list,
+		const float64_t* expected_widths, index_t num_subkernels,
+		index_t num_combinations)
+{
+	/* also keeps the indexing of expected_widths below within bounds */
+	ASSERT_EQ(combined_list->get_num_elements(), num_combinations);
+
+	index_t j = 0;
+	for (CSGObject* kernel=combined_list->get_first_element(); kernel;
+			kernel=combined_list->get_next_element())
+	{
+		CCombinedKernel* c_kernel = dynamic_cast<CCombinedKernel*>(kernel);
+		ASSERT_TRUE(c_kernel!=NULL);
+		ASSERT_EQ(c_kernel->get_num_subkernels(), num_subkernels);
+
+		index_t i = 0;
+		for (CSGObject* subkernel=c_kernel->get_first_kernel(); subkernel;
+				subkernel=c_kernel->get_next_kernel())
+		{
+			CGaussianKernel* c_subkernel = dynamic_cast<CGaussianKernel*>(subkernel);
+			ASSERT_TRUE(c_subkernel!=NULL);
+			EXPECT_EQ(c_subkernel->get_width(), expected_widths[i*num_combinations+j]);
+			++i;
+			SG_UNREF(subkernel);
+		}
+		++j;
+		SG_UNREF(kernel);
+	}
+}
+
+TEST(CombinedKernelTest,combination)
+{
+	/* a NULL list and an empty list both produce no combinations */
+	CList* kernel_list = NULL;
+	CList* combined_list = CCombinedKernel::combine_kernels(kernel_list);
+	EXPECT_EQ(combined_list->get_num_elements(),0);
+	SG_UNREF(combined_list);
+
+	kernel_list = new CList(true);
+	combined_list = CCombinedKernel::combine_kernels(kernel_list);
+	EXPECT_EQ(combined_list->get_num_elements(),0);
+	SG_UNREF(combined_list);
+
+	/* one sub-list with three kernels: three combinations of one kernel */
+	CList* sub_list_1 = new CList(true);
+	sub_list_1->append_element(new CGaussianKernel(10,3));
+	sub_list_1->append_element(new CGaussianKernel(10,5));
+	sub_list_1->append_element(new CGaussianKernel(10,7));
+	kernel_list->insert_element(sub_list_1);
+
+	float64_t combs1[1*3] = { 3, 5, 7};
+
+	combined_list = CCombinedKernel::combine_kernels(kernel_list);
+	check_combined_widths(combined_list, combs1, 1, 3);
+	SG_UNREF(combined_list);
+
+	/* a second sub-list with two kernels: 3*2 combinations of two kernels */
+	CList* sub_list_2 = new CList(true);
+	sub_list_2->append_element(new CGaussianKernel(20,21));
+	sub_list_2->append_element(new CGaussianKernel(20,31));
+	kernel_list->append_element(sub_list_2);
+
+	float64_t combs2[2*6] = { 21, 21, 21, 31, 31, 31,
+			3, 5, 7, 3, 5, 7};
+
+	combined_list = CCombinedKernel::combine_kernels(kernel_list);
+	check_combined_widths(combined_list, combs2, 2, 6);
+	SG_UNREF(combined_list);
+
+	/* a third sub-list with four kernels: 3*2*4 combinations of three kernels */
+	CList* sub_list_3 = new CList(true);
+	sub_list_3->append_element(new CGaussianKernel(25, 109));
+	sub_list_3->append_element(new CGaussianKernel(25, 203));
+	sub_list_3->append_element(new CGaussianKernel(25, 308));
+	sub_list_3->append_element(new CGaussianKernel(25, 404));
+	kernel_list->append_element(sub_list_3);
+
+	float64_t combs[3*24] = {
+		109, 109, 109, 109, 109, 109, 203, 203, 203, 203, 203, 203, 308, 308, 308, 308, 308, 308, 404, 404, 404, 404, 404, 404,
+		21, 21, 21, 31, 31, 31, 21, 21, 21, 31, 31, 31, 21, 21, 21, 31, 31, 31, 21, 21, 21, 31, 31, 31,
+		3, 5, 7, 3, 5, 7, 3, 5, 7, 3, 5, 7, 3, 5, 7, 3, 5, 7, 3, 5, 7, 3, 5, 7};
+
+	combined_list = CCombinedKernel::combine_kernels(kernel_list);
+	check_combined_widths(combined_list, combs, 3, 24);
+	SG_UNREF(combined_list);
+
+	SG_UNREF(kernel_list);
+}