diff --git a/src/shogun/statistical_testing/LinearTimeMMD.cpp b/src/shogun/statistical_testing/LinearTimeMMD.cpp
index d1cc1cf3bd9..736ab1d6de8 100644
--- a/src/shogun/statistical_testing/LinearTimeMMD.cpp
+++ b/src/shogun/statistical_testing/LinearTimeMMD.cpp
@@ -16,6 +16,7 @@
* along with this program. If not, see .
*/
+#include
#include
#include
#include
@@ -30,14 +31,22 @@ CLinearTimeMMD::CLinearTimeMMD() : CMMD()
{
}
+CLinearTimeMMD::CLinearTimeMMD(CFeatures* samples_from_p, CFeatures* samples_from_q) : CMMD()
+{
+ set_p(samples_from_p);
+ set_q(samples_from_q);
+}
+
CLinearTimeMMD::~CLinearTimeMMD()
{
}
void CLinearTimeMMD::set_num_blocks_per_burst(index_t num_blocks_per_burst)
{
- get_data_manager().set_blocksize(get_data_manager().get_min_blocksize());
- get_data_manager().set_num_blocks_per_burst(num_blocks_per_burst);
+ auto& dm=get_data_manager();
+ dm.set_blocksize(get_data_manager().get_min_blocksize());
+ dm.set_num_blocks_per_burst(num_blocks_per_burst);
+ SG_SDEBUG("Block contains %d and %d samples, from P and Q respectively!\n", dm.blocksize_at(0), dm.blocksize_at(1));
}
const std::function)> CLinearTimeMMD::get_direct_estimation_method() const
diff --git a/src/shogun/statistical_testing/LinearTimeMMD.h b/src/shogun/statistical_testing/LinearTimeMMD.h
index 8f79304d401..900660fb459 100644
--- a/src/shogun/statistical_testing/LinearTimeMMD.h
+++ b/src/shogun/statistical_testing/LinearTimeMMD.h
@@ -29,6 +29,7 @@ class CLinearTimeMMD : public CMMD
using operation=std::function)>;
public:
CLinearTimeMMD();
+ CLinearTimeMMD(CFeatures* samples_from_p, CFeatures* samples_from_q);
virtual ~CLinearTimeMMD();
void set_num_blocks_per_burst(index_t num_blocks_per_burst);