Skip to content

Commit

Permalink
Merge pull request #213 from tidymodels/new-arguments-to-step_umap
Browse files Browse the repository at this point in the history
New arguments to step umap
  • Loading branch information
EmilHvitfeldt committed Jan 17, 2024
2 parents 311d068 + f973fce commit f205fd8
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 2 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# embed (development version)

* `step_umap()` has gained `initial` and `target_weight` arguments. (#213)

# embed 1.1.3

* `step_collapse_stringdist()` will now return predictors as factors. (#204)
Expand Down
30 changes: 28 additions & 2 deletions R/umap.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
#' neighbors. See [uwot::umap()] for more details. Default to `"euclidean"`.
#' @param epochs Number of iterations for the neighbor optimization. See
#' [uwot::umap()] for more details.
#' @param initial Character, Type of initialization for the coordinates. Can be
#' one of `"spectral"`, `"normlaplacian"`, `"random"`, `"lvrandom"`,
#' `"laplacian"`, `"pca"`, `"spca"`, `"agspectral"`, or a matrix of initial
#' coordinates. See [uwot::umap()] for more details. Default to `"spectral"`.
#' @param target_weight Weighting factor between data topology and target
#' topology. A value of 0.0 weights entirely on data, a value of 1.0 weights
#' entirely on target. The default of 0.5 balances the weighting equally
#' between data and target.
#' @param learn_rate Positive number of the learning rate for the optimization
#' process.
#' @param outcome A call to `vars` to specify which variable is used as the
Expand Down Expand Up @@ -105,6 +113,8 @@ step_umap <-
metric = "euclidean",
learn_rate = 1,
epochs = NULL,
initial = "spectral",
target_weight = 0.5,
options = list(verbose = FALSE, n_threads = 1),
seed = sample(10^5, 2),
prefix = "UMAP",
Expand Down Expand Up @@ -143,6 +153,8 @@ step_umap <-
metric = metric,
learn_rate = learn_rate,
epochs = epochs,
initial = initial,
target_weight = target_weight,
options = options,
seed = seed,
prefix = prefix,
Expand All @@ -157,8 +169,8 @@ step_umap <-

step_umap_new <-
function(terms, role, trained, outcome, neighbors, num_comp, min_dist, metric,
learn_rate, epochs, options, seed, prefix, keep_original_cols,
retain, object, skip, id) {
learn_rate, epochs, initial, target_weight, options, seed, prefix,
keep_original_cols, retain, object, skip, id) {
step(
subclass = "umap",
terms = terms,
Expand All @@ -171,6 +183,8 @@ step_umap_new <-
metric = metric,
learn_rate = learn_rate,
epochs = epochs,
initial = initial,
target_weight = target_weight,
options = options,
seed = seed,
prefix = prefix,
Expand All @@ -194,9 +208,11 @@ umap_fit_call <- function(obj, y = NULL) {
cl$n_neighbors <- obj$neighbors
cl$n_components <- obj$num_comp
cl$n_epochs <- obj$epochs
cl$init <- obj$initial
cl$learning_rate <- obj$learn_rate
cl$min_dist <- obj$min_dist
cl$metric <- obj$metric
cl$target_weight <- obj$target_weight
if (length(obj$options) > 0) {
cl <- rlang::call_modify(cl, !!!obj$options)
}
Expand All @@ -216,6 +232,14 @@ prep.step_umap <- function(x, training, info = NULL, ...) {
}
x$neighbors <- min(nrow(training) - 1, x$neighbors)
x$num_comp <- min(length(col_names) - 1, x$num_comp)

if (is.null(x$initial)) {
x$initial <- "spectral"
}
if (is.null(x$target_weight)) {
x$target_weight <- 0.5
}

withr::with_seed(
x$seed[1],
res <- rlang::eval_tidy(umap_fit_call(x, y = y_name))
Expand All @@ -237,6 +261,8 @@ prep.step_umap <- function(x, training, info = NULL, ...) {
metric = x$metric,
learn_rate = x$learn_rate,
epochs = x$epochs,
initial = x$initial,
target_weight = x$target_weight,
options = x$options,
seed = x$seed,
prefix = x$prefix,
Expand Down
12 changes: 12 additions & 0 deletions man/step_umap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions tests/testthat/test-umap.R
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,24 @@ test_that("tunable", {
)
})

test_that("backwards compatible for initial and target_weight args (#213)", {
skip_if_not_installed("irlba", "2.3.5.2")

rec <- recipe(Species ~ ., data = tr) %>%
step_umap(all_predictors(), num_comp = 2)

exp_res <- prep(rec)

rec$steps[[1]]$initial <- NULL
rec$steps[[1]]$target_weight <- NULL

expect_identical(
prep(rec),
exp_res
)
})


# Infrastructure ---------------------------------------------------------------

test_that("bake method errors when needed non-standard role columns are missing", {
Expand Down

0 comments on commit f205fd8

Please sign in to comment.