Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
using factory methods in kernel ridge regression meta example (#4297)
* using factory methods * cast labels to regression * data update
- Loading branch information
1 parent
ee3da84
commit 911d35d
Showing
3 changed files
with
14 additions
and
19 deletions.
There are no files selected for viewing
Submodule data
updated
1 files
+3 −5 | testsuite/meta/regression/kernel_ridge_regression.dat |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,36 @@ | ||
CSVFile f_feats_train("../../data/regression_1d_sinc_features_train.dat") | ||
CSVFile f_feats_test("../../data/regression_1d_sinc_features_test.dat") | ||
CSVFile f_labels_train("../../data/regression_1d_sinc_labels_train.dat") | ||
CSVFile f_labels_test("../../data/regression_1d_sinc_labels_test.dat") | ||
File f_feats_train = csv_file("../../data/regression_1d_sinc_features_train.dat") | ||
File f_feats_test = csv_file("../../data/regression_1d_sinc_features_test.dat") | ||
File f_labels_train = csv_file("../../data/regression_1d_sinc_labels_train.dat") | ||
File f_labels_test = csv_file("../../data/regression_1d_sinc_labels_test.dat") | ||
|
||
#![create_features] | ||
Features features_train = features(f_feats_train) | ||
Features features_test = features(f_feats_test) | ||
RegressionLabels labels_train(f_labels_train) | ||
RegressionLabels labels_test(f_labels_test) | ||
Labels labels_train = labels(f_labels_train) | ||
Labels labels_test = labels(f_labels_test) | ||
#![create_features] | ||
|
||
#![create_appropriate_kernel] | ||
Kernel k = kernel("GaussianKernel", log_width=0.0) | ||
#![create_appropriate_kernel] | ||
|
||
#![create_instance] | ||
real tau = 0.001 | ||
KernelRidgeRegression krr(tau, k, labels_train) | ||
Machine krr = machine("KernelRidgeRegression", labels=labels_train, tau=0.001, kernel=k) | ||
#![create_instance] | ||
|
||
#![train_and_apply] | ||
krr.train(features_train) | ||
RegressionLabels labels_predict = krr.apply_regression(features_test) | ||
Labels labels_predict = krr.apply_regression(features_test) | ||
#![train_and_apply] | ||
|
||
#![extract_alpha] | ||
RealVector alpha = krr.get_alphas() | ||
RealVector alphas = krr.get_real_vector("m_alpha") | ||
#![extract_alpha] | ||
|
||
#![evaluate_error] | ||
MeanSquaredError eval() | ||
Evaluation eval = evaluation("MeanSquaredError") | ||
real mse = eval.evaluate(labels_predict, labels_test) | ||
#![evaluate_error] | ||
|
||
# integration testing variables | ||
RealVector output = labels_test.get_labels() | ||
RealVector output = labels_test.get_real_vector("labels") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters