Skip to content

Commit

Permalink
made the enum classes in statistical testing simple enums
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 4, 2016
1 parent 7b57b63 commit d35d8df
Show file tree
Hide file tree
Showing 20 changed files with 141 additions and 144 deletions.
Expand Up @@ -22,13 +22,13 @@ mmd.add_kernel(kernel3)
#![add_kernels]

#![select_kernel_single]
mmd.select_kernel(enum EKernelSelectionMethod.MAXIMIZE_POWER)
mmd.select_kernel(KSM_MAXIMIZE_POWER)
GaussianKernel learnt_kernel_single = GaussianKernel:obtain_from_generic(mmd.get_kernel())
Real width = learnt_kernel_single.get_width()
#![select_kernel_single]

#![select_kernel_combined]
mmd.select_kernel(enum EKernelSelectionMethod.MAXIMIZE_POWER, true)
mmd.select_kernel(KSM_MAXIMIZE_POWER, true)
CombinedKernel learnt_kernel_combined = CombinedKernel:obtain_from_generic(mmd.get_kernel())
RealVector weights = learnt_kernel_combined.get_subkernel_weights()
#![select_kernel_combined]
Expand Down
10 changes: 5 additions & 5 deletions examples/meta/src/statistical_testing/quadratic_time_mmd.sg
Expand Up @@ -14,32 +14,32 @@ Real alpha = 0.05
#![create_instance]

#![estimate_mmd_unbiased]
mmd.set_statistic_type(enum EStatisticType.UNBIASED_FULL)
mmd.set_statistic_type(ST_UNBIASED_FULL)
Real statistic_unbiased = mmd.compute_statistic()
#![estimate_mmd_unbiased]

#![estimate_mmd_biased]
mmd.set_statistic_type(enum EStatisticType.BIASED_FULL)
mmd.set_statistic_type(ST_BIASED_FULL)
Real statistic_biased = mmd.compute_statistic()
#![estimate_mmd_biased]

#![perform_test_permutation]
mmd.set_null_approximation_method(enum ENullApproximationMethod.PERMUTATION)
mmd.set_null_approximation_method(NAM_PERMUTATION)
mmd.set_num_null_samples(200)
Real threshold_permutation = mmd.compute_threshold(alpha)
Real p_value_permutation = mmd.compute_p_value(statistic_biased)
#![perform_test_permutation]

#![perform_test_spectrum]
mmd.set_null_approximation_method(enum ENullApproximationMethod.MMD2_SPECTRUM)
mmd.set_null_approximation_method(NAM_MMD2_SPECTRUM)
mmd.set_num_null_samples(200)
mmd.spectrum_set_num_eigenvalues(5)
Real threshold_spectrum = mmd.compute_threshold(alpha)
Real p_value_spectrum = mmd.compute_p_value(statistic_biased)
#![perform_test_spectrum]

#![perform_test_gamma]
mmd.set_null_approximation_method(enum ENullApproximationMethod.MMD2_GAMMA)
mmd.set_null_approximation_method(NAM_MMD2_GAMMA)
Real threshold_gamma = mmd.compute_threshold(alpha)
Real p_value_gamma = mmd.compute_p_value(statistic_biased)
#![perform_test_gamma]
Expand Up @@ -24,13 +24,13 @@ mmd.add_kernel(kernel3)
#![add_kernels]

#![select_kernel_single]
mmd.select_kernel(enum EKernelSelectionMethod.MAXIMIZE_MMD)
mmd.select_kernel(KSM_MAXIMIZE_MMD)
GaussianKernel learnt_kernel_single = GaussianKernel:obtain_from_generic(mmd.get_kernel())
Real width = learnt_kernel_single.get_width()
#![select_kernel_single]

