Skip to content

Commit

Permalink
Merge pull request #120 from rstudio/feature/set-session-seed
Browse files Browse the repository at this point in the history
function to set random seeds for keras session
  • Loading branch information
jjallaire committed Sep 6, 2017
2 parents 2aafe72 + e8dca1c commit 0c95b04
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 22 deletions.
9 changes: 6 additions & 3 deletions DESCRIPTION
@@ -1,7 +1,7 @@
Package: keras
Type: Package
Title: R Interface to Keras
Version: 2.0.7.9003
Version: 2.0.7.9004
Authors@R: c(
person("JJ", "Allaire", role = c("aut", "cre"), email = "jj@rstudio.com"),
person("François", "Chollet", role = c("aut", "cph")),
Expand All @@ -23,8 +23,8 @@ BugReports: https://github.com/rstudio/keras/issues
Depends:
R (>= 3.2)
Imports:
reticulate (>= 1.1),
tensorflow (>= 1.3.1),
reticulate (>= 1.1.0.9003),
tensorflow (>= 1.3.1.9000),
tfruns (>= 0.9.1),
magrittr,
methods,
Expand All @@ -36,6 +36,9 @@ Suggests:
testthat,
knitr,
rmarkdown
Remotes:
rstudio/reticulate,
rstudio/tensorflow
SystemRequirements: Keras >= 2.0 (https://keras.io)
Roxygen: list(markdown = TRUE)
RoxygenNote: 6.0.1
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Expand Up @@ -247,6 +247,7 @@ export(train_on_batch)
export(unserialize_model)
export(use_condaenv)
export(use_python)
export(use_session_with_seed)
export(use_virtualenv)
export(xception_preprocess_input)
import(R6)
Expand Down Expand Up @@ -279,6 +280,7 @@ importFrom(tensorflow,install_tensorflow)
importFrom(tensorflow,tensorboard)
importFrom(tensorflow,tf_config)
importFrom(tensorflow,tf_version)
importFrom(tensorflow,use_session_with_seed)
importFrom(tfruns,flag_boolean)
importFrom(tfruns,flag_integer)
importFrom(tfruns,flag_numeric)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Expand Up @@ -3,6 +3,8 @@

Install the development version with: `install_github("rstudio/keras")`

- Add `set_keras_seed()` function that establishes a random seed for the Keras session.

- Fix for plotting training history with early stopping callback (thanks to @JamesAllingham).

- Better support for training models from data tensors in TensorFlow (e.g. Datasets, TFRecords). Add a related example script.
Expand Down
4 changes: 4 additions & 0 deletions R/package.R
Expand Up @@ -56,6 +56,10 @@ keras <- NULL
stop(e, call. = FALSE)
}
))

# tensorflow use_session hooks
setHook("tensorflow.on_before_use_session", tensorflow_on_before_use_session)
setHook("tensorflow.on_use_session", tensorflow_on_use_session)
}

resolve_implementation_module <- function() {
Expand Down
4 changes: 4 additions & 0 deletions R/reexports.R
Expand Up @@ -23,6 +23,10 @@ reticulate::use_virtualenv
#' @export
reticulate::use_condaenv

#' @importFrom tensorflow use_session_with_seed
#' @export
tensorflow::use_session_with_seed

#' @importFrom tensorflow tensorboard
#' @export
tensorflow::tensorboard
Expand Down
16 changes: 16 additions & 0 deletions R/seed.R
@@ -0,0 +1,16 @@


tensorflow_on_before_use_session <- function(quiet) {
if (is_backend("tensorflow")) {
keras$backend$clear_session()
TRUE
} else {
FALSE
}
}

tensorflow_on_use_session <- function(sess, quiet) {
if (is_backend("tensorflow"))
keras$backend$set_session(sess)
}

4 changes: 3 additions & 1 deletion man/reexports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pkgdown/_pkgdown.yml
Expand Up @@ -290,6 +290,7 @@ reference:
- get_file
- backend
- implementation
- use_session_with_seed
- install_keras
- is_keras_available

Expand Down
32 changes: 14 additions & 18 deletions vignettes/faq.Rmd
Expand Up @@ -442,32 +442,28 @@ If you create [custom layers in R](custom_layers.html) or import other Python pa

During development of a model, sometimes it is useful to be able to obtain reproducible results from run to run in order to determine if a change in performance is due to an actual model or data modification, or merely a result of a new random sample.

The below snippet of code provides an example of how to obtain reproducible results when using the TensorFlow backend. To do this we set the R session's random seed, then manually construct a TensorFlow session (via the **tensorflow** package) and set it's random seed, and then finally arrange for Keras to use this session within its backend.
The `use_session_with_seed()` function establishes a common random seed for R, Python, NumPy, and TensorFlow. It furthermore disables hash randomization, GPU computations, and CPU parallelization,
which can be additional sources of non-reproducibility.

The `use_session_with_seed()` function is available in the development version of the Keras package, which you can install as follows:

```{r}
library(keras)
library(tensorflow)
devtools::install_github("rstudio/keras")
```

# Set R random seed
set.seed(42L)
To use the function, call it immediately after you load the keras package:

# TensorFlow session configuration that uses only a single thread. Multiple threads are a
# potential source of non-reproducible results, see: https://stackoverflow.com/questions/42022950/which-seeds-have-to-be-set-where-to-realize-100-reproducibility-of-training-res
session_conf <- tf$ConfigProto(intra_op_parallelism_threads = 1L,
inter_op_parallelism_threads = 1L)
```{r}
library(keras)
use_session_with_seed(42)
# Set TF random seed (see: https://www.tensorflow.org/api_docs/python/tf/set_random_seed)
tf$set_random_seed(1042L)
# ...rest of code follows...
# Create the session using the custom configuration
sess <- tf$Session(graph = tf$get_default_graph(), config = session_conf)
```

# Instruct Keras to use this session
K <- backend()
K$set_session(sess)
This function takes all measures known to promote reproducible results from Keras sessions, however it's possible that various individual features or libraries used by the backend escape its effects. If you encounter non-reproducible results please investigate the possible sources of the problem. The source code for `use_session_with_seed()` is here: https://github.com/rstudio/tensorflow/blob/master/R/seed.R Contributions via pull request are very welcome!

# Rest of code follows ...
```
Please note again that `use_session_with_seed()` disables GPU computations and CPU parallelization by default (as both can lead to non-deterministic computations) so should generally not be used when model training time is paramount.

## Where is the Keras configuration filed stored?

Expand Down

0 comments on commit 0c95b04

Please sign in to comment.