diff --git a/NEWS.md b/NEWS.md index b2e05433a..0f76cee0c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,7 @@ ## New Features * A "null model" is now available that fits a predictor-free model (using the mean of the outcome for regression or the mode for classification). +* `fit_xy()` can take a single column data frame or matrix for `y` without error ## Other Changes @@ -38,7 +39,7 @@ First CRAN release # parsnip 0.0.0.9005 -* The engine, and any associated arguments, are now specified using `set_engine`. There is no `engine` argument +* The engine, and any associated arguments, are now specified using `set_engine()`. There is no `engine` argument # parsnip 0.0.0.9004 @@ -64,7 +65,7 @@ First CRAN release # parsnip 0.0.0.9000 -* The `fit` interface was previously used to cover both the x/y interface as well as the formula interface. Now, `fit` is the formula interface and [`fit_xy` is for the x/y interface](https://github.com/topepo/parsnip/issues/33). +* The `fit` interface was previously used to cover both the x/y interface as well as the formula interface. Now, `fit()` is the formula interface and [`fit_xy()` is for the x/y interface](https://github.com/topepo/parsnip/issues/33). * Added a `NEWS.md` file to track changes to the package. * `predict` methods were [overhauled](https://github.com/topepo/parsnip/issues/34) to be [consistent](https://github.com/topepo/parsnip/issues/41). * MARS was added. diff --git a/R/fit.R b/R/fit.R index 594b9d21a..7e971e044 100644 --- a/R/fit.R +++ b/R/fit.R @@ -180,6 +180,14 @@ fit_xy.model_spec <- if (any(names(dots) == "engine")) stop("Use `set_engine()` to supply the engine.", call. = FALSE) + if (object$engine != "spark" & NCOL(y) == 1 & !(is.vector(y) | is.factor(y))) { + if (is.matrix(y)) { + y <- y[, 1] + } else { + y <- y[[1]] + } + } + cl <- match.call(expand.dots = TRUE) eval_env <- rlang::env() eval_env$x <- x diff --git a/tests/testthat/test_fit_interfaces.R b/tests/testthat/test_fit_interfaces.R index 075660b80..c06a8b537 100644 --- a/tests/testthat/test_fit_interfaces.R +++ b/tests/testthat/test_fit_interfaces.R @@ -23,7 +23,7 @@ test_that('good args', { expect_equal( tester(NULL, formula = f, data = iris, model = rmod), "formula") expect_equal(tester_xy(NULL, x = iris, y = iris, model = rmod), "data.frame") expect_equal( tester(NULL, f, data = iris, model = rmod), "formula") - expect_equal( tester(NULL, f, data = sprk, model = rmod), "formula") + expect_equal( tester(NULL, f, data = sprk, model = rmod), "formula") }) #test_that('unnamed args', { @@ -37,3 +37,26 @@ test_that('wrong args', { expect_error(tester(NULL, f, data = as.matrix(iris[, 1:4]))) }) +test_that('single column df for issue #129', { + + expect_error( + lm1 <- + linear_reg() %>% + set_engine("lm") %>% + fit_xy(x = mtcars[, 2:4], y = mtcars[,1, drop = FALSE]), + regexp = NA + ) + expect_error( + lm2 <- + linear_reg() %>% + set_engine("lm") %>% + fit_xy(x = mtcars[, 2:4], y = as.matrix(mtcars)[,1, drop = FALSE]), + regexp = NA + ) + lm3 <- + linear_reg() %>% + set_engine("lm") %>% + fit_xy(x = mtcars[, 2:4], y = mtcars$mpg) + expect_equal(coef(lm1), coef(lm3)) + expect_equal(coef(lm2), coef(lm3)) +})