#![select_kernel_combined]
mmd.select_kernel(enum EKernelSelectionMethod.MAXIMIZE_MMD, true)
mmd.select_kernel(KSM_MAXIMIZE_MMD, true)
CombinedKernel learnt_kernel_combined = CombinedKernel:obtain_from_generic(mmd.get_kernel())
RealVector weights = learnt_kernel_combined.get_subkernel_weights()
#![select_kernel_combined]
Expand Down
44 changes: 22 additions & 22 deletions src/shogun/statistical_testing/BTestMMD.cpp
Expand Up @@ -52,37 +52,37 @@ const std::function<float32_t(SGMatrix<float32_t>)> CBTestMMD::get_direct_estima

const float64_t CBTestMMD::normalize_statistic(float64_t statistic) const
{
const DataManager& dm = get_data_manager();
const index_t Nx = dm.num_samples_at(0);
const index_t Ny = dm.num_samples_at(1);
const index_t Bx = dm.blocksize_at(0);
const index_t By = dm.blocksize_at(1);
return Nx * Ny * statistic * CMath::sqrt((Bx + By)/float64_t(Nx + Ny)) / (Nx + Ny);
const DataManager& dm=get_data_manager();
const index_t Nx=dm.num_samples_at(0);
const index_t Ny=dm.num_samples_at(1);
const index_t Bx=dm.blocksize_at(0);
const index_t By=dm.blocksize_at(1);
return Nx*Ny*statistic*CMath::sqrt((Bx+By)/float64_t(Nx+Ny))/(Nx+Ny);
}

const float64_t CBTestMMD::normalize_variance(float64_t variance) const
{
const DataManager& dm = get_data_manager();
const index_t Bx = dm.blocksize_at(0);
const index_t By = dm.blocksize_at(1);
return variance * CMath::sq(Bx * By / float64_t(Bx + By));
const DataManager& dm=get_data_manager();
const index_t Bx=dm.blocksize_at(0);
const index_t By=dm.blocksize_at(1);
return variance*CMath::sq(Bx*By/float64_t(Bx+By));
}

