Skip to content

Commit

Permalink
code style in checking internals (#778)
Browse files Browse the repository at this point in the history
  • Loading branch information
qiushiyan committed Aug 17, 2022
1 parent 47e708b commit 9e36249
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 38 deletions.
89 changes: 53 additions & 36 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ make_classes <- function(prefix) {
#' @return If an error is not thrown (from non-empty ellipses), a NULL list.
#' @keywords internal
#' @export
check_empty_ellipse <- function (...) {
check_empty_ellipse <- function(...) {
terms <- quos(...)
if (!is_empty(terms))
if (!is_empty(terms)) {
rlang::abort("Please pass other arguments to the model function via `set_engine()`.")
}
terms
}

is_missing_arg <- function(x)
is_missing_arg <- function(x) {
identical(x, quote(missing_arg()))
}

model_info_table <-
utils::read.delim(system.file("models.tsv", package = "parsnip"))
Expand All @@ -38,7 +40,11 @@ has_loaded_implementation <- function(spec_, engine_, mode_) {
if (isFALSE(mode_ %in% c("regression", "censored regression", "classification"))) {
mode_ <- c("regression", "censored regression", "classification")
}
eng_cond <- if (is.null(engine_)) {TRUE} else {quote(engine == engine_)}
eng_cond <- if (is.null(engine_)) {
TRUE
} else {
quote(engine == engine_)
}

avail <-
get_from_env(spec_) %>%
Expand All @@ -56,7 +62,7 @@ has_loaded_implementation <- function(spec_, engine_, mode_) {

is_printable_spec <- function(x) {
!is.null(x$method$fit$args) &&
has_loaded_implementation(class(x)[1], x$engine, x$mode)
has_loaded_implementation(class(x)[1], x$engine, x$mode)
}

# construct a message informing the user that there are no
Expand Down Expand Up @@ -109,22 +115,25 @@ show_call <- function(object) {
map(object$method$fit$args, convert_arg)

call2(object$method$fit$func["fun"],
!!!object$method$fit$args,
.ns = object$method$fit$func["pkg"])
!!!object$method$fit$args,
.ns = object$method$fit$func["pkg"]
)
}

convert_arg <- function(x) {
if (is_quosure(x))
if (is_quosure(x)) {
quo_get_expr(x)
else
} else {
x
}
}

levels_from_formula <- function(f, dat) {
if (inherits(dat, "tbl_spark"))
if (inherits(dat, "tbl_spark")) {
res <- NULL
else
} else {
res <- levels(eval_tidy(f[[2]], dat))
}
res
}

Expand All @@ -134,7 +143,7 @@ levels_from_formula <- function(f, dat) {
show_fit <- function(model, eng) {
mod <- translate(x = model, engine = eng)
fit_call <- show_call(mod)
call_text <- deparse(fit_call)
call_text <- deparse(fit_call)
call_text <- paste0(call_text, collapse = "\n")
paste0(
"\\preformatted{\n",
Expand All @@ -157,9 +166,10 @@ check_args.default <- function(object) {

# copied form recipes

names0 <- function (num, prefix = "x") {
if (num < 1)
names0 <- function(num, prefix = "x") {
if (num < 1) {
rlang::abort("`num` should be > 0.")
}
ind <- format(1:num)
ind <- gsub(" ", "0", ind)
paste0(prefix, ind)
Expand All @@ -172,16 +182,16 @@ names0 <- function (num, prefix = "x") {
#' @keywords internal
#' @rdname add_on_exports
update_dot_check <- function(...) {

dots <- enquos(...)

if (length(dots) > 0)
if (length(dots) > 0) {
rlang::abort(
glue::glue(
"Extra arguments will be ignored: ",
glue::glue_collapse(glue::glue("`{names(dots)}`"), sep = ", ")
)
)
}
invisible(NULL)
}

Expand All @@ -192,15 +202,16 @@ update_dot_check <- function(...) {
#' @rdname add_on_exports
new_model_spec <- function(cls, args, eng_args, mode, method, engine,
check_missing_spec = TRUE) {

check_spec_mode_engine_val(cls, engine, mode)

if ((!has_loaded_implementation(cls, engine, mode)) && check_missing_spec) {
rlang::inform(inform_missing_implementation(cls, engine, mode))
}

out <- list(args = args, eng_args = eng_args,
mode = mode, method = method, engine = engine)
out <- list(
args = args, eng_args = eng_args,
mode = mode, method = method, engine = engine
)
class(out) <- make_classes(cls)
out
}
Expand All @@ -211,8 +222,9 @@ check_outcome <- function(y, spec) {
if (spec$mode == "unknown") {
return(invisible(NULL))
} else if (spec$mode == "regression") {
if (!all(map_lgl(y, is.numeric)))
if (!all(map_lgl(y, is.numeric))) {
rlang::abort("For a regression model, the outcome should be numeric.")
}
} else if (spec$mode == "classification") {
if (!all(map_lgl(y, is.factor))) {
rlang::abort("For a classification model, the outcome should be a factor.")
Expand Down Expand Up @@ -250,7 +262,6 @@ check_final_param <- function(x) {
#' @keywords internal
#' @rdname add_on_exports
update_main_parameters <- function(args, param) {

if (length(param) == 0) {
return(args)
}
Expand All @@ -263,8 +274,10 @@ update_main_parameters <- function(args, param) {
extra_args <- names(param)[has_extra_args]
if (any(has_extra_args)) {
rlang::abort(
paste("At least one argument is not a main argument:",
paste0("`", extra_args, "`", collapse = ", "))
paste(
"At least one argument is not a main argument:",
paste0("`", extra_args, "`", collapse = ", ")
)
)
}
param <- param[!has_extra_args]
Expand All @@ -276,7 +289,6 @@ update_main_parameters <- function(args, param) {
#' @keywords internal
#' @rdname add_on_exports
update_engine_parameters <- function(eng_args, fresh, ...) {

dots <- enquos(...)

## only update from dots when there are eng args in original model spec
Expand All @@ -303,16 +315,20 @@ update_engine_parameters <- function(eng_args, fresh, ...) {
stan_conf_int <- function(object, newdata) {
check_installs(list(method = list(libs = "rstanarm")))
if (utils::packageVersion("rstanarm") >= "2.21.1") {
fn <- rlang::call2("posterior_epred", .ns = "rstanarm",
object = expr(object),
newdata = expr(newdata),
seed = expr(sample.int(10^5, 1)))
fn <- rlang::call2("posterior_epred",
.ns = "rstanarm",
object = expr(object),
newdata = expr(newdata),
seed = expr(sample.int(10^5, 1))
)
} else {
fn <- rlang::call2("posterior_linpred", .ns = "rstanarm",
object = expr(object),
newdata = expr(newdata),
transform = TRUE,
seed = expr(sample.int(10^5, 1)))
fn <- rlang::call2("posterior_linpred",
.ns = "rstanarm",
object = expr(object),
newdata = expr(newdata),
transform = TRUE,
seed = expr(sample.int(10^5, 1))
)
}
rlang::eval_tidy(fn)
}
Expand Down Expand Up @@ -357,30 +373,31 @@ stan_conf_int <- function(object, newdata) {
#' @keywords internal
#' @export
.check_glmnet_penalty_predict <- function(penalty = NULL, object, multi = FALSE) {

if (is.null(penalty)) {
penalty <- object$fit$lambda
}

# when using `predict()`, allow for a single lambda
if (!multi) {
if (length(penalty) != 1)
if (length(penalty) != 1) {
rlang::abort(
glue::glue(
"`penalty` should be a single numeric value. `multi_predict()` ",
"can be used to get multiple predictions per row of data.",
)
)
}
}

if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda)
if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda) {
rlang::abort(
glue::glue(
"The glmnet model was fit with a single penalty value of ",
"{object$fit$lambda}. Predicting with a value of {penalty} ",
"will give incorrect results from `glmnet()`."
)
)
}

penalty
}
Expand Down
1 change: 1 addition & 0 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ print.model_spec <- function(x, ...) {
#' @rdname add_on_exports
#' @export
print_model_spec <- function(x, cls = class(x)[1], desc = get_model_desc(cls), ...) {

cat(desc, " Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

Expand Down
6 changes: 5 additions & 1 deletion R/required_pkgs.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ get_pkgs <- function(x, infra) {
pkgs <-
get_from_env(paste0(cls, "_pkgs")) %>%
dplyr::filter(engine == x$engine)
res <- pkgs$pkg[[1]]
if (length(pkgs$pkg) == 0) {
res <- character(0)
} else {
res <- pkgs$pkg[[1]]
}
if (length(res) == 0) {
res <- character(0)
}
Expand Down
1 change: 0 additions & 1 deletion parsnip.Rproj
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,5 @@ StripTrailingWhitespace: Yes

BuildType: Package
PackageUseDevtools: Yes
PackageCleanBeforeInstall: Yes
PackageInstallArgs: --no-multiarch --with-keep.source
PackageRoxygenize: rd,collate,namespace

0 comments on commit 9e36249

Please sign in to comment.