Skip to content
Merged

Seed #442

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Suggests:
keras,
tfestimators,
callr
RoxygenNote: 7.0.2
RoxygenNote: 7.1.1
Config/reticulate:
list(
packages = list(
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ export(np_array)
export(parse_arguments)
export(parse_flags)
export(run_dir)
export(set_random_seed)
export(shape)
export(tensorboard)
export(tf)
Expand Down
57 changes: 55 additions & 2 deletions R/seed.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#' and TensorFlow. GPU computations and CPU parallelism will also be disabled by
#' default.
#'
#' @inheritParams reticulate py_set_seed
#'
#' @param seed A single value, interpreted as an integer
#' @param disable_gpu `TRUE` to disable GPU execution (see *Parallelism* below).
Expand Down Expand Up @@ -50,8 +49,15 @@ use_session_with_seed <- function(seed,
disable_parallel_cpu = TRUE,
quiet = FALSE) {

if (tf_version() >= "2.0")

msg <- "use_session_with_seed will be deprecated in the future. use set_random_seed instead."
if (tf_version() >= "2.0") {
tf <- tf$compat$v1
warning(msg)
}

if (tf_version() >= "2.3")
stop(msg)

# cast seed to integer
seed <- as.integer(seed)
Expand Down Expand Up @@ -123,3 +129,50 @@ use_session_with_seed <- function(seed,
# return session invisibly
invisible(sess)
}

#' Set random seed for TensorFlow
#'
#' Sets all random seeds needed to make TensorFlow code reproducible.
#'
#' @details
#'
#' This function should be used instead of [use_session_with_seed()] if
#' you are using TensorFlow >= 2.0, as the concept of `session` doesn't
#' really make sense anymore.
#'
#' This functions sets:
#'
#' - The R random seed with [set.seed()].
#' - The python and Numpy seeds via ([reticulate::py_set_seed()]).
#' - The TensorFlow seed with (`tf$random$set_seed()`)
#'
#' It also optionally disables the GPU execution as this is a potential
#' source of non-reproducibility.
#'
#' @param seed A single value, interpreted as an integer
#' @param disable_gpu `TRUE` to disable GPU execution (see *Parallelism* below).
#'
#' @export
set_random_seed <- function(seed, disable_gpu = TRUE) {

if (tf_version() < "2.0")
stop("set_random_seed only works for TF >= 2.0")

# cast seed to integer
seed <- as.integer(seed)

# set R random seed
set.seed(seed)

# set Python/NumPy random seed
py_set_seed(seed)

# set tensorflow random seed
tensorflow::tf$random$set_seed(seed)

if (disable_gpu) {
Sys.setenv(CUDA_VISIBLE_DEVICES = "-1")
}

invisible(NULL)
}
4 changes: 2 additions & 2 deletions man/install_tensorflow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/reexports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions man/set_random_seed.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/tf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/tf_extract_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions tests/testthat/test-seed.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
test_that("use_session_with_seed works", {
skip_if_no_tensorflow()

if (tf_version() >= "2.3")
skip("use_session_with seed doesn't work with TF >= 2.3")

f <- function() {
library(keras)
use_session_with_seed(seed = 1)
Expand All @@ -14,3 +17,24 @@ test_that("use_session_with_seed works", {

expect_equal(run1, run2)
})

test_that("set_random_seed", {

skip_if_no_tensorflow()

if (tf_version() < "2.0")
skip("set_random_seed only works for TF >= 2.0")

f <- function() {
library(keras)
tensorflow::set_random_seed(seed = 1)
model <- keras_model_sequential() %>%
layer_dense(units = 1)
predict(model, matrix(1, ncol = 1))
}

run1 <- callr::r(f)
run2 <- callr::r(f)

expect_equal(run1, run2)
})