-
Notifications
You must be signed in to change notification settings - Fork 75
/
cluster-call.R
59 lines (53 loc) · 1.76 KB
/
cluster-call.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#' Call a function on each node of a cluster
#'
#' `cluster_call()` executes the code on each worker and returns the results;
#' `cluster_send()` executes the code ignoring the result. Jobs are submitted
#' to workers in parallel, and then we wait until they're complete.
#'
#' @param cluster A cluster.
#' @param code An expression to execute on each worker.
#' @param ptype Determines the output type. The default, `list()`, returns a
#'   list with one element per worker, which will always succeed. Set to a
#'   narrower type (e.g. `double()`) to simplify the output to a vector of
#'   that type; this requires every worker to return a result of size 1 that
#'   can be combined into `ptype`.
#' @return For `cluster_call()`, the results with one element per worker,
#'   of the type given by `ptype`. `cluster_send()` returns `cluster`
#'   invisibly.
#' @export
#' @examples
#' cl <- default_cluster()
#'
#' # Run code on each cluster and retrieve results
#' cluster_call(cl, Sys.getpid())
#' cluster_call(cl, runif(1))
#'
#' # use ptype to simplify
#' cluster_call(cl, runif(1), ptype = double())
#'
#' # use cluster_send() to ignore results
#' cluster_send(cl, x <- runif(1))
#' cluster_call(cl, x, ptype = double())
cluster_call <- function(cluster, code, ptype = list()) {
  stopifnot(is_cluster(cluster))
  code <- enexpr(code)
  # Names handed back by the cluster's cleaner; they are removed from each
  # worker's global environment before `code` is evaluated there.
  to_rm <- attr(cluster, "cleaner")$reset()

  # Executed on the worker: drop the pending names, then evaluate `code` in
  # the worker's global environment.
  f <- function(code, to_rm) {
    rm(list = to_rm, envir = globalenv())
    eval(code, globalenv())
  }

  # Submit to every worker before waiting on any, so the jobs run in parallel.
  lapply(cluster, function(x) x$call(f, list(code = code, to_rm = to_rm)))
  lapply(cluster, function(x) x$poll_process(-1))
  results <- lapply(cluster, function(x) x$read())

  errs <- lapply(results, "[[", "error")
  failed <- !vapply(errs, is.null, logical(1))
  if (any(failed)) {
    # Only the first failure is reported.
    err <- errs[failed][[1]]
    abort("Computation failed", parent = err)
  }

  out <- lapply(results, "[[", "result")

  # List-like ptypes (the default) can be cast from the list directly.
  if (is.list(ptype) && !is.data.frame(ptype)) {
    return(vctrs::vec_cast(out, ptype))
  }

  # vctrs does not cast a list straight to an atomic/vector ptype, so the
  # per-worker results are combined instead. Each worker must contribute
  # exactly one element so the output stays aligned with the workers.
  sizes <- vapply(out, vctrs::vec_size, integer(1))
  if (any(sizes != 1L)) {
    abort("Each worker must return a result of size 1 to simplify with `ptype`.")
  }
  vctrs::vec_c(!!!out, .ptype = ptype)
}
#' @rdname cluster_call
#' @export
cluster_send <- function(cluster, code) {
  stopifnot(is_cluster(cluster))
  user_expr <- enexpr(code)
  # Build `{ <user_expr>; NULL }`: the expression still runs for its side
  # effects, but each worker's result becomes NULL rather than its value.
  wrapped <- call2("{", user_expr, NULL)
  cluster_call(cluster, !!wrapped)
  invisible(cluster)
}