Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
1c9e2ba
Add keras_predict_classes to replace use of keras::predict_classes
EmilHvitfeldt Jan 26, 2022
0174f98
use keras_predict_classes
EmilHvitfeldt Jan 26, 2022
06449a7
update mlp_keras reference calculations in tests
EmilHvitfeldt Jan 26, 2022
dafc0da
move from predict_proba() to predict()
EmilHvitfeldt Jan 26, 2022
79f4d72
Create conditional tensorflow checking
EmilHvitfeldt Jan 31, 2022
e609d64
use newer version of tensorflow in GHA
EmilHvitfeldt Jan 31, 2022
5b7e470
Add old tensorflow version GHA
EmilHvitfeldt Jan 31, 2022
b0f3bc7
Set seeds for tensorflow
EmilHvitfeldt Jan 31, 2022
d0c751a
only set tensorflow seed when you need it
EmilHvitfeldt Jan 31, 2022
742c5f9
conditionally set seed in tensorflow by tensorflow version
EmilHvitfeldt Jan 31, 2022
243df0f
do conditional check innside keras_predict_* functions as well
EmilHvitfeldt Feb 2, 2022
fa76e09
Add missing set_seed to keras logistic reg test
EmilHvitfeldt Feb 2, 2022
a03cf8b
Merge branch 'update-keras' of github.com:tidymodels/parsnip into upd…
EmilHvitfeldt Feb 2, 2022
f03a7fe
seperate out old-tensorflow GHA
EmilHvitfeldt Feb 2, 2022
d0b8534
add last missing tensorflow set_seed
EmilHvitfeldt Feb 2, 2022
372f78b
you need tensorflow AND R seed...
EmilHvitfeldt Feb 3, 2022
54466c2
use keras version as switch
EmilHvitfeldt Feb 3, 2022
121f31b
Conditionally transform predictions depending on tensorflow version
EmilHvitfeldt Feb 3, 2022
14b5fa3
skip test if tensorflow version can't be found
EmilHvitfeldt Feb 4, 2022
0150aff
refactor keras_predict_classes to avoid post function
EmilHvitfeldt Feb 4, 2022
c2d34c4
Add keras_set_seed function
EmilHvitfeldt Feb 4, 2022
f42ed42
remove remotes
EmilHvitfeldt Feb 4, 2022
89288c6
rename keras_set_seed
EmilHvitfeldt Feb 4, 2022
e2383d3
add news
EmilHvitfeldt Feb 4, 2022
b1e9713
rename to set_tf_seed in all tests
EmilHvitfeldt Feb 4, 2022
e15d1c4
Merge branch 'main' into update-keras
EmilHvitfeldt Feb 4, 2022
b95b6b6
update from main
topepo Feb 8, 2022
5dcd898
merge main
EmilHvitfeldt Feb 8, 2022
af9cbe3
adjust "R CMD Check" and "old tensorflow" GHA
EmilHvitfeldt Feb 8, 2022
12ce5b8
don't use keras_predict_proba anymore
EmilHvitfeldt Feb 8, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ jobs:

- name: Install Miniconda
# conda can fail at downgrading python, so we specify python version in advance
env:
RETICULATE_MINICONDA_PYTHON_VERSION: "3.7"
env:
RETICULATE_MINICONDA_PYTHON_VERSION: "3.7"
run: reticulate::install_miniconda() # creates r-reticulate conda env by default
shell: Rscript {0}

- name: Install TensorFlow
run: |
tensorflow::install_tensorflow(version='1.15', conda_python_version = NULL)
tensorflow::install_tensorflow(version='2.7', conda_python_version = NULL)
shell: Rscript {0}

- uses: r-lib/actions/check-r-package@v2
Expand Down
79 changes: 79 additions & 0 deletions .github/workflows/old-tensorflow.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Workflow derived from https://github.com/r-lib/actions/tree/master/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
#
# NOTE: This workflow is overkill for most R packages and
# check-standard.yaml is likely a better choice.
# usethis::use_github_action("check-standard") will install it.
on:
push:
branches: [main, master]
pull_request:
branches: [main, master]
workflow_dispatch:

