Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added compiling (but not running) examples of kernel selection for MMD
- Loading branch information
Showing
2 changed files
with
86 additions
and
0 deletions.
There are no files selected for viewing
43 changes: 43 additions & 0 deletions
43
examples/meta/src/statistical_testing/linear_time_mmd_kernel_selection.sg
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
GaussianBlobsDataGenerator features_p_train() | ||
GaussianBlobsDataGenerator features_q_train() | ||
GaussianBlobsDataGenerator features_p_test() | ||
GaussianBlobsDataGenerator features_q_test() | ||
|
||
#![create_instance] | ||
LinearTimeMMD mmd() | ||
mmd.set_p(features_p_train) | ||
mmd.set_q(features_q_train) | ||
mmd.set_num_samples_p(100) | ||
mmd.set_num_samples_q(100) | ||
mmd.set_num_blocks_per_burst(100) | ||
#![create_instance] | ||
|
||
#![add_kernels] | ||
GaussianKernel kernel1(10, 0.1) | ||
GaussianKernel kernel2(10, 1) | ||
GaussianKernel kernel3(10, 10) | ||
mmd.add_kernel(kernel1) | ||
mmd.add_kernel(kernel2) | ||
mmd.add_kernel(kernel3) | ||
#![add_kernels] | ||
|
||
#![select_kernel_single] | ||
mmd.select_kernel(enum EKernelSelectionMethod.MAXIMIZE_POWER) | ||
GaussianKernel learnt_kernel_single = GaussianKernel:obtain_from_generic(mmd.get_kernel()) | ||
Real width = learnt_kernel_single.get_width() | ||
#![select_kernel_single] | ||
|
||
#![select_kernel_combined] | ||
mmd.select_kernel(enum EKernelSelectionMethod.MAXIMIZE_POWER) | ||
CombinedKernel learnt_kernel_combined = CombinedKernel:obtain_from_generic(mmd.get_kernel()) | ||
RealVector weights = learnt_kernel_combined.get_subkernel_weights() | ||
#![select_kernel_combined] | ||
|
||
#![perform_test] | ||
mmd.set_p(features_p_test) | ||
mmd.set_q(features_q_test) | ||
mmd.set_num_samples_p(100) | ||
mmd.set_num_samples_q(100) | ||
Real statistic = mmd.compute_statistic() | ||
Real p_value = mmd.compute_p_value(statistic) | ||
#![perform_test] |
43 changes: 43 additions & 0 deletions
43
examples/meta/src/statistical_testing/quadratic_time_mmd_kernel_selection.sg
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
CSVFile f_features_p_train("../../data/two_sample_test_gaussian.dat") | ||
CSVFile f_features_q_train("../../data/two_sample_test_laplace.dat") | ||
CSVFile f_features_p_test("../../data/two_sample_test_gaussian.dat") | ||
CSVFile f_features_q_test("../../data/two_sample_test_laplace.dat") | ||
|
||
#![create_features] | ||
RealFeatures features_p_train(f_features_p_train) | ||
RealFeatures features_q_train(f_features_q_train) | ||
RealFeatures features_p_test(f_features_p_test) | ||
RealFeatures features_q_test(f_features_q_test) | ||
#![create_features] | ||
|
||
#![create_instance] | ||
QuadraticTimeMMD mmd(features_p_train, features_q_train) | ||
#![create_instance] | ||
|
||
#![add_kernels] | ||
GaussianKernel kernel1(10, 0.1) | ||
GaussianKernel kernel2(10, 1) | ||
GaussianKernel kernel3(10, 10) | ||
mmd.add_kernel(kernel1) | ||
mmd.add_kernel(kernel2) | ||
mmd.add_kernel(kernel3) | ||
#![add_kernels] | ||
|
||
#![select_kernel_single] | ||
mmd.select_kernel(enum EKernelSelectionMethod.MAXIMIZE_MMD) | ||
GaussianKernel learnt_kernel_single = GaussianKernel:obtain_from_generic(mmd.get_kernel()) | ||
Real width = learnt_kernel_single.get_width() | ||
#![select_kernel_single] | ||
|
||
#![select_kernel_combined] | ||
mmd.select_kernel(enum EKernelSelectionMethod.MAXIMIZE_MMD) | ||
CombinedKernel learnt_kernel_combined = CombinedKernel:obtain_from_generic(mmd.get_kernel()) | ||
RealVector weights = learnt_kernel_combined.get_subkernel_weights() | ||
#![select_kernel_combined] | ||
|
||
#![perform_test] | ||
mmd.set_p(features_p_test) | ||
mmd.set_q(features_q_test) | ||
Real statistic = mmd.compute_statistic() | ||
Real p_value = mmd.compute_p_value(statistic) | ||
#![perform_test] |