From 737983e3ca9db9c75cbff9ce44feee56ca8acbc9 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Sun, 20 Feb 2022 21:16:49 -0500 Subject: [PATCH] better tuning parameters for brulee --- R/mlp_data.R | 4 ++-- R/tunable.R | 26 ++++++++++++++++++++++++++ R/zzz.R | 1 + 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/R/mlp_data.R b/R/mlp_data.R index f5d9e7c4b..b912bfd0c 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -420,7 +420,7 @@ set_model_arg( eng = "brulee", parsnip = "learn_rate", original = "learn_rate", - func = list(pkg = "dials", fun = "learn_rate"), + func = list(pkg = "dials", fun = "learn_rate", range = c(-2.5, -0.5)), has_submodel = FALSE ) @@ -448,7 +448,7 @@ set_model_arg( eng = "brulee", parsnip = "activation", original = "activation", - func = list(pkg = "dials", fun = "activation"), + func = list(pkg = "dials", fun = "activation", values = c('relu', 'elu', 'tanh')), has_submodel = FALSE ) diff --git a/R/tunable.R b/R/tunable.R index bbb47621e..8a1c83375 100644 --- a/R/tunable.R +++ b/R/tunable.R @@ -137,6 +137,21 @@ earth_engine_args <- component_id = "engine" ) +brulee_engine_args <- + tibble::tibble( + name = c( + "batch_size", + "class_weights" + ), + call_info = list( + list(pkg = "dials", fun = "batch_size", range = c(5, 10)), + list(pkg = "dials", fun = "class_weights") + ), + source = "model_spec", + component = "mlp", + component_id = "engine" + ) + # ------------------------------------------------------------------------------ # Lazily registered in .onLoad() @@ -227,3 +242,14 @@ tunable_svm_poly <- function(x, ...) { } res } + + +# Lazily registered in .onLoad() +tunable_mlp <- function(x, ...) { + res <- NextMethod() + if (x$engine == "brulee") { + res <- add_engine_parameters(res, brulee_engine_args) + } + res +} + diff --git a/R/zzz.R b/R/zzz.R index 60e919b37..9fc2ef57c 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -45,6 +45,7 @@ vctrs::s3_register("generics::tunable", "mars", tunable_mars) vctrs::s3_register("generics::tunable", "decision_tree", tunable_decision_tree) vctrs::s3_register("generics::tunable", "svm_poly", tunable_svm_poly) + vctrs::s3_register("generics::tunable", "mlp", tunable_mlp) } }