Skip to content

Commit

Permalink
intermediate work for test-train data split
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 3, 2016
1 parent e7d7b7c commit 755b8fe
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 69 deletions.
6 changes: 0 additions & 6 deletions src/shogun/statistical_testing/BTestMMD.cpp
Expand Up @@ -110,12 +110,6 @@ float64_t CBTestMMD::compute_threshold(float64_t alpha)
return result;
}

std::shared_ptr<CCustomDistance> CBTestMMD::compute_distance()
{
auto distance=std::shared_ptr<CCustomDistance>(new CCustomDistance());
return distance;
}

const char* CBTestMMD::get_name() const
{
return "BTestMMD";
Expand Down
1 change: 0 additions & 1 deletion src/shogun/statistical_testing/BTestMMD.h
Expand Up @@ -42,7 +42,6 @@ class CBTestMMD : public CMMD
virtual const operation get_direct_estimation_method() const override;
virtual const float64_t normalize_statistic(float64_t statistic) const override;
virtual const float64_t normalize_variance(float64_t variance) const override;
virtual std::shared_ptr<CCustomDistance> compute_distance() override;
};

}
Expand Down
6 changes: 0 additions & 6 deletions src/shogun/statistical_testing/LinearTimeMMD.cpp
Expand Up @@ -145,12 +145,6 @@ float64_t CLinearTimeMMD::compute_threshold(float64_t alpha)
return result;
}

std::shared_ptr<CCustomDistance> CLinearTimeMMD::compute_distance()
{
auto distance=std::shared_ptr<CCustomDistance>(new CCustomDistance());
return distance;
}

const char* CLinearTimeMMD::get_name() const
{
return "LinearTimeMMD";
Expand Down
1 change: 0 additions & 1 deletion src/shogun/statistical_testing/LinearTimeMMD.h
Expand Up @@ -42,7 +42,6 @@ class CLinearTimeMMD : public CMMD
virtual const operation get_direct_estimation_method() const override;
virtual const float64_t normalize_statistic(float64_t statistic) const override;
virtual const float64_t normalize_variance(float64_t variance) const override;
virtual std::shared_ptr<CCustomDistance> compute_distance() override;
const float64_t gaussian_variance(float64_t variance) const;
};

Expand Down
57 changes: 56 additions & 1 deletion src/shogun/statistical_testing/MMD.cpp
Expand Up @@ -35,6 +35,8 @@
#include <shogun/kernel/CustomKernel.h>
#include <shogun/kernel/CombinedKernel.h>
#include <shogun/features/Features.h>
#include <shogun/distance/EuclideanDistance.h>
#include <shogun/distance/CustomDistance.h>
#include <shogun/statistical_testing/MMD.h>
#include <shogun/statistical_testing/QuadraticTimeMMD.h>
#include <shogun/statistical_testing/BTestMMD.h>
Expand Down Expand Up @@ -74,6 +76,7 @@ struct CMMD::Self

std::pair<float64_t, float64_t> compute_statistic_variance();
std::pair<SGVector<float64_t>, SGMatrix<float64_t>> compute_statistic_and_Q();
std::shared_ptr<CCustomDistance> compute_distance();
SGVector<float64_t> sample_null();

CMMD& owner;
Expand Down Expand Up @@ -386,6 +389,57 @@ SGVector<float64_t> CMMD::Self::sample_null()
return statistic;
}

std::shared_ptr<CCustomDistance> CMMD::Self::compute_distance()
{
auto distance=std::shared_ptr<CCustomDistance>(new CCustomDistance());
DataManager& dm=owner.get_data_manager();

bool blockwise=dm.is_blockwise();
dm.set_blockwise(false);

// using data manager next() API in order to make it work with
// streaming samples as well.
dm.start();
auto samples=dm.next();
if (!samples.empty())
{
dm.end();

// use 0th block from each distribution (since there is only one block
// for quadratic time MMD
CFeatures *samples_p=samples[0][0].get();
CFeatures *samples_q=samples[1][0].get();

try
{
auto p_and_q=FeaturesUtil::create_merged_copy(samples_p, samples_q);
samples.clear();
auto euclidean_distance=std::unique_ptr<CEuclideanDistance>(new CEuclideanDistance());
if (euclidean_distance->init(p_and_q, p_and_q))
{
auto dist_mat=euclidean_distance->get_distance_matrix<float32_t>();
distance->set_triangle_distance_matrix_from_full(dist_mat.data(), dist_mat.num_rows, dist_mat.num_cols);
}
else
{
SG_SERROR("Computing distance matrix was not possible! Please contact Shogun developers.\n");
}
}
catch (ShogunException e)
{
SG_SERROR("%s, Data is too large! Computing distance matrix was not possible!\n", e.get_exception_string());
}
}
else
{
dm.end();
SG_SERROR("Could not fetch samples!\n");
}

dm.set_blockwise(blockwise);
return distance;
}

