-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
linear_time_mmd.sg
59 lines (50 loc) · 1.63 KB
/
linear_time_mmd.sg
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Build the two-sample test object: two default-constructed
# GaussianBlobsDataGenerator sources are attached as p and q, one GaussianKernel
# is installed, and 1000 samples are requested from each source.
# NOTE(review): the GaussianKernel arguments (10, 1) are presumably
# (cache size, width) -- confirm against the constructor documentation.
# alpha (0.05) is only declared here; it is consumed later by compute_threshold.
GaussianBlobsDataGenerator features_p()
GaussianBlobsDataGenerator features_q()
#![create_instance]
LinearTimeMMD mmd()
GaussianKernel kernel(10, 1)
mmd.set_kernel(kernel)
mmd.set_p(features_p)
mmd.set_q(features_q)
mmd.set_num_samples_p(1000)
mmd.set_num_samples_q(1000)
real alpha = 0.05
#![create_instance]
# Configure the number of blocks per burst to 1000.
# NOTE(review): presumably this is how many blocks are fetched and processed
# together when streaming samples -- confirm against the API documentation.
#![set_burst]
mmd.set_num_blocks_per_burst(1000)
#![set_burst]
# Compute the test statistic with the configuration set up so far and keep it
# in `statistic` for the p-value computation that follows.
#![estimate_mmd]
real statistic = mmd.compute_statistic()
#![estimate_mmd]
# Run the test two ways: obtain a threshold for the chosen alpha, and obtain a
# p-value for the statistic computed above.
#![perform_test]
real threshold = mmd.compute_threshold(alpha)
real p_value = mmd.compute_p_value(statistic)
#![perform_test]
# Register three candidate Gaussian kernels that differ only in their second
# constructor argument (0.1, 1, 10) so that a kernel can be selected among them.
# NOTE(review): the second argument is presumably the kernel width (it is later
# read back via get_width on the selected kernel) -- confirm.
#![add_kernels]
GaussianKernel kernel1(10, 0.1)
GaussianKernel kernel2(10, 1)
GaussianKernel kernel3(10, 10)
mmd.add_kernel(kernel1)
mmd.add_kernel(kernel2)
mmd.add_kernel(kernel3)
#![add_kernels]
# Switch on train/test mode and set the train/test ratio to 1.
# NOTE(review): presumably a ratio of 1 splits the samples evenly between kernel
# learning and testing -- confirm against the API documentation.
#![enable_train_test_mode]
mmd.set_train_test_mode(True)
mmd.set_train_test_ratio(1)
#![enable_train_test_mode]
# Select a kernel using the KSM_MAXIMIZE_POWER strategy, then cast the generic
# kernel returned by get_kernel back to GaussianKernel to read its width.
#![select_kernel_single]
mmd.set_kernel_selection_strategy(enum EKernelSelectionMethod.KSM_MAXIMIZE_POWER)
mmd.select_kernel()
GaussianKernel learnt_kernel_single = GaussianKernel:obtain_from_generic(mmd.get_kernel())
real width = learnt_kernel_single.get_width()
#![select_kernel_single]
# Repeat kernel selection with the same strategy but an extra `True` argument;
# the result is cast to CombinedKernel and its sub-kernel weights are read.
# NOTE(review): the `True` flag presumably requests a weighted combination of
# the candidate kernels instead of a single one -- confirm.
#![select_kernel_combined]
mmd.set_kernel_selection_strategy(enum EKernelSelectionMethod.KSM_MAXIMIZE_POWER, True)
mmd.select_kernel()
CombinedKernel learnt_kernel_combined = CombinedKernel:obtain_from_generic(mmd.get_kernel())
RealVector weights = learnt_kernel_combined.get_subkernel_weights()
#![select_kernel_combined]
# Re-run the test after kernel selection: recompute the statistic and derive the
# p-value from that freshly computed value (not from the earlier `statistic`,
# which was obtained before any kernel was selected).
#![perform_test_optimized]
real statistic_optimized = mmd.compute_statistic()
real p_value_optimized = mmd.compute_p_value(statistic_optimized)
#![perform_test_optimized]