Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ export(extract_preprocessor)
export(extract_recipe)
export(extract_spec_parsnip)
export(extract_workflow)
export(fct_encode_one_hot)
export(forge)
export(frequency_weights)
export(get_data_classes)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# hardhat (development version)

* New `fct_encode_one_hot()` that encodes a factor as a one-hot indicator matrix
(#215).

* Added more documentation about importance and frequency weights in
`?importance_weights()` and `?frequency_weights()` (#214).

Expand Down
57 changes: 57 additions & 0 deletions R/encoding.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#' Encode a factor as a one-hot indicator matrix
#'
#' @description
#' `fct_encode_one_hot()` turns a factor into an integer indicator matrix with
#' one row per element of `x` and one column per level of `x`.
#'
#' Row `i` contains a single `1L`, placed in the column named after the level
#' of `x[[i]]`. Every other entry in that row is `0L`.
#'
#' @details
#' Columns appear in the same order as `levels(x)`. Levels that never occur in
#' the data still get a column, which is all `0L`.
#'
#' When `x` is named, its names become the row names of the result.
#'
#' @param x A factor.
#'
#'   Missing values are not allowed in `x`.
#'
#'   Ordered factors are accepted.
#'
#' @return An integer matrix with `length(x)` rows and `length(levels(x))`
#'   columns.
#'
#' @export
#' @examples
#' fct_encode_one_hot(factor(letters))
#'
#' fct_encode_one_hot(factor(letters[1:2], levels = letters))
#'
#' set.seed(1234)
#' fct_encode_one_hot(factor(sample(letters[1:4], 10, TRUE)))
fct_encode_one_hot <- function(x) {
  if (!is.factor(x)) {
    abort("`x` must be a factor.")
  }

  lvls <- levels(x)
  nms <- names(x)
  size <- length(x)

  # Drop the factor class to work with the underlying integer level codes
  codes <- unclass(x)

  if (vec_any_missing(codes)) {
    abort("`x` can't contain missing values.")
  }

  # Start from an all-zero matrix that already carries the final dimnames
  out <- matrix(
    data = 0L,
    nrow = size,
    ncol = length(lvls),
    dimnames = list(nms, lvls)
  )

  # Each (row, code) pair addresses the single cell to switch on in that row
  out[cbind(seq_len(size), codes)] <- 1L

  out
}
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ reference:
- new_model
- add_intercept_column
- weighted_table
- fct_encode_one_hot

- title: Validation
contents:
Expand Down
39 changes: 39 additions & 0 deletions man/fct_encode_one_hot.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions tests/testthat/_snaps/encoding.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# errors on missing values

Code
fct_encode_one_hot(x)
Condition
Error in `fct_encode_one_hot()`:
! `x` can't contain missing values.

# errors on non-factors

Code
fct_encode_one_hot(1)
Condition
Error in `fct_encode_one_hot()`:
! `x` must be a factor.

75 changes: 75 additions & 0 deletions tests/testthat/test-encoding.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
test_that("generates one-hot indicator matrix", {
  fct <- factor(c("a", "b", "a", "a", "c"))

  # Expected result spelled out one observation per line
  expected <- matrix(
    c(
      1L, 0L, 0L,
      0L, 1L, 0L,
      1L, 0L, 0L,
      1L, 0L, 0L,
      0L, 0L, 1L
    ),
    nrow = 5,
    ncol = 3,
    byrow = TRUE,
    dimnames = list(NULL, c("a", "b", "c"))
  )

  expect_identical(fct_encode_one_hot(fct), expected)
})

test_that("works with factors with just 1 level", {
  single_level <- factor(rep("a", 3))

  # With a single level, the only cell in every row is switched on
  expected <- matrix(1L, nrow = 3, ncol = 1, dimnames = list(NULL, "a"))

  expect_identical(fct_encode_one_hot(single_level), expected)
})

test_that("works with levels that aren't in the data", {
  lvls <- c("a", "b", "c", "d")
  fct <- factor(c("a", "c", "a"), levels = lvls)

  # Columns for the unused levels "b" and "d" stay all zero
  expected <- matrix(
    c(
      1L, 0L, 0L, 0L,
      0L, 0L, 1L, 0L,
      1L, 0L, 0L, 0L
    ),
    nrow = 3,
    ncol = 4,
    byrow = TRUE,
    dimnames = list(NULL, lvls)
  )

  expect_identical(fct_encode_one_hot(fct), expected)
})

test_that("works with factors with explicit `NA` level but no `NA` data", {
  # `exclude = NULL` keeps `NA` as a genuine level of the factor
  lvls <- c("a", NA)
  fct <- factor("a", levels = lvls, exclude = NULL)

  expected <- matrix(
    data = c(1L, 0L),
    nrow = 1,
    ncol = 2,
    dimnames = list(NULL, lvls)
  )

  expect_identical(fct_encode_one_hot(fct), expected)
})

test_that("works with empty factors", {
  # No observations and no levels: the result has zero rows and zero columns
  no_data_no_levels <- factor()

  expected <- matrix(
    data = integer(),
    nrow = 0,
    ncol = 0,
    dimnames = list(NULL, NULL)
  )

  expect_identical(fct_encode_one_hot(no_data_no_levels), expected)
})

test_that("works with empty factors with levels", {
  # No observations, but the declared levels still produce columns
  lvls <- c("a", "b")
  no_data <- factor(levels = lvls)

  expected <- matrix(
    data = integer(),
    nrow = 0,
    ncol = 2,
    dimnames = list(NULL, lvls)
  )

  expect_identical(fct_encode_one_hot(no_data), expected)
})

test_that("propagates names onto the row names", {
  row_labels <- c("x", "y", "z")
  fct <- set_names(factor(c("a", "b", "a")), row_labels)

  encoded <- fct_encode_one_hot(fct)

  # Names of the input become the row names of the output
  expect_identical(rownames(encoded), row_labels)
})

test_that("works with ordered factors", {
  lvls <- c("c", "b", "a")
  fct <- factor(c("a", "b", "a", "a", "c"), levels = lvls, ordered = TRUE)

  # Columns follow the level order ("c", "b", "a"), not alphabetical order
  expected <- matrix(
    c(
      0L, 0L, 1L,
      0L, 1L, 0L,
      0L, 0L, 1L,
      0L, 0L, 1L,
      1L, 0L, 0L
    ),
    nrow = 5,
    ncol = 3,
    byrow = TRUE,
    dimnames = list(NULL, lvls)
  )

  expect_identical(fct_encode_one_hot(fct), expected)
})

# Missing data in the factor must be rejected with an informative error.
# The error output is compared against the recorded snapshot for this test,
# so the test title and the code inside `expect_snapshot()` should stay as-is.
test_that("errors on missing values", {
# Second element is a missing value, not a level
x <- factor(c("a", NA))

expect_snapshot(error = TRUE, {
fct_encode_one_hot(x)
})
})

# Non-factor input (here a bare double) must be rejected with an informative
# error. The error output is compared against the recorded snapshot for this
# test, so the test title and the snapshotted code should stay as-is.
test_that("errors on non-factors", {
expect_snapshot(error = TRUE, {
fct_encode_one_hot(1)
})
})