diff --git a/NAMESPACE b/NAMESPACE index a284e4f0e..c5b5391f0 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -43,6 +43,7 @@ S3method(evaluate,keras.src.models.model.Model) S3method(fit,keras.src.models.model.Model) S3method(format,keras.src.models.model.Model) S3method(format,keras_shape) +S3method(pillar::type_sum,keras.src.backend.jax.core.JaxVariable) S3method(pillar::type_sum,keras.src.backend.jax.core.Variable) S3method(plot,keras.src.models.model.Model) S3method(plot,keras_training_history) @@ -62,6 +63,7 @@ S3method(r_to_py,keras_shape) S3method(str,jax.Array) S3method(str,jaxlib._jax.ArrayImpl) S3method(str,jaxlib.xla_extension.ArrayImpl) +S3method(str,keras.src.backend.jax.core.JaxVariable) S3method(str,keras.src.backend.jax.core.Variable) S3method(summary,keras.src.models.model.Model) S3method(tensorflow::export_savedmodel,keras.src.models.model.Model) diff --git a/R/jax-methods.R b/R/jax-methods.R index 361d6aa43..b27dce0f5 100644 --- a/R/jax-methods.R +++ b/R/jax-methods.R @@ -84,3 +84,13 @@ type_sum.keras.src.backend.jax.core.Variable <- function(x) { x <- sub("shape=\\((None|[[:digit:]]+),\\)", "shape=(\\1)", x) x } + +## new S3 class names in Keras 3.11 +#' @exportS3Method str keras.src.backend.jax.core.JaxVariable +str.keras.src.backend.jax.core.JaxVariable <- str.keras.src.backend.jax.core.Variable + +#' @exportS3Method pillar::type_sum keras.src.backend.jax.core.JaxVariable +type_sum.keras.src.backend.jax.core.JaxVariable <- type_sum.keras.src.backend.jax.core.Variable + +# "keras.src.backend.Variable" too? +# "keras.src.backend.common.variables.Variable" too?