diff --git a/R/utils.R b/R/utils.R index 60a11db5a..a3bdd44b2 100644 --- a/R/utils.R +++ b/R/utils.R @@ -224,14 +224,13 @@ keras_array <- function(x, dtype = NULL) { if ( tf_version() >= "1.12" && ( - tensorflow::tf$contrib$framework$is_tensor(x) || - is.list(x) && all(vapply(x, tensorflow::tf$contrib$framework$is_tensor, logical(1))) + is_keras_tensor(x) || is.list(x) && all(vapply(x, is_keras_tensor, logical(1))) ) ) { return(x) } } else { - if ((keras_version() >= "2.2.0") && k_is_tensor(x)) { + if ((keras_version() >= "2.2.0") && is_keras_tensor(x)) { return(x) } } @@ -381,3 +380,11 @@ as_shape <- function(x) { as.integer(d) }) } + +is_keras_tensor <- function(x) { + if (is_tensorflow_implementation()) { + if (tensorflow::tf_version() >= "2.0") tensorflow::tf$is_tensor(x) else tensorflow::tf$contrib$framework$is_tensor(x) + } else { + k_is_tensor(x) + } +}