
last_fit() and fit() give different probability predictions for the same workflow and data #300

Closed
AndrewKostandy opened this issue Oct 11, 2020 · 4 comments · Fixed by #323
Labels
bug an unexpected problem or unintended behavior


@AndrewKostandy

When the same workflow is used with last_fit() and fit() on the same training and testing data, the probability predictions differ. Below is a reproducible example in which I fit the same workflow four times using an identical seed: twice with last_fit() and twice with fit(). The two sets of last_fit() predictions are identical to each other, and the two sets of fit() predictions are identical to each other. However, the last_fit() predictions and the fit() predictions are very different from one another (this is not a rounding problem).

Note that this problem happens whether or not parallel processing is used, e.g., when a backend is registered with the line below:

# doParallel::registerDoParallel(cores = 7)

Reproducible example

library(tidyverse)
suppressMessages(library(tidymodels))
library(modeldata)
suppressMessages(library(themis))
library(embed)
data(attrition)

attrition <- janitor::clean_names(attrition)

attrition <- attrition %>% 
  mutate(attrition_class = fct_rev(attrition)) %>% 
  select(-attrition)

levels(attrition$attrition_class)
#> [1] "Yes" "No"

set.seed(1954)
attrition_split <- initial_split(attrition, strata = "attrition_class")
attrition_train <- training(attrition_split)
attrition_test <- testing(attrition_split)

xgboost_recipe <- 
  recipe(formula = attrition_class ~ ., data = attrition_train) %>% 
  step_upsample(attrition_class, seed = 936) %>% 
  step_discretize_xgb(age, daily_rate, distance_from_home, hourly_rate,
                      monthly_income, monthly_rate, total_working_years,
                      years_at_company, outcome = "attrition_class") %>% 
  step_novel(all_nominal(), -all_outcomes()) %>% 
  step_dummy(all_nominal(), -all_outcomes(), one_hot = TRUE) %>% 
  step_zv(all_predictors()) 

xgboost_spec <- 
  boost_tree(trees = 462, min_n = 9, tree_depth = 15, learn_rate = 0.02160385, 
             loss_reduction = 0.6299069, sample_size = 0.1487263) %>% 
  set_mode("classification") %>% 
  set_engine("xgboost") 

xgboost_workflow <- 
  workflow() %>% 
  add_recipe(xgboost_recipe) %>% 
  add_model(xgboost_spec)

common_seed <- 12345

# fit 2 models with last_fit() and 2 with fit(), resetting the same seed before each
set.seed(common_seed)
last_fit1 <- last_fit(xgboost_workflow, attrition_split)

set.seed(common_seed)
last_fit2 <- last_fit(xgboost_workflow, attrition_split)

set.seed(common_seed)
fit1 <- fit(xgboost_workflow, attrition_train)

set.seed(common_seed)
fit2 <- fit(xgboost_workflow, attrition_train)

# get predictions of all 4 models
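# the .predictions list column of a last_fit() result holds the test-set predictions
last_fit1_pred <- last_fit1$.predictions[[1]]
last_fit2_pred <- last_fit2$.predictions[[1]]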

fit1_pred <- predict(fit1, attrition_test, type = "prob")

fit2_pred <- predict(fit2, attrition_test, type = "prob")

# see predictions of all 4 models
head(last_fit1_pred)
#> # A tibble: 6 x 5
#>   .pred_Yes .pred_No  .row .pred_class attrition_class
#>       <dbl>    <dbl> <int> <fct>       <fct>          
#> 1     0.243    0.757    13 No          No             
#> 2     0.255    0.745    16 No          No             
#> 3     0.366    0.634    22 No          Yes            
#> 4     0.788    0.212    27 Yes         Yes            
#> 5     0.237    0.763    28 No          No             
#> 6     0.239    0.761    41 No          No

head(last_fit2_pred)
#> # A tibble: 6 x 5
#>   .pred_Yes .pred_No  .row .pred_class attrition_class
#>       <dbl>    <dbl> <int> <fct>       <fct>          
#> 1     0.243    0.757    13 No          No             
#> 2     0.255    0.745    16 No          No             
#> 3     0.366    0.634    22 No          Yes            
#> 4     0.788    0.212    27 Yes         Yes            
#> 5     0.237    0.763    28 No          No             
#> 6     0.239    0.761    41 No          No

head(fit1_pred)
#> # A tibble: 6 x 2
#>   .pred_Yes .pred_No
#>       <dbl>    <dbl>
#> 1     0.351    0.649
#> 2     0.231    0.769
#> 3     0.324    0.676
#> 4     0.865    0.135
#> 5     0.158    0.842
#> 6     0.217    0.783

head(fit2_pred)
#> # A tibble: 6 x 2
#>   .pred_Yes .pred_No
#>       <dbl>    <dbl>
#> 1     0.351    0.649
#> 2     0.231    0.769
#> 3     0.324    0.676
#> 4     0.865    0.135
#> 5     0.158    0.842
#> 6     0.217    0.783

# last_fit prediction probabilities identical to each other
identical(last_fit1_pred$.pred_Yes, last_fit2_pred$.pred_Yes)
#> [1] TRUE

# fit prediction probabilities identical to each other
identical(fit1_pred$.pred_Yes, fit2_pred$.pred_Yes)
#> [1] TRUE

# last_fit prediction probabilities NOT identical to fit prediction probabilities
identical(last_fit1_pred$.pred_Yes, fit1_pred$.pred_Yes)
#> [1] FALSE

Created on 2020-10-11 by the reprex package (v0.3.0)

Session info
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.0.2 (2020-06-22)
#>  os       macOS Catalina 10.15.7      
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_CA.UTF-8                 
#>  ctype    en_CA.UTF-8                 
#>  tz       America/Toronto             
#>  date     2020-10-11                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date       lib source        
#>  assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.0.2)
#>  backports      1.1.10     2020-09-15 [1] CRAN (R 4.0.2)
#>  base64enc      0.1-3      2015-07-28 [1] CRAN (R 4.0.2)
#>  bayesplot      1.7.2      2020-05-28 [1] CRAN (R 4.0.2)
#>  BBmisc         1.11       2017-03-10 [1] CRAN (R 4.0.2)
#>  blob           1.2.1      2020-01-20 [1] CRAN (R 4.0.2)
#>  boot           1.3-25     2020-04-26 [1] CRAN (R 4.0.2)
#>  broom        * 0.7.1      2020-10-02 [1] CRAN (R 4.0.2)
#>  callr          3.5.0      2020-10-08 [1] CRAN (R 4.0.2)
#>  cellranger     1.1.0      2016-07-27 [1] CRAN (R 4.0.2)
#>  checkmate      2.0.0      2020-02-06 [1] CRAN (R 4.0.2)
#>  class          7.3-17     2020-04-26 [1] CRAN (R 4.0.2)
#>  cli            2.0.2      2020-02-28 [1] CRAN (R 4.0.2)
#>  codetools      0.2-16     2018-12-24 [1] CRAN (R 4.0.2)
#>  colorspace     1.4-1      2019-03-18 [1] CRAN (R 4.0.2)
#>  colourpicker   1.1.0      2020-09-14 [1] CRAN (R 4.0.2)
#>  crayon         1.3.4      2017-09-16 [1] CRAN (R 4.0.2)
#>  crosstalk      1.1.0.1    2020-03-13 [1] CRAN (R 4.0.2)
#>  curl           4.3        2019-12-02 [1] CRAN (R 4.0.1)
#>  data.table     1.13.0     2020-07-24 [1] CRAN (R 4.0.2)
#>  DBI            1.1.0      2019-12-15 [1] CRAN (R 4.0.2)
#>  dbplyr         1.4.4      2020-05-27 [1] CRAN (R 4.0.2)
#>  desc           1.2.0      2018-05-01 [1] CRAN (R 4.0.2)
#>  devtools       2.3.2      2020-09-18 [1] CRAN (R 4.0.2)
#>  dials        * 0.0.9      2020-09-16 [1] CRAN (R 4.0.2)
#>  DiceDesign     1.8-1      2019-07-31 [1] CRAN (R 4.0.2)
#>  digest         0.6.25     2020-02-23 [1] CRAN (R 4.0.2)
#>  doParallel     1.0.15     2019-08-02 [1] CRAN (R 4.0.2)
#>  dplyr        * 1.0.2      2020-08-18 [1] CRAN (R 4.0.2)
#>  DT             0.15       2020-08-05 [1] CRAN (R 4.0.2)
#>  dygraphs       1.1.1.6    2018-07-11 [1] CRAN (R 4.0.2)
#>  ellipsis       0.3.1      2020-05-15 [1] CRAN (R 4.0.2)
#>  embed        * 0.1.1      2020-07-03 [1] CRAN (R 4.0.2)
#>  evaluate       0.14       2019-05-28 [1] CRAN (R 4.0.1)
#>  fansi          0.4.1      2020-01-08 [1] CRAN (R 4.0.2)
#>  fastmap        1.0.1      2019-10-08 [1] CRAN (R 4.0.2)
#>  fastmatch      1.1-0      2017-01-28 [1] CRAN (R 4.0.2)
#>  FNN            1.1.3      2019-02-15 [1] CRAN (R 4.0.2)
#>  forcats      * 0.5.0      2020-03-01 [1] CRAN (R 4.0.2)
#>  foreach        1.5.0      2020-03-30 [1] CRAN (R 4.0.2)
#>  fs             1.5.0      2020-07-31 [1] CRAN (R 4.0.2)
#>  furrr          0.1.0      2018-05-16 [1] CRAN (R 4.0.2)
#>  future         1.19.1     2020-09-22 [1] CRAN (R 4.0.2)
#>  generics       0.0.2      2018-11-29 [1] CRAN (R 4.0.2)
#>  ggplot2      * 3.3.2      2020-06-19 [1] CRAN (R 4.0.2)
#>  ggridges       0.5.2      2020-01-12 [1] CRAN (R 4.0.2)
#>  globals        0.13.1     2020-10-11 [1] CRAN (R 4.0.2)
#>  glue           1.4.2      2020-08-27 [1] CRAN (R 4.0.2)
#>  gower          0.2.2      2020-06-23 [1] CRAN (R 4.0.2)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.0.2)
#>  gridExtra      2.3        2017-09-09 [1] CRAN (R 4.0.2)
#>  gtable         0.3.0      2019-03-25 [1] CRAN (R 4.0.2)
#>  gtools         3.8.2      2020-03-31 [1] CRAN (R 4.0.2)
#>  hardhat        0.1.4      2020-07-02 [1] CRAN (R 4.0.2)
#>  haven          2.3.1      2020-06-01 [1] CRAN (R 4.0.2)
#>  highr          0.8        2019-03-20 [1] CRAN (R 4.0.2)
#>  hms            0.5.3      2020-01-08 [1] CRAN (R 4.0.2)
#>  htmltools      0.5.0      2020-06-16 [1] CRAN (R 4.0.2)
#>  htmlwidgets    1.5.2      2020-10-03 [1] CRAN (R 4.0.2)
#>  httpuv         1.5.4      2020-06-06 [1] CRAN (R 4.0.2)
#>  httr           1.4.2      2020-07-20 [1] CRAN (R 4.0.2)
#>  igraph         1.2.6      2020-10-06 [1] CRAN (R 4.0.2)
#>  infer        * 0.5.3      2020-07-14 [1] CRAN (R 4.0.2)
#>  inline         0.3.16     2020-09-06 [1] CRAN (R 4.0.2)
#>  ipred          0.9-9      2019-04-28 [1] CRAN (R 4.0.2)
#>  iterators      1.0.12     2019-07-26 [1] CRAN (R 4.0.2)
#>  janitor        2.0.1      2020-04-12 [1] CRAN (R 4.0.2)
#>  jsonlite       1.7.1      2020-09-07 [1] CRAN (R 4.0.2)
#>  keras          2.3.0.0    2020-05-19 [1] CRAN (R 4.0.2)
#>  knitr          1.30       2020-09-22 [1] CRAN (R 4.0.2)
#>  later          1.1.0.1    2020-06-05 [1] CRAN (R 4.0.2)
#>  lattice        0.20-41    2020-04-02 [1] CRAN (R 4.0.2)
#>  lava           1.6.8      2020-09-26 [1] CRAN (R 4.0.2)
#>  lhs            1.1.1      2020-10-05 [1] CRAN (R 4.0.2)
#>  lifecycle      0.2.0      2020-03-06 [1] CRAN (R 4.0.2)
#>  listenv        0.8.0      2019-12-05 [1] CRAN (R 4.0.2)
#>  lme4           1.1-23     2020-04-07 [1] CRAN (R 4.0.1)
#>  loo            2.3.1      2020-07-14 [1] CRAN (R 4.0.2)
#>  lubridate      1.7.9      2020-06-08 [1] CRAN (R 4.0.2)
#>  magrittr       1.5        2014-11-22 [1] CRAN (R 4.0.2)
#>  markdown       1.1        2019-08-07 [1] CRAN (R 4.0.2)
#>  MASS           7.3-53     2020-09-09 [1] CRAN (R 4.0.2)
#>  Matrix         1.2-18     2019-11-27 [1] CRAN (R 4.0.2)
#>  matrixStats    0.57.0     2020-09-25 [1] CRAN (R 4.0.2)
#>  memoise        1.1.0      2017-04-21 [1] CRAN (R 4.0.2)
#>  mime           0.9        2020-02-04 [1] CRAN (R 4.0.2)
#>  miniUI         0.1.1.1    2018-05-18 [1] CRAN (R 4.0.2)
#>  minqa          1.2.4      2014-10-09 [1] CRAN (R 4.0.2)
#>  mlr            2.18.0     2020-10-05 [1] CRAN (R 4.0.2)
#>  modeldata    * 0.0.2      2020-06-22 [1] CRAN (R 4.0.2)
#>  modelr         0.1.8      2020-05-19 [1] CRAN (R 4.0.2)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.0.2)
#>  nlme           3.1-149    2020-08-23 [1] CRAN (R 4.0.2)
#>  nloptr         1.2.2.2    2020-07-02 [1] CRAN (R 4.0.2)
#>  nnet           7.3-14     2020-04-26 [1] CRAN (R 4.0.2)
#>  parallelMap    1.5.0      2020-03-26 [1] CRAN (R 4.0.2)
#>  ParamHelpers   1.14       2020-03-24 [1] CRAN (R 4.0.2)
#>  parsnip      * 0.1.3      2020-08-04 [1] CRAN (R 4.0.2)
#>  pillar         1.4.6      2020-07-10 [1] CRAN (R 4.0.2)
#>  pkgbuild       1.1.0      2020-07-13 [1] CRAN (R 4.0.2)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.0.2)
#>  pkgload        1.1.0      2020-05-29 [1] CRAN (R 4.0.2)
#>  plyr           1.8.6      2020-03-03 [1] CRAN (R 4.0.2)
#>  prettyunits    1.1.1      2020-01-24 [1] CRAN (R 4.0.2)
#>  pROC           1.16.2     2020-03-19 [1] CRAN (R 4.0.2)
#>  processx       3.4.4      2020-09-03 [1] CRAN (R 4.0.2)
#>  prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.0.2)
#>  promises       1.1.1      2020-06-09 [1] CRAN (R 4.0.2)
#>  ps             1.4.0      2020-10-07 [1] CRAN (R 4.0.2)
#>  purrr        * 0.3.4      2020-04-17 [1] CRAN (R 4.0.2)
#>  R6             2.4.1      2019-11-12 [1] CRAN (R 4.0.2)
#>  RANN           2.6.1      2019-01-08 [1] CRAN (R 4.0.2)
#>  Rcpp           1.0.5      2020-07-06 [1] CRAN (R 4.0.2)
#>  RcppParallel   5.0.2      2020-06-24 [1] CRAN (R 4.0.2)
#>  readr        * 1.4.0      2020-10-05 [1] CRAN (R 4.0.2)
#>  readxl         1.3.1      2019-03-13 [1] CRAN (R 4.0.2)
#>  recipes      * 0.1.13     2020-06-23 [1] CRAN (R 4.0.2)
#>  remotes        2.2.0      2020-07-21 [1] CRAN (R 4.0.2)
#>  reprex         0.3.0      2019-05-16 [1] CRAN (R 4.0.2)
#>  reshape2       1.4.4      2020-04-09 [1] CRAN (R 4.0.2)
#>  reticulate     1.16       2020-05-27 [1] CRAN (R 4.0.2)
#>  rlang          0.4.7      2020-07-09 [1] CRAN (R 4.0.2)
#>  rmarkdown      2.4        2020-09-30 [1] CRAN (R 4.0.2)
#>  ROSE           0.0-3      2014-07-15 [1] CRAN (R 4.0.2)
#>  rpart          4.1-15     2019-04-12 [1] CRAN (R 4.0.2)
#>  rprojroot      1.3-2      2018-01-03 [1] CRAN (R 4.0.2)
#>  rsample      * 0.0.8      2020-09-23 [1] CRAN (R 4.0.2)
#>  rsconnect      0.8.16     2019-12-13 [1] CRAN (R 4.0.2)
#>  rstan          2.21.2     2020-07-27 [1] CRAN (R 4.0.2)
#>  rstanarm       2.21.1     2020-07-20 [1] CRAN (R 4.0.2)
#>  rstantools     2.1.1      2020-07-06 [1] CRAN (R 4.0.2)
#>  rstudioapi     0.11       2020-02-07 [1] CRAN (R 4.0.2)
#>  rvest          0.3.6      2020-07-25 [1] CRAN (R 4.0.2)
#>  scales       * 1.1.1      2020-05-11 [1] CRAN (R 4.0.2)
#>  sessioninfo    1.1.1      2018-11-05 [1] CRAN (R 4.0.2)
#>  shiny          1.5.0      2020-06-23 [1] CRAN (R 4.0.2)
#>  shinyjs        2.0.0      2020-09-09 [1] CRAN (R 4.0.2)
#>  shinystan      2.5.0      2018-05-01 [1] CRAN (R 4.0.2)
#>  shinythemes    1.1.2      2018-11-06 [1] CRAN (R 4.0.2)
#>  snakecase      0.11.0     2019-05-25 [1] CRAN (R 4.0.2)
#>  StanHeaders    2.21.0-6   2020-08-16 [1] CRAN (R 4.0.2)
#>  statmod        1.4.34     2020-02-17 [1] CRAN (R 4.0.2)
#>  stringi        1.5.3      2020-09-09 [1] CRAN (R 4.0.2)
#>  stringr      * 1.4.0      2019-02-10 [1] CRAN (R 4.0.2)
#>  survival       3.2-7      2020-09-28 [1] CRAN (R 4.0.2)
#>  tensorflow     2.2.0      2020-05-11 [1] CRAN (R 4.0.2)
#>  testthat       2.3.2      2020-03-02 [1] CRAN (R 4.0.2)
#>  tfruns         1.4        2018-08-25 [1] CRAN (R 4.0.2)
#>  themis       * 0.1.2      2020-08-14 [1] CRAN (R 4.0.2)
#>  threejs        0.3.3      2020-01-21 [1] CRAN (R 4.0.2)
#>  tibble       * 3.0.3      2020-07-10 [1] CRAN (R 4.0.2)
#>  tidymodels   * 0.1.1      2020-07-14 [1] CRAN (R 4.0.2)
#>  tidyr        * 1.1.2      2020-08-27 [1] CRAN (R 4.0.2)
#>  tidyselect     1.1.0      2020-05-11 [1] CRAN (R 4.0.2)
#>  tidyverse    * 1.3.0      2019-11-21 [1] CRAN (R 4.0.2)
#>  timeDate       3043.102   2018-02-21 [1] CRAN (R 4.0.2)
#>  tune         * 0.1.1      2020-07-08 [1] CRAN (R 4.0.2)
#>  unbalanced     2.0        2015-06-26 [1] CRAN (R 4.0.2)
#>  usethis        1.6.3      2020-09-17 [1] CRAN (R 4.0.2)
#>  utf8           1.1.4      2018-05-24 [1] CRAN (R 4.0.2)
#>  uwot           0.1.8      2020-03-16 [1] CRAN (R 4.0.2)
#>  V8             3.2.0      2020-06-19 [1] CRAN (R 4.0.2)
#>  vctrs          0.3.4      2020-08-29 [1] CRAN (R 4.0.2)
#>  whisker        0.4        2019-08-28 [1] CRAN (R 4.0.2)
#>  withr          2.3.0      2020-09-22 [1] CRAN (R 4.0.2)
#>  workflows    * 0.2.1      2020-10-08 [1] CRAN (R 4.0.2)
#>  xfun           0.18       2020-09-29 [1] CRAN (R 4.0.2)
#>  xgboost        1.2.0.1    2020-09-02 [1] CRAN (R 4.0.2)
#>  xml2           1.3.2      2020-04-23 [1] CRAN (R 4.0.2)
#>  xtable         1.8-4      2019-04-21 [1] CRAN (R 4.0.2)
#>  xts            0.12.1     2020-09-09 [1] CRAN (R 4.0.2)
#>  yaml           2.2.1      2020-02-01 [1] CRAN (R 4.0.2)
#>  yardstick    * 0.0.7      2020-07-13 [1] CRAN (R 4.0.2)
#>  zeallot        0.1.0      2018-01-28 [1] CRAN (R 4.0.2)
#>  zoo            1.8-8      2020-05-02 [1] CRAN (R 4.0.2)
#> 
#> [1] /Library/Frameworks/R.framework/Versions/4.0/Resources/library
@DavisVaughan
Member

@topepo this comes from the explicit seed that we now set in tune since #275

i.e., the one generated here:

resamples <- dplyr::mutate(resamples, .seed = sample.int(10^5, nrow(resamples)))

and set in the iteration here:

set.seed(resamples$.seed[[rs_iter]])

If you remove the setting of that seed, the results are identical (I checked).
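A minimal base-R sketch of why that seed changes the results, assuming the randomness consumed by the model fit can be represented by a few runif() draws (simulate_fit() is a hypothetical stand-in, not tune or xgboost code). Path 1 mimics fit() after set.seed(); path 2 mimics last_fit() before the fix, where a per-resample seed is drawn with sample.int() and then set before fitting.

```r
# Hypothetical stand-in for the random draws a model fit consumes
# (for example, row subsampling); this is not tune's or xgboost's code.
simulate_fit <- function() runif(5)

common_seed <- 12345

# Path 1 (what fit() sees): the fit draws right after set.seed().
set.seed(common_seed)
direct <- simulate_fit()

# Path 2 (what last_fit() saw before the fix): a per-resample seed is drawn
# first, mirroring the sample.int() call quoted above, and then set.
set.seed(common_seed)
resample_seed <- sample.int(10^5, 1)
set.seed(resample_seed)
via_resample_seed <- simulate_fit()

# Each path is reproducible on its own...
set.seed(common_seed)
stopifnot(identical(direct, simulate_fit()))

# ...but the two paths draw from different random streams.
identical(direct, via_resample_seed)
```

Both paths are deterministic, which matches the reprex: repeated last_fit() calls agree with each other, and repeated fit() calls agree with each other, but the two groups do not agree.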

@juliasilge
Member

I'm going to move this to tune.

@juliasilge juliasilge transferred this issue from tidymodels/parsnip Oct 13, 2020
@juliasilge juliasilge added the bug an unexpected problem or unintended behavior label Oct 13, 2020
topepo added a commit that referenced this issue Nov 2, 2020
@topepo
Member

topepo commented Nov 3, 2020

I'm about to submit a PR to fix this, but I want to document some technical notes on it first.

fit_resamples() and tune_grid() generate a set of artificial RNG seeds, one for each resample. During processing, each resample's seed is set when its grid values are being processed. This is a low-level approach to make sure the random numbers are reproducible when the calculations are run in parallel. Unfortunately, generating these seeds itself consumes random numbers, so the RNG state at the time of the model fit differs from the state in the main R process. That is the reason last_fit() generates different results.
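To illustrate why per-resample seeds help with parallel reproducibility, here is a minimal base-R sketch (not tune's implementation; run_resample() and the three resamples are hypothetical). Each task resets the RNG from its own pre-drawn seed, so evaluating the tasks in reverse order, a stand-in for parallel workers finishing in an arbitrary order, does not change any task's random numbers.

```r
# Draw one seed per resample up front (three hypothetical resamples).
set.seed(2020)
seeds <- sample.int(10^5, 3)

# Hypothetical per-resample task: it resets the RNG from its own seed, so
# its draws do not depend on which tasks ran before it.
run_resample <- function(i) {
  set.seed(seeds[[i]])
  runif(2)
}

# Evaluating in reverse order stands in for workers finishing in an
# arbitrary order; after restoring the order, the results match.
in_order <- lapply(1:3, run_resample)
reversed <- rev(lapply(3:1, run_resample))
identical(in_order, reversed)
#> [1] TRUE
```

The catch, as described above, is that the sample.int() call used to create the seeds advances the main process's RNG state, so a fit run inside this scheme never sees the state the user set with set.seed().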

Even if we didn't do that, when the user has a parallel backend registered it is quite possible that the worker process won't have the same RNG state (and therefore the same random numbers) as the main process.

The solution used in commit dce5586 is to avoid generating any random numbers when a single model is fit on a single resample (i.e., last_fit()). It also turns off parallel processing when this criterion is met.
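A minimal base-R sketch of the idea behind that fix, under stated assumptions: add_resample_seeds() is a hypothetical helper, resamples are represented as a plain data frame, and this is not the actual code from commit dce5586 (the parallel-processing change is not shown). Seeds are only drawn when more than one model/resample combination is evaluated, so the single-fit case leaves the RNG state untouched.

```r
# Hypothetical helper sketching the idea behind the fix; this is not the
# code from commit dce5586, and the parallel-processing change is omitted.
add_resample_seeds <- function(resamples, n_candidates) {
  single_fit <- nrow(resamples) == 1 && n_candidates == 1
  if (single_fit) {
    # last_fit() case: draw no random numbers, so the model fit sees the
    # same RNG state the user created with set.seed().
    resamples$.seed <- NA_integer_
  } else {
    # Several fits: pre-draw one seed per resample for reproducibility.
    resamples$.seed <- sample.int(10^5, nrow(resamples))
  }
  resamples
}

# A single train/test split leaves the RNG state untouched...
set.seed(12345)
state_before <- .Random.seed
one_split <- add_resample_seeds(data.frame(id = "train/test"), n_candidates = 1)
identical(state_before, .Random.seed)
#> [1] TRUE

# ...while five folds consume random numbers to generate their seeds.
set.seed(12345)
state_before <- .Random.seed
five_folds <- add_resample_seeds(data.frame(id = paste0("Fold", 1:5)),
                                 n_candidates = 1)
identical(state_before, .Random.seed)
#> [1] FALSE
```

With the single-fit branch, the RNG state at fit time is whatever set.seed() produced, which is the same state fit() sees in the reprex.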

@github-actions

github-actions bot commented Mar 6, 2021

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Mar 6, 2021