hai_auto_knn() #223

Closed · Tracked by #205
spsanderson opened this issue Apr 28, 2022 · 0 comments

Assignees: spsanderson
Labels: enhancement (New feature or request), function (A new function)

spsanderson (Owner) commented Apr 28, 2022

Function: `hai_auto_knn()` builds a K-Nearest Neighbors workflow (parsnip `nearest_neighbor()` with the kknn engine) from a data set and a recipe, and can optionally tune `neighbors`, `weight_func`, and `dist_power` over a latin hypercube grid:

hai_auto_knn <- function(.data, .rec_obj, .splits_obj = NULL, .rsamp_obj = NULL, 
                         .tune = TRUE, .grid_size = 10, .num_cores = 1, 
                         .best_metric = "rmse", .model_type = "regression"){
  
  # Tidyeval ----
  grid_size <- as.numeric(.grid_size)
  num_cores <- as.numeric(.num_cores)
  best_metric <- as.character(.best_metric)
  
  data_tbl <- dplyr::as_tibble(.data)
  
  splits <- .splits_obj
  rec_obj <- .rec_obj
  rsamp_obj <- .rsamp_obj
  model_type <- as.character(.model_type)
  
  # Checks ----
  if (!inherits(x = splits, what = "rsplit") && !is.null(splits)){
    rlang::abort(
      message = "The '.splits_obj' argument must either be NULL or an object of class 'rsplit'. Use the rsample package.",
      use_cli_format = TRUE
    )
  }
  
  if (!inherits(x = rec_obj, what = "recipe")){
    rlang::abort(
      message = "'.rec_obj' must have a class of 'recipe'."
    )
  }
  
  if (!model_type %in% c("regression","classification")){
    rlang::abort(
      message = paste0(
        "You chose a mode of: '",
        model_type,
        "' this is unsupported. Choose from either 'regression' or 'classification'."
      ),
      use_cli_format = TRUE
    )
  }
  
  if (!inherits(x = rsamp_obj, what = "rset") && !is.null(rsamp_obj)){
    rlang::abort(
      message = "The '.rsamp_obj' argument must either be NULL or an object of class 'rset'.",
      use_cli_format = TRUE
    )
  }
  
  # Set default metric set ----
  if (model_type == "classification"){
    ms <- hai_default_classification_metric_set()
  } else {
    ms <- hai_default_regression_metric_set()
  }
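  # Note: both hai_default_*_metric_set() helpers return yardstick metric
  # sets; whichever is selected here is passed to tune::tune_grid() as the
  # `metrics` argument below.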
  
  # Get splits; if none were supplied, create them
  if (is.null(splits)){
    splits <- rsample::initial_split(data = data_tbl)
  }
  
  # Tune/Spec ----
  if (.tune){
    # Model Specification
    model_spec <- parsnip::nearest_neighbor(
      neighbors = tune::tune(), 
      weight_func = tune::tune(),
      dist_power = tune::tune()
    )
  } else {
    model_spec <- parsnip::nearest_neighbor()
  }
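  # The tune::tune() placeholders above are filled in later by
  # tune::finalize_workflow() once tuning results are available.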
  
  # Model Specification ----
  model_spec <- model_spec %>%
    parsnip::set_mode(mode = model_type) %>%
    parsnip::set_engine(engine = "kknn")
  
  # Workflow ----
  wflw <- workflows::workflow() %>%
    workflows::add_recipe(rec_obj) %>%
    workflows::add_model(model_spec)
  
  # Tuning Grid ----
  if (.tune){
    
    # Make tuning grid
    tuning_grid_spec <- dials::grid_latin_hypercube(
      hardhat::extract_parameter_set_dials(model_spec),
      size = grid_size
    )
    
    # Cross validation object
    if (is.null(rsamp_obj)){
      cv_obj <- rsample::mc_cv(
        data = rsample::training(splits)
      )
    } else {
      cv_obj <- rsamp_obj
    }
    
    # Tune the workflow
    # Start parallel backend
    modeltime::parallel_start(num_cores)
    
    tuned_results <- wflw %>%
      tune::tune_grid(
        resamples = cv_obj,
        grid      = tuning_grid_spec,
        metrics   = ms
      )
    
    modeltime::parallel_stop()
    
    # Get the best result set by a specified metric
    best_result_set <- tuned_results %>%
      tune::show_best(metric = best_metric, n = 1)
    
    # Plot results
    tune_results_plt <- tuned_results %>%
      tune::autoplot() +
      ggplot2::theme_minimal() +
      ggplot2::geom_smooth(se = FALSE) +
      ggplot2::theme(legend.position = "bottom")
    
    # Finalize the workflow with the best parameters and fit it
    wflw_fit <- wflw %>%
      tune::finalize_workflow(best_result_set) %>%
      parsnip::fit(rsample::training(splits))
    
  } else {
    wflw_fit <- wflw %>%
      parsnip::fit(rsample::training(splits))
  }
  
  # Return ----
  output <- list(
    recipe_info = rec_obj,
    model_info = list(
      model_spec  = model_spec,
      wflw        = wflw,
      fitted_wflw = wflw_fit,
      was_tuned   = ifelse(.tune, "tuned", "not_tuned")
    )
  )
  
  if (.tune){
    output$tuned_info <- list(
      tuning_grid      = tuning_grid_spec,
      cv_obj           = cv_obj,
      tuned_results    = tuned_results,
      grid_size        = grid_size,
      best_metric      = best_metric,
      best_result_set  = best_result_set,
      tuning_grid_plot = tune_results_plt,
      plotly_grid_plot = plotly::ggplotly(tune_results_plt)
    )
  }
  
  return(invisible(output))
  
}
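
A minimal usage sketch (hypothetical: `iris` stands in for the real data, and the recipe steps mirror the four shown in the example output below):

library(recipes)
library(rsample)

# Hypothetical recipe matching the four preprocessing steps in the output below
rec_obj <- recipe(Species ~ ., data = iris) %>%
  step_novel(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric())

# Pre-build the split; if .splits_obj is left NULL the function creates one
splits <- initial_split(iris)

output <- hai_auto_knn(
  .data        = iris,
  .rec_obj     = rec_obj,
  .splits_obj  = splits,
  .tune        = TRUE,
  .grid_size   = 10,
  .num_cores   = 2,
  .best_metric = "f_meas",
  .model_type  = "classification"
)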

Example output:

output
$recipe_info
Recipe

Inputs:

      role #variables
   outcome          1
 predictor          4

Operations:

Novel factor level assignment for recipes::all_nominal_predictors()
Dummy variables from recipes::all_nominal_predictors()
Zero variance filter on recipes::all_predictors()
Centering and scaling for recipes::all_numeric()

$model_info
$model_info$model_spec
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = tune::tune()
  weight_func = tune::tune()
  dist_power = tune::tune()

Computational engine: kknn 


$model_info$wflw
== Workflow ===============================================================================
Preprocessor: Recipe
Model: nearest_neighbor()

-- Preprocessor ---------------------------------------------------------------------------
4 Recipe Steps

* step_novel()
* step_dummy()
* step_zv()
* step_normalize()

-- Model ----------------------------------------------------------------------------------
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = tune::tune()
  weight_func = tune::tune()
  dist_power = tune::tune()

Computational engine: kknn 


$model_info$fitted_wflw
== Workflow [trained] =====================================================================
Preprocessor: Recipe
Model: nearest_neighbor()

-- Preprocessor ---------------------------------------------------------------------------
4 Recipe Steps

* step_novel()
* step_dummy()
* step_zv()
* step_normalize()

-- Model ----------------------------------------------------------------------------------

Call:
kknn::train.kknn(formula = ..y ~ ., data = data, ks = min_rows(5L, data, 5), distance = ~1.58310485205147, kernel = ~"inv")

Type of response variable: nominal
Minimal misclassification: 0.03571429
Best kernel: inv
Best k: 5

$model_info$was_tuned
[1] "tuned"


$tuned_info
$tuned_info$tuning_grid
# A tibble: 10 x 3
   neighbors weight_func  dist_power
       <int> <chr>             <dbl>
 1         2 triweight         1.11 
 2        10 gaussian          1.67 
 3         7 epanechnikov      0.667
 4        12 optimal           0.730
 5         4 rank              1.87 
 6        12 triangular        0.163
 7        14 cos               1.28 
 8         5 inv               1.58 
 9         8 rectangular       0.315
10         4 biweight          0.917

$tuned_info$cv_obj
# Monte Carlo cross-validation (0.75/0.25) with 25 resamples  
# A tibble: 25 x 2
   splits          id        
   <list>          <chr>     
 1 <split [84/28]> Resample01
 2 <split [84/28]> Resample02
 3 <split [84/28]> Resample03
 4 <split [84/28]> Resample04
 5 <split [84/28]> Resample05
 6 <split [84/28]> Resample06
 7 <split [84/28]> Resample07
 8 <split [84/28]> Resample08
 9 <split [84/28]> Resample09
10 <split [84/28]> Resample10
# ... with 15 more rows

$tuned_info$tuned_results
# Tuning results
# Monte Carlo cross-validation (0.75/0.25) with 25 resamples  
# A tibble: 25 x 4
   splits          id         .metrics           .notes          
   <list>          <chr>      <list>             <list>          
 1 <split [84/28]> Resample01 <tibble [110 x 7]> <tibble [0 x 3]>
 2 <split [84/28]> Resample02 <tibble [110 x 7]> <tibble [0 x 3]>
 3 <split [84/28]> Resample03 <tibble [110 x 7]> <tibble [0 x 3]>
 4 <split [84/28]> Resample04 <tibble [110 x 7]> <tibble [0 x 3]>
 5 <split [84/28]> Resample05 <tibble [110 x 7]> <tibble [0 x 3]>
 6 <split [84/28]> Resample06 <tibble [110 x 7]> <tibble [0 x 3]>
 7 <split [84/28]> Resample07 <tibble [110 x 7]> <tibble [0 x 3]>
 8 <split [84/28]> Resample08 <tibble [110 x 7]> <tibble [0 x 3]>
 9 <split [84/28]> Resample09 <tibble [110 x 7]> <tibble [0 x 3]>
10 <split [84/28]> Resample10 <tibble [110 x 7]> <tibble [0 x 3]>
# ... with 15 more rows

$tuned_info$grid_size
[1] 10

$tuned_info$best_metric
[1] "f_meas"

$tuned_info$best_result_set
# A tibble: 1 x 9
  neighbors weight_func dist_power .metric .estimator  mean     n std_err .config          
      <int> <chr>            <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>            
1         5 inv               1.58 f_meas  macro      0.963    25 0.00580 Preprocessor1_Mo~

$tuned_info$tuning_grid_plot
`geom_smooth()` using method = 'loess' and formula 'y ~ x'

$tuned_info$plotly_grid_plot

(image: interactive plotly rendering of the tuning grid plot)
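
The fitted workflow in the returned list behaves like any trained workflows object; a short prediction sketch (assuming the hypothetical `splits` from the usage example above):

# Pull the trained workflow out of the returned list
fitted_wflw <- output$model_info$fitted_wflw

# Predict on the holdout portion of the split
predict(fitted_wflw, new_data = rsample::testing(splits))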

@spsanderson spsanderson self-assigned this Apr 28, 2022
@spsanderson spsanderson added enhancement New feature or request function A new function labels Apr 28, 2022
@spsanderson spsanderson added this to the healthyR.ai 0.0.7 milestone Apr 28, 2022
spsanderson added a commit that referenced this issue Apr 28, 2022