In [1]:
suppressPackageStartupMessages({
    suppressWarnings({
        library(data.table)
        library(tidymodels)
        library(tidyverse)
        library(treesnip)
    })
})

In [2]:
# Fix the RNG state so the train/test partition is reproducible.
set.seed(12)

# Put 60% of the iris rows in the training partition; the rest are held out.
spl <- initial_split(iris, prop = 0.6)

# Materialise both partitions as data.tables.
train_set <- as.data.table(training(spl))
test_set <- as.data.table(testing(spl))

In [3]:
# 5-fold cross-validation folds drawn from the training partition only.
resample <- train_set %>%
  vfold_cv(v = 5)

In [4]:
# Classification model specifications, all left at their default
# hyperparameters (nothing is marked for tuning).

# Gradient-boosted trees ----
cb_spc <- boost_tree(mode = "classification") %>%
  set_engine("catboost")

xgb_spc <- boost_tree(mode = "classification") %>%
  set_engine("xgboost")

lgb_spc <- boost_tree(mode = "classification") %>%
  set_engine("lightgbm")

# Random forests ----
rngr_spc <- rand_forest(mode = "classification") %>%
  set_engine("ranger")

rforest_spc <- rand_forest(mode = "classification") %>%
  set_engine("randomForest")

# Single-hidden-layer neural networks ----
nn_spc <- mlp(mode = "classification") %>%
  set_engine("nnet")

brl_spc <- mlp(mode = "classification") %>%
  set_engine("brulee")

In [5]:
# Preprocessing recipe: predict Species from every other column and
# dummy-encode any nominal predictors.
# NOTE(review): the iris predictors all appear to be numeric, so the dummy
# step likely selects no columns here -- confirm it is still wanted.
dummied <- recipe(Species ~ ., data = train_set) %>%
  step_dummy(all_nominal_predictors())

In [6]:
# Pair the single recipe with every model specification, giving one workflow
# per model. Despite the variable name, the recipe used here contains only a
# dummy-encoding step, not a normalization step.
normalized <- workflow_set(
  preproc = list(dummied = dummied),
  models = list(
    cb = cb_spc,
    xgb = xgb_spc,
    lgb = lgb_spc,
    rngr = rngr_spc,
    rforest = rforest_spc,
    nn = nn_spc,
    brl = brl_spc
  )
)

In [7]:
# Resampling control: keep the out-of-fold predictions and the fitted
# workflow objects, and log progress for every fold.
grid_ctrl <-
   control_resamples(
      save_pred = TRUE,
      save_workflow = TRUE,
      verbose = TRUE
   )

# Evaluate every workflow in the set on the cross-validation folds.
# None of the model specs has tunable parameters and the control object is a
# `control_resamples()`, so request `fit_resamples` explicitly instead of
# relying on workflow_map()'s default tuning function.
grid_results <-
   normalized %>%
   workflow_map(
      fn = "fit_resamples",
      seed = 1503,          # same seed is used for each workflow
      resamples = resample,
      control = grid_ctrl
   )

[34mi[39m [30mFold1: preprocessor 1/1[39m

[32m✓[39m [30mFold1: preprocessor 1/1[39m

[34mi[39m [30mFold1: preprocessor 1/1, model 1/1[39m

[32m✓[39m [30mFold1: preprocessor 1/1, model 1/1[39m

[34mi[39m [30mFold1: preprocessor 1/1, model 1/1 (predictions)[39m

[34mi[39m [30mFold2: preprocessor 1/1[39m

[32m✓[39m [30mFold2: preprocessor 1/1[39m

[34mi[39m [30mFold2: preprocessor 1/1, model 1/1[39m

[32m✓[39m [30mFold2: preprocessor 1/1, model 1/1[39m

[34mi[39m [30mFold2: preprocessor 1/1, model 1/1 (predictions)[39m

[34mi[39m [30mFold3: preprocessor 1/1[39m

[32m✓[39m [30mFold3: preprocessor 1/1[39m

[34mi[39m [30mFold3: preprocessor 1/1, model 1/1[39m

[32m✓[39m [30mFold3: preprocessor 1/1, model 1/1[39m

[34mi[39m [30mFold3: preprocessor 1/1, model 1/1 (predictions)[39m

[34mi[39m [30mFold4: preprocessor 1/1[39m

[32m✓[39m [30mFold4: preprocessor 1/1[39m

[34mi[39m [30mFold4: preprocessor 1/1, model 1/1[39m

[

In [None]:
# Compare the workflows: rank them by accuracy and plot the accuracy of each.
# The outcome (Species) is categorical, so a regression metric such as RMSE is
# never computed for these results; both the ranking metric and the plotted
# metric must be classification metrics.
autoplot(
   grid_results,
   rank_metric = "accuracy",  # <- how to order models
   metric = "accuracy",       # <- which metric to visualize
   select_best = TRUE         # <- one point per workflow
) +
   # Write each workflow id vertically, starting just below its point.
   # The offset is on the accuracy scale (0-1).
   geom_text(aes(y = mean - 0.05, label = wflow_id), angle = 90, hjust = 1) +
   # Pad the lower end of the axis to leave room for the rotated labels rather
   # than hard-coding limits that could drop points; adjust the multiplier if
   # labels are clipped.
   scale_y_continuous(expand = expansion(mult = c(0.5, 0.05))) +
   theme(legend.position = "none")