diff --git a/R/install.R b/R/install.R index 62e83a312..0c7e33f48 100644 --- a/R/install.R +++ b/R/install.R @@ -170,6 +170,8 @@ use_backend <- function(backend, gpu = NA) { reticulate::import("os")$environ$update(list(KERAS_BACKEND = backend)) } + set_envvar("UV_CONSTRAINT", pkg_file("keras-constraints.txt"), + action = "append", sep = " ", unique = TRUE) switch( paste0(get_os(), "_", backend), @@ -222,11 +224,11 @@ use_backend <- function(backend, gpu = NA) { gpu <- has_gpu() if (gpu) { + uv_unset_override_tf_cpu() py_require(action = "remove", c("tensorflow", "tensorflow-cpu")) py_require("tensorflow[and-cuda]") } else { - py_require(action = "remove", c("tensorflow", "tensorflow[and-cuda]")) - py_require("tensorflow-cpu") + uv_set_override_tf_cpu() } }, @@ -234,6 +236,7 @@ use_backend <- function(backend, gpu = NA) { py_require(action = "remove", c("tensorflow", "tensorflow[and-cuda]", "jax[cuda12]", "jax[cpu]")) + uv_set_override_tf_cpu() if (is.na(gpu)) gpu <- has_gpu() @@ -248,6 +251,7 @@ use_backend <- function(backend, gpu = NA) { Linux_torch = { py_require(c("tensorflow", "tensorflow[and-cuda]"), action = "remove") + uv_set_override_tf_cpu() if (is.na(gpu)) gpu <- has_gpu() @@ -264,6 +268,7 @@ use_backend <- function(backend, gpu = NA) { }, Linux_numpy = { + uv_set_override_tf_cpu() py_require(c("tensorflow", "tensorflow[and-cuda]"), action = "remove") py_require(c("tensorflow-cpu", "numpy", "jax[cpu]")) }, @@ -301,8 +306,62 @@ use_backend <- function(backend, gpu = NA) { invisible(backend) } +set_envvar <- function( + name, + value, + action = c("replace", "append", "prepend"), + sep = .Platform$path.sep, + unique = FALSE +) { + old <- Sys.getenv(name, NA) + + if (is.null(value) || is.na(value)) { + Sys.unsetenv(name) + return(invisible(old)) + } + + if (!is.na(old)) { + value <- switch( + match.arg(action), + replace = value, + append = paste(old, value, sep = sep), + prepend = paste(value, old, sep = sep) + ) + if (unique) { + value <- unique(unlist(strsplit(value, sep, fixed = TRUE))) + value <- paste0(value, collapse = sep) + } + } + + value <- list(value) + names(value) <- name + do.call(Sys.setenv, value) + invisible(old) +} +uv_set_override_tf_cpu <- function() { + py_require(action = "remove", c( + "tensorflow", "tensorflow[and-cuda]", "tensorflow-cpu", + "tensorflow-metal", "tensorflow-macos" + )) + py_require(if (is_linux()) "tensorflow-cpu" else "tensorflow") + set_envvar("UV_OVERRIDE", pkg_file("tf-cpu-override.txt"), + action = "append", sep = " ", unique = TRUE) +} +uv_unset_override_tf_cpu <- function() { + override <- Sys.getenv("UV_OVERRIDE", NA) + if (is.na(override)) return() + cpu_override <- pkg_file("tf-cpu-override.txt") + if (override == cpu_override) { + Sys.unsetenv(override) + } else { + new <- gsub(cpu_override, "", override, fixed = TRUE) + new <- gsub(" +", " ", new) + Sys.setenv("UV_OVERRIDE" = new) + } + invisible(override) +} get_os <- function() { if (is_windows()) "Windows" else if (is_mac_arm64()) "macOS" else "Linux" @@ -321,6 +380,13 @@ is_keras_loaded <- function() { !exists("module", envir = keras) } +pkg_file <- function(..., package = "keras3") { + path <- system.file(..., package = "keras3", mustWork = TRUE) + if(is_windows()) + path <- utils::shortPathName(path) + path +} + has_gpu <- function() { diff --git a/inst/keras-constraints.txt b/inst/keras-constraints.txt new file mode 100644 index 000000000..a2780d921 --- /dev/null +++ b/inst/keras-constraints.txt @@ -0,0 +1,9 @@ + +# unconstrained keras-hub in a larger requirements list might resolve to v0.19.0, +# which is over a year old and generally what people want, and arguably a bug with `uv`. +# This is a workaround to nudge uv to resolve the latest keras-hub. +keras-hub>0.19.0 + + +# tensorflow-text 2.19.* fails to load with tensorflow-cpu>=2.19.0 +tensorflow-cpu==2.18.* diff --git a/inst/tf-cpu-override.txt b/inst/tf-cpu-override.txt new file mode 100644 index 000000000..189aaddab --- /dev/null +++ b/inst/tf-cpu-override.txt @@ -0,0 +1 @@ +tensorflow; sys_platform == "never" diff --git a/tools/install.R b/tools/install.R new file mode 100755 index 000000000..e1bf00e39 --- /dev/null +++ b/tools/install.R @@ -0,0 +1,3 @@ +#!/usr/bin/Rscript + +remotes::install_local(force = TRUE)