From 44cb92a7e141ff776c4492e154d9ca032d287127 Mon Sep 17 00:00:00 2001 From: lambday Date: Fri, 3 Jun 2016 16:29:39 +0100 Subject: [PATCH] updated kernel selection unit-tests --- .../KernelSelectionMaxMMD_unittest.cc | 121 ------- .../KernelSelectionMaxPower_unittest.cc | 121 ------- .../KernelSelectionMaxXValidation_unittest.cc | 75 ---- ...KernelSelectionMedianHeuristic_unittest.cc | 113 ------- .../KernelSelection_unittest.cc | 320 ++++++++++++++++++ 5 files changed, 320 insertions(+), 430 deletions(-) delete mode 100644 tests/unit/statistical_testing/KernelSelectionMaxMMD_unittest.cc delete mode 100644 tests/unit/statistical_testing/KernelSelectionMaxPower_unittest.cc delete mode 100644 tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc delete mode 100644 tests/unit/statistical_testing/KernelSelectionMedianHeuristic_unittest.cc create mode 100644 tests/unit/statistical_testing/KernelSelection_unittest.cc diff --git a/tests/unit/statistical_testing/KernelSelectionMaxMMD_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMaxMMD_unittest.cc deleted file mode 100644 index 291fd447224..00000000000 --- a/tests/unit/statistical_testing/KernelSelectionMaxMMD_unittest.cc +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright (c) The Shogun Machine Learning Toolbox - * Written (W) 2012-2013 Heiko Strathmann - * Written (w) 2016 Soumyajit De - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * The views and conclusions contained in the software and documentation are those - * of the authors and should not be interpreted as representing official policies, - * either expressed or implied, of the Shogun Development Team. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace shogun; - -TEST(KernelSelectionMaxMMD, single_kernel) -{ - const index_t m=20; - const index_t n=30; - const index_t dim=2; - const float64_t difference=0.5; - const index_t num_kernels=10; - - // use fixed seed - sg_rand->set_seed(12345); - - // streaming data generator for mean shift distributions - auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); - auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); - - // create MMD instance, convienience constructor - auto mmd=some(gen_p, gen_q); - mmd->set_statistic_type(ST_BIASED_FULL); - mmd->set_num_samples_p(m); - mmd->set_num_samples_q(n); - mmd->set_num_blocks_per_burst(1000); - - for (auto i=0; iadd_kernel(new CGaussianKernel(10, sq_sigma_twice)); - } - - mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_MMD); - mmd->set_train_test_mode(true); - mmd->select_kernel(); - auto selected_kernel=static_cast(mmd->get_kernel()); - EXPECT_NEAR(selected_kernel->get_width(), 0.5, 1E-10); -} - -TEST(KernelSelectionMaxMMD, weighted_kernel) -{ - const index_t m=20; - const index_t n=30; - const index_t dim=2; - const float64_t difference=0.5; - const index_t num_kernels=10; - - // use fixed seed - sg_rand->set_seed(12345); - - // streaming data generator for mean shift distributions - auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); - auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); - - // create MMD instance, convienience constructor - auto mmd=some(gen_p, gen_q); - mmd->set_statistic_type(ST_BIASED_FULL); - mmd->set_num_samples_p(m); - mmd->set_num_samples_q(n); - mmd->set_num_blocks_per_burst(1000); - - for (auto i=0; iadd_kernel(new CGaussianKernel(10, sq_sigma_twice)); - } - - mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_MMD, true); - mmd->set_train_test_mode(true); - mmd->select_kernel(); - auto weighted_kernel=dynamic_cast(mmd->get_kernel()); - ASSERT_TRUE(weighted_kernel!=nullptr); - ASSERT_TRUE(weighted_kernel->get_num_subkernels()==num_kernels); - SGVector weights=weighted_kernel->get_subkernel_weights(); - weights.display_vector("weights"); // TODO remove -} diff --git a/tests/unit/statistical_testing/KernelSelectionMaxPower_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMaxPower_unittest.cc deleted file mode 100644 index 63a6e2101eb..00000000000 --- a/tests/unit/statistical_testing/KernelSelectionMaxPower_unittest.cc +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright (c) The Shogun Machine Learning Toolbox - * Written (W) 2012-2013 Heiko Strathmann - * Written (w) 2016 Soumyajit De - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * The views and conclusions contained in the software and documentation are those - * of the authors and should not be interpreted as representing official policies, - * either expressed or implied, of the Shogun Development Team. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace shogun; - -TEST(KernelSelectionMaxPower, single_kernel) -{ - const index_t m=20; - const index_t n=30; - const index_t dim=2; - const float64_t difference=0.5; - const index_t num_kernels=10; - - // use fixed seed - sg_rand->set_seed(12345); - - // streaming data generator for mean shift distributions - auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); - auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); - - // create MMD instance, convienience constructor - auto mmd=some(gen_p, gen_q); - mmd->set_statistic_type(ST_BIASED_FULL); - mmd->set_num_samples_p(m); - mmd->set_num_samples_q(n); - mmd->set_num_blocks_per_burst(1000); - - for (auto i=0; iadd_kernel(new CGaussianKernel(10, sq_sigma_twice)); - } - - mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_POWER); - mmd->set_train_test_mode(true); - mmd->select_kernel(); - auto selected_kernel=static_cast(mmd->get_kernel()); - EXPECT_NEAR(selected_kernel->get_width(), 0.5, 1E-10); -} - -TEST(KernelSelectionMaxPower, weighted_kernel) -{ - const index_t m=24; - const index_t n=32; - const index_t dim=2; - const float64_t difference=0.5; - const index_t num_kernels=4; - - // use fixed seed - sg_rand->set_seed(12345); - - // streaming data generator for mean shift distributions - auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); - auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); - - // create MMD instance, convienience constructor - auto mmd=some(gen_p, gen_q); - mmd->set_statistic_type(ST_BIASED_FULL); - mmd->set_num_samples_p(m); - mmd->set_num_samples_q(n); - mmd->set_num_blocks_per_burst(1000); - - for (auto i=0; iadd_kernel(new CGaussianKernel(10, sq_sigma_twice)); - } - - mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_POWER, true); - mmd->set_train_test_mode(true); - mmd->select_kernel(); - auto weighted_kernel=dynamic_cast(mmd->get_kernel()); - ASSERT_TRUE(weighted_kernel!=nullptr); - ASSERT_TRUE(weighted_kernel->get_num_subkernels()==num_kernels); - SGVector weights=weighted_kernel->get_subkernel_weights(); - weights.display_vector("weights"); // TODO remove -} diff --git a/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc deleted file mode 100644 index d501db8e6ca..00000000000 --- a/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright (c) The Shogun Machine Learning Toolbox - * Written (w) 2016 Soumyajit De - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * The views and conclusions contained in the software and documentation are those - * of the authors and should not be interpreted as representing official policies, - * either expressed or implied, of the Shogun Development Team. - */ - -#include -#include -#include -#include -#include -#include -#include - -using namespace shogun; - -TEST(KernelSelectionMaxXValidation, single_kernel) -{ - const index_t m=5; - const index_t n=10; - const index_t dim=1; - const float64_t difference=0.5; - const index_t num_kernels=10; - const index_t num_runs=1; - const index_t num_folds=5; - const float64_t alpha=0.05; - - sg_rand->set_seed(12345); - - auto gen_p=some(0, dim, 0); - auto gen_q=some(difference, dim, 0); - auto feats_p=gen_p->get_streamed_features(m); - auto feats_q=gen_q->get_streamed_features(n); - - auto mmd=some(feats_p, feats_q); - - for (auto i=0, sigma=-5; iadd_kernel(new CGaussianKernel(10, tau)); - } - - mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_XVALIDATION, num_runs, alpha); - mmd->set_train_test_mode(true); - mmd->set_train_test_ratio(num_folds-1); - mmd->select_kernel(); - mmd->set_train_test_mode(false); - - auto selected_kernel=static_cast(mmd->get_kernel()); - EXPECT_NEAR(selected_kernel->get_width(), 0.03125, 1E-10); -} diff --git a/tests/unit/statistical_testing/KernelSelectionMedianHeuristic_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMedianHeuristic_unittest.cc deleted file mode 100644 index 242f5520369..00000000000 --- a/tests/unit/statistical_testing/KernelSelectionMedianHeuristic_unittest.cc +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) The Shogun Machine Learning Toolbox - * Written (W) 2012-2013 Heiko Strathmann - * Written (w) 2016 Soumyajit De - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * The views and conclusions contained in the software and documentation are those - * of the authors and should not be interpreted as representing official policies, - * either expressed or implied, of the Shogun Development Team. - */ - -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace shogun; - -TEST(KernelSelectionMedianHeuristic, quadratic_time_mmd) -{ - const index_t m=20; - const index_t n=30; - const index_t dim=2; - const float64_t difference=0.5; - const index_t num_kernels=10; - - // use fixed seed - sg_rand->set_seed(12345); - - // streaming data generator for mean shift distributions - auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); - auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); - - // create MMD instance, convienience constructor - auto mmd=some(gen_p, gen_q); - mmd->set_statistic_type(ST_BIASED_FULL); - mmd->set_num_samples_p(m); - mmd->set_num_samples_q(n); - - for (auto i=0; iadd_kernel(new CGaussianKernel(10, sq_sigma_twice)); - } - - mmd->set_kernel_selection_strategy(KSM_MEDIAN_HEURISTIC); - mmd->set_train_test_mode(true); - mmd->select_kernel(); - auto selected_kernel=static_cast(mmd->get_kernel()); - EXPECT_NEAR(selected_kernel->get_width(), 1.62, 1E-10); -} - -TEST(KernelSelectionMedianHeuristic, linear_time_mmd) -{ - const index_t m=20; - const index_t n=30; - const index_t dim=2; - const float64_t difference=0.5; - const index_t num_kernels=10; - - // use fixed seed - sg_rand->set_seed(12345); - - // streaming data generator for mean shift distributions - auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); - auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); - - // create MMD instance, convienience constructor - auto mmd=some(gen_p, gen_q); - mmd->set_statistic_type(ST_BIASED_FULL); - mmd->set_num_samples_p(m); - mmd->set_num_samples_q(n); - - for (auto i=0; iadd_kernel(new CGaussianKernel(10, sq_sigma_twice)); - } - - mmd->set_kernel_selection_strategy(KSM_MEDIAN_HEURISTIC); - mmd->set_train_test_mode(true); - mmd->select_kernel(); - auto selected_kernel=static_cast(mmd->get_kernel()); - EXPECT_NEAR(selected_kernel->get_width(), 1.62, 1E-10); -} diff --git a/tests/unit/statistical_testing/KernelSelection_unittest.cc b/tests/unit/statistical_testing/KernelSelection_unittest.cc new file mode 100644 index 00000000000..218972561c3 --- /dev/null +++ b/tests/unit/statistical_testing/KernelSelection_unittest.cc @@ -0,0 +1,320 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (W) 2012-2013 Heiko Strathmann + * Written (w) 2016 Soumyajit De + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace shogun; + +TEST(KernelSelectionMaxMMD, linear_time_single_kernel_streaming) +{ + const index_t m=5; + const index_t n=10; + const index_t dim=1; + const float64_t difference=0.5; + const index_t num_kernels=10; + + sg_rand->set_seed(12345); + + auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); + auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); + + auto mmd=some(gen_p, gen_q); + mmd->set_statistic_type(ST_BIASED_FULL); + mmd->set_num_samples_p(m); + mmd->set_num_samples_q(n); + mmd->set_num_blocks_per_burst(1000); + for (auto i=0, sigma=-5; iadd_kernel(new CGaussianKernel(10, tau)); + } + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_MMD); + + mmd->set_train_test_mode(true); + mmd->select_kernel(); + mmd->set_train_test_mode(false); + + auto selected_kernel=static_cast(mmd->get_kernel()); + EXPECT_NEAR(selected_kernel->get_width(), 0.03125, 1E-10); +} + +TEST(KernelSelectionMaxMMD, linear_time_weighted_kernel_streaming) +{ + const index_t m=5; + const index_t n=10; + const index_t dim=1; + const float64_t difference=0.5; + const index_t num_kernels=10; + + sg_rand->set_seed(12345); + + auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); + auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); + + auto mmd=some(gen_p, gen_q); + mmd->set_statistic_type(ST_BIASED_FULL); + mmd->set_num_samples_p(m); + mmd->set_num_samples_q(n); + mmd->set_num_blocks_per_burst(1000); + for (auto i=0, sigma=-5; iadd_kernel(new CGaussianKernel(10, tau)); + } + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_MMD, true); + + mmd->set_train_test_mode(true); + mmd->select_kernel(); + mmd->set_train_test_mode(false); + + auto weighted_kernel=dynamic_cast(mmd->get_kernel()); + ASSERT_TRUE(weighted_kernel!=nullptr); + ASSERT_TRUE(weighted_kernel->get_num_subkernels()==num_kernels); + + SGVector weights=weighted_kernel->get_subkernel_weights(); + for (auto i=0; iset_seed(12345); + + auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); + auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); + + auto mmd=some(gen_p, gen_q); + mmd->set_statistic_type(ST_BIASED_FULL); + mmd->set_num_samples_p(m); + mmd->set_num_samples_q(n); + mmd->set_num_blocks_per_burst(1000); + for (auto i=0, sigma=-5; iadd_kernel(new CGaussianKernel(10, tau)); + } + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_POWER); + + mmd->set_train_test_mode(true); + mmd->select_kernel(); + mmd->set_train_test_mode(false); + + auto selected_kernel=static_cast(mmd->get_kernel()); + EXPECT_NEAR(selected_kernel->get_width(), 0.03125, 1E-10); +} + +TEST(KernelSelectionMaxTestPower, linear_time_weighted_kernel_streaming) +{ + const index_t m=5; + const index_t n=10; + const index_t dim=1; + const float64_t difference=0.5; + const index_t num_kernels=10; + + sg_rand->set_seed(12345); + + auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); + auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); + + auto mmd=some(gen_p, gen_q); + mmd->set_statistic_type(ST_BIASED_FULL); + mmd->set_num_samples_p(m); + mmd->set_num_samples_q(n); + mmd->set_num_blocks_per_burst(1000); + for (auto i=0, sigma=-5; iadd_kernel(new CGaussianKernel(10, tau)); + } + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_POWER, true); + + mmd->set_train_test_mode(true); + mmd->select_kernel(); + mmd->set_train_test_mode(false); + + auto weighted_kernel=dynamic_cast(mmd->get_kernel()); + ASSERT_TRUE(weighted_kernel!=nullptr); + ASSERT_TRUE(weighted_kernel->get_num_subkernels()==num_kernels); + + SGVector weights=weighted_kernel->get_subkernel_weights(); + for (auto i=0; iset_seed(12345); + + auto gen_p=some(0, dim, 0); + auto gen_q=some(difference, dim, 0); + auto feats_p=gen_p->get_streamed_features(m); + auto feats_q=gen_q->get_streamed_features(n); + + auto mmd=some(feats_p, feats_q); + mmd->set_statistic_type(ST_BIASED_FULL); + for (auto i=0, sigma=-5; iadd_kernel(new CGaussianKernel(10, tau)); + } + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_XVALIDATION, num_runs, alpha); + + mmd->set_train_test_mode(true); + mmd->set_train_test_ratio(num_folds-1); + mmd->select_kernel(); + mmd->set_train_test_mode(false); + + auto selected_kernel=static_cast(mmd->get_kernel()); + EXPECT_NEAR(selected_kernel->get_width(), 0.03125, 1E-10); +} + +TEST(KernelSelectionMaxXValidation, linear_time_single_kernel_dense) +{ + const index_t m=5; + const index_t n=10; + const index_t dim=1; + const float64_t difference=0.5; + const index_t num_kernels=10; + const index_t num_runs=1; + const index_t num_folds=5; + const float64_t alpha=0.05; + + sg_rand->set_seed(12345); + + auto gen_p=some(0, dim, 0); + auto gen_q=some(difference, dim, 0); + auto feats_p=gen_p->get_streamed_features(m); + auto feats_q=gen_q->get_streamed_features(n); + + auto mmd=some(feats_p, feats_q); + mmd->set_statistic_type(ST_BIASED_FULL); + for (auto i=0, sigma=-5; iadd_kernel(new CGaussianKernel(10, tau)); + } + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_XVALIDATION, num_runs, alpha); + + mmd->set_train_test_mode(true); + mmd->set_train_test_ratio(num_folds-1); + mmd->select_kernel(); + mmd->set_train_test_mode(false); + + auto selected_kernel=static_cast(mmd->get_kernel()); + EXPECT_NEAR(selected_kernel->get_width(), 0.03125, 1E-10); +} + +TEST(KernelSelectionMedianHeuristic, quadratic_time_single_kernel_dense) +{ + const index_t m=5; + const index_t n=10; + const index_t dim=1; + const float64_t difference=0.5; + const index_t num_kernels=10; + + sg_rand->set_seed(12345); + + auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); + auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); + + auto mmd=some(gen_p, gen_q); + mmd->set_statistic_type(ST_BIASED_FULL); + mmd->set_num_samples_p(m); + mmd->set_num_samples_q(n); + for (auto i=0, sigma=-5; iadd_kernel(new CGaussianKernel(10, tau)); + } + mmd->set_kernel_selection_strategy(KSM_MEDIAN_HEURISTIC); + + mmd->set_train_test_mode(true); + mmd->select_kernel(); + mmd->set_train_test_mode(false); + + auto selected_kernel=static_cast(mmd->get_kernel()); + EXPECT_NEAR(selected_kernel->get_width(), 1.0, 1E-10); +} + +TEST(KernelSelectionMedianHeuristic, linear_time_single_kernel_dense) +{ + const index_t m=5; + const index_t n=10; + const index_t dim=1; + const float64_t difference=0.5; + const index_t num_kernels=10; + + sg_rand->set_seed(12345); + + auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); + auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); + + auto mmd=some(gen_p, gen_q); + mmd->set_statistic_type(ST_BIASED_FULL); + mmd->set_num_samples_p(m); + mmd->set_num_samples_q(n); + for (auto i=0, sigma=-5; iadd_kernel(new CGaussianKernel(10, tau)); + } + mmd->set_kernel_selection_strategy(KSM_MEDIAN_HEURISTIC); + + mmd->set_train_test_mode(true); + mmd->select_kernel(); + mmd->set_train_test_mode(false); + + auto selected_kernel=static_cast(mmd->get_kernel()); + EXPECT_NEAR(selected_kernel->get_width(), 1.0, 1E-10); +}