Add composition to preprocessor blueprints #150

Merged 22 commits on Oct 5, 2020

juliasilge
Member

This is one of the first pieces needed for supporting sparse data structures in tidymodels.

library(hardhat)
library(recipes)
#> Loading required package: dplyr
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
#> 
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#> 
#>     step

train <- iris[1:100,]
test <- iris[101:150,]

rec <- recipe(Species ~ Sepal.Length + Sepal.Width, train) %>%
  step_log(Sepal.Length)

sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix")
sparse_bp
#> Recipe blueprint: 
#>  
#> # Predictors: 0 
#>   # Outcomes: 0 
#>    Intercept: FALSE 
#> Novel Levels: FALSE 
#>  Composition: dgCMatrix

processed <- mold(rec, train, blueprint = sparse_bp)
processed$blueprint
#> Recipe blueprint: 
#>  
#> # Predictors: 2 
#>   # Outcomes: 1 
#>    Intercept: FALSE 
#> Novel Levels: FALSE 
#>  Composition: dgCMatrix

forge(train, blueprint = processed$blueprint)
#> $predictors
#> 83 x 2 sparse Matrix of class "dgCMatrix"
#>       Sepal.Length Sepal.Width
#>  [1,]     1.629241         3.5
#>  [2,]     1.589235         3.0
#>  [3,]     1.547563         3.2
#>  [4,]     1.526056         3.1
#>  [5,]     1.609438         3.6
#>  [6,]     1.686399         3.9
#>  [7,]     1.526056         3.4
#>  [8,]     1.609438         3.4
#>  [9,]     1.481605         2.9
#> [10,]     1.589235         3.1
#> [11,]     1.686399         3.7
#> [12,]     1.568616         3.4
#> [13,]     1.568616         3.0
#> [14,]     1.458615         3.0
#> [15,]     1.757858         4.0
#> [16,]     1.740466         4.4
#> [17,]     1.740466         3.8
#> [18,]     1.629241         3.8
#> [19,]     1.686399         3.4
#> [20,]     1.629241         3.7
#> [21,]     1.526056         3.6
#> [22,]     1.629241         3.3
#> [23,]     1.609438         3.0
#> [24,]     1.648659         3.5
#> [25,]     1.648659         3.4
#> [26,]     1.568616         3.1
#> [27,]     1.648659         4.1
#> [28,]     1.704748         4.2
#> [29,]     1.609438         3.2
#> [30,]     1.704748         3.5
#> [31,]     1.589235         3.6
#> [32,]     1.481605         3.0
#> [33,]     1.629241         3.4
#> [34,]     1.609438         3.5
#> [35,]     1.504077         2.3
#> [36,]     1.481605         3.2
#> [37,]     1.526056         3.2
#> [38,]     1.667707         3.7
#> [39,]     1.609438         3.3
#> [40,]     1.945910         3.2
#> [41,]     1.856298         3.2
#> [42,]     1.931521         3.1
#> [43,]     1.704748         2.3
#> [44,]     1.871802         2.8
#> [45,]     1.740466         2.8
#> [46,]     1.840550         3.3
#> [47,]     1.589235         2.4
#> [48,]     1.887070         2.9
#> [49,]     1.648659         2.7
#> [50,]     1.609438         2.0
#> [51,]     1.774952         3.0
#> [52,]     1.791759         2.2
#> [53,]     1.808289         2.9
#> [54,]     1.722767         2.9
#> [55,]     1.902108         3.1
#> [56,]     1.722767         3.0
#> [57,]     1.757858         2.7
#> [58,]     1.824549         2.2
#> [59,]     1.722767         2.5
#> [60,]     1.774952         3.2
#> [61,]     1.808289         2.8
#> [62,]     1.840550         2.5
#> [63,]     1.856298         2.9
#> [64,]     1.887070         3.0
#> [65,]     1.916923         2.8
#> [66,]     1.902108         3.0
#> [67,]     1.791759         2.9
#> [68,]     1.740466         2.6
#> [69,]     1.704748         2.4
#> [70,]     1.791759         2.7
#> [71,]     1.686399         3.0
#> [72,]     1.791759         3.4
#> [73,]     1.840550         2.3
#> [74,]     1.704748         2.5
#> [75,]     1.704748         2.6
#> [76,]     1.808289         3.0
#> [77,]     1.757858         2.6
#> [78,]     1.609438         2.3
#> [79,]     1.722767         2.7
#> [80,]     1.740466         3.0
#> [81,]     1.740466         2.9
#> [82,]     1.824549         2.9
#> [83,]     1.629241         2.5
#> 
#> $outcomes
#> NULL
#> 
#> $extras
#> $extras$roles
#> NULL
forge(test, blueprint = processed$blueprint)
#> $predictors
#> 44 x 2 sparse Matrix of class "dgCMatrix"
#>       Sepal.Length Sepal.Width
#>  [1,]     1.840550         3.3
#>  [2,]     1.757858         2.7
#>  [3,]     1.960095         3.0
#>  [4,]     1.840550         2.9
#>  [5,]     1.871802         3.0
#>  [6,]     2.028148         3.0
#>  [7,]     1.589235         2.5
#>  [8,]     1.987874         2.9
#>  [9,]     1.902108         2.5
#> [10,]     1.974081         3.6
#> [11,]     1.871802         3.2
#> [12,]     1.856298         2.7
#> [13,]     1.916923         3.0
#> [14,]     1.740466         2.5
#> [15,]     1.757858         2.8
#> [16,]     1.856298         3.2
#> [17,]     2.041220         3.8
#> [18,]     2.041220         2.6
#> [19,]     1.791759         2.2
#> [20,]     1.931521         3.2
#> [21,]     1.722767         2.8
#> [22,]     2.041220         2.8
#> [23,]     1.840550         2.7
#> [24,]     1.902108         3.3
#> [25,]     1.974081         3.2
#> [26,]     1.824549         2.8
#> [27,]     1.808289         3.0
#> [28,]     1.856298         2.8
#> [29,]     1.974081         3.0
#> [30,]     2.001480         2.8
#> [31,]     2.066863         3.8
#> [32,]     1.840550         2.8
#> [33,]     1.808289         2.6
#> [34,]     2.041220         3.0
#> [35,]     1.840550         3.4
#> [36,]     1.856298         3.1
#> [37,]     1.791759         3.0
#> [38,]     1.931521         3.1
#> [39,]     1.902108         3.1
#> [40,]     1.916923         3.2
#> [41,]     1.902108         3.0
#> [42,]     1.840550         2.5
#> [43,]     1.824549         3.4
#> [44,]     1.774952         3.0
#> 
#> $outcomes
#> NULL
#> 
#> $extras
#> $extras$roles
#> NULL

Created on 2020-09-28 by the reprex package (v0.3.0.9001)

@juliasilge
Member Author

Related to tidymodels/tidymodels#42

@juliasilge
Member Author

Closes #100 eventually

@juliasilge
Member Author

I worked more on this today. During forge(), the recipe is now baked only once on the predictors, and everything else still wires up correctly. 🎉

library(hardhat)
library(recipes)
#> Loading required package: dplyr
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
#> 
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#> 
#>     step

train <- iris[1:100,]
test <- iris[101:150,]

rec <- recipe(Species ~ Sepal.Length + Sepal.Width, train) %>%
  step_normalize(Sepal.Length)

sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix")
sparse_bp
#> Recipe blueprint: 
#>  
#> # Predictors: 0 
#>   # Outcomes: 0 
#>    Intercept: FALSE 
#> Novel Levels: FALSE 
#>  Composition: dgCMatrix

processed <- mold(rec, train, blueprint = sparse_bp)
processed$blueprint
#> Recipe blueprint: 
#>  
#> # Predictors: 2 
#>   # Outcomes: 1 
#>    Intercept: FALSE 
#> Novel Levels: FALSE 
#>  Composition: dgCMatrix

forge(train, blueprint = processed$blueprint)
#> $predictors
#> 100 x 2 sparse Matrix of class "dgCMatrix"
#>        Sepal.Length Sepal.Width
#>   [1,]  -0.57815327         3.5
#>   [2,]  -0.88982620         3.0
#>   [3,]  -1.20149912         3.2
#>   [4,]  -1.35733558         3.1
#>   [5,]  -0.73398974         3.6
#>   [6,]  -0.11064389         3.9
#>   [7,]  -1.35733558         3.4
#>   [8,]  -0.73398974         3.4
#>   [9,]  -1.66900851         2.9
#>  [10,]  -0.88982620         3.1
#>  [11,]  -0.11064389         3.7
#>  [12,]  -1.04566266         3.4
#>  [13,]  -1.04566266         3.0
#>  [14,]  -1.82484497         3.0
#>  [15,]   0.51270196         4.0
#>  [16,]   0.35686550         4.4
#>  [17,]  -0.11064389         3.9
#>  [18,]  -0.57815327         3.5
#>  [19,]   0.35686550         3.8
#>  [20,]  -0.57815327         3.8
#>  [21,]  -0.11064389         3.4
#>  [22,]  -0.57815327         3.7
#>  [23,]  -1.35733558         3.6
#>  [24,]  -0.57815327         3.3
#>  [25,]  -1.04566266         3.4
#>  [26,]  -0.73398974         3.0
#>  [27,]  -0.73398974         3.4
#>  [28,]  -0.42231681         3.5
#>  [29,]  -0.42231681         3.4
#>  [30,]  -1.20149912         3.2
#>  [31,]  -1.04566266         3.1
#>  [32,]  -0.11064389         3.4
#>  [33,]  -0.42231681         4.1
#>  [34,]   0.04519257         4.2
#>  [35,]  -0.88982620         3.1
#>  [36,]  -0.73398974         3.2
#>  [37,]   0.04519257         3.5
#>  [38,]  -0.88982620         3.6
#>  [39,]  -1.66900851         3.0
#>  [40,]  -0.57815327         3.4
#>  [41,]  -0.73398974         3.5
#>  [42,]  -1.51317205         2.3
#>  [43,]  -1.66900851         3.2
#>  [44,]  -0.73398974         3.5
#>  [45,]  -0.57815327         3.8
#>  [46,]  -1.04566266         3.0
#>  [47,]  -0.57815327         3.8
#>  [48,]  -1.35733558         3.2
#>  [49,]  -0.26648035         3.7
#>  [50,]  -0.73398974         3.3
#>  [51,]   2.38273950         3.2
#>  [52,]   1.44772073         3.2
#>  [53,]   2.22690304         3.1
#>  [54,]   0.04519257         2.3
#>  [55,]   1.60355719         2.8
#>  [56,]   0.35686550         2.8
#>  [57,]   1.29188427         3.3
#>  [58,]  -0.88982620         2.4
#>  [59,]   1.75939366         2.9
#>  [60,]  -0.42231681         2.7
#>  [61,]  -0.73398974         2.0
#>  [62,]   0.66853842         3.0
#>  [63,]   0.82437488         2.2
#>  [64,]   0.98021135         2.9
#>  [65,]   0.20102904         2.9
#>  [66,]   1.91523012         3.1
#>  [67,]   0.20102904         3.0
#>  [68,]   0.51270196         2.7
#>  [69,]   1.13604781         2.2
#>  [70,]   0.20102904         2.5
#>  [71,]   0.66853842         3.2
#>  [72,]   0.98021135         2.8
#>  [73,]   1.29188427         2.5
#>  [74,]   0.98021135         2.8
#>  [75,]   1.44772073         2.9
#>  [76,]   1.75939366         3.0
#>  [77,]   2.07106658         2.8
#>  [78,]   1.91523012         3.0
#>  [79,]   0.82437488         2.9
#>  [80,]   0.35686550         2.6
#>  [81,]   0.04519257         2.4
#>  [82,]   0.04519257         2.4
#>  [83,]   0.51270196         2.7
#>  [84,]   0.82437488         2.7
#>  [85,]  -0.11064389         3.0
#>  [86,]   0.82437488         3.4
#>  [87,]   1.91523012         3.1
#>  [88,]   1.29188427         2.3
#>  [89,]   0.20102904         3.0
#>  [90,]   0.04519257         2.5
#>  [91,]   0.04519257         2.6
#>  [92,]   0.98021135         3.0
#>  [93,]   0.51270196         2.6
#>  [94,]  -0.73398974         2.3
#>  [95,]   0.20102904         2.7
#>  [96,]   0.35686550         3.0
#>  [97,]   0.35686550         2.9
#>  [98,]   1.13604781         2.9
#>  [99,]  -0.57815327         2.5
#> [100,]   0.35686550         2.8
#> 
#> $outcomes
#> NULL
#> 
#> $extras
#> $extras$roles
#> NULL
forge(test, blueprint = processed$blueprint, outcomes = TRUE)
#> $predictors
#> 50 x 2 sparse Matrix of class "dgCMatrix"
#>       Sepal.Length Sepal.Width
#>  [1,]    1.2918843         3.3
#>  [2,]    0.5127020         2.7
#>  [3,]    2.5385760         3.0
#>  [4,]    1.2918843         2.9
#>  [5,]    1.6035572         3.0
#>  [6,]    3.3177583         3.0
#>  [7,]   -0.8898262         2.5
#>  [8,]    2.8502489         2.9
#>  [9,]    1.9152301         2.5
#> [10,]    2.6944124         3.6
#> [11,]    1.6035572         3.2
#> [12,]    1.4477207         2.7
#> [13,]    2.0710666         3.0
#> [14,]    0.3568655         2.5
#> [15,]    0.5127020         2.8
#> [16,]    1.4477207         3.2
#> [17,]    1.6035572         3.0
#> [18,]    3.4735947         3.8
#> [19,]    3.4735947         2.6
#> [20,]    0.8243749         2.2
#> [21,]    2.2269030         3.2
#> [22,]    0.2010290         2.8
#> [23,]    3.4735947         2.8
#> [24,]    1.2918843         2.7
#> [25,]    1.9152301         3.3
#> [26,]    2.6944124         3.2
#> [27,]    1.1360478         2.8
#> [28,]    0.9802113         3.0
#> [29,]    1.4477207         2.8
#> [30,]    2.6944124         3.0
#> [31,]    3.0060854         2.8
#> [32,]    3.7852677         3.8
#> [33,]    1.4477207         2.8
#> [34,]    1.2918843         2.8
#> [35,]    0.9802113         2.6
#> [36,]    3.4735947         3.0
#> [37,]    1.2918843         3.4
#> [38,]    1.4477207         3.1
#> [39,]    0.8243749         3.0
#> [40,]    2.2269030         3.1
#> [41,]    1.9152301         3.1
#> [42,]    2.2269030         3.1
#> [43,]    0.5127020         2.7
#> [44,]    2.0710666         3.2
#> [45,]    1.9152301         3.3
#> [46,]    1.9152301         3.0
#> [47,]    1.2918843         2.5
#> [48,]    1.6035572         3.0
#> [49,]    1.1360478         3.4
#> [50,]    0.6685384         3.0
#> 
#> $outcomes
#> # A tibble: 50 x 1
#>    Species  
#>    <fct>    
#>  1 virginica
#>  2 virginica
#>  3 virginica
#>  4 virginica
#>  5 virginica
#>  6 virginica
#>  7 virginica
#>  8 virginica
#>  9 virginica
#> 10 virginica
#> # … with 40 more rows
#> 
#> $extras
#> $extras$roles
#> NULL

Created on 2020-09-29 by the reprex package (v0.3.0.9001)

A few things to note:

  • The outcome is still a tibble, not a matrix or sparse matrix. We'll need those factors for classification models and such.
  • For situations where there are "extras", e.g. columns that are not predictors or outcomes, they are baked with the predictors. This means that they have to be numeric and the baking will error if they are categorical/nominal. Baking them with the outcomes does not work very well in any way I have tried, so I think this may just be a restriction for using a non-tibble composition.
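To illustrate why non-numeric "extras" are a problem for a matrix composition, here is a minimal base R sketch. The `mixed` data frame is made up for illustration, and this is not the code path that recipes itself uses; it only shows the general behavior of coercing mixed columns to a matrix.

```r
# Hypothetical data: one numeric predictor plus a character "extra" column.
mixed <- data.frame(x = c(1.5, 2), id = c("a", "b"))

# Coercing mixed columns to a matrix silently turns everything into character,
# which is useless for a numeric (dense or sparse) predictor matrix.
mixed_type <- typeof(as.matrix(mixed))

# Keeping only the numeric columns gives a proper double matrix.
numeric_only <- mixed[vapply(mixed, is.numeric, logical(1))]
numeric_type <- typeof(as.matrix(numeric_only))

mixed_type
numeric_type
```

This is why erroring early on categorical/nominal extras seems like a reasonable restriction for non-tibble compositions.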

I set this:

#' @param composition Either "tibble", "matrix", or "dgCMatrix" for the format
#' of the processed predictors.

I don't think there's much point to handling an explicit data.frame case, although it is a composition option in recipes.

@juliasilge juliasilge marked this pull request as ready for review September 30, 2020 04:17
@topepo
Member

topepo commented Sep 30, 2020

This might be more of a workflows issue, but .fit_pre() still generates a tibble:

# See https://github.com/tidymodels/hardhat/pull/150
# remotes::install_github("tidymodels/hardhat@recipe-blueprint-composition")

# See https://github.com/tidymodels/parsnip/pull/373
# remotes::install_github("tidymodels/parsnip@sparsity")

library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom     0.7.0          ✓ recipes   0.1.13    
#> ✓ dials     0.0.9          ✓ rsample   0.0.8     
#> ✓ dplyr     1.0.2          ✓ tibble    3.0.3     
#> ✓ ggplot2   3.3.2          ✓ tidyr     1.1.2     
#> ✓ infer     0.5.2          ✓ tune      0.1.1     
#> ✓ modeldata 0.0.2          ✓ workflows 0.2.0     
#> ✓ parsnip   0.1.3.9000     ✓ yardstick 0.0.7     
#> ✓ purrr     0.3.4
#> ── Conflicts ───────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()
library(hardhat)

data(ames)

ames <- 
  ames %>% 
  mutate(Sale_Price = log10(Sale_Price)) %>% 
  select(Sale_Price, Longitude, Latitude, Neighborhood)

rec <- 
  recipe(Sale_Price ~ ., data = ames) %>% 
  step_dummy(Neighborhood) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix")
processed <- mold(rec, ames, blueprint = sparse_bp)
class(forge(ames, blueprint = processed$blueprint)$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"
lm_spec <- linear_reg(mixture = .5) %>% set_engine("glmnet")

lm_wflow <- 
  workflow() %>% 
  add_recipe(rec, blueprint = sparse_bp) %>% 
  add_model(lm_spec)

lm_wflow_1 <- .fit_pre(lm_wflow, data = ames)
lm_wflow_1$pre$mold$predictors
#> # A tibble: 2,930 x 29
#>    Longitude Latitude Neighborhood_Co… Neighborhood_Ol… Neighborhood_Ed…
#>        <dbl>    <dbl>            <dbl>            <dbl>            <dbl>
#>  1     0.901    1.06            -0.317           -0.298           -0.266
#>  2     0.900    1.01            -0.317           -0.298           -0.266
#>  3     0.915    0.987           -0.317           -0.298           -0.266
#>  4     0.995    0.911           -0.317           -0.298           -0.266
#>  5     0.154    1.43            -0.317           -0.298           -0.266
#>  6     0.155    1.43            -0.317           -0.298           -0.266
#>  7     0.354    1.55            -0.317           -0.298           -0.266
#>  8     0.353    1.43            -0.317           -0.298           -0.266
#>  9     0.391    1.45            -0.317           -0.298           -0.266
#> 10     0.149    1.34            -0.317           -0.298           -0.266
#> # … with 2,920 more rows, and 24 more variables: Neighborhood_Somerset <dbl>,
#> #   Neighborhood_Northridge_Heights <dbl>, Neighborhood_Gilbert <dbl>,
#> #   Neighborhood_Sawyer <dbl>, Neighborhood_Northwest_Ames <dbl>,
#> #   Neighborhood_Sawyer_West <dbl>, Neighborhood_Mitchell <dbl>,
#> #   Neighborhood_Brookside <dbl>, Neighborhood_Crawford <dbl>,
#> #   Neighborhood_Iowa_DOT_and_Rail_Road <dbl>, Neighborhood_Timberland <dbl>,
#> #   Neighborhood_Northridge <dbl>, Neighborhood_Stone_Brook <dbl>,
#> #   Neighborhood_South_and_West_of_Iowa_State_University <dbl>,
#> #   Neighborhood_Clear_Creek <dbl>, Neighborhood_Meadow_Village <dbl>,
#> #   Neighborhood_Briardale <dbl>, Neighborhood_Bloomington_Heights <dbl>,
#> #   Neighborhood_Veenker <dbl>, Neighborhood_Northpark_Villa <dbl>,
#> #   Neighborhood_Blueste <dbl>, Neighborhood_Greens <dbl>,
#> #   Neighborhood_Green_Hills <dbl>, Neighborhood_Landmark <dbl>

Created on 2020-09-30 by the reprex package (v0.3.0)

Session info
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.0.2 (2020-06-22)
#>  os       macOS Catalina 10.15.5      
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       America/New_York            
#>  date     2020-09-30                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date       lib source                             
#>  assertthat    0.2.1      2019-03-21 [1] CRAN (R 4.0.0)                     
#>  backports     1.1.10     2020-09-15 [1] CRAN (R 4.0.2)                     
#>  broom       * 0.7.0      2020-07-09 [1] CRAN (R 4.0.0)                     
#>  callr         3.4.4      2020-09-07 [1] CRAN (R 4.0.2)                     
#>  class         7.3-17     2020-04-26 [1] CRAN (R 4.0.2)                     
#>  cli           2.0.2      2020-02-28 [1] CRAN (R 4.0.0)                     
#>  codetools     0.2-16     2018-12-24 [1] CRAN (R 4.0.2)                     
#>  colorspace    1.4-1      2019-03-18 [1] CRAN (R 4.0.0)                     
#>  crayon        1.3.4.9000 2020-08-18 [1] Github (r-lib/crayon@6b3f0c6)      
#>  desc          1.2.0      2018-05-01 [1] CRAN (R 4.0.0)                     
#>  devtools      2.3.1      2020-07-21 [1] CRAN (R 4.0.2)                     
#>  dials       * 0.0.9      2020-09-16 [1] CRAN (R 4.0.2)                     
#>  DiceDesign    1.8-1      2019-07-31 [1] CRAN (R 4.0.0)                     
#>  digest        0.6.25     2020-02-23 [1] CRAN (R 4.0.0)                     
#>  dplyr       * 1.0.2      2020-08-18 [1] CRAN (R 4.0.0)                     
#>  ellipsis      0.3.1      2020-05-15 [1] CRAN (R 4.0.0)                     
#>  evaluate      0.14       2019-05-28 [1] CRAN (R 4.0.0)                     
#>  fansi         0.4.1      2020-01-08 [1] CRAN (R 4.0.0)                     
#>  foreach       1.5.0      2020-03-30 [1] CRAN (R 4.0.2)                     
#>  fs            1.5.0      2020-07-31 [1] CRAN (R 4.0.2)                     
#>  furrr         0.1.0      2018-05-16 [1] CRAN (R 4.0.0)                     
#>  future        1.19.1     2020-09-22 [1] CRAN (R 4.0.2)                     
#>  generics      0.0.2      2018-11-29 [1] CRAN (R 4.0.0)                     
#>  ggplot2     * 3.3.2      2020-06-19 [1] CRAN (R 4.0.0)                     
#>  globals       0.13.0     2020-09-17 [1] CRAN (R 4.0.2)                     
#>  glue          1.4.2      2020-08-27 [1] CRAN (R 4.0.2)                     
#>  gower         0.2.2      2020-06-23 [1] CRAN (R 4.0.0)                     
#>  GPfit         1.0-8      2019-02-08 [1] CRAN (R 4.0.0)                     
#>  gtable        0.3.0      2019-03-25 [1] CRAN (R 4.0.0)                     
#>  hardhat     * 0.1.4.9000 2020-09-30 [1] Github (tidymodels/hardhat@b763b1f)
#>  highr         0.8        2019-03-20 [1] CRAN (R 4.0.0)                     
#>  htmltools     0.5.0      2020-06-16 [1] CRAN (R 4.0.0)                     
#>  infer       * 0.5.2      2020-06-14 [1] CRAN (R 4.0.0)                     
#>  ipred         0.9-9      2019-04-28 [1] CRAN (R 4.0.2)                     
#>  iterators     1.0.12     2019-07-26 [1] CRAN (R 4.0.0)                     
#>  knitr         1.30       2020-09-22 [1] CRAN (R 4.0.2)                     
#>  lattice       0.20-41    2020-04-02 [1] CRAN (R 4.0.2)                     
#>  lava          1.6.8      2020-09-26 [1] CRAN (R 4.0.2)                     
#>  lhs           1.1.0      2020-09-29 [1] CRAN (R 4.0.2)                     
#>  lifecycle     0.2.0      2020-03-06 [1] CRAN (R 4.0.0)                     
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.0.0)                     
#>  lubridate     1.7.9      2020-06-08 [1] CRAN (R 4.0.2)                     
#>  magrittr      1.5        2014-11-22 [1] CRAN (R 4.0.0)                     
#>  MASS          7.3-51.6   2020-04-26 [1] CRAN (R 4.0.2)                     
#>  Matrix        1.2-18     2019-11-27 [1] CRAN (R 4.0.2)                     
#>  memoise       1.1.0      2017-04-21 [1] CRAN (R 4.0.0)                     
#>  modeldata   * 0.0.2      2020-06-22 [1] CRAN (R 4.0.2)                     
#>  munsell       0.5.0      2018-06-12 [1] CRAN (R 4.0.0)                     
#>  nnet          7.3-14     2020-04-26 [1] CRAN (R 4.0.2)                     
#>  parsnip     * 0.1.3.9000 2020-09-30 [1] Github (tidymodels/parsnip@659b5ad)
#>  pillar        1.4.6      2020-07-10 [1] CRAN (R 4.0.0)                     
#>  pkgbuild      1.1.0      2020-07-13 [1] CRAN (R 4.0.2)                     
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.0.0)                     
#>  pkgload       1.1.0      2020-05-29 [1] CRAN (R 4.0.0)                     
#>  plyr          1.8.6      2020-03-03 [1] CRAN (R 4.0.2)                     
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.0.0)                     
#>  pROC          1.16.2     2020-03-19 [1] CRAN (R 4.0.2)                     
#>  processx      3.4.4      2020-09-03 [1] CRAN (R 4.0.2)                     
#>  prodlim       2019.11.13 2019-11-17 [1] CRAN (R 4.0.0)                     
#>  ps            1.3.4      2020-08-11 [1] CRAN (R 4.0.2)                     
#>  purrr       * 0.3.4      2020-04-17 [1] CRAN (R 4.0.0)                     
#>  R6            2.4.1      2019-11-12 [1] CRAN (R 4.0.0)                     
#>  Rcpp          1.0.5      2020-07-06 [1] CRAN (R 4.0.0)                     
#>  recipes     * 0.1.13     2020-06-23 [1] CRAN (R 4.0.2)                     
#>  remotes       2.2.0      2020-07-21 [1] CRAN (R 4.0.2)                     
#>  rlang         0.4.7      2020-07-09 [1] CRAN (R 4.0.0)                     
#>  rmarkdown     2.3        2020-06-18 [1] CRAN (R 4.0.2)                     
#>  rpart         4.1-15     2019-04-12 [1] CRAN (R 4.0.2)                     
#>  rprojroot     1.3-2      2018-01-03 [1] CRAN (R 4.0.0)                     
#>  rsample     * 0.0.8      2020-09-23 [1] CRAN (R 4.0.2)                     
#>  rstudioapi    0.11       2020-02-07 [1] CRAN (R 4.0.0)                     
#>  scales      * 1.1.1      2020-05-11 [1] CRAN (R 4.0.2)                     
#>  sessioninfo   1.1.1      2018-11-05 [1] CRAN (R 4.0.2)                     
#>  stringi       1.5.3      2020-09-09 [1] CRAN (R 4.0.2)                     
#>  stringr       1.4.0      2019-02-10 [1] CRAN (R 4.0.0)                     
#>  survival      3.1-12     2020-04-10 [1] CRAN (R 4.0.2)                     
#>  testthat      2.3.2      2020-03-02 [1] CRAN (R 4.0.2)                     
#>  tibble      * 3.0.3      2020-07-10 [1] CRAN (R 4.0.0)                     
#>  tidymodels  * 0.1.1      2020-07-14 [1] CRAN (R 4.0.0)                     
#>  tidyr       * 1.1.2      2020-08-27 [1] CRAN (R 4.0.2)                     
#>  tidyselect    1.1.0      2020-05-11 [1] CRAN (R 4.0.0)                     
#>  timeDate      3043.102   2018-02-21 [1] CRAN (R 4.0.0)                     
#>  tune        * 0.1.1      2020-07-08 [1] CRAN (R 4.0.2)                     
#>  usethis       1.9.0.9000 2020-09-30 [1] Github (r-lib/usethis@b993e83)     
#>  utf8          1.1.4      2018-05-24 [1] CRAN (R 4.0.0)                     
#>  vctrs         0.3.4      2020-08-29 [1] CRAN (R 4.0.2)                     
#>  withr         2.3.0      2020-09-22 [1] CRAN (R 4.0.2)                     
#>  workflows   * 0.2.0      2020-09-15 [1] CRAN (R 4.0.2)                     
#>  xfun          0.17       2020-09-09 [1] CRAN (R 4.0.2)                     
#>  yaml          2.2.1      2020-02-01 [1] CRAN (R 4.0.0)                     
#>  yardstick   * 0.0.7      2020-07-13 [1] CRAN (R 4.0.2)                     
#> 
#> [1] /Library/Frameworks/R.framework/Versions/4.0/Resources/library

Member

@DavisVaughan DavisVaughan left a comment

Added some general feedback - will also discuss with Julia over Zoom.

The reason that Max's example with workflows didn't work is that these changes currently touch only forge(), not mold(), but workflows uses the results of mold() in the model fitting process.

I will discuss with Julia further, but my initial thoughts about this are that:

  1. we should apply this over all preprocessors
  2. we could probably simplify this by moving the composition bits to hardhat itself. I suggest a new verb called recompose(data, composition = "tibble"/etc) that does the composition bits that currently live at the end of recipes:::bake.recipe

To add it to all preprocessor types, we could place recompose() in:

  • forge_recipe_default_process_predictors() after maybe_add_intercept_column()

  • forge_formula_default_process_predictors() after reattaching potential factor columns

  • forge_xy_default_process_predictors() after maybe_add_intercept_column()

  • mold_recipe_default_process_predictors() after maybe_add_intercept_column()

  • mold_formula_default_process_predictors() after simplify_terms()

  • mold_xy_default_process_predictors() after maybe_add_intercept_column()

This would remove all of the changes to forge_recipe_default_process(), and there would be no need to work around baking only parts of the recipe. I'm fairly certain that partial baking actually bakes the entire recipe on all columns anyway, then does the equivalent of a select() at the end for the columns you requested.
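As a rough illustration of the recompose() idea, here is a minimal sketch. The name `recompose_sketch()`, its error message, and the way the sparse matrix is built are all assumptions made for this example; this is not hardhat's actual implementation, and the "tibble" branch assumes the tibble package is installed.

```r
# Hypothetical sketch of the proposed recompose() verb; NOT hardhat's actual code.
recompose_sketch <- function(data,
                             composition = c("tibble", "matrix", "dgCMatrix")) {
  composition <- match.arg(composition)

  if (composition == "tibble") {
    return(tibble::as_tibble(data))
  }

  # Matrix-like compositions need every column to be numeric.
  is_num <- vapply(data, is.numeric, logical(1))
  if (!all(is_num)) {
    stop(
      "Non-numeric columns cannot be converted to a ", composition, ": ",
      paste(names(data)[!is_num], collapse = ", "),
      call. = FALSE
    )
  }

  mat <- as.matrix(data)
  if (composition == "matrix") {
    return(mat)
  }

  # Store only the non-zero entries (and any NAs) explicitly.
  keep <- which(mat != 0 | is.na(mat), arr.ind = TRUE)
  Matrix::sparseMatrix(
    i = keep[, 1],
    j = keep[, 2],
    x = mat[keep],
    dims = dim(mat),
    dimnames = dimnames(mat)
  )
}

# Usage with made-up data: two of the six cells are non-zero.
x <- data.frame(a = c(0, 1.5, 0), b = c(2, 0, 0))
sparse_x <- recompose_sketch(x, "dgCMatrix")
class(sparse_x)
```

Because the conversion happens in one place at the very end, each mold_*() and forge_*() predictor function listed above would only need a single call to such a helper.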

@juliasilge
Member Author

OK, I think this is getting there! 👍 I now have mold() working as well, and it is wired up to the various important bits:

library(tidymodels)
library(hardhat)

data(ames, package = "modeldata")
x <- ames %>%
  select(Longitude, Latitude, Year_Built)
y <- log10(ames$Sale_Price)


## works for xy
bp1 <- default_xy_blueprint(composition = "dgCMatrix")
x1 <- mold(x, y, blueprint = bp1)
class(x1$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"
colnames(x1$predictors)
#> [1] "Longitude"  "Latitude"   "Year_Built"

## works for formula
bp2 <- default_formula_blueprint(composition = "matrix")
x2 <- mold(log10(Sale_Price) ~ Longitude + Latitude + Neighborhood, ames, blueprint = bp2)
class(x2$predictors)
#> [1] "matrix" "array"
colnames(x2$predictors)
#>  [1] "Longitude"                                          
#>  [2] "Latitude"                                           
#>  [3] "NeighborhoodNorth_Ames"                             
#>  [4] "NeighborhoodCollege_Creek"                          
#>  [5] "NeighborhoodOld_Town"                               
#>  [6] "NeighborhoodEdwards"                                
#>  [7] "NeighborhoodSomerset"                               
#>  [8] "NeighborhoodNorthridge_Heights"                     
#>  [9] "NeighborhoodGilbert"                                
#> [10] "NeighborhoodSawyer"                                 
#> [11] "NeighborhoodNorthwest_Ames"                         
#> [12] "NeighborhoodSawyer_West"                            
#> [13] "NeighborhoodMitchell"                               
#> [14] "NeighborhoodBrookside"                              
#> [15] "NeighborhoodCrawford"                               
#> [16] "NeighborhoodIowa_DOT_and_Rail_Road"                 
#> [17] "NeighborhoodTimberland"                             
#> [18] "NeighborhoodNorthridge"                             
#> [19] "NeighborhoodStone_Brook"                            
#> [20] "NeighborhoodSouth_and_West_of_Iowa_State_University"
#> [21] "NeighborhoodClear_Creek"                            
#> [22] "NeighborhoodMeadow_Village"                         
#> [23] "NeighborhoodBriardale"                              
#> [24] "NeighborhoodBloomington_Heights"                    
#> [25] "NeighborhoodVeenker"                                
#> [26] "NeighborhoodNorthpark_Villa"                        
#> [27] "NeighborhoodBlueste"                                
#> [28] "NeighborhoodGreens"                                 
#> [29] "NeighborhoodGreen_Hills"                            
#> [30] "NeighborhoodLandmark"                               
#> [31] "NeighborhoodHayden_Lake"

## can forge
xx1 <- forge(ames, blueprint = x1$blueprint)
class(xx1$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"

## workflow can get to the "recomposed" data
rec <- 
  recipe(Sale_Price ~  Longitude + Latitude + Neighborhood, data = ames) %>% 
  step_dummy(Neighborhood) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

lasso_spec <- linear_reg(mixture = 1) %>% set_engine("glmnet")
bp3 <- default_recipe_blueprint(composition = "dgCMatrix")

lasso_wflow <- 
  workflow() %>% 
  add_recipe(rec, blueprint = bp3) %>% 
  add_model(lasso_spec)

wflow_1 <- .fit_pre(lasso_wflow, data = ames)
class(wflow_1$pre$mold$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"
colnames(wflow_1$pre$mold$predictors)
#>  [1] "Longitude"                                           
#>  [2] "Latitude"                                            
#>  [3] "Neighborhood_College_Creek"                          
#>  [4] "Neighborhood_Old_Town"                               
#>  [5] "Neighborhood_Edwards"                                
#>  [6] "Neighborhood_Somerset"                               
#>  [7] "Neighborhood_Northridge_Heights"                     
#>  [8] "Neighborhood_Gilbert"                                
#>  [9] "Neighborhood_Sawyer"                                 
#> [10] "Neighborhood_Northwest_Ames"                         
#> [11] "Neighborhood_Sawyer_West"                            
#> [12] "Neighborhood_Mitchell"                               
#> [13] "Neighborhood_Brookside"                              
#> [14] "Neighborhood_Crawford"                               
#> [15] "Neighborhood_Iowa_DOT_and_Rail_Road"                 
#> [16] "Neighborhood_Timberland"                             
#> [17] "Neighborhood_Northridge"                             
#> [18] "Neighborhood_Stone_Brook"                            
#> [19] "Neighborhood_South_and_West_of_Iowa_State_University"
#> [20] "Neighborhood_Clear_Creek"                            
#> [21] "Neighborhood_Meadow_Village"                         
#> [22] "Neighborhood_Briardale"                              
#> [23] "Neighborhood_Bloomington_Heights"                    
#> [24] "Neighborhood_Veenker"                                
#> [25] "Neighborhood_Northpark_Villa"                        
#> [26] "Neighborhood_Blueste"                                
#> [27] "Neighborhood_Greens"                                 
#> [28] "Neighborhood_Green_Hills"                            
#> [29] "Neighborhood_Landmark"

Created on 2020-09-30 by the reprex package (v0.3.0.9001)

FYI, I made significant changes to the tests for the preprocessors.

@juliasilge juliasilge changed the title WIP: Add composition to recipe blueprint Add composition to preprocessor blueprints Oct 1, 2020
Member

@DavisVaughan DavisVaughan left a comment

Looks great! Just a few final comments

R/aaa.R Outdated (resolved)
@@ -26,6 +26,9 @@
#' prediction time? This information is used by the `clean` function in the
#' `forge` function list, and is passed on to [scream()].
#'
#' @param composition Either "tibble", "matrix", or "dgCMatrix" for the format
Member

I haven't used Matrix very much, but I know they have a ton of class types. Is it always a dgCMatrix? The Matrix() function can return a number of different things, and I'm not sure if setting sparse = TRUE always guarantees a dgCMatrix class. If not, should the option just be "sparse_matrix"?

Member Author

I am pretty sure we should stick with the dgCMatrix class only. I just spent some time looking through the Matrix docs, and the dgCMatrix class is a compressed, sparse, column-oriented format which is the right kind for modeling. From the Matrix docs:

dgCMatrix is the “standard” class for sparse numeric matrices in the Matrix package.

I haven't ever seen any of the other classes used in the wild in modeling, and things like the dfm type in quanteda inherit from dgCMatrix.

(The DocumentTermMatrix in tm inherits from the simple triplet matrix in slam, but I don't think we need to worry about that here.)
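
To make the class question above concrete, here is a small sketch that uses only the Matrix package (nothing from hardhat). It checks which inputs come back from `Matrix(x, sparse = TRUE)` as a `dgCMatrix`, since `Matrix()` picks a class from the structure it detects in `x`. The three toy inputs and the two-step coercion at the end are illustrative assumptions, not what this PR implements.

```r
library(Matrix)
library(methods)

# Toy inputs (made up for illustration).
general   <- matrix(c(1, 0, 0, 2, 0, 3), nrow = 2)  # non-symmetric, 2 x 3
symmetric <- matrix(c(2, 1, 1, 2), nrow = 2)        # symmetric, 2 x 2
diagonal  <- diag(3)                                # diagonal, 3 x 3

sparse_versions <- list(
  general   = Matrix(general, sparse = TRUE),
  symmetric = Matrix(symmetric, sparse = TRUE),
  diagonal  = Matrix(diagonal, sparse = TRUE)
)

# The class Matrix() chose for each input.
chosen_class <- vapply(sparse_versions, function(m) class(m)[1], character(1))
print(chosen_class)

# All results are sparse, but not all of them are dgCMatrix.
is_sparse <- vapply(sparse_versions, is, logical(1), "sparseMatrix")
is_dgc    <- vapply(sparse_versions, is, logical(1), "dgCMatrix")
print(rbind(is_sparse, is_dgc))

# One way to drop the symmetric structure and get a general,
# column-compressed matrix back.
forced <- as(as(sparse_versions$symmetric, "generalMatrix"), "CsparseMatrix")
print(class(forced)[1])
```

Predictor matrices coming out of a preprocessor are usually tall and non-square, so the symmetric and diagonal cases should be rare in practice, but they are the cases where `sparse = TRUE` alone does not give a `dgCMatrix`.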

R/blueprint.R Outdated
@@ -26,6 +26,9 @@
#' prediction time? This information is used by the `clean` function in the
#' `forge` function list, and is passed on to [scream()].
#'
#' @param composition Either "tibble", "matrix", or "dgCMatrix" for the format
#' of the processed predictors.
Member

Either `"tibble"`, `"matrix"`, or `"dgCMatrix"` for the format
of the processed predictors. If `"matrix"` or `"dgCMatrix"` are chosen,
all of the predictors must be numeric after the preprocessing method
has been applied, otherwise an error is thrown.

I don't think we have mentioned the numeric restriction anywhere, and this seemed like a good place to do it.

Member

Maybe also mention that if dgCMatrix is used, the Matrix package must be installed.

Member Author

Matrix is a "recommended" package, which I believe means it is part of regular R installations, right? It says that here FWIW.

Member Author

I did add a check with rlang::is_installed to send an informative error in case someone happens to not have this installed, but I'm not sure how that would typically happen.
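
For readers following this thread, the two checks discussed here (all predictors numeric, Matrix installed) can be sketched roughly as below. This is a minimal base R illustration and not hardhat's actual code: `sketch_recompose()` and its error messages are hypothetical, and it uses `requireNamespace()` where the PR uses an rlang check.

```r
# Hypothetical helper for illustration only; not the function added in this PR.
sketch_recompose <- function(data, composition = c("matrix", "dgCMatrix")) {
  composition <- match.arg(composition)

  # Both matrix-like compositions need every column to be numeric.
  is_num <- vapply(data, is.numeric, logical(1))
  if (!all(is_num)) {
    stop(
      "All predictors must be numeric for composition = \"", composition,
      "\". Non-numeric columns: ",
      paste(names(data)[!is_num], collapse = ", "),
      call. = FALSE
    )
  }

  out <- as.matrix(data)

  if (composition == "dgCMatrix") {
    # Matrix is a recommended package, but check before using it anyway.
    if (!requireNamespace("Matrix", quietly = TRUE)) {
      stop(
        "The Matrix package must be installed to use composition = \"dgCMatrix\".",
        call. = FALSE
      )
    }
    out <- Matrix::Matrix(out, sparse = TRUE)
  }

  out
}

# Toy data (made up): one all-numeric frame, one with a factor column.
numeric_only <- data.frame(a = c(1, 0, 0), b = c(0, 2, 0))
with_factor  <- data.frame(a = c(1, 0, 0), f = factor(c("x", "y", "x")))

dense  <- sketch_recompose(numeric_only, "matrix")
sparse <- sketch_recompose(numeric_only, "dgCMatrix")

# The factor column triggers the numeric check; capture the message.
err <- tryCatch(
  sketch_recompose(with_factor, "matrix"),
  error = function(e) conditionMessage(e)
)
print(err)
```

Checking the column types before converting gives an error that names the offending columns, instead of letting `as.matrix()` silently turn the whole thing into a character matrix.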

R/recompose.R Outdated (resolved)
tests/testthat/test-forge-formula.R Outdated (resolved)
tests/testthat/test-forge-formula.R (resolved)
R/recompose.R (resolved)
Co-authored-by: Davis Vaughan <davis@rstudio.com>
@DavisVaughan DavisVaughan merged commit f0996eb into master Oct 5, 2020
@DavisVaughan DavisVaughan deleted the recipe-blueprint-composition branch October 5, 2020 15:28
@github-actions

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Jun 30, 2021