Skip to content

Commit

Permalink
Merge pull request #106 from schalkdaniel/general_updates
Browse files Browse the repository at this point in the history
Add test for CustomCpp
  • Loading branch information
Daniel Schalk committed Jan 27, 2018
2 parents 6a384eb + 9b9e6cf commit c611369
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 22 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ LazyData: true
RoxygenNote: 6.0.1
Imports:
Rcpp (>= 0.11.2),
RcppArmadillo,
methods,
glue
LinkingTo:
Expand Down
81 changes: 81 additions & 0 deletions tests/testthat/test_baselearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,85 @@ test_that("custom baselearner works correctly", {
custom$predictNewdata(as.matrix(newdata, ncol = 1)),
as.matrix(unname(predict(mod, newdata = data.frame(x = newdata))))
)
})

test_that("CustomCpp baselearner works", {

Rcpp::sourceCpp(code = '
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
typedef arma::mat (*instantiateDataFunPtr) (arma::mat& X);
typedef arma::mat (*trainFunPtr) (arma::vec& y, arma::mat& X);
typedef arma::mat (*predictFunPtr) (arma::mat& newdata, arma::mat& parameter);
// instantiateDataFun:
// -------------------
arma::mat instantiateDataFun (arma::mat& X)
{
return X;
}
// trainFun:
// -------------------
arma::mat trainFun (arma::vec& y, arma::mat& X)
{
return arma::solve(X, y);
}
// predictFun:
// -------------------
arma::mat predictFun (arma::mat& newdata, arma::mat& parameter)
{
return newdata * parameter;
}
// Setter function:
// ------------------
// [[Rcpp::export]]
Rcpp::XPtr<instantiateDataFunPtr> dataFunSetter ()
{
return Rcpp::XPtr<instantiateDataFunPtr> (new instantiateDataFunPtr (&instantiateDataFun));
}
// [[Rcpp::export]]
Rcpp::XPtr<trainFunPtr> trainFunSetter ()
{
return Rcpp::XPtr<trainFunPtr> (new trainFunPtr (&trainFun));
}
// [[Rcpp::export]]
Rcpp::XPtr<predictFunPtr> predictFunSetter ()
{
return Rcpp::XPtr<predictFunPtr> (new predictFunPtr (&predictFun));
}'
)

x = 1:10
X = matrix(x, ncol = 1)
y = 3 * 1:10 + rnorm(10)
newdata = runif(10, 1, 10)

X.test = as.matrix(runif(200))

custom.cpp.blearner = CustomCpp$new(X, "my_variable_name", dataFunSetter(),
trainFunSetter(), predictFunSetter())

custom.cpp.blearner$train(y)

mod = lm(y ~ 0 + x)

expect_equal(custom.cpp.blearner$getData(), X)
expect_equal(as.numeric(custom.cpp.blearner$getParameter()), unname(coef(mod)))
expect_equal(custom.cpp.blearner$predict(), as.matrix(unname(predict(mod))))
expect_equal(
custom.cpp.blearner$predictNewdata(as.matrix(newdata, ncol = 1)),
as.matrix(unname(predict(mod, newdata = data.frame(x = newdata))))
)
})
91 changes: 69 additions & 22 deletions tests/testthat/test_factory.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,72 @@ test_that("custom factory works", {
})


# test_that("custom cpp factory works", {
#
# suppressWarnings(
# Rcpp::sourceCpp("../../external_test_files/custom_cpp_learner.cpp")
# )
#
# set.seed(pi)
# X = matrix(1:10, ncol = 1)
# y = 3 * as.numeric(X) + rnorm(10, 0, 2)
#
# X.test = as.matrix(runif(200))
#
# custom.cpp.factory = CustomCppFactory$new(X, "my_variable_name", dataFunSetter(),
# trainFunSetter(), predictFunSetter())
#
# custom.cpp.factory$testTrain(y)
#
# expect_equal(custom.cpp.factory$getData(), X)
# expect_equal(custom.cpp.factory$testGetParameter(), solve(t(X) %*% X) %*% t(X) %*% y)
# expect_equal(custom.cpp.factory$testPredict(), X %*% solve(t(X) %*% X) %*% t(X) %*% y)
# expect_equal(custom.cpp.factory$testPredictNewdata(X.test), X.test %*% solve(t(X) %*% X) %*% t(X) %*% y)
# })
test_that("custom cpp factory works", {

Rcpp::sourceCpp(code = '
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
typedef arma::mat (*instantiateDataFunPtr) (arma::mat& X);
typedef arma::mat (*trainFunPtr) (arma::vec& y, arma::mat& X);
typedef arma::mat (*predictFunPtr) (arma::mat& newdata, arma::mat& parameter);
// instantiateDataFun:
// -------------------
arma::mat instantiateDataFun (arma::mat& X)
{
return X;
}
// trainFun:
// -------------------
arma::mat trainFun (arma::vec& y, arma::mat& X)
{
return arma::solve(X, y);
}
// predictFun:
// -------------------
arma::mat predictFun (arma::mat& newdata, arma::mat& parameter)
{
return newdata * parameter;
}
// Setter function:
// ------------------
// [[Rcpp::export]]
Rcpp::XPtr<instantiateDataFunPtr> dataFunSetter ()
{
return Rcpp::XPtr<instantiateDataFunPtr> (new instantiateDataFunPtr (&instantiateDataFun));
}
// [[Rcpp::export]]
Rcpp::XPtr<trainFunPtr> trainFunSetter ()
{
return Rcpp::XPtr<trainFunPtr> (new trainFunPtr (&trainFun));
}
// [[Rcpp::export]]
Rcpp::XPtr<predictFunPtr> predictFunSetter ()
{
return Rcpp::XPtr<predictFunPtr> (new predictFunPtr (&predictFun));
}'
)

set.seed(pi)
X = matrix(1:10, ncol = 1)
y = 3 * as.numeric(X) + rnorm(10, 0, 2)

X.test = as.matrix(runif(200))

custom.cpp.factory = CustomCppFactory$new(X, "my_variable_name", dataFunSetter(),
trainFunSetter(), predictFunSetter())

expect_equal(custom.cpp.factory$getData(), X)
})

0 comments on commit c611369

Please sign in to comment.