Function:
hai_auto_knn <- function(.data, .rec_obj, .splits_obj = NULL, .rsamp_obj = NULL,
                         .tune = TRUE, .grid_size = 10, .num_cores = 1,
                         .best_metric = "rmse", .model_type = "regression"){

    # Tidyeval ----
    grid_size   <- as.numeric(.grid_size)
    num_cores   <- as.numeric(.num_cores)
    best_metric <- as.character(.best_metric)

    data_tbl   <- dplyr::as_tibble(.data)
    splits     <- .splits_obj
    rec_obj    <- .rec_obj
    rsamp_obj  <- .rsamp_obj
    model_type <- as.character(.model_type)

    # Checks ----
    # Note: the unconditional 'rsplit' check on splits was dropped. It aborted
    # whenever .splits_obj was left at its NULL default and named the wrong
    # argument; the NULL-aware check further down covers the same case.
    if (!inherits(x = rec_obj, what = "recipe")){
        rlang::abort(
            message = "'.rec_obj' must have a class of 'recipe'."
        )
    }

    if (!model_type %in% c("regression","classification")){
        rlang::abort(
            message = paste0(
                "You chose a mode of: '",
                model_type,
                "' this is unsupported. Choose from either 'regression' or 'classification'."
            ),
            use_cli_format = TRUE
        )
    }

    if (!inherits(x = rsamp_obj, what = "rset") && !is.null(rsamp_obj)){
        rlang::abort(
            message = "The '.rsamp_obj' argument must either be NULL or an object of class 'rset'.",
            use_cli_format = TRUE
        )
    }

    if (!inherits(x = splits, what = "rsplit") && !is.null(splits)){
        rlang::abort(
            message = "The '.splits_obj' argument must either be NULL or an object of class 'rsplit'.",
            use_cli_format = TRUE
        )
    }

    # Set default metric set ----
    if (model_type == "classification"){
        ms <- hai_default_classification_metric_set()
    } else {
        ms <- hai_default_regression_metric_set()
    }

    # Use the supplied splits, otherwise create them ----
    if (is.null(splits)){
        splits <- rsample::initial_split(data = data_tbl)
    }

    # Tune/Spec ----
    if (.tune){
        # Mark all three kknn hyperparameters for tuning
        model_spec <- parsnip::nearest_neighbor(
            neighbors     = tune::tune()
            , weight_func = tune::tune()
            , dist_power  = tune::tune()
        )
    } else {
        model_spec <- parsnip::nearest_neighbor()
    }

    # Model Specification ----
    model_spec <- model_spec %>%
        parsnip::set_mode(mode = model_type) %>%
        parsnip::set_engine(engine = "kknn")

    # Workflow ----
    wflw <- workflows::workflow() %>%
        workflows::add_recipe(rec_obj) %>%
        workflows::add_model(model_spec)

    # Tuning Grid ----
    if (.tune){

        # Make tuning grid
        tuning_grid_spec <- dials::grid_latin_hypercube(
            hardhat::extract_parameter_set_dials(model_spec),
            size = grid_size
        )

        # Cross validation object: default to Monte Carlo CV on the training data
        if (is.null(rsamp_obj)){
            cv_obj <- rsample::mc_cv(
                data = rsample::training(splits)
            )
        } else {
            cv_obj <- rsamp_obj
        }

        # Tune the workflow
        # Start parallel backend
        modeltime::parallel_start(num_cores)

        tuned_results <- wflw %>%
            tune::tune_grid(
                resamples = cv_obj,
                grid      = tuning_grid_spec,
                metrics   = ms
            )

        modeltime::parallel_stop()

        # Get the best result set by a specified metric
        best_result_set <- tuned_results %>%
            tune::show_best(metric = best_metric, n = 1)

        # Plot results
        tune_results_plt <- tuned_results %>%
            tune::autoplot() +
            ggplot2::theme_minimal() +
            ggplot2::geom_smooth(se = FALSE) +
            ggplot2::theme(legend.position = "bottom")

        # Make final workflow from the best result and fit it on the training data
        wflw_fit <- wflw %>%
            tune::finalize_workflow(best_result_set) %>%
            parsnip::fit(rsample::training(splits))

    } else {
        wflw_fit <- wflw %>%
            parsnip::fit(rsample::training(splits))
    }

    # Return ----
    output <- list(
        recipe_info = rec_obj,
        model_info  = list(
            model_spec  = model_spec,
            wflw        = wflw,
            fitted_wflw = wflw_fit,
            was_tuned   = ifelse(.tune, "tuned", "not_tuned")
        )
    )

    if (.tune){
        output$tuned_info <- list(
            tuning_grid      = tuning_grid_spec,
            cv_obj           = cv_obj,
            tuned_results    = tuned_results,
            grid_size        = grid_size,
            best_metric      = best_metric,
            best_result_set  = best_result_set,
            tuning_grid_plot = tune_results_plt,
            plotly_grid_plot = plotly::ggplotly(tune_results_plt)
        )
    }

    return(invisible(output))
}
Example:
output
$recipe_info
Recipe

Inputs:

      role #variables
   outcome          1
 predictor          4

Operations:

Novel factor level assignment for recipes::all_nominal_predictors()
Dummy variables from recipes::all_nominal_predictors()
Zero variance filter on recipes::all_predictors()
Centering and scaling for recipes::all_numeric()

$model_info
$model_info$model_spec
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = tune::tune()
  weight_func = tune::tune()
  dist_power = tune::tune()

Computational engine: kknn


$model_info$wflw
== Workflow ===============================================================================
Preprocessor: Recipe
Model: nearest_neighbor()

-- Preprocessor ---------------------------------------------------------------------------
4 Recipe Steps

* step_novel()
* step_dummy()
* step_zv()
* step_normalize()

-- Model ----------------------------------------------------------------------------------
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = tune::tune()
  weight_func = tune::tune()
  dist_power = tune::tune()

Computational engine: kknn


$model_info$fitted_wflw
== Workflow [trained] =====================================================================
Preprocessor: Recipe
Model: nearest_neighbor()

-- Preprocessor ---------------------------------------------------------------------------
4 Recipe Steps

* step_novel()
* step_dummy()
* step_zv()
* step_normalize()

-- Model ----------------------------------------------------------------------------------

Call:
kknn::train.kknn(formula = ..y ~ ., data = data, ks = min_rows(5L, data, 5), distance = ~1.58310485205147, kernel = ~"inv")

Type of response variable: nominal
Minimal misclassification: 0.03571429
Best kernel: inv
Best k: 5

$model_info$was_tuned
[1] "tuned"


$tuned_info
$tuned_info$tuning_grid
# A tibble: 10 x 3
   neighbors weight_func  dist_power
       <int> <chr>             <dbl>
 1         2 triweight         1.11
 2        10 gaussian          1.67
 3         7 epanechnikov      0.667
 4        12 optimal           0.730
 5         4 rank              1.87
 6        12 triangular        0.163
 7        14 cos               1.28
 8         5 inv               1.58
 9         8 rectangular       0.315
10         4 biweight          0.917

$tuned_info$cv_obj
# Monte Carlo cross-validation (0.75/0.25) with 25 resamples
# A tibble: 25 x 2
   splits          id
   <list>          <chr>
 1 <split [84/28]> Resample01
 2 <split [84/28]> Resample02
 3 <split [84/28]> Resample03
 4 <split [84/28]> Resample04
 5 <split [84/28]> Resample05
 6 <split [84/28]> Resample06
 7 <split [84/28]> Resample07
 8 <split [84/28]> Resample08
 9 <split [84/28]> Resample09
10 <split [84/28]> Resample10
# ... with 15 more rows

$tuned_info$tuned_results
# Tuning results
# Monte Carlo cross-validation (0.75/0.25) with 25 resamples
# A tibble: 25 x 4
   splits          id         .metrics           .notes
   <list>          <chr>      <list>             <list>
 1 <split [84/28]> Resample01 <tibble [110 x 7]> <tibble [0 x 3]>
 2 <split [84/28]> Resample02 <tibble [110 x 7]> <tibble [0 x 3]>
 3 <split [84/28]> Resample03 <tibble [110 x 7]> <tibble [0 x 3]>
 4 <split [84/28]> Resample04 <tibble [110 x 7]> <tibble [0 x 3]>
 5 <split [84/28]> Resample05 <tibble [110 x 7]> <tibble [0 x 3]>
 6 <split [84/28]> Resample06 <tibble [110 x 7]> <tibble [0 x 3]>
 7 <split [84/28]> Resample07 <tibble [110 x 7]> <tibble [0 x 3]>
 8 <split [84/28]> Resample08 <tibble [110 x 7]> <tibble [0 x 3]>
 9 <split [84/28]> Resample09 <tibble [110 x 7]> <tibble [0 x 3]>
10 <split [84/28]> Resample10 <tibble [110 x 7]> <tibble [0 x 3]>
# ... with 15 more rows

$tuned_info$grid_size
[1] 10

$tuned_info$best_metric
[1] "f_meas"

$tuned_info$best_result_set
# A tibble: 1 x 9
  neighbors weight_func dist_power .metric .estimator  mean     n std_err .config
      <int> <chr>            <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>
1         5 inv               1.58 f_meas  macro      0.963    25 0.00580 Preprocessor1_Mo~

$tuned_info$tuning_grid_plot
`geom_smooth()` using method = 'loess' and formula 'y ~ x'

$tuned_info$plotly_grid_plot
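The example shows only the printed result, not the call that produced it. Below is a minimal sketch of a call that could yield output of this shape. It rests on several assumptions: the original does not name the data set, so `iris` is used because four predictors and 84/28 resamples are consistent with a 75% training split of its 150 rows; the recipe steps are inferred from the printed recipe; and the seed is arbitrary. The call is guarded so the sketch runs even if `hai_auto_knn()` has not been sourced into the session.

```r
library(magrittr)  # the function body calls %>% unqualified, so the pipe must be attached
library(recipes)
library(rsample)

set.seed(123)  # arbitrary seed (an assumption) so the split is reproducible

# Assumption: iris stands in for the unnamed data set in the original example.
splits <- initial_split(iris, prop = 0.75)

# Recipe steps inferred from the printed recipe_info above.
rec_obj <- recipe(Species ~ ., data = training(splits)) %>%
  step_novel(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric())

# Guarded: only runs when hai_auto_knn() and its metric-set helpers are defined.
if (exists("hai_auto_knn")) {
  output <- hai_auto_knn(
    .data        = iris,
    .rec_obj     = rec_obj,
    .splits_obj  = splits,
    .best_metric = "f_meas",
    .model_type  = "classification"
  )

  # Best hyperparameter combination according to the chosen metric
  print(output$tuned_info$best_result_set)

  # Class predictions for the held-out rows, using the finalized workflow
  preds <- predict(output$model_info$fitted_wflw, new_data = testing(splits))
  print(head(preds))
}
```

Passing `.splits_obj` explicitly keeps the training rows used for the recipe and for the final fit identical. The function returns its list invisibly, so assign the result and print the elements you need.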
hai_knn()
hai_knn_automl()
06f6dc8
Merge pull request #226 from spsanderson/development
718faf2
Fixes #223 Fixes #205
spsanderson