Skip to content

Commit

Permalink
fix some parameter mappings between parsnip and the underlying model …
Browse files Browse the repository at this point in the history
…function for #238
  • Loading branch information
topepo committed Dec 2, 2019
1 parent 42d5ba5 commit 897c927
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 24 deletions.
2 changes: 1 addition & 1 deletion R/boost_tree_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ set_model_arg(
model = "boost_tree",
eng = "spark",
parsnip = "min_info_gain",
original = "gamma",
original = "loss_reduction",
func = list(pkg = "dials", fun = "loss_reduction"),
has_submodel = FALSE
)
Expand Down
9 changes: 9 additions & 0 deletions R/linear_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,15 @@ set_model_engine("linear_reg", "regression", "keras")
set_dependency("linear_reg", "keras", "keras")
set_dependency("linear_reg", "keras", "magrittr")

set_model_arg(
model = "linear_reg",
eng = "keras",
parsnip = "penalty",
original = "penalty",
func = list(pkg = "dials", fun = "penalty"),
has_submodel = FALSE
)

set_fit(
model = "linear_reg",
eng = "keras",
Expand Down
6 changes: 3 additions & 3 deletions R/logistic_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,9 @@ set_dependency("logistic_reg", "keras", "magrittr")
set_model_arg(
model = "logistic_reg",
eng = "keras",
parsnip = "decay",
original = "decay",
func = list(pkg = "dials", fun = "weight_decay"),
parsnip = "penalty",
original = "penalty",
func = list(pkg = "dials", fun = "penalty"),
has_submodel = FALSE
)

Expand Down
20 changes: 10 additions & 10 deletions R/mlp.R
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,12 @@ class2ind <- function (x, drop2nd = FALSE) {
#' @param x A data frame or matrix of predictors
#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
#' @param hidden_units An integer for the number of hidden units.
#' @param decay A non-negative real number for the amount of weight decay. Either
#' @param penalty A non-negative real number for the amount of weight decay. Either
#' this parameter _or_ `dropout` can specified.
#' @param dropout The proportion of parameters to set to zero. Either
#' this parameter _or_ `decay` can specified.
#' this parameter _or_ `penalty` can specified.
#' @param epochs An integer for the number of passes through the data.
#' @param act A character string for the type of activation function between layers.
#' @param activation A character string for the type of activation function between layers.
#' @param seeds A vector of three positive integers to control randomness of the
#' calculations.
#' @param ... Currently ignored.
Expand All @@ -279,11 +279,11 @@ class2ind <- function (x, drop2nd = FALSE) {
#' @export
keras_mlp <-
function(x, y,
hidden_units = 5, decay = 0, dropout = 0, epochs = 20, act = "softmax",
hidden_units = 5, penalty = 0, dropout = 0, epochs = 20, activation = "softmax",
seeds = sample.int(10^5, size = 3),
...) {

if (decay > 0 & dropout > 0) {
if (penalty > 0 & dropout > 0) {
stop("Please use either dropoput or weight decay.", call. = FALSE)
}
if (!is.matrix(x)) {
Expand All @@ -307,20 +307,20 @@ keras_mlp <-

model <- keras::keras_model_sequential()

if (decay > 0) {
if (penalty > 0) {
model %>%
keras::layer_dense(
units = hidden_units,
activation = act,
activation = activation,
input_shape = ncol(x),
kernel_regularizer = keras::regularizer_l2(decay),
kernel_regularizer = keras::regularizer_l2(penalty),
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
)
} else {
model %>%
keras::layer_dense(
units = hidden_units,
activation = act,
activation = activation,
input_shape = ncol(x),
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
)
Expand All @@ -330,7 +330,7 @@ keras_mlp <-
model %>%
keras::layer_dense(
units = hidden_units,
activation = act,
activation = activation,
input_shape = ncol(x),
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
) %>%
Expand Down
4 changes: 2 additions & 2 deletions R/mlp_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ set_model_arg(
eng = "keras",
parsnip = "penalty",
original = "penalty",
func = list(pkg = "dials", fun = "weight_decay"),
func = list(pkg = "dials", fun = "penalty"),
has_submodel = FALSE
)
set_model_arg(
Expand Down Expand Up @@ -188,7 +188,7 @@ set_model_arg(
eng = "nnet",
parsnip = "penalty",
original = "decay",
func = list(pkg = "dials", fun = "weight_decay"),
func = list(pkg = "dials", fun = "penalty"),
has_submodel = FALSE
)
set_model_arg(
Expand Down
6 changes: 3 additions & 3 deletions R/multinom_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ set_dependency("multinom_reg", "keras", "magrittr")
set_model_arg(
model = "multinom_reg",
eng = "keras",
parsnip = "decay",
original = "decay",
func = list(pkg = "dials", fun = "weight_decay"),
parsnip = "penalty",
original = "penalty",
func = list(pkg = "dials", fun = "penalty"),
has_submodel = FALSE
)

Expand Down
10 changes: 5 additions & 5 deletions man/keras_mlp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 897c927

Please sign in to comment.