From 1c9e2ba55cd405e8d973a564516170fd4ab76e52 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 26 Jan 2022 14:05:37 -0800 Subject: [PATCH 01/26] Add keras_predict_classes to replace use of keras::predict_classes --- NAMESPACE | 1 + R/mlp.R | 10 ++++++++-- man/keras_predict_classes.Rd | 17 +++++++++++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 man/keras_predict_classes.Rd diff --git a/NAMESPACE b/NAMESPACE index bba504820..53eb308d5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -206,6 +206,7 @@ export(glance) export(has_multi_predict) export(is_varying) export(keras_mlp) +export(keras_predict_classes) export(linear_reg) export(logistic_reg) export(make_call) diff --git a/R/mlp.R b/R/mlp.R index 2a72a9015..45d60e9e2 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -438,5 +438,11 @@ reformat_torch_num <- function(results, object) { results } - - +#' Wrapper for keras class predictions +#' @param object A keras model fit +#' @param x A data set. +#' @export +#' @keywords internal +keras_predict_classes <- function(object, x) { + object %>% predict(x) %>% keras::k_argmax() %>% as.integer() +} diff --git a/man/keras_predict_classes.Rd b/man/keras_predict_classes.Rd new file mode 100644 index 000000000..d0677de97 --- /dev/null +++ b/man/keras_predict_classes.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/mlp.R +\name{keras_predict_classes} +\alias{keras_predict_classes} +\title{Wrapper for keras class predictions} +\usage{ +keras_predict_classes(object, x) +} +\arguments{ +\item{object}{A keras model fit} + +\item{x}{A data set.} +} +\description{ +Wrapper for keras class predictions +} +\keyword{internal} From 0174f9859e953a9dc7916b737c39c20877d6d2d4 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 26 Jan 2022 14:05:50 -0800 Subject: [PATCH 02/26] use keras_predict_classes --- R/logistic_reg_data.R | 2 +- R/mlp_data.R | 2 +- R/multinom_reg_data.R | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index b30828254..68c4ca098 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -441,7 +441,7 @@ set_pred( post = function(x, object) { object$lvl[x + 1] }, - func = c(pkg = "keras", fun = "predict_classes"), + func = c(pkg = "parsnip", fun = "keras_predict_classes"), args = list( object = quote(object$fit), diff --git a/R/mlp_data.R b/R/mlp_data.R index 62c2a0cd4..54fefb8fe 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -149,7 +149,7 @@ set_pred( post = function(x, object) { object$lvl[x + 1] }, - func = c(pkg = "keras", fun = "predict_classes"), + func = c(pkg = "parsnip", fun = "keras_predict_classes"), args = list( object = quote(object$fit), diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index fab239bc3..30a0d4aca 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -237,7 +237,7 @@ set_pred( post = function(x, object) { object$lvl[x + 1] }, - func = c(pkg = "keras", fun = "predict_classes"), + func = c(pkg = "parsnip", fun = "keras_predict_classes"), args = list(object = quote(object$fit), x = quote(as.matrix(new_data))) From 06449a70611acdec3ebd7566b108489b3116e550 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 26 Jan 2022 14:06:12 -0800 Subject: [PATCH 03/26] update mlp_keras reference calculations in tests --- tests/testthat/test_mlp_keras.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_mlp_keras.R b/tests/testthat/test_mlp_keras.R index 629c602ff..9c7206119 100644 --- a/tests/testthat/test_mlp_keras.R +++ b/tests/testthat/test_mlp_keras.R @@ -74,7 +74,7 @@ test_that('keras classification prediction', { control = ctrl ) - xy_pred <- keras::predict_classes(xy_fit$fit, x = as.matrix(hpc[1:8, num_pred])) + xy_pred <- predict(xy_fit$fit, x = as.matrix(hpc[1:8, num_pred])) %>% keras::k_argmax() %>% as.integer() xy_pred <- factor(levels(hpc$class)[xy_pred + 1], levels = levels(hpc$class)) expect_equal(xy_pred, predict(xy_fit, new_data = hpc[1:8, num_pred], type = "class")[[".pred_class"]]) @@ -87,7 +87,7 @@ test_that('keras classification prediction', { control = ctrl ) - form_pred <- keras::predict_classes(form_fit$fit, x = as.matrix(hpc[1:8, num_pred])) + form_pred <- predict(form_fit$fit, x = as.matrix(hpc[1:8, num_pred])) %>% keras::k_argmax() %>% as.integer() form_pred <- factor(levels(hpc$class)[form_pred + 1], levels = levels(hpc$class)) expect_equal(form_pred, predict(form_fit, new_data = hpc[1:8, num_pred], type = "class")[[".pred_class"]]) From dafc0da1612c8b2fb6c11505a9a96c2ecac75022 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 26 Jan 2022 14:31:00 -0800 Subject: [PATCH 04/26] move from predict_proba() to predict() --- R/logistic_reg_data.R | 2 +- R/mlp_data.R | 2 +- R/multinom_reg_data.R | 2 +- tests/testthat/test_logistic_reg_keras.R | 6 ++---- tests/testthat/test_mlp_keras.R | 4 ++-- tests/testthat/test_multinom_reg_keras.R | 4 ++-- 6 files changed, 9 insertions(+), 11 deletions(-) diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index 68c4ca098..8e6d2f1c9 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -462,7 +462,7 @@ set_pred( x <- as_tibble(x) x }, - func = c(pkg = "keras", fun = "predict_proba"), + func = c(fun = "predict"), args = list( object = quote(object$fit), diff --git a/R/mlp_data.R b/R/mlp_data.R index 54fefb8fe..5b7f69863 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -170,7 +170,7 @@ set_pred( x <- as_tibble(x) x }, - func = c(pkg = "keras", fun = "predict_proba"), + func = c(fun = "predict"), args = list( object = quote(object$fit), diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 30a0d4aca..5fa959dbe 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -256,7 +256,7 @@ set_pred( x <- as_tibble(x) x }, - func = c(pkg = "keras", fun = "predict_proba"), + func = c(fun = "predict"), args = list(object = quote(object$fit), x = quote(as.matrix(new_data))) diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R index 7fee3c5c4..4eb8f8614 100644 --- a/tests/testthat/test_logistic_reg_keras.R +++ b/tests/testthat/test_logistic_reg_keras.R @@ -163,7 +163,7 @@ test_that('classification probabilities', { y = tr_dat$Class ) - keras_pred <- keras::predict_proba(lr_fit$fit, as.matrix(te_dat[, -1])) + keras_pred <- predict(lr_fit$fit, as.matrix(te_dat[, -1])) colnames(keras_pred) <- paste0(".pred_", lr_fit$lvl) keras_pred <- as_tibble(keras_pred) @@ -179,7 +179,7 @@ test_that('classification probabilities', { y = tr_dat$Class ) - keras_pred <- keras::predict_proba(plrfit$fit, as.matrix(te_dat[, -1])) + keras_pred <- predict(plrfit$fit, as.matrix(te_dat[, -1])) colnames(keras_pred) <- paste0(".pred_", lr_fit$lvl) keras_pred <- as_tibble(keras_pred) @@ -187,5 +187,3 @@ test_that('classification probabilities', { expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) }) - - diff --git a/tests/testthat/test_mlp_keras.R b/tests/testthat/test_mlp_keras.R index 9c7206119..04e210897 100644 --- a/tests/testthat/test_mlp_keras.R +++ b/tests/testthat/test_mlp_keras.R @@ -106,7 +106,7 @@ test_that('keras classification probabilities', { control = ctrl ) - xy_pred <- keras::predict_proba(xy_fit$fit, x = as.matrix(hpc[1:8, num_pred])) + xy_pred <- predict(xy_fit$fit, x = as.matrix(hpc[1:8, num_pred])) colnames(xy_pred) <- paste0(".pred_", levels(hpc$class)) xy_pred <- as_tibble(xy_pred) expect_equal(xy_pred, predict(xy_fit, new_data = hpc[1:8, num_pred], type = "prob")) @@ -120,7 +120,7 @@ test_that('keras classification probabilities', { control = ctrl ) - form_pred <- keras::predict_proba(form_fit$fit, x = as.matrix(hpc[1:8, num_pred])) + form_pred <- predict(form_fit$fit, x = as.matrix(hpc[1:8, num_pred])) colnames(form_pred) <- paste0(".pred_", levels(hpc$class)) form_pred <- as_tibble(form_pred) expect_equal(form_pred, predict(form_fit, new_data = hpc[1:8, num_pred], type = "prob")) diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R index e504f9d85..c8a59ace8 100644 --- a/tests/testthat/test_multinom_reg_keras.R +++ b/tests/testthat/test_multinom_reg_keras.R @@ -160,7 +160,7 @@ test_that('classification probabilities', { ) keras_pred <- - keras::predict_proba(lr_fit$fit, as.matrix(te_dat[, -5])) %>% + predict(lr_fit$fit, as.matrix(te_dat[, -5])) %>% as_tibble(.name_repair = "minimal") %>% setNames(paste0(".pred_", lr_fit$lvl)) @@ -177,7 +177,7 @@ test_that('classification probabilities', { ) keras_pred <- - keras::predict_proba(plrfit$fit, as.matrix(te_dat[, -5])) %>% + predict(plrfit$fit, as.matrix(te_dat[, -5])) %>% as_tibble(.name_repair = "minimal") %>% setNames(paste0(".pred_", lr_fit$lvl)) parsnip_pred <- predict(plrfit, te_dat[, -5], type = "prob") From 79f4d72519e7839674b03e99c5165b51ded20de5 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 31 Jan 2022 09:39:47 -0800 Subject: [PATCH 05/26] Create conditional tensorflow checking --- DESCRIPTION | 3 ++- NAMESPACE | 1 + R/logistic_reg_data.R | 2 +- R/mlp.R | 20 +++++++++++++++++++- R/mlp_data.R | 2 +- R/multinom_reg_data.R | 2 +- man/keras_predict_proba.Rd | 17 +++++++++++++++++ 7 files changed, 42 insertions(+), 5 deletions(-) create mode 100644 man/keras_predict_proba.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 716857000..2d13a3c1c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -39,6 +39,7 @@ Suggests: covr, dials (>= 0.0.10.9001), earth, + tensorflow, ggplot2, keras, kernlab, @@ -84,6 +85,6 @@ Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) RoxygenNote: 7.1.2 -Remotes: +Remotes: tidymodels/dials, tidymodels/hardhat diff --git a/NAMESPACE b/NAMESPACE index 53eb308d5..5f5bb0469 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -207,6 +207,7 @@ export(has_multi_predict) export(is_varying) export(keras_mlp) export(keras_predict_classes) +export(keras_predict_proba) export(linear_reg) export(logistic_reg) export(make_call) diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index 8e6d2f1c9..d2a49653b 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -462,7 +462,7 @@ set_pred( x <- as_tibble(x) x }, - func = c(fun = "predict"), + func = c(pkg = "parsnip", fun = "keras_predict_proba"), args = list( object = quote(object$fit), diff --git a/R/mlp.R b/R/mlp.R index 45d60e9e2..d62a65784 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -444,5 +444,23 @@ reformat_torch_num <- function(results, object) { #' @export #' @keywords internal keras_predict_classes <- function(object, x) { - object %>% predict(x) %>% keras::k_argmax() %>% as.integer() + if (tensorflow::tf_version() >= package_version("2.6")) { + object %>% predict(x) %>% keras::k_argmax() %>% as.integer() + } else { + keras::predict_classes(object, x) + } +} + +#' Wrapper for keras class probability predictions +#' @param object A keras model fit +#' @param x A data set. +#' @export +#' @keywords internal +keras_predict_proba <- function(object, x) { + if (tensorflow::tf_version() >= package_version("2.6")) { + object %>% predict(x) + } else { + keras::predict_proba(object, x) + } + } diff --git a/R/mlp_data.R b/R/mlp_data.R index 5b7f69863..5d11199c1 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -170,7 +170,7 @@ set_pred( x <- as_tibble(x) x }, - func = c(fun = "predict"), + func = c(pkg = "parsnip", fun = "keras_predict_proba"), args = list( object = quote(object$fit), diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 5fa959dbe..8719dd5a2 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -256,7 +256,7 @@ set_pred( x <- as_tibble(x) x }, - func = c(fun = "predict"), + func = c(pkg = "parsnip", fun = "keras_predict_proba"), args = list(object = quote(object$fit), x = quote(as.matrix(new_data))) diff --git a/man/keras_predict_proba.Rd b/man/keras_predict_proba.Rd new file mode 100644 index 000000000..45e94b4e1 --- /dev/null +++ b/man/keras_predict_proba.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/mlp.R +\name{keras_predict_proba} +\alias{keras_predict_proba} +\title{Wrapper for keras class probability predictions} +\usage{ +keras_predict_proba(object, x) +} +\arguments{ +\item{object}{A keras model fit} + +\item{x}{A data set.} +} +\description{ +Wrapper for keras class probability predictions +} +\keyword{internal} From e609d64b11a206d03928cc513b47c0dc733cd9e3 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 31 Jan 2022 09:42:26 -0800 Subject: [PATCH 06/26] use newer version of tensorflow in GHA --- .github/workflows/R-CMD-check.yaml | 2 +- .github/workflows/pkgdown.yaml | 2 +- .github/workflows/test-coverage.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 52d7d472d..d4ba90e02 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -65,7 +65,7 @@ jobs: - name: Install TensorFlow run: | reticulate::conda_create('r-reticulate', packages = c('python==3.6.9')) - tensorflow::install_tensorflow(version='1.14.0') + tensorflow::install_tensorflow(version='2.7.0') shell: Rscript {0} - uses: r-lib/actions/check-r-package@v2 diff --git a/.github/workflows/pkgdown.yaml b/.github/workflows/pkgdown.yaml index 4043e37c5..588aa3669 100644 --- a/.github/workflows/pkgdown.yaml +++ b/.github/workflows/pkgdown.yaml @@ -34,7 +34,7 @@ jobs: pak::pkg_install('rstudio/reticulate') reticulate::install_miniconda() reticulate::conda_create('r-reticulate', packages = c('python==3.6.9')) - tensorflow::install_tensorflow(version='1.14.0') + tensorflow::install_tensorflow(version='2.7.0') shell: Rscript {0} - name: Install package diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 1dbe389ae..dc3fa5544 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -31,7 +31,7 @@ jobs: pak::pkg_install('rstudio/reticulate') reticulate::install_miniconda() reticulate::conda_create('r-reticulate', packages = c('python==3.6.9')) - tensorflow::install_tensorflow(version='1.14.0') + tensorflow::install_tensorflow(version='2.7.0') shell: Rscript {0} - name: Test coverage From 5b7e470742e04a2498c43c6939ca2b68f9378a43 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 31 Jan 2022 09:44:20 -0800 Subject: [PATCH 07/26] Add old tensorflow version GHA --- .github/workflows/old-tensorflow.yaml | 79 +++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 .github/workflows/old-tensorflow.yaml diff --git a/.github/workflows/old-tensorflow.yaml b/.github/workflows/old-tensorflow.yaml new file mode 100644 index 000000000..b70bcf661 --- /dev/null +++ b/.github/workflows/old-tensorflow.yaml @@ -0,0 +1,79 @@ +# Workflow derived from https://github.com/r-lib/actions/tree/master/examples +# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help +# +# NOTE: This workflow is overkill for most R packages and +# check-standard.yaml is likely a better choice. +# usethis::use_github_action("check-standard") will install it. +on: + push: + branches: [main, master] + pull_request: + branches: [main, master] + +name: R-CMD-check + +jobs: + R-CMD-check: + runs-on: ${{ matrix.config.os }} + + name: ${{ matrix.config.os }} (${{ matrix.config.r }}) + + strategy: + fail-fast: false + matrix: + config: + - {os: windows-latest, r: 'release'} + # Use older ubuntu to maximise backward compatibility + - {os: ubuntu-18.04, r: 'devel', http-user-agent: 'release'} + + env: + GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} + R_KEEP_PKG_SOURCE: yes + CXX14: g++ + CXX14STD: -std=c++1y + CXX14FLAGS: -Wall -g -02 + + steps: + - uses: actions/checkout@v2 + + - uses: r-lib/actions/setup-pandoc@v2 + + - uses: r-lib/actions/setup-r@v2 + with: + r-version: ${{ matrix.config.r }} + http-user-agent: ${{ matrix.config.http-user-agent }} + use-public-rspm: true + + - uses: r-lib/actions/setup-r-dependencies@v2 + with: + extra-packages: rcmdcheck + + - name: Install Miniconda + run: | + pak::pkg_install('rstudio/reticulate') + reticulate::install_miniconda() + shell: Rscript {0} + + - name: Find Miniconda on macOS + if: runner.os == 'macOS' + run: echo "options(reticulate.conda_binary = reticulate:::miniconda_conda())" >> .Rprofile + + - name: Install TensorFlow + run: | + reticulate::conda_create('r-reticulate', packages = c('python==3.6.9')) + tensorflow::install_tensorflow(version='1.14.0') + shell: Rscript {0} + + - uses: r-lib/actions/check-r-package@v2 + + - name: Show testthat output + if: always() + run: find check -name 'testthat.Rout*' -exec cat '{}' \; || true + shell: bash + + - name: Upload check results + if: failure() + uses: actions/upload-artifact@main + with: + name: ${{ runner.os }}-r${{ matrix.config.r }}-results + path: check From b0f3bc759d927f24d7faea240cbd4388d8795aa6 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 31 Jan 2022 13:23:58 -0800 Subject: [PATCH 08/26] Set seeds for tensorflow --- tests/testthat/test_linear_reg_keras.R | 2 ++ tests/testthat/test_logistic_reg_keras.R | 2 ++ tests/testthat/test_multinom_reg_keras.R | 2 ++ 3 files changed, 6 insertions(+) diff --git a/tests/testthat/test_linear_reg_keras.R b/tests/testthat/test_linear_reg_keras.R index a0625a43c..3fcf1ac59 100644 --- a/tests/testthat/test_linear_reg_keras.R +++ b/tests/testthat/test_linear_reg_keras.R @@ -29,6 +29,7 @@ test_that('model fitting', { skip_if_not_installed("keras") set.seed(257) + tensorflow::tf$random$set_seed(257) expect_error( fit1 <- fit_xy( @@ -41,6 +42,7 @@ test_that('model fitting', { ) set.seed(257) + tensorflow::tf$random$set_seed(257) expect_error( fit2 <- fit_xy( diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R index 4eb8f8614..27d35110d 100644 --- a/tests/testthat/test_logistic_reg_keras.R +++ b/tests/testthat/test_logistic_reg_keras.R @@ -155,6 +155,7 @@ test_that('classification probabilities', { library(keras) set.seed(257) + tensorflow::tf$random$set_seed(257) lr_fit <- fit_xy( basic_mod, @@ -171,6 +172,7 @@ test_that('classification probabilities', { expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) set.seed(257) + tensorflow::tf$random$set_seed(257) plrfit <- fit_xy( reg_mod, diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R index c8a59ace8..a37b53567 100644 --- a/tests/testthat/test_multinom_reg_keras.R +++ b/tests/testthat/test_multinom_reg_keras.R @@ -151,6 +151,7 @@ test_that('classification probabilities', { library(keras) set.seed(257) + tensorflow::tf$random$set_seed(257) lr_fit <- fit_xy( basic_mod, @@ -168,6 +169,7 @@ test_that('classification probabilities', { expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) set.seed(257) + tensorflow::tf$random$set_seed(257) plrfit <- fit_xy( reg_mod, From d0c751a36062927e8e9b19b9188d98e8d2f5c613 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 31 Jan 2022 14:19:14 -0800 Subject: [PATCH 09/26] only set tensorflow seed when you need it --- tests/testthat/test_linear_reg_keras.R | 8 ++++++-- tests/testthat/test_logistic_reg_keras.R | 8 ++++++-- tests/testthat/test_multinom_reg_keras.R | 8 ++++++-- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/testthat/test_linear_reg_keras.R b/tests/testthat/test_linear_reg_keras.R index 3fcf1ac59..254b0125a 100644 --- a/tests/testthat/test_linear_reg_keras.R +++ b/tests/testthat/test_linear_reg_keras.R @@ -29,7 +29,9 @@ test_that('model fitting', { skip_if_not_installed("keras") set.seed(257) - tensorflow::tf$random$set_seed(257) + if (tensorflow::tf_version() >= package_version("2.0")) { + tensorflow::tf$random$set_seed(257) + } expect_error( fit1 <- fit_xy( @@ -42,7 +44,9 @@ test_that('model fitting', { ) set.seed(257) - tensorflow::tf$random$set_seed(257) + if (tensorflow::tf_version() >= package_version("2.0")) { + tensorflow::tf$random$set_seed(257) + } expect_error( fit2 <- fit_xy( diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R index 27d35110d..b83d44659 100644 --- a/tests/testthat/test_logistic_reg_keras.R +++ b/tests/testthat/test_logistic_reg_keras.R @@ -155,7 +155,9 @@ test_that('classification probabilities', { library(keras) set.seed(257) - tensorflow::tf$random$set_seed(257) + if (tensorflow::tf_version() >= package_version("2.0")) { + tensorflow::tf$random$set_seed(257) + } lr_fit <- fit_xy( basic_mod, @@ -172,7 +174,9 @@ test_that('classification probabilities', { expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) set.seed(257) - tensorflow::tf$random$set_seed(257) + if (tensorflow::tf_version() >= package_version("2.0")) { + tensorflow::tf$random$set_seed(257) + } plrfit <- fit_xy( reg_mod, diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R index a37b53567..c133a5a06 100644 --- a/tests/testthat/test_multinom_reg_keras.R +++ b/tests/testthat/test_multinom_reg_keras.R @@ -151,7 +151,9 @@ test_that('classification probabilities', { library(keras) set.seed(257) - tensorflow::tf$random$set_seed(257) + if (tensorflow::tf_version() >= package_version("2.0")) { + tensorflow::tf$random$set_seed(257) + } lr_fit <- fit_xy( basic_mod, @@ -169,7 +171,9 @@ test_that('classification probabilities', { expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) set.seed(257) - tensorflow::tf$random$set_seed(257) + if (tensorflow::tf_version() >= package_version("2.0")) { + tensorflow::tf$random$set_seed(257) + } plrfit <- fit_xy( reg_mod, From 742c5f930c4fcf44c7441635b38a57f32aa1d290 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 31 Jan 2022 14:37:43 -0800 Subject: [PATCH 10/26] conditionally set seed in tensorflow by tensorflow version --- tests/testthat/test_linear_reg_keras.R | 4 ++++ tests/testthat/test_logistic_reg_keras.R | 4 ++++ tests/testthat/test_multinom_reg_keras.R | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/tests/testthat/test_linear_reg_keras.R b/tests/testthat/test_linear_reg_keras.R index 254b0125a..e61fbfdff 100644 --- a/tests/testthat/test_linear_reg_keras.R +++ b/tests/testthat/test_linear_reg_keras.R @@ -31,6 +31,8 @@ test_that('model fitting', { set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { tensorflow::tf$random$set_seed(257) + } else { + tensorflow::tf$random$set_random_seed(257) } expect_error( fit1 <- @@ -46,6 +48,8 @@ test_that('model fitting', { set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { tensorflow::tf$random$set_seed(257) + } else { + tensorflow::tf$random$set_random_seed(257) } expect_error( fit2 <- diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R index b83d44659..2d02a620f 100644 --- a/tests/testthat/test_logistic_reg_keras.R +++ b/tests/testthat/test_logistic_reg_keras.R @@ -157,6 +157,8 @@ test_that('classification probabilities', { set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { tensorflow::tf$random$set_seed(257) + } else { + tensorflow::tf$random$set_random_seed(257) } lr_fit <- fit_xy( @@ -176,6 +178,8 @@ test_that('classification probabilities', { set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { tensorflow::tf$random$set_seed(257) + } else { + tensorflow::tf$random$set_random_seed(257) } plrfit <- fit_xy( diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R index c133a5a06..0d5063417 100644 --- a/tests/testthat/test_multinom_reg_keras.R +++ b/tests/testthat/test_multinom_reg_keras.R @@ -153,6 +153,8 @@ test_that('classification probabilities', { set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { tensorflow::tf$random$set_seed(257) + } else { + tensorflow::tf$random$set_random_seed(257) } lr_fit <- fit_xy( @@ -173,6 +175,8 @@ test_that('classification probabilities', { set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { tensorflow::tf$random$set_seed(257) + } else { + tensorflow::tf$random$set_random_seed(257) } plrfit <- fit_xy( From 243df0f2b079e82d7ed7a850221bf8dc750194be Mon Sep 17 00:00:00 2001 From: EmilHvitfeldt Date: Tue, 1 Feb 2022 16:14:18 -0800 Subject: [PATCH 11/26] do conditional check innside keras_predict_* functions as well --- R/mlp.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/mlp.R b/R/mlp.R index d62a65784..8485eb7b0 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -444,7 +444,7 @@ reformat_torch_num <- function(results, object) { #' @export #' @keywords internal keras_predict_classes <- function(object, x) { - if (tensorflow::tf_version() >= package_version("2.6")) { + if (tensorflow::tf_version() >= package_version("2.0")) { object %>% predict(x) %>% keras::k_argmax() %>% as.integer() } else { keras::predict_classes(object, x) @@ -457,7 +457,7 @@ keras_predict_classes <- function(object, x) { #' @export #' @keywords internal keras_predict_proba <- function(object, x) { - if (tensorflow::tf_version() >= package_version("2.6")) { + if (tensorflow::tf_version() >= package_version("2.0")) { object %>% predict(x) } else { keras::predict_proba(object, x) From fa76e09c81224a92e397874c1cf62c0400080f29 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 1 Feb 2022 18:11:52 -0800 Subject: [PATCH 12/26] Add missing set_seed to keras logistic reg test --- tests/testthat/test_logistic_reg_keras.R | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R index 2d02a620f..e8d50295d 100644 --- a/tests/testthat/test_logistic_reg_keras.R +++ b/tests/testthat/test_logistic_reg_keras.R @@ -41,7 +41,11 @@ test_that('model fitting', { skip_on_cran() skip_if_not_installed("keras") - set.seed(257) + if (tensorflow::tf_version() >= package_version("2.0")) { + tensorflow::tf$random$set_seed(257) + } else { + tensorflow::tf$random$set_random_seed(257) + } expect_error( fit1 <- fit_xy( @@ -53,7 +57,11 @@ test_that('model fitting', { regexp = NA ) - set.seed(257) + if (tensorflow::tf_version() >= package_version("2.0")) { + tensorflow::tf$random$set_seed(257) + } else { + tensorflow::tf$random$set_random_seed(257) + } expect_error( fit2 <- fit_xy( From f03a7fe4ef6dcc0638a2056e37a84a29f63f57af Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 1 Feb 2022 18:34:03 -0800 Subject: [PATCH 13/26] seperate out old-tensorflow GHA --- .github/workflows/old-tensorflow.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/old-tensorflow.yaml b/.github/workflows/old-tensorflow.yaml index b70bcf661..ed6c9f6cf 100644 --- a/.github/workflows/old-tensorflow.yaml +++ b/.github/workflows/old-tensorflow.yaml @@ -10,10 +10,10 @@ on: pull_request: branches: [main, master] -name: R-CMD-check +name: old-tensorflow jobs: - R-CMD-check: + old-tensorflow: runs-on: ${{ matrix.config.os }} name: ${{ matrix.config.os }} (${{ matrix.config.r }}) From d0b85346b101199705694c2c896057e59a2e4f0c Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 1 Feb 2022 20:11:46 -0800 Subject: [PATCH 14/26] add last missing tensorflow set_seed --- tests/testthat/test_multinom_reg_keras.R | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R index 0d5063417..71d68a688 100644 --- a/tests/testthat/test_multinom_reg_keras.R +++ b/tests/testthat/test_multinom_reg_keras.R @@ -37,7 +37,11 @@ test_that('model fitting', { skip_on_cran() skip_if_not_installed("keras") - set.seed(257) + if (tensorflow::tf_version() >= package_version("2.0")) { + tensorflow::tf$random$set_seed(257) + } else { + tensorflow::tf$random$set_random_seed(257) + } expect_error( fit1 <- fit_xy( @@ -49,7 +53,11 @@ test_that('model fitting', { regexp = NA ) - set.seed(257) + if (tensorflow::tf_version() >= package_version("2.0")) { + tensorflow::tf$random$set_seed(257) + } else { + tensorflow::tf$random$set_random_seed(257) + } expect_error( fit2 <- fit_xy( From 372f78b6bb13da5a7d6906d32fd1068ddf441b61 Mon Sep 17 00:00:00 2001 From: EmilHvitfeldt Date: Wed, 2 Feb 2022 20:21:15 -0800 Subject: [PATCH 15/26] you need tensorflow AND R seed... --- tests/testthat/test_logistic_reg_keras.R | 2 ++ tests/testthat/test_multinom_reg_keras.R | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R index e8d50295d..8ce3e751c 100644 --- a/tests/testthat/test_logistic_reg_keras.R +++ b/tests/testthat/test_logistic_reg_keras.R @@ -41,6 +41,7 @@ test_that('model fitting', { skip_on_cran() skip_if_not_installed("keras") + set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { tensorflow::tf$random$set_seed(257) } else { @@ -57,6 +58,7 @@ test_that('model fitting', { regexp = NA ) + set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { tensorflow::tf$random$set_seed(257) } else { diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R index 71d68a688..fd3dec354 100644 --- a/tests/testthat/test_multinom_reg_keras.R +++ b/tests/testthat/test_multinom_reg_keras.R @@ -37,6 +37,7 @@ test_that('model fitting', { skip_on_cran() skip_if_not_installed("keras") + set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { tensorflow::tf$random$set_seed(257) } else { @@ -53,6 +54,7 @@ test_that('model fitting', { regexp = NA ) + set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { tensorflow::tf$random$set_seed(257) } else { From 54466c29e474c3543f95d0205b5322968f698e3a Mon Sep 17 00:00:00 2001 From: EmilHvitfeldt Date: Thu, 3 Feb 2022 11:49:17 -0800 Subject: [PATCH 16/26] use keras version as switch --- R/mlp.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/mlp.R b/R/mlp.R index 8485eb7b0..a6bf684d5 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -444,7 +444,7 @@ reformat_torch_num <- function(results, object) { #' @export #' @keywords internal keras_predict_classes <- function(object, x) { - if (tensorflow::tf_version() >= package_version("2.0")) { + if (utils::packageVersion("keras") >= package_version("2.6")) { object %>% predict(x) %>% keras::k_argmax() %>% as.integer() } else { keras::predict_classes(object, x) @@ -457,7 +457,7 @@ keras_predict_classes <- function(object, x) { #' @export #' @keywords internal keras_predict_proba <- function(object, x) { - if (tensorflow::tf_version() >= package_version("2.0")) { + if (utils::packageVersion("keras") >= package_version("2.6")) { object %>% predict(x) } else { keras::predict_proba(object, x) From 121f31be1c90018645ce6d2344f1e9cf8be29a7c Mon Sep 17 00:00:00 2001 From: EmilHvitfeldt Date: Thu, 3 Feb 2022 15:32:35 -0800 Subject: [PATCH 17/26] Conditionally transform predictions depending on tensorflow version --- R/mlp.R | 8 +++++++- tests/testthat/test_mlp_keras.R | 19 +++++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/R/mlp.R b/R/mlp.R index a6bf684d5..6470560b1 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -445,7 +445,13 @@ reformat_torch_num <- function(results, object) { #' @keywords internal keras_predict_classes <- function(object, x) { if (utils::packageVersion("keras") >= package_version("2.6")) { - object %>% predict(x) %>% keras::k_argmax() %>% as.integer() + preds <- predict(object, x) + if (tensorflow::tf_version() <= package_version("2.0.0")) { + # -1 to assign with keras' zero indexing + apply(preds, 1, which.max) - 1 + } else { + preds %>% keras::k_argmax() %>% as.integer() + } } else { keras::predict_classes(object, x) } diff --git a/tests/testthat/test_mlp_keras.R b/tests/testthat/test_mlp_keras.R index 04e210897..5d8e776ae 100644 --- a/tests/testthat/test_mlp_keras.R +++ b/tests/testthat/test_mlp_keras.R @@ -74,7 +74,14 @@ test_that('keras classification prediction', { control = ctrl ) - xy_pred <- predict(xy_fit$fit, x = as.matrix(hpc[1:8, num_pred])) %>% keras::k_argmax() %>% as.integer() + xy_pred <- predict(xy_fit$fit, x = as.matrix(hpc[1:8, num_pred])) + if (tensorflow::tf_version() <= package_version("2.0.0")) { + # -1 to assign with keras' zero indexing + xy_pred <- apply(xy_pred, 1, which.max) - 1 + } else { + xy_pred <- xy_pred %>% keras::k_argmax() %>% as.integer() + } + xy_pred <- factor(levels(hpc$class)[xy_pred + 1], levels = levels(hpc$class)) expect_equal(xy_pred, predict(xy_fit, new_data = hpc[1:8, num_pred], type = "class")[[".pred_class"]]) @@ -87,7 +94,15 @@ test_that('keras classification prediction', { control = ctrl ) - form_pred <- predict(form_fit$fit, x = as.matrix(hpc[1:8, num_pred])) %>% keras::k_argmax() %>% as.integer() + + form_pred <- predict(form_fit$fit, x = as.matrix(hpc[1:8, num_pred])) + if (tensorflow::tf_version() <= package_version("2.0.0")) { + # -1 to assign with keras' zero indexing + form_pred <- apply(form_pred, 1, which.max) - 1 + } else { + form_pred <- form_pred %>% keras::k_argmax() %>% as.integer() + } + form_pred <- factor(levels(hpc$class)[form_pred + 1], levels = levels(hpc$class)) expect_equal(form_pred, predict(form_fit, new_data = hpc[1:8, num_pred], type = "class")[[".pred_class"]]) From 14b5fa3b3b16d9e7a4fa228344a98e82ac23e5d1 Mon Sep 17 00:00:00 2001 From: EmilHvitfeldt Date: Thu, 3 Feb 2022 16:37:37 -0800 Subject: [PATCH 18/26] skip test if tensorflow version can't be found --- tests/testthat/test_linear_reg_keras.R | 2 ++ tests/testthat/test_logistic_reg_keras.R | 3 +++ tests/testthat/test_mlp_keras.R | 6 ++++++ tests/testthat/test_multinom_reg_keras.R | 3 +++ 4 files changed, 14 insertions(+) diff --git a/tests/testthat/test_linear_reg_keras.R b/tests/testthat/test_linear_reg_keras.R index e61fbfdff..74fceaad6 100644 --- a/tests/testthat/test_linear_reg_keras.R +++ b/tests/testthat/test_linear_reg_keras.R @@ -27,6 +27,7 @@ ctrl <- control_parsnip(verbosity = 0, catch = FALSE) test_that('model fitting', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { @@ -104,6 +105,7 @@ test_that('model fitting', { test_that('regression prediction', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) library(keras) diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R index 8ce3e751c..aa0d3cce8 100644 --- a/tests/testthat/test_logistic_reg_keras.R +++ b/tests/testthat/test_logistic_reg_keras.R @@ -40,6 +40,7 @@ ctrl <- control_parsnip(verbosity = 0, catch = FALSE) test_that('model fitting', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { @@ -117,6 +118,7 @@ test_that('model fitting', { test_that('classification prediction', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) library(keras) @@ -161,6 +163,7 @@ test_that('classification prediction', { test_that('classification probabilities', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) library(keras) diff --git a/tests/testthat/test_mlp_keras.R b/tests/testthat/test_mlp_keras.R index 5d8e776ae..3663562d7 100644 --- a/tests/testthat/test_mlp_keras.R +++ b/tests/testthat/test_mlp_keras.R @@ -22,6 +22,7 @@ nn_dat <- read.csv("nnet_test.txt") test_that('keras execution, classification', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) expect_error( res <- parsnip::fit( @@ -65,6 +66,7 @@ test_that('keras execution, classification', { test_that('keras classification prediction', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) library(keras) xy_fit <- parsnip::fit_xy( @@ -113,6 +115,7 @@ test_that('keras classification prediction', { test_that('keras classification probabilities', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) xy_fit <- parsnip::fit_xy( hpc_keras, @@ -163,6 +166,7 @@ bad_keras_reg <- test_that('keras execution, regression', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) expect_error( res <- parsnip::fit( @@ -190,6 +194,7 @@ test_that('keras execution, regression', { test_that('keras regression prediction', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) xy_fit <- parsnip::fit_xy( mlp(mode = "regression", hidden_units = 2, epochs = 500, penalty = .1) %>% @@ -223,6 +228,7 @@ test_that('keras regression prediction', { test_that('multivariate nnet formula', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) nnet_form <- mlp(mode = "regression", hidden_units = 3, penalty = 0.01) %>% diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R index fd3dec354..c1f7a16e6 100644 --- a/tests/testthat/test_multinom_reg_keras.R +++ b/tests/testthat/test_multinom_reg_keras.R @@ -36,6 +36,7 @@ ctrl <- control_parsnip(verbosity = 0, catch = FALSE) test_that('model fitting', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) set.seed(257) if (tensorflow::tf_version() >= package_version("2.0")) { @@ -113,6 +114,7 @@ test_that('model fitting', { test_that('classification prediction', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) library(keras) @@ -157,6 +159,7 @@ test_that('classification prediction', { test_that('classification probabilities', { skip_on_cran() skip_if_not_installed("keras") + skip_if(is.null(tensorflow::tf_version())) library(keras) From 0150affe367f0611548a96349a4fccfab2dcd91f Mon Sep 17 00:00:00 2001 From: EmilHvitfeldt Date: Thu, 3 Feb 2022 16:52:04 -0800 Subject: [PATCH 19/26] refactor keras_predict_classes to avoid post function --- R/logistic_reg_data.R | 6 ++---- R/mlp.R | 9 +++++---- R/mlp_data.R | 6 ++---- R/multinom_reg_data.R | 6 ++---- 4 files changed, 11 insertions(+), 16 deletions(-) diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index d2a49653b..6f0cf6f87 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -438,13 +438,11 @@ set_pred( type = "class", value = list( pre = NULL, - post = function(x, object) { - object$lvl[x + 1] - }, + post = NULL, func = c(pkg = "parsnip", fun = "keras_predict_classes"), args = list( - object = quote(object$fit), + object = quote(object), x = quote(as.matrix(new_data)) ) ) diff --git a/R/mlp.R b/R/mlp.R index 6470560b1..14b047ea6 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -445,16 +445,17 @@ reformat_torch_num <- function(results, object) { #' @keywords internal keras_predict_classes <- function(object, x) { if (utils::packageVersion("keras") >= package_version("2.6")) { - preds <- predict(object, x) + preds <- predict(object$fit, x) if (tensorflow::tf_version() <= package_version("2.0.0")) { # -1 to assign with keras' zero indexing - apply(preds, 1, which.max) - 1 + index <- apply(preds, 1, which.max) - 1 } else { - preds %>% keras::k_argmax() %>% as.integer() + index <- preds %>% keras::k_argmax() %>% as.integer() } } else { - keras::predict_classes(object, x) + index <- keras::predict_classes(object$fit, x) } + object$lvl[index + 1] } #' Wrapper for keras class probability predictions diff --git a/R/mlp_data.R b/R/mlp_data.R index 5d11199c1..df6d079fc 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -146,13 +146,11 @@ set_pred( type = "class", value = list( pre = NULL, - post = function(x, object) { - object$lvl[x + 1] - }, + post = NULL, func = c(pkg = "parsnip", fun = "keras_predict_classes"), args = list( - object = quote(object$fit), + object = quote(object), x = quote(as.matrix(new_data)) ) ) diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 8719dd5a2..65139889d 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -234,12 +234,10 @@ set_pred( type = "class", value = list( pre = NULL, - post = function(x, object) { - object$lvl[x + 1] - }, + post = NULL, func = c(pkg = "parsnip", fun = "keras_predict_classes"), args = - list(object = quote(object$fit), + list(object = quote(object), x = quote(as.matrix(new_data))) ) ) From c2d34c4726bd237b58a36d80b4abeee62fbb108e Mon Sep 17 00:00:00 2001 From: EmilHvitfeldt Date: Thu, 3 Feb 2022 17:11:18 -0800 Subject: [PATCH 20/26] Add keras_set_seed function --- NAMESPACE | 1 + R/mlp.R | 21 ++++++++++++++-- man/keras_set_seed.Rd | 17 +++++++++++++ tests/testthat/test_linear_reg_keras.R | 16 +++--------- tests/testthat/test_logistic_reg_keras.R | 32 ++++++------------------ tests/testthat/test_multinom_reg_keras.R | 32 ++++++------------------ 6 files changed, 57 insertions(+), 62 deletions(-) create mode 100644 man/keras_set_seed.Rd diff --git a/NAMESPACE b/NAMESPACE index 5f5bb0469..374c5cf0f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -208,6 +208,7 @@ export(is_varying) export(keras_mlp) export(keras_predict_classes) export(keras_predict_proba) +export(keras_set_seed) export(linear_reg) export(logistic_reg) export(make_call) diff --git a/R/mlp.R b/R/mlp.R index 14b047ea6..7175bc597 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -443,7 +443,7 @@ reformat_torch_num <- function(results, object) { #' @param x A data set. #' @export #' @keywords internal -keras_predict_classes <- function(object, x) { +keras_predict_classes <- function(object, x) { if (utils::packageVersion("keras") >= package_version("2.6")) { preds <- predict(object$fit, x) if (tensorflow::tf_version() <= package_version("2.0.0")) { @@ -463,11 +463,28 @@ keras_predict_classes <- function(object, x) { #' @param x A data set. #' @export #' @keywords internal -keras_predict_proba <- function(object, x) { +keras_predict_proba <- function(object, x) { if (utils::packageVersion("keras") >= package_version("2.6")) { object %>% predict(x) } else { keras::predict_proba(object, x) } +} +#' Set seed in R and TensorFlow at the same time +#' +#' Some Keras models requires seeds to be set in both R and TensorFlow to +#' achieve reproducible results. This function sets these seeds at the same +#' time using version appropriate functions. +#' +#' @param seed 1 integer value. +#' @export +#' @keywords internal +keras_set_seed <- function(seed) { + set.seed(seed) + if (tensorflow::tf_version() >= package_version("2.0")) { + tensorflow::tf$random$set_seed(seed) + } else { + tensorflow::tf$random$set_random_seed(seed) + } } diff --git a/man/keras_set_seed.Rd b/man/keras_set_seed.Rd new file mode 100644 index 000000000..899abbf7d --- /dev/null +++ b/man/keras_set_seed.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/mlp.R +\name{keras_set_seed} +\alias{keras_set_seed} +\title{Set seed in R and TensorFlow at the same time} +\usage{ +keras_set_seed(seed) +} +\arguments{ +\item{seed}{1 integer value.} +} +\description{ +Some Keras models requires seeds to be set in both R and TensorFlow to +achieve reproducible results. This function sets these seeds at the same +time using version appropriate functions. +} +\keyword{internal} diff --git a/tests/testthat/test_linear_reg_keras.R b/tests/testthat/test_linear_reg_keras.R index 74fceaad6..679758fac 100644 --- a/tests/testthat/test_linear_reg_keras.R +++ b/tests/testthat/test_linear_reg_keras.R @@ -29,12 +29,8 @@ test_that('model fitting', { skip_if_not_installed("keras") skip_if(is.null(tensorflow::tf_version())) - set.seed(257) - if (tensorflow::tf_version() >= package_version("2.0")) { - tensorflow::tf$random$set_seed(257) - } else { - tensorflow::tf$random$set_random_seed(257) - } + keras_set_seed(257) + expect_error( fit1 <- fit_xy( @@ -46,12 +42,8 @@ test_that('model fitting', { regexp = NA ) - set.seed(257) - if (tensorflow::tf_version() >= package_version("2.0")) { - tensorflow::tf$random$set_seed(257) - } else { - tensorflow::tf$random$set_random_seed(257) - } + keras_set_seed(257) + expect_error( fit2 <- fit_xy( diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R index aa0d3cce8..21d2ca447 100644 --- a/tests/testthat/test_logistic_reg_keras.R +++ b/tests/testthat/test_logistic_reg_keras.R @@ -42,12 +42,8 @@ test_that('model fitting', { skip_if_not_installed("keras") skip_if(is.null(tensorflow::tf_version())) - set.seed(257) - if (tensorflow::tf_version() >= package_version("2.0")) { - tensorflow::tf$random$set_seed(257) - } else { - tensorflow::tf$random$set_random_seed(257) - } + keras_set_seed(257) + expect_error( fit1 <- fit_xy( @@ -59,12 +55,8 @@ test_that('model fitting', { regexp = NA ) - set.seed(257) - if (tensorflow::tf_version() >= package_version("2.0")) { - tensorflow::tf$random$set_seed(257) - } else { - tensorflow::tf$random$set_random_seed(257) - } + keras_set_seed(257) + expect_error( fit2 <- fit_xy( @@ -167,12 +159,8 @@ test_that('classification probabilities', { library(keras) - set.seed(257) - if (tensorflow::tf_version() >= package_version("2.0")) { - tensorflow::tf$random$set_seed(257) - } else { - tensorflow::tf$random$set_random_seed(257) - } + keras_set_seed(257) + lr_fit <- fit_xy( basic_mod, @@ -188,12 +176,8 @@ test_that('classification probabilities', { parsnip_pred <- predict(lr_fit, te_dat[, -1], type = "prob") expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) - set.seed(257) - if (tensorflow::tf_version() >= package_version("2.0")) { - tensorflow::tf$random$set_seed(257) - } else { - tensorflow::tf$random$set_random_seed(257) - } + keras_set_seed(257) + plrfit <- fit_xy( reg_mod, diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R index c1f7a16e6..baab531ba 100644 --- a/tests/testthat/test_multinom_reg_keras.R +++ b/tests/testthat/test_multinom_reg_keras.R @@ -38,12 +38,8 @@ test_that('model fitting', { skip_if_not_installed("keras") skip_if(is.null(tensorflow::tf_version())) - set.seed(257) - if (tensorflow::tf_version() >= package_version("2.0")) { - tensorflow::tf$random$set_seed(257) - } else { - tensorflow::tf$random$set_random_seed(257) - } + keras_set_seed(257) + expect_error( fit1 <- fit_xy( @@ -55,12 +51,8 @@ test_that('model fitting', { regexp = NA ) - set.seed(257) - if (tensorflow::tf_version() >= package_version("2.0")) { - tensorflow::tf$random$set_seed(257) - } else { - tensorflow::tf$random$set_random_seed(257) - } + keras_set_seed(257) + expect_error( fit2 <- fit_xy( @@ -163,12 +155,8 @@ test_that('classification probabilities', { library(keras) - set.seed(257) - if (tensorflow::tf_version() >= package_version("2.0")) { - tensorflow::tf$random$set_seed(257) - } else { - tensorflow::tf$random$set_random_seed(257) - } + keras_set_seed(257) + lr_fit <- fit_xy( basic_mod, @@ -185,12 +173,8 @@ test_that('classification probabilities', { parsnip_pred <- predict(lr_fit, te_dat[, -5], type = "prob") expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) - set.seed(257) - if (tensorflow::tf_version() >= package_version("2.0")) { - tensorflow::tf$random$set_seed(257) - } else { - tensorflow::tf$random$set_random_seed(257) - } + keras_set_seed(257) + plrfit <- fit_xy( reg_mod, From f42ed424fd5fc092437b70ab7521ea1ad0fd50fe Mon Sep 17 00:00:00 2001 From: EmilHvitfeldt Date: Thu, 3 Feb 2022 17:13:02 -0800 Subject: [PATCH 21/26] remove remotes --- DESCRIPTION | 3 --- 1 file changed, 3 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 2d13a3c1c..0df8a900d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -85,6 +85,3 @@ Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) RoxygenNote: 7.1.2 -Remotes: - tidymodels/dials, - tidymodels/hardhat From 89288c6fe742b8bcdd8636c0f6a2bc752409f4ac Mon Sep 17 00:00:00 2001 From: EmilHvitfeldt Date: Thu, 3 Feb 2022 17:16:42 -0800 Subject: [PATCH 22/26] rename keras_set_seed --- NAMESPACE | 2 +- R/mlp.R | 2 +- man/{keras_set_seed.Rd => set_tf_seed.Rd} | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) rename man/{keras_set_seed.Rd => set_tf_seed.Rd} (86%) diff --git a/NAMESPACE b/NAMESPACE index 374c5cf0f..98af14539 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -208,7 +208,6 @@ export(is_varying) export(keras_mlp) export(keras_predict_classes) export(keras_predict_proba) -export(keras_set_seed) export(linear_reg) export(logistic_reg) export(make_call) @@ -272,6 +271,7 @@ export(set_model_engine) export(set_model_mode) export(set_new_model) export(set_pred) +export(set_tf_seed) export(show_call) export(show_engines) export(show_fit) diff --git a/R/mlp.R b/R/mlp.R index 7175bc597..070ece7eb 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -480,7 +480,7 @@ keras_predict_proba <- function(object, x) { #' @param seed 1 integer value. #' @export #' @keywords internal -keras_set_seed <- function(seed) { +set_tf_seed <- function(seed) { set.seed(seed) if (tensorflow::tf_version() >= package_version("2.0")) { tensorflow::tf$random$set_seed(seed) diff --git a/man/keras_set_seed.Rd b/man/set_tf_seed.Rd similarity index 86% rename from man/keras_set_seed.Rd rename to man/set_tf_seed.Rd index 899abbf7d..1c89f5f00 100644 --- a/man/keras_set_seed.Rd +++ b/man/set_tf_seed.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/mlp.R -\name{keras_set_seed} -\alias{keras_set_seed} +\name{set_tf_seed} +\alias{set_tf_seed} \title{Set seed in R and TensorFlow at the same time} \usage{ -keras_set_seed(seed) +set_tf_seed(seed) } \arguments{ \item{seed}{1 integer value.} From e2383d37fa6db5202ba08033f22963fed4e2c9bb Mon Sep 17 00:00:00 2001 From: EmilHvitfeldt Date: Thu, 3 Feb 2022 17:19:19 -0800 Subject: [PATCH 23/26] add news --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index 37b9422a4..bd5a6c529 100644 --- a/NEWS.md +++ b/NEWS.md @@ -35,6 +35,8 @@ * `set_dependency()` now allows developers to create package requirements that are specific to the model's mode (#604). +* parsnip is now more robust working with keras and tensorflow for a larger range of versions (#596). + # parsnip 0.1.7 From b1e971384f4bafad95b795f2448b9cfbbc674383 Mon Sep 17 00:00:00 2001 From: EmilHvitfeldt Date: Thu, 3 Feb 2022 18:10:10 -0800 Subject: [PATCH 24/26] rename to set_tf_seed in all tests --- tests/testthat/test_linear_reg_keras.R | 4 ++-- tests/testthat/test_logistic_reg_keras.R | 8 ++++---- tests/testthat/test_multinom_reg_keras.R | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/testthat/test_linear_reg_keras.R b/tests/testthat/test_linear_reg_keras.R index 679758fac..c81373493 100644 --- a/tests/testthat/test_linear_reg_keras.R +++ b/tests/testthat/test_linear_reg_keras.R @@ -29,7 +29,7 @@ test_that('model fitting', { skip_if_not_installed("keras") skip_if(is.null(tensorflow::tf_version())) - keras_set_seed(257) + set_tf_seed(257) expect_error( fit1 <- @@ -42,7 +42,7 @@ test_that('model fitting', { regexp = NA ) - keras_set_seed(257) + set_tf_seed(257) expect_error( fit2 <- diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R index 21d2ca447..29ea8e883 100644 --- a/tests/testthat/test_logistic_reg_keras.R +++ b/tests/testthat/test_logistic_reg_keras.R @@ -42,7 +42,7 @@ test_that('model fitting', { skip_if_not_installed("keras") skip_if(is.null(tensorflow::tf_version())) - keras_set_seed(257) + set_tf_seed(257) expect_error( fit1 <- @@ -55,7 +55,7 @@ test_that('model fitting', { regexp = NA ) - keras_set_seed(257) + set_tf_seed(257) expect_error( fit2 <- @@ -159,7 +159,7 @@ test_that('classification probabilities', { library(keras) - keras_set_seed(257) + set_tf_seed(257) lr_fit <- fit_xy( @@ -176,7 +176,7 @@ test_that('classification probabilities', { parsnip_pred <- predict(lr_fit, te_dat[, -1], type = "prob") expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) - keras_set_seed(257) + set_tf_seed(257) plrfit <- fit_xy( diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R index baab531ba..8d13585c0 100644 --- a/tests/testthat/test_multinom_reg_keras.R +++ b/tests/testthat/test_multinom_reg_keras.R @@ -38,7 +38,7 @@ test_that('model fitting', { skip_if_not_installed("keras") skip_if(is.null(tensorflow::tf_version())) - keras_set_seed(257) + set_tf_seed(257) expect_error( fit1 <- @@ -51,7 +51,7 @@ test_that('model fitting', { regexp = NA ) - keras_set_seed(257) + set_tf_seed(257) expect_error( fit2 <- @@ -155,7 +155,7 @@ test_that('classification probabilities', { library(keras) - keras_set_seed(257) + set_tf_seed(257) lr_fit <- fit_xy( @@ -173,7 +173,7 @@ test_that('classification probabilities', { parsnip_pred <- predict(lr_fit, te_dat[, -5], type = "prob") expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) - keras_set_seed(257) + set_tf_seed(257) plrfit <- fit_xy( From af9cbe311692aeb12702a7642db42fb62c4f1ab2 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 8 Feb 2022 10:42:40 -0800 Subject: [PATCH 25/26] adjust "R CMD Check" and "old tensorflow" GHA --- .github/workflows/R-CMD-check.yaml | 2 +- .github/workflows/old-tensorflow.yaml | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 13efd7f5a..dcb5e3fac 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -66,7 +66,7 @@ jobs: - name: Install TensorFlow run: | - tensorflow::install_tensorflow(version='1.15', conda_python_version = NULL) + tensorflow::install_tensorflow(version='2.7', conda_python_version = NULL) shell: Rscript {0} - uses: r-lib/actions/check-r-package@v2 diff --git a/.github/workflows/old-tensorflow.yaml b/.github/workflows/old-tensorflow.yaml index ed6c9f6cf..c2fe0e67e 100644 --- a/.github/workflows/old-tensorflow.yaml +++ b/.github/workflows/old-tensorflow.yaml @@ -9,6 +9,7 @@ on: branches: [main, master] pull_request: branches: [main, master] + workflow_dispatch: name: old-tensorflow @@ -25,7 +26,6 @@ jobs: - {os: windows-latest, r: 'release'} # Use older ubuntu to maximise backward compatibility - {os: ubuntu-18.04, r: 'devel', http-user-agent: 'release'} - env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} R_KEEP_PKG_SOURCE: yes @@ -48,20 +48,20 @@ jobs: with: extra-packages: rcmdcheck - - name: Install Miniconda - run: | - pak::pkg_install('rstudio/reticulate') - reticulate::install_miniconda() + - name: Install dev reticulate + run: pak::pkg_install('rstudio/reticulate') shell: Rscript {0} - - name: Find Miniconda on macOS - if: runner.os == 'macOS' - run: echo "options(reticulate.conda_binary = reticulate:::miniconda_conda())" >> .Rprofile + - name: Install Miniconda + # conda can fail at downgrading python, so we specify python version in advance + env: + RETICULATE_MINICONDA_PYTHON_VERSION: "3.7" + run: reticulate::install_miniconda() # creates r-reticulate conda env by default + shell: Rscript {0} - name: Install TensorFlow run: | - reticulate::conda_create('r-reticulate', packages = c('python==3.6.9')) - tensorflow::install_tensorflow(version='1.14.0') + tensorflow::install_tensorflow(version='1.15', conda_python_version = NULL) shell: Rscript {0} - uses: r-lib/actions/check-r-package@v2 From 12ce5b85d4a10e3bdda810a1d96fdcb5eb0d11f0 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 8 Feb 2022 11:40:32 -0800 Subject: [PATCH 26/26] don't use keras_predict_proba anymore --- NAMESPACE | 1 - R/logistic_reg_data.R | 2 +- R/mlp.R | 13 ------------- R/mlp_data.R | 2 +- R/multinom_reg_data.R | 2 +- man/keras_predict_proba.Rd | 17 ----------------- 6 files changed, 3 insertions(+), 34 deletions(-) delete mode 100644 man/keras_predict_proba.Rd diff --git a/NAMESPACE b/NAMESPACE index 74daa0f4a..a54c266fa 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -207,7 +207,6 @@ export(has_multi_predict) export(is_varying) export(keras_mlp) export(keras_predict_classes) -export(keras_predict_proba) export(knit_engine_docs) export(linear_reg) export(list_md_problems) diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index e72e0cab2..db6f12389 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -460,7 +460,7 @@ set_pred( x <- as_tibble(x) x }, - func = c(pkg = "parsnip", fun = "keras_predict_proba"), + func = c(fun = "predict"), args = list( object = quote(object$fit), diff --git a/R/mlp.R b/R/mlp.R index e9c6f65ff..444a19ec2 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -456,19 +456,6 @@ keras_predict_classes <- function(object, x) { object$lvl[index + 1] } -#' Wrapper for keras class probability predictions -#' @param object A keras model fit -#' @param x A data set. -#' @export -#' @keywords internal -keras_predict_proba <- function(object, x) { - if (utils::packageVersion("keras") >= package_version("2.6")) { - object %>% predict(x) - } else { - keras::predict_proba(object, x) - } -} - #' Set seed in R and TensorFlow at the same time #' #' Some Keras models requires seeds to be set in both R and TensorFlow to diff --git a/R/mlp_data.R b/R/mlp_data.R index 460e0da33..f5d9e7c4b 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -168,7 +168,7 @@ set_pred( x <- as_tibble(x) x }, - func = c(pkg = "parsnip", fun = "keras_predict_proba"), + func = c(fun = "predict"), args = list( object = quote(object$fit), diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 35d85551e..96188f62c 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -254,7 +254,7 @@ set_pred( x <- as_tibble(x) x }, - func = c(pkg = "parsnip", fun = "keras_predict_proba"), + func = c(fun = "predict"), args = list(object = quote(object$fit), x = quote(as.matrix(new_data))) diff --git a/man/keras_predict_proba.Rd b/man/keras_predict_proba.Rd deleted file mode 100644 index 45e94b4e1..000000000 --- a/man/keras_predict_proba.Rd +++ /dev/null @@ -1,17 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/mlp.R -\name{keras_predict_proba} -\alias{keras_predict_proba} -\title{Wrapper for keras class probability predictions} -\usage{ -keras_predict_proba(object, x) -} -\arguments{ -\item{object}{A keras model fit} - -\item{x}{A data set.} -} -\description{ -Wrapper for keras class probability predictions -} -\keyword{internal}