float64_t CBTestMMD::compute_p_value(float64_t statistic)
{
float64_t result = 0;
float64_t result=0;
switch (get_null_approximation_method())
{
case ENullApproximationMethod::MMD1_GAUSSIAN:
case NAM_MMD1_GAUSSIAN:
{
float64_t sigma_sq = compute_variance();
float64_t std_dev = CMath::sqrt(sigma_sq);
result = 1.0 - CStatistics::normal_cdf(statistic, std_dev);
float64_t sigma_sq=compute_variance();
float64_t std_dev=CMath::sqrt(sigma_sq);
result=1.0-CStatistics::normal_cdf(statistic, std_dev);
break;
}
default:
{
result = CHypothesisTest::compute_p_value(statistic);
result=CHypothesisTest::compute_p_value(statistic);
break;
}
}
Expand All @@ -91,19 +91,19 @@ float64_t CBTestMMD::compute_p_value(float64_t statistic)

float64_t CBTestMMD::compute_threshold(float64_t alpha)
{
float64_t result = 0;
float64_t result=0;
switch (get_null_approximation_method())
{
case ENullApproximationMethod::MMD1_GAUSSIAN:
case NAM_MMD1_GAUSSIAN:
{
float64_t sigma_sq = compute_variance();
float64_t std_dev = CMath::sqrt(sigma_sq);
result = 1.0 - CStatistics::inverse_normal_cdf(1 - alpha, 0, std_dev);
float64_t sigma_sq=compute_variance();
float64_t std_dev=CMath::sqrt(sigma_sq);
result=1.0-CStatistics::inverse_normal_cdf(1-alpha, 0, std_dev);
break;
}
default:
{
result = CHypothesisTest::compute_threshold(alpha);
result=CHypothesisTest::compute_threshold(alpha);
break;
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/shogun/statistical_testing/LinearTimeMMD.cpp
Expand Up @@ -83,7 +83,7 @@ const float64_t CLinearTimeMMD::normalize_variance(float64_t variance) const
const index_t Bx = dm.blocksize_at(0);
const index_t By = dm.blocksize_at(1);
const index_t B = Bx + By;
if (get_statistic_type() == EStatisticType::UNBIASED_INCOMPLETE)
if (get_statistic_type() == ST_UNBIASED_INCOMPLETE)
{
return variance * B * (B - 2) / 16;
}
Expand All @@ -96,7 +96,7 @@ const float64_t CLinearTimeMMD::gaussian_variance(float64_t variance) const
const index_t Bx = dm.blocksize_at(0);
const index_t By = dm.blocksize_at(1);
const index_t B = Bx + By;
if (get_statistic_type() == EStatisticType::UNBIASED_INCOMPLETE)
if (get_statistic_type() == ST_UNBIASED_INCOMPLETE)
{
return variance * 4 / (B - 2);
}
Expand All @@ -108,7 +108,7 @@ float64_t CLinearTimeMMD::compute_p_value(float64_t statistic)
float64_t result = 0;
switch (get_null_approximation_method())
{
case ENullApproximationMethod::MMD1_GAUSSIAN:
case NAM_MMD1_GAUSSIAN:
{
float64_t sigma_sq = gaussian_variance(compute_variance());
float64_t std_dev = CMath::sqrt(sigma_sq);
Expand All @@ -129,7 +129,7 @@ float64_t CLinearTimeMMD::compute_threshold(float64_t alpha)
float64_t result = 0;
switch (get_null_approximation_method())
{
case ENullApproximationMethod::MMD1_GAUSSIAN:
case NAM_MMD1_GAUSSIAN:
{
float64_t sigma_sq = gaussian_variance(compute_variance());
float64_t std_dev = CMath::sqrt(sigma_sq);
Expand Down
28 changes: 14 additions & 14 deletions src/shogun/statistical_testing/MMD.cpp
Expand Up @@ -94,9 +94,9 @@ struct CMMD::Self

CMMD::Self::Self(CMMD& cmmd) : owner(cmmd),
use_gpu(false), num_null_samples(250),
statistic_type(EStatisticType::UNBIASED_FULL),
variance_estimation_method(EVarianceEstimationMethod::DIRECT),
null_approximation_method(ENullApproximationMethod::PERMUTATION),
statistic_type(ST_UNBIASED_FULL),
variance_estimation_method(VEM_DIRECT),
null_approximation_method(NAM_PERMUTATION),
statistic_job(nullptr), variance_job(nullptr)
{
}
Expand All @@ -114,13 +114,13 @@ void CMMD::Self::create_statistic_job()
auto By=dm.blocksize_at(1);
switch (statistic_type)
{
case EStatisticType::UNBIASED_FULL:
case ST_UNBIASED_FULL:
statistic_job=mmd::UnbiasedFull(Bx);
break;
case EStatisticType::UNBIASED_INCOMPLETE:
case ST_UNBIASED_INCOMPLETE:
statistic_job=mmd::UnbiasedIncomplete(Bx);
break;
case EStatisticType::BIASED_FULL:
case ST_BIASED_FULL:
statistic_job=mmd::BiasedFull(Bx);
break;
default : break;
Expand All @@ -132,10 +132,10 @@ void CMMD::Self::create_variance_job()
{
switch (variance_estimation_method)
{
case EVarianceEstimationMethod::DIRECT:
case VEM_DIRECT:
variance_job=owner.get_direct_estimation_method();
break;
case EVarianceEstimationMethod::PERMUTATION:
case VEM_PERMUTATION:
variance_job=permutation_job;
break;
default : break;
Expand Down Expand Up @@ -225,7 +225,7 @@ std::pair<float64_t, float64_t> CMMD::Self::compute_statistic_variance()
statistic_term_counter++;
}

if (variance_estimation_method==EVarianceEstimationMethod::DIRECT)
if (variance_estimation_method==VEM_DIRECT)
{
for (size_t i=0; i<mmds.size(); ++i)
{
Expand All @@ -252,7 +252,7 @@ std::pair<float64_t, float64_t> CMMD::Self::compute_statistic_variance()

// normalize statistic and variance
statistic=owner.normalize_statistic(statistic);
if (variance_estimation_method==EVarianceEstimationMethod::PERMUTATION)
if (variance_estimation_method==VEM_PERMUTATION)
variance=owner.normalize_variance(variance);

return std::make_pair(statistic, variance);
Expand Down Expand Up @@ -416,27 +416,27 @@ void CMMD::select_kernel(EKernelSelectionMethod kmethod, bool weighted_kernel, f

switch (kmethod)
{
case EKernelSelectionMethod::MEDIAN_HEURISTIC:
case KSM_MEDIAN_HEURISTIC:
{
REQUIRE(!weighted_kernel, "Weighted kernel selection is not possible with MEDIAN_HEURISTIC!\n");
auto distance=compute_distance();
policy=std::unique_ptr<MedianHeuristic>(new MedianHeuristic(self->kernel_selection_mgr, distance));
dm.set_train_test_ratio(0);
}
break;
case EKernelSelectionMethod::MAXIMIZE_XVALIDATION:
case KSM_MAXIMIZE_XVALIDATION:
{
REQUIRE(!weighted_kernel, "Weighted kernel selection is not possible with MAXIMIZE_XVALIDATION!\n");
policy=std::unique_ptr<MaxXValidation>(new MaxXValidation(self->kernel_selection_mgr, this, num_run, alpha));
}
break;
case EKernelSelectionMethod::MAXIMIZE_MMD:
case KSM_MAXIMIZE_MMD:
if (weighted_kernel)
policy=std::unique_ptr<WeightedMaxMeasure>(new WeightedMaxMeasure(self->kernel_selection_mgr, this));
else
policy=std::unique_ptr<MaxMeasure>(new MaxMeasure(self->kernel_selection_mgr, this));
break;
case EKernelSelectionMethod::MAXIMIZE_POWER:
case KSM_MAXIMIZE_POWER:
if (weighted_kernel)
policy=std::unique_ptr<WeightedMaxTestPower>(new WeightedMaxTestPower(self->kernel_selection_mgr, this));
else
Expand Down
39 changes: 19 additions & 20 deletions src/shogun/statistical_testing/MMD.h
Expand Up @@ -53,35 +53,34 @@ class WeightedMaxTestPower;

}

// TODO change enum class to enum in order to co-operate with python swig etc blah
enum class EStatisticType
enum EStatisticType
{
UNBIASED_FULL,
UNBIASED_INCOMPLETE,
BIASED_FULL
ST_UNBIASED_FULL,
ST_UNBIASED_INCOMPLETE,
ST_BIASED_FULL
};

enum class EVarianceEstimationMethod
enum EVarianceEstimationMethod
{
DIRECT,
PERMUTATION
VEM_DIRECT,
VEM_PERMUTATION
};

enum class ENullApproximationMethod
enum ENullApproximationMethod
{
PERMUTATION,
MMD1_GAUSSIAN,
MMD2_SPECTRUM,
MMD2_GAMMA
NAM_PERMUTATION,
NAM_MMD1_GAUSSIAN,
NAM_MMD2_SPECTRUM,
NAM_MMD2_GAMMA
};

enum class EKernelSelectionMethod
enum EKernelSelectionMethod
{
MEDIAN_HEURISTIC,
MAXIMIZE_MMD,
MAXIMIZE_POWER,
MAXIMIZE_XVALIDATION,
AUTO
KSM_MEDIAN_HEURISTIC,
KSM_MAXIMIZE_MMD,
KSM_MAXIMIZE_POWER,
KSM_MAXIMIZE_XVALIDATION,
KSM_AUTO
};

class CMMD : public CTwoSampleTest
Expand All @@ -95,7 +94,7 @@ class CMMD : public CTwoSampleTest
virtual ~CMMD();

void add_kernel(CKernel *kernel);
void select_kernel(EKernelSelectionMethod kmethod=EKernelSelectionMethod::AUTO,
void select_kernel(EKernelSelectionMethod kmethod=KSM_AUTO,
bool weighted_kernel=false, float64_t train_test_ratio=1.0,
index_t num_run=10, float64_t alpha=0.05);

Expand Down

0 comments on commit d35d8df

Please sign in to comment.