-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
linear_time_mmd_kernel_selection.sg
43 lines (38 loc) · 1.32 KB
/
linear_time_mmd_kernel_selection.sg
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Linear-time MMD two-sample test with kernel selection.
# Flow of this script: kernel selection is run while mmd is attached to the
# "train" generators; mmd is then re-attached to the separate "test"
# generators before the statistic and p-value are computed.
# NOTE(review): the `#![name]` line pairs look like named snippet delimiters
# read by external tooling -- keep each pair intact and unrenamed; confirm.
#
# Four default-constructed blob generators: one (p, q) pair used during kernel
# selection and one (p, q) pair used for the final test.
GaussianBlobsDataGenerator features_p_train()
GaussianBlobsDataGenerator features_q_train()
GaussianBlobsDataGenerator features_p_test()
GaussianBlobsDataGenerator features_q_test()
# Create the LinearTimeMMD instance, attach the training generators as p and q,
# and set 100 samples for each side and 100 blocks per burst.
# NOTE(review): the precise meaning of "blocks per burst" is defined by
# LinearTimeMMD and is not visible here -- confirm against its documentation.
#![create_instance]
LinearTimeMMD mmd()
mmd.set_p(features_p_train)
mmd.set_q(features_q_train)
mmd.set_num_samples_p(100)
mmd.set_num_samples_q(100)
mmd.set_num_blocks_per_burst(100)
#![create_instance]
# Register three GaussianKernel candidates with mmd. They share the first
# constructor argument (10) and differ only in the second (0.1, 1, 10).
# NOTE(review): presumably the arguments are (cache size, width) -- confirm
# against the GaussianKernel constructor.
#![add_kernels]
GaussianKernel kernel1(10, 0.1)
GaussianKernel kernel2(10, 1)
GaussianKernel kernel3(10, 10)
mmd.add_kernel(kernel1)
mmd.add_kernel(kernel2)
mmd.add_kernel(kernel3)
#![add_kernels]
# Run kernel selection with the MAXIMIZE_POWER method.
# The two commented-out lines in this snippet would fetch the kernel held by
# mmd as a GaussianKernel and read its width; they are disabled in this file.
#![select_kernel_single]
mmd.select_kernel(enum EKernelSelectionMethod.MAXIMIZE_POWER)
#GaussianKernel learnt_kernel_single = GaussianKernel:obtain_from_generic(mmd.get_kernel())
#Real width = learnt_kernel_single.get_width()
#![select_kernel_single]
# Run kernel selection again with the same method plus an extra `true` argument.
# NOTE(review): the disabled lines in this snippet treat the result as a
# CombinedKernel with sub-kernel weights, so the flag presumably requests a
# weighted combination of the added kernels -- confirm against select_kernel.
# NOTE(review): this call comes after the single-kernel selection above, so it
# presumably supersedes that result for the test below -- confirm.
#![select_kernel_combined]
mmd.select_kernel(enum EKernelSelectionMethod.MAXIMIZE_POWER, true)
#CombinedKernel learnt_kernel_combined = CombinedKernel:obtain_from_generic(mmd.get_kernel())
#RealVector weights = learnt_kernel_combined.get_subkernel_weights()
#![select_kernel_combined]
# Re-attach mmd to the test generators, set both sample counts to 100 again,
# then compute the statistic and pass it to compute_p_value.
#![perform_test]
mmd.set_p(features_p_test)
mmd.set_q(features_q_test)
mmd.set_num_samples_p(100)
mmd.set_num_samples_q(100)
Real statistic = mmd.compute_statistic()
Real p_value = mmd.compute_p_value(statistic)
#![perform_test]