diff --git a/R/compboost.R b/R/compboost.R index b62b3761..cc4c7ebf 100644 --- a/R/compboost.R +++ b/R/compboost.R @@ -337,8 +337,10 @@ Compboost = R6::R6Class("Compboost", checkmate::assertNumeric(learning_rate, lower = 0, upper = 1, any.missing = FALSE, len = 1) checkmate::assertNumeric(oob_fraction, lower = 0, upper = 1, any.missing = FALSE, len = 1, null.ok = TRUE) - if (! target %in% names(data)) { - stop ("The target ", target, " is not present within the data") + if (! isRcppClass(target, "Response")) { + if (! target %in% names(data)) { + stop ("The target ", target, " is not present within the data") + } } if (inherits(loss, "C++Class")) { stop ("Loss should be an initialized loss object by calling the constructor: ", deparse(substitute(loss)), "$new()") @@ -369,7 +371,7 @@ Compboost = R6::R6Class("Compboost", self$oob_fraction = oob_fraction self$target = self$response$getTargetName() - self$data = data[private$train_idx, !colnames(data) %in% target, drop = FALSE] + self$data = data[private$train_idx, !colnames(data) %in% self$target, drop = FALSE] self$optimizer = optimizer self$loss = loss self$learning_rate = learning_rate diff --git a/R/helper.R b/R/helper.R index 9084e5e5..200e0211 100644 --- a/R/helper.R +++ b/R/helper.R @@ -1,7 +1,7 @@ assertRcppClass = function (x, x_class, stop_when.error = TRUE) { cls = class(x) - rcpp_class = TRUE + if (! grepl("Rcpp", cls)) { stop("Object was not exposed by Rcpp.") } @@ -10,6 +10,20 @@ assertRcppClass = function (x, x_class, stop_when.error = TRUE) } } +isRcppClass = function (x, x_class) +{ + cls = class(x) + is_rcpp_class = TRUE + + if (! grepl("Rcpp", cls)) { + is_rcpp_class = FALSE + } + if (! grepl(x_class, cls)) { + is_rcpp_class = FALSE + } + return(is_rcpp_class) +} + vectorToResponse = function (vec, target) { # Transform factor or character labels to -1 and 1 diff --git a/vignettes/extending_compboost.Rmd b/vignettes/extending_compboost.Rmd index d47fba1d..6d2d2436 100644 --- a/vignettes/extending_compboost.Rmd +++ b/vignettes/extending_compboost.Rmd @@ -23,7 +23,7 @@ library(compboost) `compboost` was designed to provide a component-wise boosting framework with maximal flexibility. This document gives an overview about the two main -possibilities of extending compboost with custom functions. We will have +possibilities of extending compboost with custom functions. We will take a look at: - Using custom losses.