CMMD::CMMD() : CTwoSampleTest()
{
#if EIGEN_VERSION_AT_LEAST(3,1,0)
Expand Down Expand Up @@ -419,9 +473,10 @@ void CMMD::select_kernel(EKernelSelectionMethod kmethod, bool weighted_kernel, f
case KSM_MEDIAN_HEURISTIC:
{
REQUIRE(!weighted_kernel, "Weighted kernel selection is not possible with MEDIAN_HEURISTIC!\n");
auto distance=compute_distance();
auto distance=self->compute_distance();
policy=std::unique_ptr<MedianHeuristic>(new MedianHeuristic(self->kernel_selection_mgr, distance));
dm.set_train_test_ratio(0);
dm.reset();
}
break;
case KSM_MAXIMIZE_XVALIDATION:
Expand Down
1 change: 0 additions & 1 deletion src/shogun/statistical_testing/MMD.h
Expand Up @@ -123,7 +123,6 @@ class CMMD : public CTwoSampleTest
virtual const operation get_direct_estimation_method() const=0;
virtual const float64_t normalize_statistic(float64_t statistic) const=0;
virtual const float64_t normalize_variance(float64_t variance) const=0;
virtual std::shared_ptr<CCustomDistance> compute_distance()=0;
bool use_gpu() const;
private:
struct Self;
Expand Down
52 changes: 0 additions & 52 deletions src/shogun/statistical_testing/QuadraticTimeMMD.cpp
Expand Up @@ -36,8 +36,6 @@
#include <shogun/kernel/CustomKernel.h>
#include <shogun/mathematics/eigen3.h>
#include <shogun/mathematics/Statistics.h>
#include <shogun/distance/EuclideanDistance.h>
#include <shogun/distance/CustomDistance.h>
#include <shogun/statistical_testing/QuadraticTimeMMD.h>
#include <shogun/statistical_testing/internals/FeaturesUtil.h>
#include <shogun/statistical_testing/internals/NextSamples.h>
Expand Down Expand Up @@ -487,56 +485,6 @@ SGVector<float64_t> CQuadraticTimeMMD::spectrum_sample_null()
return null_samples;
}

std::shared_ptr<CCustomDistance> CQuadraticTimeMMD::compute_distance()
{
auto distance=std::shared_ptr<CCustomDistance>(new CCustomDistance());
DataManager& dm=get_data_manager();

// using data manager next() API in order to make it work with
// streaming samples as well.
dm.start();
auto samples=dm.next();
if (!samples.empty())
{
dm.end();

// use 0th block from each distribution (since there is only one block
// for quadratic time MMD
CFeatures *samples_p=samples[0][0].get();
CFeatures *samples_q=samples[1][0].get();

try
{
auto p_and_q=FeaturesUtil::create_merged_copy(samples_p, samples_q);
samples.clear();
auto euclidean_distance=std::unique_ptr<CEuclideanDistance>(new CEuclideanDistance());
if (euclidean_distance->init(p_and_q, p_and_q))
{
auto dist_mat=euclidean_distance->get_distance_matrix<float32_t>();
if (io->get_loglevel()==MSG_DEBUG)
{
dist_mat.display_matrix("distance_matrix");
}
distance->set_triangle_distance_matrix_from_full(dist_mat.data(), dist_mat.num_rows, dist_mat.num_cols);
}
else
{
SG_SERROR("Computing distance matrix was not possible! Please contact Shogun developers.\n");
}
}
catch (ShogunException e)
{
SG_SERROR("%s, Data is too large! Computing distance matrix was not possible!\n", e.get_exception_string());
}
}
else
{
dm.end();
SG_SERROR("Could not fetch samples!\n");
}
return distance;
}

const char* CQuadraticTimeMMD::get_name() const
{
return "QuadraticTimeMMD";
Expand Down
1 change: 0 additions & 1 deletion src/shogun/statistical_testing/QuadraticTimeMMD.h
Expand Up @@ -67,7 +67,6 @@ class CQuadraticTimeMMD : public CMMD
virtual const operation get_direct_estimation_method() const override;
virtual const float64_t normalize_statistic(float64_t statistic) const override;
virtual const float64_t normalize_variance(float64_t variance) const override;
virtual std::shared_ptr<CCustomDistance> compute_distance() override;
SGVector<float64_t> gamma_fit_null();
SGVector<float64_t> spectrum_sample_null();
};
Expand Down
18 changes: 18 additions & 0 deletions src/shogun/statistical_testing/internals/DataManager.cpp
Expand Up @@ -196,6 +196,24 @@ const index_t DataManager::blocksize_at(size_t i) const
return fetchers[i]->m_block_details.m_blocksize;
}

