-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
GaussianProcessRegression.cpp
142 lines (117 loc) · 3.81 KB
/
GaussianProcessRegression.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Jacob Walker, Roman Votyakov, Sergey Lisitsyn, Soeren Sonnenburg,
* Heiko Strathmann, Wu Lin
*/
#include <shogun/regression/GaussianProcessRegression.h>
#include <shogun/io/SGIO.h>
#include <shogun/machine/gp/FITCInferenceMethod.h>
using namespace shogun;
/** Default constructor.
 *
 * Construction is delegated entirely to CGaussianProcessMachine's default
 * constructor; no additional state is set up here.
 */
CGaussianProcessRegression::CGaussianProcessRegression() :
	CGaussianProcessMachine()
{
	// intentionally empty
}
/** Creates a GP regression machine around an existing inference method.
 *
 * @param method inference method to use; must not be NULL. The labels held
 * by the inference method are adopted as this machine's labels.
 */
CGaussianProcessRegression::CGaussianProcessRegression(CInference* method)
	: CGaussianProcessMachine(method)
{
	// Report a readable error instead of dereferencing NULL below; this
	// mirrors the m_method checks done by the other entry points.
	REQUIRE(method, "Inference method should not be NULL\n")

	// Adopt the inference method's labels. The pointer is stored directly
	// rather than via set_labels(); NOTE(review): this assumes get_labels()
	// hands back an already-referenced object - confirm.
	m_labels=method->get_labels();
}
/** Destructor.
 *
 * Nothing is released here; any cleanup is left to the base-class
 * destructors.
 */
CGaussianProcessRegression::~CGaussianProcessRegression()
{
	// intentionally empty
}
/** Predicts regression labels with the trained GP.
 *
 * @param data features to predict on. If NULL, prediction is done on the
 * features held by the inference method: the inducing features for FITC
 * regression, the regular features otherwise.
 * @return newly allocated labels holding the predictive means
 */
CRegressionLabels* CGaussianProcessRegression::apply_regression(CFeatures* data)
{
	// Refuse to predict unless an inference method is set and its
	// inference/likelihood combination supports regression.
	REQUIRE(m_method, "Inference method should not be NULL\n")
	CLikelihoodModel* model=m_method->get_model();
	REQUIRE(m_method->supports_regression(), "%s with %s doesn't support "
		"regression\n", m_method->get_name(), model->get_name())
	SG_UNREF(model);

	// Explicit test features: predict on them directly.
	if (data)
		return new CRegressionLabels(get_mean_vector(data));

	// No test features given: fall back to the inference method's own
	// features. Both getters return a referenced object, released below.
	CFeatures* stored_feat=NULL;
	if (m_method->get_inference_type()!=INF_FITC_REGRESSION)
		stored_feat=m_method->get_features();
	else
		stored_feat=m_method->as<CFITCInferenceMethod>()->get_inducing_features();

	CRegressionLabels* predicted=new CRegressionLabels(get_mean_vector(stored_feat));
	SG_UNREF(stored_feat);
	return predicted;
}
/** Trains the GP regression machine.
 *
 * Checks that the inference method supports regression, optionally installs
 * new features (inducing features for FITC regression, regular features
 * otherwise), and then runs inference via update().
 *
 * @param data features to train on; if NULL the features already held by
 * the inference method are used unchanged
 * @return always true
 */
bool CGaussianProcessRegression::train_machine(CFeatures* data)
{
	// check whether given combination of inference method and likelihood
	// function supports regression
	REQUIRE(m_method, "Inference method should not be NULL\n")
	CLikelihoodModel* lik=m_method->get_model();
	REQUIRE(m_method->supports_regression(), "%s with %s doesn't support "
		"regression\n", m_method->get_name(), lik->get_name())
	SG_UNREF(lik);

	if (data)
	{
		// set inducing features for FITC inference method
		if (m_method->get_inference_type()==INF_FITC_REGRESSION)
		{
			// The pointer returned by as<>() is not SG_UNREF'd: it is a
			// cast view of m_method, and releasing it would drop one of
			// m_method's own references. This matches apply_regression.
			// NOTE(review): relies on CSGObject::as<>() not adding a
			// reference (plain dynamic_cast) - verify.
			CFITCInferenceMethod* fitc_method=m_method->as<CFITCInferenceMethod>();
			fitc_method->set_inducing_features(data);
		}
		else
			m_method->set_features(data);
	}

	// perform inference
	m_method->update();
	return true;
}
/** Computes predictive means for the given features.
 *
 * The posterior means and variances of the latent function are computed
 * first, then mapped to predictive means by the likelihood model.
 *
 * @param data features to predict on
 * @return predictive mean vector
 */
SGVector<float64_t> CGaussianProcessRegression::get_mean_vector(CFeatures* data)
{
	// Regression must be supported by the inference/likelihood combination.
	REQUIRE(m_method, "Inference method should not be NULL\n")
	CLikelihoodModel* model=m_method->get_model();
	REQUIRE(m_method->supports_regression(), "%s with %s doesn't support "
		"regression\n", m_method->get_name(), model->get_name())
	SG_UNREF(model);

	// Hold a reference on the features while the posterior is evaluated.
	SG_REF(data);
	SGVector<float64_t> post_mean=get_posterior_means(data);
	SGVector<float64_t> post_var=get_posterior_variances(data);
	SG_UNREF(data);

	// The likelihood is fetched again here and released once it has been used.
	model=m_method->get_model();
	SGVector<float64_t> predictive_mean=
		model->get_predictive_means(post_mean, post_var);
	SG_UNREF(model);

	return predictive_mean;
}
/** Computes predictive variances for the given features.
 *
 * The posterior means and variances of the latent function are computed
 * first, then mapped to predictive variances by the likelihood model.
 *
 * @param data features to predict on
 * @return predictive variance vector
 */
SGVector<float64_t> CGaussianProcessRegression::get_variance_vector(
	CFeatures* data)
{
	REQUIRE(m_method, "Inference method should not be NULL\n")

	// The likelihood is held for the whole call. It is used both in the
	// support-check message and for the final predictive variances.
	CLikelihoodModel* model=m_method->get_model();
	REQUIRE(m_method->supports_regression(), "%s with %s doesn't support "
		"regression\n", m_method->get_name(), model->get_name())

	// Hold a reference on the features while the posterior is evaluated.
	SG_REF(data);
	SGVector<float64_t> post_mean=get_posterior_means(data);
	SGVector<float64_t> post_var=get_posterior_variances(data);
	SG_UNREF(data);

	SGVector<float64_t> predictive_var=
		model->get_predictive_variances(post_mean, post_var);
	SG_UNREF(model);

	return predictive_var;
}