diff --git a/src/shogun/statistical_testing/LinearTimeMMD.cpp b/src/shogun/statistical_testing/LinearTimeMMD.cpp index d1cc1cf3bd9..736ab1d6de8 100644 --- a/src/shogun/statistical_testing/LinearTimeMMD.cpp +++ b/src/shogun/statistical_testing/LinearTimeMMD.cpp @@ -16,6 +16,7 @@ * along with this program. If not, see . */ +#include #include #include #include @@ -30,14 +31,22 @@ CLinearTimeMMD::CLinearTimeMMD() : CMMD() { } +CLinearTimeMMD::CLinearTimeMMD(CFeatures* samples_from_p, CFeatures* samples_from_q) : CMMD() +{ + set_p(samples_from_p); + set_q(samples_from_q); +} + CLinearTimeMMD::~CLinearTimeMMD() { } void CLinearTimeMMD::set_num_blocks_per_burst(index_t num_blocks_per_burst) { - get_data_manager().set_blocksize(get_data_manager().get_min_blocksize()); - get_data_manager().set_num_blocks_per_burst(num_blocks_per_burst); + auto& dm=get_data_manager(); + dm.set_blocksize(get_data_manager().get_min_blocksize()); + dm.set_num_blocks_per_burst(num_blocks_per_burst); + SG_SDEBUG("Block contains %d and %d samples, from P and Q respectively!\n", dm.blocksize_at(0), dm.blocksize_at(1)); } const std::function)> CLinearTimeMMD::get_direct_estimation_method() const diff --git a/src/shogun/statistical_testing/LinearTimeMMD.h b/src/shogun/statistical_testing/LinearTimeMMD.h index 8f79304d401..900660fb459 100644 --- a/src/shogun/statistical_testing/LinearTimeMMD.h +++ b/src/shogun/statistical_testing/LinearTimeMMD.h @@ -29,6 +29,7 @@ class CLinearTimeMMD : public CMMD using operation=std::function)>; public: CLinearTimeMMD(); + CLinearTimeMMD(CFeatures* samples_from_p, CFeatures* samples_from_q); virtual ~CLinearTimeMMD(); void set_num_blocks_per_burst(index_t num_blocks_per_burst);