-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add composition to preprocessor blueprints #150
Conversation
Related to tidymodels/tidymodels#42 |
Closes #100 eventually |
I worked more on this today, and now during library(hardhat)
library(recipes)
#> Loading required package: dplyr
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
#>
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#>
#> step
train <- iris[1:100,]
test <- iris[101:150,]
rec <- recipe(Species ~ Sepal.Length + Sepal.Width, train) %>%
step_normalize(Sepal.Length)
sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix")
sparse_bp
#> Recipe blueprint:
#>
#> # Predictors: 0
#> # Outcomes: 0
#> Intercept: FALSE
#> Novel Levels: FALSE
#> Composition: dgCMatrix
processed <- mold(rec, train, blueprint = sparse_bp)
processed$blueprint
#> Recipe blueprint:
#>
#> # Predictors: 2
#> # Outcomes: 1
#> Intercept: FALSE
#> Novel Levels: FALSE
#> Composition: dgCMatrix
forge(train, blueprint = processed$blueprint)
#> $predictors
#> 100 x 2 sparse Matrix of class "dgCMatrix"
#> Sepal.Length Sepal.Width
#> [1,] -0.57815327 3.5
#> [2,] -0.88982620 3.0
#> [3,] -1.20149912 3.2
#> [4,] -1.35733558 3.1
#> [5,] -0.73398974 3.6
#> [6,] -0.11064389 3.9
#> [7,] -1.35733558 3.4
#> [8,] -0.73398974 3.4
#> [9,] -1.66900851 2.9
#> [10,] -0.88982620 3.1
#> [11,] -0.11064389 3.7
#> [12,] -1.04566266 3.4
#> [13,] -1.04566266 3.0
#> [14,] -1.82484497 3.0
#> [15,] 0.51270196 4.0
#> [16,] 0.35686550 4.4
#> [17,] -0.11064389 3.9
#> [18,] -0.57815327 3.5
#> [19,] 0.35686550 3.8
#> [20,] -0.57815327 3.8
#> [21,] -0.11064389 3.4
#> [22,] -0.57815327 3.7
#> [23,] -1.35733558 3.6
#> [24,] -0.57815327 3.3
#> [25,] -1.04566266 3.4
#> [26,] -0.73398974 3.0
#> [27,] -0.73398974 3.4
#> [28,] -0.42231681 3.5
#> [29,] -0.42231681 3.4
#> [30,] -1.20149912 3.2
#> [31,] -1.04566266 3.1
#> [32,] -0.11064389 3.4
#> [33,] -0.42231681 4.1
#> [34,] 0.04519257 4.2
#> [35,] -0.88982620 3.1
#> [36,] -0.73398974 3.2
#> [37,] 0.04519257 3.5
#> [38,] -0.88982620 3.6
#> [39,] -1.66900851 3.0
#> [40,] -0.57815327 3.4
#> [41,] -0.73398974 3.5
#> [42,] -1.51317205 2.3
#> [43,] -1.66900851 3.2
#> [44,] -0.73398974 3.5
#> [45,] -0.57815327 3.8
#> [46,] -1.04566266 3.0
#> [47,] -0.57815327 3.8
#> [48,] -1.35733558 3.2
#> [49,] -0.26648035 3.7
#> [50,] -0.73398974 3.3
#> [51,] 2.38273950 3.2
#> [52,] 1.44772073 3.2
#> [53,] 2.22690304 3.1
#> [54,] 0.04519257 2.3
#> [55,] 1.60355719 2.8
#> [56,] 0.35686550 2.8
#> [57,] 1.29188427 3.3
#> [58,] -0.88982620 2.4
#> [59,] 1.75939366 2.9
#> [60,] -0.42231681 2.7
#> [61,] -0.73398974 2.0
#> [62,] 0.66853842 3.0
#> [63,] 0.82437488 2.2
#> [64,] 0.98021135 2.9
#> [65,] 0.20102904 2.9
#> [66,] 1.91523012 3.1
#> [67,] 0.20102904 3.0
#> [68,] 0.51270196 2.7
#> [69,] 1.13604781 2.2
#> [70,] 0.20102904 2.5
#> [71,] 0.66853842 3.2
#> [72,] 0.98021135 2.8
#> [73,] 1.29188427 2.5
#> [74,] 0.98021135 2.8
#> [75,] 1.44772073 2.9
#> [76,] 1.75939366 3.0
#> [77,] 2.07106658 2.8
#> [78,] 1.91523012 3.0
#> [79,] 0.82437488 2.9
#> [80,] 0.35686550 2.6
#> [81,] 0.04519257 2.4
#> [82,] 0.04519257 2.4
#> [83,] 0.51270196 2.7
#> [84,] 0.82437488 2.7
#> [85,] -0.11064389 3.0
#> [86,] 0.82437488 3.4
#> [87,] 1.91523012 3.1
#> [88,] 1.29188427 2.3
#> [89,] 0.20102904 3.0
#> [90,] 0.04519257 2.5
#> [91,] 0.04519257 2.6
#> [92,] 0.98021135 3.0
#> [93,] 0.51270196 2.6
#> [94,] -0.73398974 2.3
#> [95,] 0.20102904 2.7
#> [96,] 0.35686550 3.0
#> [97,] 0.35686550 2.9
#> [98,] 1.13604781 2.9
#> [99,] -0.57815327 2.5
#> [100,] 0.35686550 2.8
#>
#> $outcomes
#> NULL
#>
#> $extras
#> $extras$roles
#> NULL
forge(test, blueprint = processed$blueprint, outcomes = TRUE)
#> $predictors
#> 50 x 2 sparse Matrix of class "dgCMatrix"
#> Sepal.Length Sepal.Width
#> [1,] 1.2918843 3.3
#> [2,] 0.5127020 2.7
#> [3,] 2.5385760 3.0
#> [4,] 1.2918843 2.9
#> [5,] 1.6035572 3.0
#> [6,] 3.3177583 3.0
#> [7,] -0.8898262 2.5
#> [8,] 2.8502489 2.9
#> [9,] 1.9152301 2.5
#> [10,] 2.6944124 3.6
#> [11,] 1.6035572 3.2
#> [12,] 1.4477207 2.7
#> [13,] 2.0710666 3.0
#> [14,] 0.3568655 2.5
#> [15,] 0.5127020 2.8
#> [16,] 1.4477207 3.2
#> [17,] 1.6035572 3.0
#> [18,] 3.4735947 3.8
#> [19,] 3.4735947 2.6
#> [20,] 0.8243749 2.2
#> [21,] 2.2269030 3.2
#> [22,] 0.2010290 2.8
#> [23,] 3.4735947 2.8
#> [24,] 1.2918843 2.7
#> [25,] 1.9152301 3.3
#> [26,] 2.6944124 3.2
#> [27,] 1.1360478 2.8
#> [28,] 0.9802113 3.0
#> [29,] 1.4477207 2.8
#> [30,] 2.6944124 3.0
#> [31,] 3.0060854 2.8
#> [32,] 3.7852677 3.8
#> [33,] 1.4477207 2.8
#> [34,] 1.2918843 2.8
#> [35,] 0.9802113 2.6
#> [36,] 3.4735947 3.0
#> [37,] 1.2918843 3.4
#> [38,] 1.4477207 3.1
#> [39,] 0.8243749 3.0
#> [40,] 2.2269030 3.1
#> [41,] 1.9152301 3.1
#> [42,] 2.2269030 3.1
#> [43,] 0.5127020 2.7
#> [44,] 2.0710666 3.2
#> [45,] 1.9152301 3.3
#> [46,] 1.9152301 3.0
#> [47,] 1.2918843 2.5
#> [48,] 1.6035572 3.0
#> [49,] 1.1360478 3.4
#> [50,] 0.6685384 3.0
#>
#> $outcomes
#> # A tibble: 50 x 1
#> Species
#> <fct>
#> 1 virginica
#> 2 virginica
#> 3 virginica
#> 4 virginica
#> 5 virginica
#> 6 virginica
#> 7 virginica
#> 8 virginica
#> 9 virginica
#> 10 virginica
#> # … with 40 more rows
#>
#> $extras
#> $extras$roles
#> NULL Created on 2020-09-29 by the reprex package (v0.3.0.9001) A few things to note:
I set this: #' @param composition Either "tibble", "matrix", or "dgCMatrix" for the format
#' of the processed predictors. I don't think there's much point to handling an explicit |
This might be more of a # See https://github.com/tidymodels/hardhat/pull/150
# remotes::install_github("tidymodels/hardhat@recipe-blueprint-composition")
# See https://github.com/tidymodels/parsnip/pull/373
# remotes::install_github("tidymodels/parsnip@sparsity")
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom 0.7.0 ✓ recipes 0.1.13
#> ✓ dials 0.0.9 ✓ rsample 0.0.8
#> ✓ dplyr 1.0.2 ✓ tibble 3.0.3
#> ✓ ggplot2 3.3.2 ✓ tidyr 1.1.2
#> ✓ infer 0.5.2 ✓ tune 0.1.1
#> ✓ modeldata 0.0.2 ✓ workflows 0.2.0
#> ✓ parsnip 0.1.3.9000 ✓ yardstick 0.0.7
#> ✓ purrr 0.3.4
#> ── Conflicts ───────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x recipes::step() masks stats::step()
library(hardhat)
data(ames)
ames <-
ames %>%
mutate(Sale_Price = log10(Sale_Price)) %>%
select(Sale_Price, Longitude, Latitude, Neighborhood)
rec <-
recipe(Sale_Price ~ ., data = ames) %>%
step_dummy(Neighborhood) %>%
step_zv(all_predictors()) %>%
step_normalize(all_predictors())
sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix") processed <- mold(rec, ames, blueprint = sparse_bp)
class(forge(ames, blueprint = processed$blueprint)$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix" lm_spec <- linear_reg(mixture = .5) %>% set_engine("glmnet")
lm_wflow <-
workflow() %>%
add_recipe(rec, blueprint = sparse_bp) %>%
add_model(lm_spec)
lm_wflow_1 <- .fit_pre(lm_wflow, data = ames)
lm_wflow_1$pre$mold$predictors
#> # A tibble: 2,930 x 29
#> Longitude Latitude Neighborhood_Co… Neighborhood_Ol… Neighborhood_Ed…
#> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0.901 1.06 -0.317 -0.298 -0.266
#> 2 0.900 1.01 -0.317 -0.298 -0.266
#> 3 0.915 0.987 -0.317 -0.298 -0.266
#> 4 0.995 0.911 -0.317 -0.298 -0.266
#> 5 0.154 1.43 -0.317 -0.298 -0.266
#> 6 0.155 1.43 -0.317 -0.298 -0.266
#> 7 0.354 1.55 -0.317 -0.298 -0.266
#> 8 0.353 1.43 -0.317 -0.298 -0.266
#> 9 0.391 1.45 -0.317 -0.298 -0.266
#> 10 0.149 1.34 -0.317 -0.298 -0.266
#> # … with 2,920 more rows, and 24 more variables: Neighborhood_Somerset <dbl>,
#> # Neighborhood_Northridge_Heights <dbl>, Neighborhood_Gilbert <dbl>,
#> # Neighborhood_Sawyer <dbl>, Neighborhood_Northwest_Ames <dbl>,
#> # Neighborhood_Sawyer_West <dbl>, Neighborhood_Mitchell <dbl>,
#> # Neighborhood_Brookside <dbl>, Neighborhood_Crawford <dbl>,
#> # Neighborhood_Iowa_DOT_and_Rail_Road <dbl>, Neighborhood_Timberland <dbl>,
#> # Neighborhood_Northridge <dbl>, Neighborhood_Stone_Brook <dbl>,
#> # Neighborhood_South_and_West_of_Iowa_State_University <dbl>,
#> # Neighborhood_Clear_Creek <dbl>, Neighborhood_Meadow_Village <dbl>,
#> # Neighborhood_Briardale <dbl>, Neighborhood_Bloomington_Heights <dbl>,
#> # Neighborhood_Veenker <dbl>, Neighborhood_Northpark_Villa <dbl>,
#> # Neighborhood_Blueste <dbl>, Neighborhood_Greens <dbl>,
#> # Neighborhood_Green_Hills <dbl>, Neighborhood_Landmark <dbl> Created on 2020-09-30 by the reprex package (v0.3.0) Session infodevtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.0.2 (2020-06-22)
#> os macOS Catalina 10.15.5
#> system x86_64, darwin17.0
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz America/New_York
#> date 2020-09-30
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date lib source
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.0.0)
#> backports 1.1.10 2020-09-15 [1] CRAN (R 4.0.2)
#> broom * 0.7.0 2020-07-09 [1] CRAN (R 4.0.0)
#> callr 3.4.4 2020-09-07 [1] CRAN (R 4.0.2)
#> class 7.3-17 2020-04-26 [1] CRAN (R 4.0.2)
#> cli 2.0.2 2020-02-28 [1] CRAN (R 4.0.0)
#> codetools 0.2-16 2018-12-24 [1] CRAN (R 4.0.2)
#> colorspace 1.4-1 2019-03-18 [1] CRAN (R 4.0.0)
#> crayon 1.3.4.9000 2020-08-18 [1] Github (r-lib/crayon@6b3f0c6)
#> desc 1.2.0 2018-05-01 [1] CRAN (R 4.0.0)
#> devtools 2.3.1 2020-07-21 [1] CRAN (R 4.0.2)
#> dials * 0.0.9 2020-09-16 [1] CRAN (R 4.0.2)
#> DiceDesign 1.8-1 2019-07-31 [1] CRAN (R 4.0.0)
#> digest 0.6.25 2020-02-23 [1] CRAN (R 4.0.0)
#> dplyr * 1.0.2 2020-08-18 [1] CRAN (R 4.0.0)
#> ellipsis 0.3.1 2020-05-15 [1] CRAN (R 4.0.0)
#> evaluate 0.14 2019-05-28 [1] CRAN (R 4.0.0)
#> fansi 0.4.1 2020-01-08 [1] CRAN (R 4.0.0)
#> foreach 1.5.0 2020-03-30 [1] CRAN (R 4.0.2)
#> fs 1.5.0 2020-07-31 [1] CRAN (R 4.0.2)
#> furrr 0.1.0 2018-05-16 [1] CRAN (R 4.0.0)
#> future 1.19.1 2020-09-22 [1] CRAN (R 4.0.2)
#> generics 0.0.2 2018-11-29 [1] CRAN (R 4.0.0)
#> ggplot2 * 3.3.2 2020-06-19 [1] CRAN (R 4.0.0)
#> globals 0.13.0 2020-09-17 [1] CRAN (R 4.0.2)
#> glue 1.4.2 2020-08-27 [1] CRAN (R 4.0.2)
#> gower 0.2.2 2020-06-23 [1] CRAN (R 4.0.0)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.0.0)
#> gtable 0.3.0 2019-03-25 [1] CRAN (R 4.0.0)
#> hardhat * 0.1.4.9000 2020-09-30 [1] Github (tidymodels/hardhat@b763b1f)
#> highr 0.8 2019-03-20 [1] CRAN (R 4.0.0)
#> htmltools 0.5.0 2020-06-16 [1] CRAN (R 4.0.0)
#> infer * 0.5.2 2020-06-14 [1] CRAN (R 4.0.0)
#> ipred 0.9-9 2019-04-28 [1] CRAN (R 4.0.2)
#> iterators 1.0.12 2019-07-26 [1] CRAN (R 4.0.0)
#> knitr 1.30 2020-09-22 [1] CRAN (R 4.0.2)
#> lattice 0.20-41 2020-04-02 [1] CRAN (R 4.0.2)
#> lava 1.6.8 2020-09-26 [1] CRAN (R 4.0.2)
#> lhs 1.1.0 2020-09-29 [1] CRAN (R 4.0.2)
#> lifecycle 0.2.0 2020-03-06 [1] CRAN (R 4.0.0)
#> listenv 0.8.0 2019-12-05 [1] CRAN (R 4.0.0)
#> lubridate 1.7.9 2020-06-08 [1] CRAN (R 4.0.2)
#> magrittr 1.5 2014-11-22 [1] CRAN (R 4.0.0)
#> MASS 7.3-51.6 2020-04-26 [1] CRAN (R 4.0.2)
#> Matrix 1.2-18 2019-11-27 [1] CRAN (R 4.0.2)
#> memoise 1.1.0 2017-04-21 [1] CRAN (R 4.0.0)
#> modeldata * 0.0.2 2020-06-22 [1] CRAN (R 4.0.2)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.0.0)
#> nnet 7.3-14 2020-04-26 [1] CRAN (R 4.0.2)
#> parsnip * 0.1.3.9000 2020-09-30 [1] Github (tidymodels/parsnip@659b5ad)
#> pillar 1.4.6 2020-07-10 [1] CRAN (R 4.0.0)
#> pkgbuild 1.1.0 2020-07-13 [1] CRAN (R 4.0.2)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.0.0)
#> pkgload 1.1.0 2020-05-29 [1] CRAN (R 4.0.0)
#> plyr 1.8.6 2020-03-03 [1] CRAN (R 4.0.2)
#> prettyunits 1.1.1 2020-01-24 [1] CRAN (R 4.0.0)
#> pROC 1.16.2 2020-03-19 [1] CRAN (R 4.0.2)
#> processx 3.4.4 2020-09-03 [1] CRAN (R 4.0.2)
#> prodlim 2019.11.13 2019-11-17 [1] CRAN (R 4.0.0)
#> ps 1.3.4 2020-08-11 [1] CRAN (R 4.0.2)
#> purrr * 0.3.4 2020-04-17 [1] CRAN (R 4.0.0)
#> R6 2.4.1 2019-11-12 [1] CRAN (R 4.0.0)
#> Rcpp 1.0.5 2020-07-06 [1] CRAN (R 4.0.0)
#> recipes * 0.1.13 2020-06-23 [1] CRAN (R 4.0.2)
#> remotes 2.2.0 2020-07-21 [1] CRAN (R 4.0.2)
#> rlang 0.4.7 2020-07-09 [1] CRAN (R 4.0.0)
#> rmarkdown 2.3 2020-06-18 [1] CRAN (R 4.0.2)
#> rpart 4.1-15 2019-04-12 [1] CRAN (R 4.0.2)
#> rprojroot 1.3-2 2018-01-03 [1] CRAN (R 4.0.0)
#> rsample * 0.0.8 2020-09-23 [1] CRAN (R 4.0.2)
#> rstudioapi 0.11 2020-02-07 [1] CRAN (R 4.0.0)
#> scales * 1.1.1 2020-05-11 [1] CRAN (R 4.0.2)
#> sessioninfo 1.1.1 2018-11-05 [1] CRAN (R 4.0.2)
#> stringi 1.5.3 2020-09-09 [1] CRAN (R 4.0.2)
#> stringr 1.4.0 2019-02-10 [1] CRAN (R 4.0.0)
#> survival 3.1-12 2020-04-10 [1] CRAN (R 4.0.2)
#> testthat 2.3.2 2020-03-02 [1] CRAN (R 4.0.2)
#> tibble * 3.0.3 2020-07-10 [1] CRAN (R 4.0.0)
#> tidymodels * 0.1.1 2020-07-14 [1] CRAN (R 4.0.0)
#> tidyr * 1.1.2 2020-08-27 [1] CRAN (R 4.0.2)
#> tidyselect 1.1.0 2020-05-11 [1] CRAN (R 4.0.0)
#> timeDate 3043.102 2018-02-21 [1] CRAN (R 4.0.0)
#> tune * 0.1.1 2020-07-08 [1] CRAN (R 4.0.2)
#> usethis 1.9.0.9000 2020-09-30 [1] Github (r-lib/usethis@b993e83)
#> utf8 1.1.4 2018-05-24 [1] CRAN (R 4.0.0)
#> vctrs 0.3.4 2020-08-29 [1] CRAN (R 4.0.2)
#> withr 2.3.0 2020-09-22 [1] CRAN (R 4.0.2)
#> workflows * 0.2.0 2020-09-15 [1] CRAN (R 4.0.2)
#> xfun 0.17 2020-09-09 [1] CRAN (R 4.0.2)
#> yaml 2.2.1 2020-02-01 [1] CRAN (R 4.0.0)
#> yardstick * 0.0.7 2020-07-13 [1] CRAN (R 4.0.2)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.0/Resources/library |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added some general feedback - will also discuss with Julia over Zoom.
The reason that Max's example with workflows didn't work is because mold()
is currently not being touched by these changes, only forge()
, but workflows uses the results of mold()
in the model fitting process.
I will discuss with Julia further, but my initial thoughts about this are that:
- we should apply this over all preprocessors
- we could probably simplify this by moving the composition bits to hardhat itself. I suggest a new verb called
recompose(data, composition = "tibble"/etc)
that does the composition bits that currently live at the end ofrecipes:::bake.recipe
To add it to all preprocessor types, we could place recompose()
in:
-
forge_recipe_default_process_predictors()
aftermaybe_add_intercept_column()
-
forge_formula_default_process_predictors()
after reattaching potential factor columns -
forge_xy_default_process_predictors()
aftermaybe_add_intercept_column()
-
mold_recipe_default_process_predictors()
aftermaybe_add_intercept_column()
-
mold_formula_default_process_predictors()
aftersimplify_terms()
-
mold_xy_default_process_predictors()
aftermaybe_add_intercept_column()
This would remove all of the changes to forge_recipe_default_process()
, there would be no need to try and work around baking only parts of the recipe, which I'm fairly certain actually just bakes the entire recipe on all columns, then just does the equivalent of a select()
at the end for the columns you requested
OK, I think this is getting there! 👍 I now have library(tidymodels)
library(hardhat)
data(ames, package = "modeldata")
x <- ames %>%
select(Longitude, Latitude, Year_Built)
y <- log10(ames$Sale_Price)
## works for xy
bp1 <- default_xy_blueprint(composition = "dgCMatrix")
x1 <- mold(x, y, blueprint = bp1)
class(x1$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"
colnames(x1$predictors)
#> [1] "Longitude" "Latitude" "Year_Built"
## works for formula
bp2 <- default_formula_blueprint(composition = "matrix")
x2 <- mold(log10(Sale_Price) ~ Longitude + Latitude + Neighborhood, ames, blueprint = bp2)
class(x2$predictors)
#> [1] "matrix" "array"
colnames(x2$predictors)
#> [1] "Longitude"
#> [2] "Latitude"
#> [3] "NeighborhoodNorth_Ames"
#> [4] "NeighborhoodCollege_Creek"
#> [5] "NeighborhoodOld_Town"
#> [6] "NeighborhoodEdwards"
#> [7] "NeighborhoodSomerset"
#> [8] "NeighborhoodNorthridge_Heights"
#> [9] "NeighborhoodGilbert"
#> [10] "NeighborhoodSawyer"
#> [11] "NeighborhoodNorthwest_Ames"
#> [12] "NeighborhoodSawyer_West"
#> [13] "NeighborhoodMitchell"
#> [14] "NeighborhoodBrookside"
#> [15] "NeighborhoodCrawford"
#> [16] "NeighborhoodIowa_DOT_and_Rail_Road"
#> [17] "NeighborhoodTimberland"
#> [18] "NeighborhoodNorthridge"
#> [19] "NeighborhoodStone_Brook"
#> [20] "NeighborhoodSouth_and_West_of_Iowa_State_University"
#> [21] "NeighborhoodClear_Creek"
#> [22] "NeighborhoodMeadow_Village"
#> [23] "NeighborhoodBriardale"
#> [24] "NeighborhoodBloomington_Heights"
#> [25] "NeighborhoodVeenker"
#> [26] "NeighborhoodNorthpark_Villa"
#> [27] "NeighborhoodBlueste"
#> [28] "NeighborhoodGreens"
#> [29] "NeighborhoodGreen_Hills"
#> [30] "NeighborhoodLandmark"
#> [31] "NeighborhoodHayden_Lake"
## can forge
xx1 <- forge(ames, blueprint = x1$blueprint)
class(xx1$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"
## workflow can get to the "recomposed" data
rec <-
recipe(Sale_Price ~ Longitude + Latitude + Neighborhood, data = ames) %>%
step_dummy(Neighborhood) %>%
step_zv(all_predictors()) %>%
step_normalize(all_predictors())
lasso_spec <- linear_reg(mixture = 1) %>% set_engine("glmnet")
bp3 <- default_recipe_blueprint(composition = "dgCMatrix")
lasso_wflow <-
workflow() %>%
add_recipe(rec, blueprint = bp3) %>%
add_model(lasso_spec)
wflow_1 <- .fit_pre(lasso_wflow, data = ames)
class(wflow_1$pre$mold$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"
colnames(wflow_1$pre$mold$predictors)
#> [1] "Longitude"
#> [2] "Latitude"
#> [3] "Neighborhood_College_Creek"
#> [4] "Neighborhood_Old_Town"
#> [5] "Neighborhood_Edwards"
#> [6] "Neighborhood_Somerset"
#> [7] "Neighborhood_Northridge_Heights"
#> [8] "Neighborhood_Gilbert"
#> [9] "Neighborhood_Sawyer"
#> [10] "Neighborhood_Northwest_Ames"
#> [11] "Neighborhood_Sawyer_West"
#> [12] "Neighborhood_Mitchell"
#> [13] "Neighborhood_Brookside"
#> [14] "Neighborhood_Crawford"
#> [15] "Neighborhood_Iowa_DOT_and_Rail_Road"
#> [16] "Neighborhood_Timberland"
#> [17] "Neighborhood_Northridge"
#> [18] "Neighborhood_Stone_Brook"
#> [19] "Neighborhood_South_and_West_of_Iowa_State_University"
#> [20] "Neighborhood_Clear_Creek"
#> [21] "Neighborhood_Meadow_Village"
#> [22] "Neighborhood_Briardale"
#> [23] "Neighborhood_Bloomington_Heights"
#> [24] "Neighborhood_Veenker"
#> [25] "Neighborhood_Northpark_Villa"
#> [26] "Neighborhood_Blueste"
#> [27] "Neighborhood_Greens"
#> [28] "Neighborhood_Green_Hills"
#> [29] "Neighborhood_Landmark" Created on 2020-09-30 by the reprex package (v0.3.0.9001) I added significant changes to the tests for the preprocessors FYI. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Just a few final comments
@@ -26,6 +26,9 @@ | |||
#' prediction time? This information is used by the `clean` function in the | |||
#' `forge` function list, and is passed on to [scream()]. | |||
#' | |||
#' @param composition Either "tibble", "matrix", or "dgCMatrix" for the format |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't used Matrix very much, but I know they have a ton of class types. Is it always a dgCMatrix
? The Matrix()
function can return a number of different things, and I'm not sure if setting sparse = TRUE
always guarantees a dgCMatrix class. If not, should the option just be "sparse_matrix"
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am pretty sure we should stick with the dgCMatrix class only. I just spent some time looking through the Matrix docs, and the dgCMatrix class is a compressed, sparse, column-oriented format which is the right kind for modeling. From the Matrix docs:
dgCMatrix
is the “standard” class for sparse numeric matrices in the Matrix package.
I haven't ever seen any of the other classes used in the wild in modeling, and things like the dfm type in quanteda inherit from dgCMatrix.
(The DocumentTermMatrix in tm inherits from the simple triplet matrix in slam, but I don't think we need to worry about that here.)
R/blueprint.R
Outdated
@@ -26,6 +26,9 @@ | |||
#' prediction time? This information is used by the `clean` function in the | |||
#' `forge` function list, and is passed on to [scream()]. | |||
#' | |||
#' @param composition Either "tibble", "matrix", or "dgCMatrix" for the format | |||
#' of the processed predictors. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Either `"tibble"`, `"matrix"`, or `"dgCMatrix"` for the format
of the processed predictors. If `"matrix"` or `"dgCMatrix"` are chosen,
all of the predictors must be numeric after the preprocessing method
has been applied, otherwise an error is thrown.
I don't think we have mentioned the numeric restriction anywhere, and this seemed like a good place
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe also mention that if dgCMatrix is used, the Matrix package must be installed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Matrix is a "recommended" package, which I believe means it is part of regular R installations, right? It says that here FWIW.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did add rlang::if_installed
to send an informative error in case someone happens to not have this installed but I'm not sure how that would typically happen.
Co-authored-by: Davis Vaughan <davis@rstudio.com>
This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue. |
This is one of the first pieces needed for supporting sparse data structures in tidymodels.
Created on 2020-09-28 by the reprex package (v0.3.0.9001)