diff --git a/tests/testthat/test_factory.R b/tests/testthat/test_factory.R index 04874a44..e5c3481f 100644 --- a/tests/testthat/test_factory.R +++ b/tests/testthat/test_factory.R @@ -110,3 +110,27 @@ test_that("custom factory works", { predictFun(mod.test, X.test) ) }) + + +# test_that("custom cpp factory works", { +# +# suppressWarnings( +# Rcpp::sourceCpp("../../external_test_files/custom_cpp_learner.cpp") +# ) +# +# set.seed(pi) +# X = matrix(1:10, ncol = 1) +# y = 3 * as.numeric(X) + rnorm(10, 0, 2) +# +# X.test = as.matrix(runif(200)) +# +# custom.cpp.factory = CustomCppFactory$new(X, "my_variable_name", dataFunSetter(), +# trainFunSetter(), predictFunSetter()) +# +# custom.cpp.factory$testTrain(y) +# +# expect_equal(custom.cpp.factory$getData(), X) +# expect_equal(custom.cpp.factory$testGetParameter(), solve(t(X) %*% X) %*% t(X) %*% y) +# expect_equal(custom.cpp.factory$testPredict(), X %*% solve(t(X) %*% X) %*% t(X) %*% y) +# expect_equal(custom.cpp.factory$testPredictNewdata(X.test), X.test %*% solve(t(X) %*% X) %*% t(X) %*% y) +# }) \ No newline at end of file