From 9acb4471fe9ffaf811c99c5d3b37933a6cb3db9d Mon Sep 17 00:00:00 2001 From: Tomasz Kalinowski Date: Thu, 27 Mar 2025 08:38:48 -0400 Subject: [PATCH 1/3] Capture stderr from gpu-detection code --- R/install.R | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/R/install.R b/R/install.R index db14b3c..ec66682 100644 --- a/R/install.R +++ b/R/install.R @@ -186,12 +186,14 @@ function(method = c("auto", "virtualenv", "conda"), has_nvidia_gpu <- function() { lspci_listed <- tryCatch( - as.logical(length( - system("{ lspci | grep -i nvidia; } 2>/dev/null", intern = TRUE) - )), - warning = function(w) FALSE, # warning emitted by system for non-0 exit status + { + lspci <- system("lspci", intern = TRUE, ignore.stderr = TRUE) + any(grepl("nvidia", lspci, ignore.case = TRUE)) + }, + warning = function(w) FALSE, error = function(e) FALSE ) + if (lspci_listed) return(TRUE) @@ -342,10 +344,10 @@ has_gpu <- function() { # on.exit(options(oop), add = TRUE) lspci_listed <- tryCatch( - as.logical(length( - system("{ lspci | grep -i nvidia; } 2>/dev/null", intern = TRUE) - )), - # warning emitted by system for non-0 exit status + { + lspci <- system("lspci", intern = TRUE, ignore.stderr = TRUE) + any(grepl("nvidia", lspci, ignore.case = TRUE)) + }, warning = function(w) FALSE, error = function(e) FALSE ) From 28d8d698ff8fa9b70431eee9d6343dc845aa9bd0 Mon Sep 17 00:00:00 2001 From: Tomasz Kalinowski Date: Thu, 27 Mar 2025 08:39:27 -0400 Subject: [PATCH 2/3] export `%*%` method for tensors --- R/generics.R | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/R/generics.R b/R/generics.R index 85ac8e5..77bb300 100644 --- a/R/generics.R +++ b/R/generics.R @@ -883,3 +883,19 @@ py_to_r.keras.src.utils.tracking.TrackedList <- function(x) import("builtins")$l #' @export py_to_r.keras.src.utils.tracking.TrackedSet <- function(x) import("builtins")$list(x) + + +#' @rdname Ops-python-methods +#' @rawNamespace if (getRversion() >= "4.3.0") S3method("%*%",python.builtin.object) +`%*%.tensorflow.tensor` <- function(x, y) { + if (is.atomic(x) && is_tensor(y)) { + if (length(x) > 1L) + x <- as.array(x) + x <- tf$convert_to_tensor(x, dtype = y$dtype) + } else if (is_tensor(x) && is.atomic(y)) { + if (length(y) > 1L) + y <- as.array(y) + y <- tf$convert_to_tensor(y, dtype = x$dtype) + } + NextMethod() +} From f9d6f4b794a9af758095f6ad56bbbaa8a034d428 Mon Sep 17 00:00:00 2001 From: Tomasz Kalinowski Date: Mon, 31 Mar 2025 10:16:16 -0400 Subject: [PATCH 3/3] fix help handler --- R/help.R | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/R/help.R b/R/help.R index c222569..f052cfe 100644 --- a/R/help.R +++ b/R/help.R @@ -3,17 +3,19 @@ register_tf_help_handler <- function() { reticulate::register_module_help_handler("tensorflow", function(name, subtopic = NULL) { - # get the base tensorflow help url - version <- tf$`__version__` - version <- strsplit(version, ".", fixed = TRUE)[[1]] - help_url <- paste0("https://www.tensorflow.org/versions/r", - version[1], ".", version[2], "/api_docs/python/") - - # upstream TF is missing public docs for later version - # https://github.com/tensorflow/tensorflow/issues/89084 + # # Version specific URLs are disabled because + # # upstream TF is missing public docs for later version + # # https://github.com/tensorflow/tensorflow/issues/89084 + # # get the base tensorflow help url + # version <- tf$`__version__` + # version <- strsplit(version, ".", fixed = TRUE)[[1]] + # help_url <- paste0("https://www.tensorflow.org/versions/r", + # version[1], ".", version[2], "/api_docs/python/") + help_url <- "https://www.tensorflow.org/api_docs/python/" # some adjustments + name <- sub("^tensorflow\\._api\\.v2\\.", "tensorflow.", name) name <- sub("^tensorflow", "tf", name) name <- sub("python.client.session.", "", name, fixed = TRUE) name <- sub("python.ops.", "", name, fixed = TRUE)