From dfdcb0350667aa886737e920fb635d85de3049e8 Mon Sep 17 00:00:00 2001 From: Heiko Strathmann Date: Fri, 1 Apr 2016 10:42:47 +0100 Subject: [PATCH] add constructor to QuadraticTimeMMD --- .../meta/src/statistical_testing/quadratic_time_mmd.sg | 5 ++--- src/shogun/statistical_testing/QuadraticTimeMMD.cpp | 8 ++++++++ src/shogun/statistical_testing/QuadraticTimeMMD.h | 2 ++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/examples/meta/src/statistical_testing/quadratic_time_mmd.sg b/examples/meta/src/statistical_testing/quadratic_time_mmd.sg index cd57cb3467e..bfb15e54a54 100644 --- a/examples/meta/src/statistical_testing/quadratic_time_mmd.sg +++ b/examples/meta/src/statistical_testing/quadratic_time_mmd.sg @@ -7,11 +7,9 @@ RealFeatures features_q(f_features_q) #![create_features] #![create_instance] -QuadraticTimeMMD mmd() +QuadraticTimeMMD mmd(features_p, features_q) GaussianKernel kernel() mmd.set_kernel(kernel) -mmd.set_p(features_p) -mmd.set_q(features_q) Real alpha = 0.05 #![create_instance] @@ -36,6 +34,7 @@ Real h0_rejected_permutation = mmd.perform_test(alpha) #![perform_test_spectrum] mmd.set_null_approximation_method(enum ENullApproximationMethod.MMD2_SPECTRUM) mmd.set_num_null_samples(200) +mmd.spectrum_set_num_eigenvalues(5) Real threshold_spectrum = mmd.compute_threshold(alpha) Real p_value_spectrum = mmd.compute_p_value(statistic_biased) Real h0_rejected_spectrum = mmd.perform_test(alpha) diff --git a/src/shogun/statistical_testing/QuadraticTimeMMD.cpp b/src/shogun/statistical_testing/QuadraticTimeMMD.cpp index 41f72a29ec8..cf3fe5aad6f 100644 --- a/src/shogun/statistical_testing/QuadraticTimeMMD.cpp +++ b/src/shogun/statistical_testing/QuadraticTimeMMD.cpp @@ -46,6 +46,14 @@ CQuadraticTimeMMD::CQuadraticTimeMMD() : CMMD() self = std::unique_ptr(new Self()); } +CQuadraticTimeMMD::CQuadraticTimeMMD(CFeatures* samples_from_p, + CFeatures* samples_from_q) : CMMD() +{ + self = std::unique_ptr(new Self()); + set_p(samples_from_p); + set_q(samples_from_p); +} + CQuadraticTimeMMD::~CQuadraticTimeMMD() { } diff --git a/src/shogun/statistical_testing/QuadraticTimeMMD.h b/src/shogun/statistical_testing/QuadraticTimeMMD.h index 7deac1bfcee..12dc1ba7a20 100644 --- a/src/shogun/statistical_testing/QuadraticTimeMMD.h +++ b/src/shogun/statistical_testing/QuadraticTimeMMD.h @@ -32,6 +32,8 @@ class CQuadraticTimeMMD : public CMMD using operation = std::function)>; public: CQuadraticTimeMMD(); + CQuadraticTimeMMD(CFeatures* samples_from_p, CFeatures* samples_from_q); + virtual ~CQuadraticTimeMMD(); virtual SGVector sample_null() override;