const bool DataManager::is_blockwise() const
{
SG_SDEBUG("Entering!\n");
bool blockwise=true;
for (size_t i=0; i<fetchers.size(); ++i)
blockwise&=!fetchers[i]->m_block_details.m_full_data;
SG_SDEBUG("Leaving!\n");
return blockwise;
}

void DataManager::set_blockwise(bool blockwise)
{
SG_SDEBUG("Entering!\n");
for (size_t i=0; i<fetchers.size(); ++i)
fetchers[i]->m_block_details.m_full_data=!blockwise;
SG_SDEBUG("Leaving!\n");
}

void DataManager::set_train_test_ratio(float64_t train_test_ratio)
{
SG_SDEBUG("Entering!\n");
Expand Down
13 changes: 13 additions & 0 deletions src/shogun/statistical_testing/internals/DataManager.h
Expand Up @@ -176,6 +176,19 @@ class DataManager
*/
const index_t blocksize_at(size_t i) const;

/**
* @return True if block-wise fetching is on, False otherwise.
*/
const bool is_blockwise() const;

/**
* Turns on blockwise fetching if True is passed. Turns off blockwise fetching if
* False is passed. The blockwise details are not destroyed when set to False, i.e.
* turning blockwise fetching back on again, we can get blocks as we would have got
* in the original setup.
*/
void set_blockwise(bool blockwise);

/**
* @return Total number of samples that can be fetched from all the data sources.
*/
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/statistical_testing/internals/DataManager_unittest.cc
Expand Up @@ -875,3 +875,55 @@ TEST(DataManager, train_data_two_distributions_streaming_feats_blockwise)
ASSERT_TRUE(total==num_vec);
mgr.end();
}

TEST(DataManager, set_blockwise_on_off)
{
const index_t dim=3;
const index_t num_vec=8;
const index_t blocksize=2;
const index_t num_blocks_per_burst=2;
const index_t num_distributions=1;

SGMatrix<float64_t> data_p(dim, num_vec);
std::iota(data_p.matrix, data_p.matrix+dim*num_vec, 0);

auto feats_p=new CDenseFeatures<float64_t>(data_p);

DataManager mgr(num_distributions);
mgr.samples_at(0)=feats_p;
mgr.set_blocksize(blocksize);
mgr.set_num_blocks_per_burst(num_blocks_per_burst);

mgr.set_blockwise(false);
mgr.start();
auto next_burst=mgr.next();
ASSERT_TRUE(!next_burst.empty());
ASSERT_TRUE(next_burst.num_blocks()==1);
auto casted=dynamic_cast<CDenseFeatures<float64_t>*>(next_burst[0][0].get());
casted->get_feature_matrix().display_matrix("whole");
ASSERT_TRUE(casted!=nullptr);
ASSERT_TRUE(casted->get_num_vectors()==num_vec);
next_burst=mgr.next();
ASSERT_TRUE(next_burst.empty());
mgr.end();

mgr.reset();
mgr.set_blockwise(true);
mgr.start();
auto total=0;
next_burst=mgr.next();
while (!next_burst.empty())
{
// ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst);
for (auto i=0; i<next_burst.num_blocks(); ++i)
{
auto tmp=dynamic_cast<CDenseFeatures<float64_t>*>(next_burst[0][i].get());
ASSERT_TRUE(tmp!=nullptr);
tmp->get_feature_matrix().display_matrix("block");
// ASSERT_TRUE(tmp->get_num_vectors()==blocksize);
total+=tmp->get_num_vectors();
}
next_burst=mgr.next();
}
// ASSERT_TRUE(total==num_vec);
}

0 comments on commit 755b8fe

Please sign in to comment.