-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
gaussian_process_regression.sg
46 lines (39 loc) · 1.74 KB
/
gaussian_process_regression.sg
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Open the four CSV data files used by this example: features and labels,
# each split into a training and a test file (the file names indicate a
# 1-D sinc regression dataset). Paths are relative to the example's run directory.
CSVFile f_feats_train("../../data/regression_1d_sinc_features_train.dat")
CSVFile f_feats_test("../../data/regression_1d_sinc_features_test.dat")
CSVFile f_labels_train("../../data/regression_1d_sinc_labels_train.dat")
CSVFile f_labels_test("../../data/regression_1d_sinc_labels_test.dat")
# Build real-valued feature objects and regression label objects from the
# CSV file handles, one pair for training and one pair for testing.
#![create_features]
RealFeatures features_train(f_feats_train)
RealFeatures features_test(f_feats_test)
RegressionLabels labels_train(f_labels_train)
RegressionLabels labels_test(f_labels_test)
#![create_features]
# Create a Gaussian kernel with width 1.0, initialised with the training
# features on both sides, and a zero mean function. Both are handed to the
# inference method below.
#![create_appropriate_kernel_and_mean_function]
real width = 1.0
GaussianKernel kernel(features_train, features_train, width)
ZeroMean mean_function()
#![create_appropriate_kernel_and_mean_function]
# Assemble the model: a Gaussian likelihood and an exact inference method
# that ties together the kernel, training features, mean function, training
# labels and likelihood; the GP regression machine is built on that
# inference method.
#![create_instance]
GaussianLikelihood gauss_likelihood()
ExactInferenceMethod inference_method(kernel, features_train, mean_function, labels_train, gauss_likelihood)
GaussianProcessRegression gp_regression(inference_method)
#![create_instance]
# Train the GP regression machine and predict labels for the test features.
# These predictions are made with the initial settings (kernel width 1.0),
# i.e. before any model selection is run.
#![train_and_apply]
gp_regression.train()
RegressionLabels labels_predict = gp_regression.apply_regression(features_test)
#![train_and_apply]
# Gradient-based model selection: a gradient evaluation is set up over the
# GP machine with the training data and a gradient criterion, the inference
# method is registered as the function being evaluated, and the parameter
# combination returned by select_model() is applied back to the machine.
# NOTE(review): labels_predict was computed before this step and is not
# recomputed afterwards, so it presumably does not reflect the parameters
# applied here - confirm whether that is intended.
#![optimize_marginal_likelihood]
GradientCriterion grad_criterion()
GradientEvaluation grad(gp_regression, features_train, labels_train, grad_criterion)
grad.set_function(inference_method)
GradientModelSelection grad_selection(grad)
ParameterCombination best_theta = grad_selection.select_model()
best_theta.apply_to_machine(gp_regression)
#![optimize_marginal_likelihood]
# Evaluate the model: mean squared error of the predictions against the test
# labels, and the value reported by the inference method's
# get_negative_log_marginal_likelihood().
# NOTE(review): marg_ll holds the *negative* log marginal likelihood, despite
# the variable name.
# NOTE(review): mse uses labels_predict, which was produced before the
# optimised parameters were applied to the machine, whereas marg_ll is read
# after that step - the two numbers may describe different parameter
# settings; confirm this is intended.
#![evaluate_error_and_marginal_likelihood]
MeanSquaredError eval()
real mse = eval.evaluate(labels_predict, labels_test)
real marg_ll = inference_method.get_negative_log_marginal_likelihood()
#![evaluate_error_and_marginal_likelihood]
# integration testing variables
# NOTE(review): output is taken from labels_test (the labels loaded from
# file), not from labels_predict - confirm this is the value the integration
# test is meant to compare.
RealVector output = labels_test.get_labels()