-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
multiple_kernel_learning.sg
49 lines (41 loc) · 1.74 KB
/
multiple_kernel_learning.sg
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Open the four CSV data files used by this example: features and labels,
# each for a train and a test split. All paths are relative (../../data/).
# The file names indicate a 1-D sinc regression data set.
File f_feats_train = csv_file("../../data/regression_1d_sinc_features_train.dat")
File f_feats_test = csv_file("../../data/regression_1d_sinc_features_test.dat")
File f_labels_train = csv_file("../../data/regression_1d_sinc_labels_train.dat")
File f_labels_test = csv_file("../../data/regression_1d_sinc_labels_test.dat")
# Wrap the opened files as Features / Labels objects, one pair for training
# and one pair for testing.
#![create_features]
Features features_train = features(f_feats_train)
Features features_test = features(f_feats_test)
Labels labels_train = labels(f_labels_train)
Labels labels_test = labels(f_labels_test)
#![create_features]
# Build the three base kernels that will be combined: one polynomial kernel
# of degree 2 and two Gaussian kernels that differ only in log_width
# (0.0 and 1.0).
#![create_kernel]
Kernel poly_kernel = kernel("PolyKernel", degree=2)
Kernel gauss_kernel_1 = kernel("GaussianKernel", log_width=0.0)
Kernel gauss_kernel_2 = kernel("GaussianKernel", log_width=1.0)
#![create_kernel]
# Create a CombinedKernel, append the three base kernels to its
# "kernel_array" (polynomial first, then the two Gaussians), and initialise
# it with the training features as both arguments.
# NOTE(review): the order of the add() calls presumably fixes the order of the
# sub-kernel weights read back later — confirm before reordering.
#![create_combined_train]
Kernel combined_kernel = kernel("CombinedKernel")
combined_kernel.add("kernel_array", poly_kernel)
combined_kernel.add("kernel_array", gauss_kernel_1)
combined_kernel.add("kernel_array", gauss_kernel_2)
combined_kernel.init(features_train, features_train)
#![create_combined_train]
# Create an SVRLight machine and pass it to an MKLRegression machine as its
# "svm", together with the combined kernel and the training labels; then
# train the MKL machine.
# Note: despite its name, binary_svm_solver holds an SVRLight instance; the
# name is kept unchanged because other tooling may refer to it.
#![train_mkl]
Machine binary_svm_solver = machine("SVRLight")
Machine mkl = machine("MKLRegression", svm=binary_svm_solver, kernel=combined_kernel, labels=labels_train)
mkl.train()
#![train_mkl]
# Read back two vectors after training: the sub-kernel weights from the
# combined kernel (beta) and the real vector registered as "m_alpha" on the
# MKL machine (alpha).
# NOTE(review): presumably beta has one entry per base kernel and alpha holds
# the coefficients of the underlying SVR solution — confirm against the
# Shogun API.
#![extract_weights]
RealVector beta = combined_kernel.get_subkernel_weights()
RealVector alpha = mkl.get_real_vector("m_alpha")
#![extract_weights]
# Re-initialise the combined kernel with (training, test) features, then call
# apply() on the MKL machine to obtain predicted labels.
# apply() receives no arguments here, so the test features reach the machine
# only through the kernel re-initialisation on the line before it.
# NOTE(review): presumably the init() call must precede apply() for the
# predictions to refer to the test set — confirm.
#![mkl_apply]
combined_kernel.init(features_train, features_test)
Labels labels_predict = mkl.apply()
#![mkl_apply]
# Score the predictions against the test labels using a MeanSquaredError
# evaluation object; the result is stored in mse.
#![evaluate_error]
Evaluation eval = evaluation("MeanSquaredError")
real mse = eval.evaluate(labels_predict, labels_test)
#![evaluate_error]
# Additional integration testing variables: the predicted labels extracted
# as a plain real vector via the "labels" parameter.
RealVector output = labels_predict.get_real_vector("labels")