Skip to content

Commit

Permalink
simplifies serialization()
Browse files Browse the repository at this point in the history
  • Loading branch information
shikokuchuo committed Jun 28, 2024
1 parent d4b78ff commit 80514d3
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 188 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: mirai
Type: Package
Title: Minimalist Async Evaluation Framework for R
Version: 1.1.0.9008
Version: 1.1.0.9009
Description: High-performance parallel code execution and distributed computing.
Designed for simplicity, a 'mirai' evaluates an R expression asynchronously,
on local or network resources, resolving automatically upon completion.
Expand Down
5 changes: 3 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# mirai 1.1.0.9008 (development)
# mirai 1.1.0.9009 (development)

* Ephemeral daemons now exit as soon as permissible, eiliminating the 2s linger period.
* `serialization()` function signature simplified for clarity and ease of use.
* `dispatcher()` argument 'retry' now defaults to FALSE for consistency with non-dispatcher behaviour.
* `remote_config()` gains argument 'quote' to control whether or not to quote the daemon launch commmand, and now works with Slurm (thanks @michaelmayer2 #119).
* Ephemeral daemons now exit as soon as permissible, eiliminating the 2s linger period.
* Requires `nanonext` >= 1.1.1.

# mirai 1.1.0
Expand Down
96 changes: 55 additions & 41 deletions R/daemons.R
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ daemons <- function(n, url = NULL, remote = NULL, dispatcher = TRUE, ...,
`[[<-`(.., .compute, `[[<-`(`[[<-`(envir, "sock", sock), "n", n))
if (length(remote))
launch_remote(url = envir[["urls"]], remote = remote, tls = envir[["tls"]], ..., .compute = .compute)
serialization_refhook()
check_register_everywhere()
} else {
daemons(n = 0L, .compute = .compute)
return(daemons(n = n, url = url, remote = remote, dispatcher = dispatcher, ..., seed = seed, tls = tls, pass = pass, .compute = .compute))
Expand Down Expand Up @@ -359,7 +359,7 @@ daemons <- function(n, url = NULL, remote = NULL, dispatcher = TRUE, ...,
`[[<-`(envir, "urls", urld)
}
`[[<-`(.., .compute, `[[<-`(`[[<-`(envir, "sock", sock), "n", n))
serialization_refhook()
check_register_everywhere()
} else {
daemons(n = 0L, .compute = .compute)
return(daemons(n = n, url = url, remote = remote, dispatcher = dispatcher, ..., seed = seed, tls = tls, pass = pass, .compute = .compute))
Expand Down Expand Up @@ -487,58 +487,70 @@ status <- function(.compute = "default") {
#' Custom Serialization Functions
#'
#' Registers custom serialization and unserialization functions for sending and
#' receiving external pointer reference objects.
#'
#' @param refhook \strong{either} a list or pairlist of two functions: the
#' signature for the first must accept a reference object inheriting from
#' 'class' (or a list of such objects) and return a raw vector, and the
#' second must accept a raw vector and return reference objects (or a list
#' of such objects), \cr \strong{or else} NULL to reset.
#' @param class [default ""] a character string representing the class of object
#' that these serialization function will be applied to, e.g. 'ArrowTabular'
#' or 'torch_tensor'.
#' @param vec [default FALSE] the serialization functions accept and return
#' reference object individually e.g. \code{arrow::write_to_raw} and
#' \code{arrow::read_ipc_stream}. If TRUE, the serialization functions are
#' vectorized and accept and return a list of reference objects, e.g.
#' receiving reference objects.
#'
#' @param class the class of reference object (as a character string) that these
#' functions are applied to, e.g. 'ArrowTabular' or 'torch_tensor',
#' \strong{or else} NULL to cancel registered functions.
#' @param sfunc serialization function: must accept a reference object (or list
#' of objects) inheriting from \sQuote{class} and return a raw vector.
#' @param ufunc unserialization function: must accept a raw vector and return
#' a reference object (or list of reference objects).
#' @param vec [default FALSE] if FALSE the functions must accept and return
#' reference objects individually e.g. \code{arrow::write_to_raw} and
#' \code{arrow::read_ipc_stream}. If TRUE, the functions are vectorized and
#' must accept and return a list of reference objects, e.g.
#' \code{torch::torch_serialize} and \code{torch::torch_load}.
#'
#' @return Invisibly, the pairlist of currently-registered 'refhook' functions.
#' A message is printed to the console when functions are successfully
#' @return Invisibly, a list comprising the currently-registered values for
#' 'class', 'sfunc', 'ufunc' and 'vec', or else NULL if unregistered. A
#' message is printed to the console when functions are successfully
#' registered or reset.
#'
#' @details Calling without any arguments returns the pairlist of
#' currently-registered 'refhook' functions.
#' @details Registering new functions replaces any existing registered
#' functions.
#'
#' To cancel registered functions, specify 'class' as NULL, without the
#' need to supply 'sfunc' or 'ufunc'.
#'
#' Calling without any arguments returns the pairlist of
#' currently-registered serialization functions.
#'
#' This function may be called prior to or after setting daemons, with the
#' registered functions applying across all compute profiles.
#'
#' @examples
#' r <- serialization(list(function(x) serialize(x, NULL), unserialize))
#' print(serialization())
#' serialization(r)
#' reg <- serialization(
#' class = "",
#' sfunc = function(x) serialize(x, NULL),
#' ufunc = base::unserialize
#' )
#' reg
#'
#' serialization(NULL)
#' print(serialization())
#'
#' @export
#'
serialization <- function(refhook = list(), class = "", vec = FALSE) {
serialization <- function(class, sfunc, ufunc, vec = FALSE) {

register <- !missing(refhook)
cfg <- next_config(refhook = refhook, class = class, vec = vec)
missing(class) && return(.[["serial"]])

if (register) {
if (is.list(refhook) && length(refhook) == 2L && is.function(refhook[[1L]]) && is.function(refhook[[2L]]))
cat("mirai serialization functions registered\n", file = stderr()) else
if (is.null(refhook))
cat("mirai serialization functions cancelled\n", file = stderr()) else
stop(._[["refhook_invalid"]])
`[[<-`(., "refhook", list(refhook, class, vec))
register_everywhere(refhook = refhook, class = class, vec = vec)
if (is.null(class)) {
serial <- NULL
next_config(NULL)
cat("mirai serialization functions cancelled\n", file = stderr())
} else if (is.character(class) && is.function(sfunc) && is.function(ufunc)) {
serial <- list(class, sfunc, ufunc, vec)
next_config(refhook = list(sfunc, ufunc), class = class, vec = vec)
cat("mirai serialization functions registered\n", file = stderr())
} else {
stop(._[["serial_invalid"]])
}

invisible(cfg)
`[[<-`(., "serial", serial)
register_everywhere(serial = serial)
invisible(serial)

}

Expand Down Expand Up @@ -658,13 +670,15 @@ query_status <- function(envir) {
dimnames = list(envir[["urls"]], c("i", "online", "instance", "assigned", "complete"))))
}

register_everywhere <- function(refhook, class, vec)
register_everywhere <- function(serial)
for (.compute in names(..))
everywhere(mirai::serialization(refhook = refhook, class = class, vec = vec),
refhook = refhook, class = class, vec = vec, .compute = .compute)
everywhere(
mirai::serialization(class = serial[[1L]], sfunc = serial[[2L]], ufunc = serial[[3L]], vec = serial[[4L]]),
.args = list(serial = serial),
.compute = .compute
)

serialization_refhook <- function(refhook = .[["refhook"]])
if (length(refhook[[1L]]))
register_everywhere(refhook = refhook[[1L]], class = refhook[[2L]], vec = refhook[[3L]])
check_register_everywhere <- function(serial = .[["serial"]])
if (length(serial[[1L]])) register_everywhere(serial = serial)

._scm_. <- as.raw(c(0x07, 0x00, 0x00, 0x00, 0x42, 0x0a, 0x03, 0x00, 0x00, 0x00, 0x02, 0x03, 0x04, 0x00, 0x00, 0x05, 0x03, 0x00, 0x05, 0x00, 0x00, 0x00, 0x55, 0x54, 0x46, 0x2d, 0x38, 0xfc, 0x00, 0x00, 0x00))
2 changes: 1 addition & 1 deletion R/mirai-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
register_cluster = "this function requires a more recent version of R",
requires_daemons = "launching one local daemon as none previously set",
requires_local = "SSH tunnelling requires 'url' hostname to be '127.0.0.1' or 'localhost'",
refhook_invalid = "'refhook' must be a list of 2 functions or NULL",
serial_invalid = "'class' must be a character value or NULL, 'sfunc' and 'ufunc' must be functions",
single_url = "only one 'url' should be specified",
sync_timeout = "initial sync with dispatcher/daemon timed out after 10s",
url_spec = "numeric value for 'url' is out of bounds",
Expand Down
56 changes: 33 additions & 23 deletions man/serialization.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions tests/tests.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ nanotesterr(launch_remote("ws://[::1]:5555", remote = remote_config(command = "e
nanotesterr(launch_remote(c("tcp://localhost:5555", "tcp://localhost:6666", "tcp://localhost:7777"), remote = remote_config(command = "echo", args = list(c("test", "."), c("test", ".")))), "must be of length 1 or the same length")
nanotesterr(launch_local(1L), "requires daemons to be set")
nanotestn(everywhere(mirai::serialization()))
nanotest(is.list(serialization()))
nanotesterr(serialization(list(NULL)), "must be a list of 2 functions or NULL")
nanotestn(serialization())
nanotesterr(serialization(list(NULL)), "must be a character value or NULL")
nanotest(is.character(host_url()))
nanotest(substr(host_url(ws = TRUE, tls = TRUE), 1L, 3L) == "wss")
nanotest(substr(host_url(tls = TRUE), 1L, 3L) == "tls")
Expand Down Expand Up @@ -240,8 +240,8 @@ connection && .Platform[["OS.type"]] != "windows" && Sys.getenv("NOT_CRAN") == "
nanotestz(sum(tstatus[, "assigned"]))
nanotestz(sum(tstatus[, "complete"]))
nanotestz(daemons(0))
nanotest(is.list(serialization(list(function(x) serialize(x, NULL), unserialize))))
nanotest(is.function(serialization()[[1L]]))
nanotest(is.list(serialization(class = "", sfunc = function(x) serialize(x, NULL), ufunc = unserialize)))
nanotest(is.function(serialization()[[2L]]))
nanotesto(daemons(url = "wss://127.0.0.1:0", token = TRUE, pass = "test"))
nanotestn(launch_local(1L))
Sys.sleep(1L)
Expand All @@ -253,7 +253,7 @@ connection && .Platform[["OS.type"]] != "windows" && Sys.getenv("NOT_CRAN") == "
nanotestn(saisei(1))
nanotesterr(launch_local(0:1), "out of bounds")
nanotesterr(launch_remote(1:2), "out of bounds")
nanotestn(unlist(serialization(NULL)))
nanotestn(serialization(NULL))
option <- 15L
nanotesto(daemons(1, dispatcher = TRUE, maxtasks = 10L, timerstart = 1L, walltime = 1000L, seed = 1546, token = TRUE, cleanup = option, autoexit = tools::SIGCONT))
Sys.sleep(1L)
Expand Down
11 changes: 6 additions & 5 deletions vignettes/databases.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ everywhere({
con <<- dbConnect(adbi::adbi("adbcsqlite"), uri = ":memory:")
})
serialization(
refhook = list(arrow::write_to_raw,
function(x) arrow::read_ipc_stream(x, as_data_frame = FALSE)),
class = "nanoarrow_array_stream"
class = "nanoarrow_array_stream",
sfunc = arrow::write_to_raw,
ufunc = function(x) arrow::read_ipc_stream(x, as_data_frame = FALSE)
)
```
`mirai()` calls may then be used to write to or query the database all in the Arrow format.
Expand Down Expand Up @@ -249,8 +249,9 @@ server <- function(input, output, session) {

# serialization() specifies the native Arrow serialization functions
serialization(
refhook = list(arrow::write_to_raw, nanoarrow::read_nanoarrow),
class = "nanoarrow_array_stream"
class = "nanoarrow_array_stream",
sfunc = arrow::write_to_raw,
ufunc = nanoarrow::read_nanoarrow
)

# run Shiny app
Expand Down
11 changes: 6 additions & 5 deletions vignettes/databases.Rmd.orig
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ everywhere({
con <<- dbConnect(adbi::adbi("adbcsqlite"), uri = ":memory:")
})
serialization(
refhook = list(arrow::write_to_raw,
function(x) arrow::read_ipc_stream(x, as_data_frame = FALSE)),
class = "nanoarrow_array_stream"
class = "nanoarrow_array_stream",
sfunc = arrow::write_to_raw,
ufunc = function(x) arrow::read_ipc_stream(x, as_data_frame = FALSE)
)
```
`mirai()` calls may then be used to write to or query the database all in the Arrow format.
Expand Down Expand Up @@ -182,8 +182,9 @@ server <- function(input, output, session) {

# serialization() specifies the native Arrow serialization functions
serialization(
refhook = list(arrow::write_to_raw, nanoarrow::read_nanoarrow),
class = "nanoarrow_array_stream"
class = "nanoarrow_array_stream",
sfunc = arrow::write_to_raw,
ufunc = nanoarrow::read_nanoarrow
)

# run Shiny app
Expand Down
Loading

0 comments on commit 80514d3

Please sign in to comment.