name: old-tensorflow

jobs:
old-tensorflow:
runs-on: ${{ matrix.config.os }}

name: ${{ matrix.config.os }} (${{ matrix.config.r }})

strategy:
fail-fast: false
matrix:
config:
- {os: windows-latest, r: 'release'}
# Use older ubuntu to maximise backward compatibility
- {os: ubuntu-18.04, r: 'devel', http-user-agent: 'release'}
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
R_KEEP_PKG_SOURCE: yes
CXX14: g++
CXX14STD: -std=c++1y
CXX14FLAGS: -Wall -g -02

steps:
- uses: actions/checkout@v2

- uses: r-lib/actions/setup-pandoc@v2

- uses: r-lib/actions/setup-r@v2
with:
r-version: ${{ matrix.config.r }}
http-user-agent: ${{ matrix.config.http-user-agent }}
use-public-rspm: true

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: rcmdcheck

- name: Install dev reticulate
run: pak::pkg_install('rstudio/reticulate')
shell: Rscript {0}

- name: Install Miniconda
# conda can fail at downgrading python, so we specify python version in advance
env:
RETICULATE_MINICONDA_PYTHON_VERSION: "3.7"
run: reticulate::install_miniconda() # creates r-reticulate conda env by default
shell: Rscript {0}

- name: Install TensorFlow
run: |
tensorflow::install_tensorflow(version='1.15', conda_python_version = NULL)
shell: Rscript {0}

- uses: r-lib/actions/check-r-package@v2

- name: Show testthat output
if: always()
run: find check -name 'testthat.Rout*' -exec cat '{}' \; || true
shell: bash

- name: Upload check results
if: failure()
uses: actions/upload-artifact@main
with:
name: ${{ runner.os }}-r${{ matrix.config.r }}-results
path: check
2 changes: 1 addition & 1 deletion .github/workflows/pkgdown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
pak::pkg_install('rstudio/reticulate')
reticulate::install_miniconda()
reticulate::conda_create('r-reticulate', packages = c('python==3.6.9'))
tensorflow::install_tensorflow(version='1.14.0')
tensorflow::install_tensorflow(version='2.7.0')
shell: Rscript {0}

- name: Install package
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
pak::pkg_install('rstudio/reticulate')
reticulate::install_miniconda()
reticulate::conda_create('r-reticulate', packages = c('python==3.6.9'))
tensorflow::install_tensorflow(version='1.14.0')
tensorflow::install_tensorflow(version='2.7.0')
shell: Rscript {0}

- name: Test coverage
Expand Down
4 changes: 1 addition & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Suggests:
covr,
dials (>= 0.0.10.9001),
earth,
tensorflow,
ggplot2,
keras,
kernlab,
Expand Down Expand Up @@ -86,6 +87,3 @@ Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.1.2
Remotes:
tidymodels/dials,
tidymodels/hardhat
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ export(glance)
export(has_multi_predict)
export(is_varying)
export(keras_mlp)
export(keras_predict_classes)
export(knit_engine_docs)
export(linear_reg)
export(list_md_problems)
Expand Down Expand Up @@ -271,6 +272,7 @@ export(set_model_engine)
export(set_model_mode)
export(set_new_model)
export(set_pred)
export(set_tf_seed)
export(show_call)
export(show_engines)
export(show_fit)
Expand Down
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@
* Argument `interval` was added for prediction: For types "survival" and "quantile", estimates for the confidence or prediction interval can be added if available (#615).

* `set_dependency()` now allows developers to create package requirements that are specific to the model's mode (#604).

*
* `varying()` is soft-deprecated in favor of `tune()`.

* `varying_args()` is soft-deprecated in favor of `tune_args()`.

* parsnip is now more robust working with keras and tensorflow for a larger range of versions (#596).

# parsnip 0.1.7

## Model Specification Changes
Expand Down
10 changes: 4 additions & 6 deletions R/logistic_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -438,13 +438,11 @@ set_pred(
type = "class",
value = list(
pre = NULL,
post = function(x, object) {
object$lvl[x + 1]
},
func = c(pkg = "keras", fun = "predict_classes"),
post = NULL,
func = c(pkg = "parsnip", fun = "keras_predict_classes"),
args =
list(
object = quote(object$fit),
object = quote(object),
x = quote(as.matrix(new_data))
)
)
Expand All @@ -462,7 +460,7 @@ set_pred(
x <- as_tibble(x)
x
},
func = c(pkg = "keras", fun = "predict_proba"),
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
Expand Down
37 changes: 36 additions & 1 deletion R/mlp.R
Original file line number Diff line number Diff line change
Expand Up @@ -436,5 +436,40 @@ reformat_torch_num <- function(results, object) {
results
}

#' Wrapper for keras class predictions
#' @param object A keras model fit
#' @param x A data set.
#' @export
#' @keywords internal
keras_predict_classes <- function(object, x) {
if (utils::packageVersion("keras") >= package_version("2.6")) {
preds <- predict(object$fit, x)
if (tensorflow::tf_version() <= package_version("2.0.0")) {
# -1 to assign with keras' zero indexing
index <- apply(preds, 1, which.max) - 1
} else {
index <- preds %>% keras::k_argmax() %>% as.integer()
}
} else {
index <- keras::predict_classes(object$fit, x)
}
object$lvl[index + 1]
}


#' Set seed in R and TensorFlow at the same time
#'
#' Some Keras models requires seeds to be set in both R and TensorFlow to
#' achieve reproducible results. This function sets these seeds at the same
#' time using version appropriate functions.
#'
#' @param seed 1 integer value.
#' @export
#' @keywords internal
set_tf_seed <- function(seed) {
set.seed(seed)
if (tensorflow::tf_version() >= package_version("2.0")) {
tensorflow::tf$random$set_seed(seed)
} else {
tensorflow::tf$random$set_random_seed(seed)
}
}
10 changes: 4 additions & 6 deletions R/mlp_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,11 @@ set_pred(
type = "class",
value = list(
pre = NULL,
post = function(x, object) {
object$lvl[x + 1]
},
func = c(pkg = "keras", fun = "predict_classes"),
post = NULL,
func = c(pkg = "parsnip", fun = "keras_predict_classes"),
args =
list(
object = quote(object$fit),
object = quote(object),
x = quote(as.matrix(new_data))
)
)
Expand All @@ -170,7 +168,7 @@ set_pred(
x <- as_tibble(x)
x
},
func = c(pkg = "keras", fun = "predict_proba"),
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
Expand Down
10 changes: 4 additions & 6 deletions R/multinom_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,10 @@ set_pred(
type = "class",
value = list(
pre = NULL,
post = function(x, object) {
object$lvl[x + 1]
},
func = c(pkg = "keras", fun = "predict_classes"),
post = NULL,
func = c(pkg = "parsnip", fun = "keras_predict_classes"),
args =
list(object = quote(object$fit),
list(object = quote(object),
x = quote(as.matrix(new_data)))
)
)
Expand All @@ -256,7 +254,7 @@ set_pred(
x <- as_tibble(x)
x
},
func = c(pkg = "keras", fun = "predict_proba"),
func = c(fun = "predict"),
args =
list(object = quote(object$fit),
x = quote(as.matrix(new_data)))
Expand Down
17 changes: 17 additions & 0 deletions man/keras_predict_classes.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions man/set_tf_seed.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions tests/testthat/test_linear_reg_keras.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ ctrl <- control_parsnip(verbosity = 0, catch = FALSE)
test_that('model fitting', {
skip_on_cran()
skip_if_not_installed("keras")
skip_if(is.null(tensorflow::tf_version()))

set_tf_seed(257)

set.seed(257)
expect_error(
fit1 <-
fit_xy(
Expand All @@ -40,7 +42,8 @@ test_that('model fitting', {
regexp = NA
)

set.seed(257)
set_tf_seed(257)

expect_error(
fit2 <-
fit_xy(
Expand Down Expand Up @@ -94,6 +97,7 @@ test_that('model fitting', {
test_that('regression prediction', {
skip_on_cran()
skip_if_not_installed("keras")
skip_if(is.null(tensorflow::tf_version()))

library(keras)

Expand Down
Loading