Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complex numbers #321

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
12 changes: 12 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ S3method(.subset_draws,draws_df)
S3method(.subset_draws,draws_list)
S3method(.subset_draws,draws_matrix)
S3method(.subset_draws,draws_rvars)
S3method(Complex,rvar)
S3method(Math,rvar)
S3method(Math,rvar_factor)
S3method(Ops,rvar)
Expand Down Expand Up @@ -297,6 +298,7 @@ S3method(thin_draws,draws)
S3method(thin_draws,rvar)
S3method(unique,rvar)
S3method(unique,rvar_factor)
S3method(var,complex)
S3method(var,default)
S3method(var,rvar)
S3method(variables,"NULL")
Expand All @@ -305,6 +307,7 @@ S3method(variables,draws_df)
S3method(variables,draws_list)
S3method(variables,draws_matrix)
S3method(variables,draws_rvars)
S3method(variance,complex)
S3method(variance,draws_array)
S3method(variance,draws_matrix)
S3method(variance,rvar)
Expand All @@ -313,6 +316,7 @@ S3method(vec_cast,character.rvar_factor)
S3method(vec_cast,character.rvar_ordered)
S3method(vec_cast,distribution.rvar)
S3method(vec_cast,rvar.character)
S3method(vec_cast,rvar.complex)
S3method(vec_cast,rvar.distribution)
S3method(vec_cast,rvar.double)
S3method(vec_cast,rvar.factor)
Expand All @@ -323,6 +327,7 @@ S3method(vec_cast,rvar.rvar)
S3method(vec_cast,rvar.rvar_factor)
S3method(vec_cast,rvar.rvar_ordered)
S3method(vec_cast,rvar_factor.character)
S3method(vec_cast,rvar_factor.complex)
S3method(vec_cast,rvar_factor.double)
S3method(vec_cast,rvar_factor.factor)
S3method(vec_cast,rvar_factor.integer)
Expand All @@ -332,6 +337,7 @@ S3method(vec_cast,rvar_factor.rvar)
S3method(vec_cast,rvar_factor.rvar_factor)
S3method(vec_cast,rvar_factor.rvar_ordered)
S3method(vec_cast,rvar_ordered.character)
S3method(vec_cast,rvar_ordered.complex)
S3method(vec_cast,rvar_ordered.double)
S3method(vec_cast,rvar_ordered.factor)
S3method(vec_cast,rvar_ordered.integer)
Expand All @@ -349,6 +355,7 @@ S3method(vec_ptype,rvar_factor)
S3method(vec_ptype,rvar_ordered)
S3method(vec_ptype2,character.rvar_factor)
S3method(vec_ptype2,character.rvar_ordered)
S3method(vec_ptype2,complex.rvar)
S3method(vec_ptype2,distribution.rvar)
S3method(vec_ptype2,double.rvar)
S3method(vec_ptype2,factor.rvar_factor)
Expand All @@ -357,6 +364,7 @@ S3method(vec_ptype2,integer.rvar)
S3method(vec_ptype2,logical.rvar)
S3method(vec_ptype2,ordered.rvar_factor)
S3method(vec_ptype2,ordered.rvar_ordered)
S3method(vec_ptype2,rvar.complex)
S3method(vec_ptype2,rvar.distribution)
S3method(vec_ptype2,rvar.double)
S3method(vec_ptype2,rvar.integer)
Expand Down Expand Up @@ -396,6 +404,7 @@ export(as_draws_list)
export(as_draws_matrix)
export(as_draws_rvars)
export(as_rvar)
export(as_rvar_complex)
export(as_rvar_factor)
export(as_rvar_integer)
export(as_rvar_logical)
Expand Down Expand Up @@ -438,7 +447,10 @@ export(is_draws_list)
export(is_draws_matrix)
export(is_draws_rvars)
export(is_rvar)
export(is_rvar_complex)
export(is_rvar_factor)
export(is_rvar_integer)
export(is_rvar_logical)
export(is_rvar_ordered)
export(iteration_ids)
export(mad)
Expand Down
4 changes: 2 additions & 2 deletions R/as_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ check_draws_object <- function(x) {
#' @noRd
check_variables_are_numeric <- function(
x, to = "draws_array",
is_non_numeric = function(x_i) !is.numeric(x_i) && !is.logical(x_i),
is_non_numeric = function(x_i) !is.numeric(x_i) && !is.logical(x_i) && !is.complex(x_i),
convert = TRUE
) {

Expand Down Expand Up @@ -145,7 +145,7 @@ validate_draws_per_variable <- function(...) {
# '.nchains' is an additional argument in chain supporting formats
stop_no_call("'.nchains' is not supported for this format.")
}
out <- lapply(out, as.numeric)
out <- lapply(out, function(x) if (is.numeric(x) || is.complex(x)) x else as.numeric(x))
ndraws_per_variable <- lengths(out)
ndraws <- max(ndraws_per_variable)
if (!all(ndraws_per_variable %in% c(1, ndraws))) {
Expand Down
31 changes: 27 additions & 4 deletions R/rvar-.R
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,14 @@ setOldClass(get_rvar_class(ordered(NULL)))

# helpers: validation -----------------------------------------------------------------

# check the given rvar is not complex
check_rvar_not_complex <- function(x, f = NULL) {
if (is_rvar_complex(x)) {
f <- if (is.null(f)) "" else paste0("`", f, "` ")
stop_no_call("Cannot apply ", f, "function to complex rvars.")
}
}

# Check the passed yank index (for x[[...]]) is valid
check_rvar_yank_index = function(x, i, ...) {
index <- dots_list(i, ..., .preserve_empty = TRUE, .ignore_empty = "none")
Expand Down Expand Up @@ -948,12 +956,16 @@ summarise_rvar_within_draws <- function(x, .f, ..., .transpose = FALSE, .when_em
#' by first collapsing dimensions into columns of the draws matrix
#' (so that .f can be a rowXXX() function)
#' @param x an rvar
#' @param name function name to use for error messages
#' @param .name function name to use for error messages
#' @param .f a function that takes a matrix and summarises its rows, like rowMeans
#' @param ... arguments passed to `.f`
#' @param .ordered_okay can this function be applied to rvar_ordereds?
#' @noRd
summarise_rvar_within_draws_via_matrix <- function(x, .name, .f, ..., .ordered_okay = FALSE) {
if (is_rvar_complex(x)) {
return(summarise_rvar_within_draws(x, match.fun(.name), ...))
}

.length <- length(x)
if (!.length) {
x <- rvar()
Expand All @@ -966,7 +978,7 @@ summarise_rvar_within_draws_via_matrix <- function(x, .name, .f, ..., .ordered_o
.draws <- .f(draws_of(as_rvar_numeric(x)), ...)
.draws <- while_preserving_dims(function(.draws) ordered(.levels[round(.draws)], .levels), .draws)
} else if (is_rvar_factor(x)) {
stop_no_call("Cannot apply `", .name, "` function to rvar_factor objects.")
stop_no_call("Cannot apply `rvar_", .name, "` function to rvar_factor objects.")
} else {
.draws <- .f(draws_of(x), ...)
}
Expand Down Expand Up @@ -997,18 +1009,29 @@ summarise_rvar_by_element <- function(x, .f, ...) {
#' by first collapsing dimensions into columns of the draws matrix, applying the
#' function, then restoring dimensions (so that .f can be a colXXX() function)
#' @param x an rvar
#' @param name function name to use for error messages
#' @param .name function name to use for error messages, and also function to
#' be used as a backup for complex numbers
#' @param .f a function that takes a matrix and summarises its columns, like colMeans
#' @param .extra_dim extra dims added by `.f` to the output, e.g. in the case of
#' matrixStats::colRanges this is `2`
#' @param .extra_dimnames extra dimension names for dims added by `.f` to the output
#' @param .ordered_okay can this function be applied to rvar_ordereds?
#' @param .factor_okay can this function be applied to rvar_factors?
#' @param .complex_okay can this function be applied to complex rvars? If not,
#' the function match.fun(.name) will be used instead, element-by-element.
#' @param ... arguments passed to `.f`
#' @noRd
summarise_rvar_by_element_via_matrix <- function(
x, .name, .f, .extra_dim = NULL, .extra_dimnames = NULL, .ordered_okay = TRUE, .factor_okay = FALSE, ...
x, .name, .f,
.extra_dim = NULL, .extra_dimnames = NULL,
.ordered_okay = TRUE, .factor_okay = FALSE,
.complex_okay = FALSE,
...
) {
if (is_rvar_complex(x) && !.complex_okay) {
return(summarise_rvar_by_element(x, match.fun(.name), ...))
}

.dim <- dim(x)
.dimnames <- dimnames(x)
.length <- length(x)
Expand Down
75 changes: 72 additions & 3 deletions R/rvar-cast.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
#' @details For objects that are already [`rvar`]s, returns them (with modified dimensions
#' if `dim` is not `NULL`).
#'
#' For numeric or logical vectors or arrays, returns an [`rvar`] with a single draw and
#' For [`numeric`], [`complex`], or [`logical`] vectors or arrays, returns an [`rvar`] with a single draw and
#' the same dimensions as `x`. This is in contrast to the [rvar()] constructor, which
#' treats the first dimension of `x` as the draws dimension. As a result, `as_rvar()`
#' is useful for creating constants.
#'
#' While `as_rvar()` attempts to pick the most suitable subtype of [`rvar`] based on the
#' type of `x` (possibly returning an [`rvar_factor`] or [`rvar_ordered`]),
#' `as_rvar_numeric()`, `as_rvar_integer()`, and `as_rvar_logical()` always coerce
#' the draws of the output [`rvar`] to be [`numeric`], [`integer`], or [`logical`]
#' `as_rvar_numeric()`, `as_rvar_complex()`, `as_rvar_integer()`, and `as_rvar_logical()` always coerce
#' the draws of the output [`rvar`] to be [`numeric`], [`complex`], [`integer`], or [`logical`]
#' (respectively), and always return a base [`rvar`], never a subtype.
#'
#' @seealso [rvar()] to construct [`rvar`]s directly. See [rdo()], [rfun()], and
Expand Down Expand Up @@ -87,6 +87,14 @@ as_rvar_numeric <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) {
out
}

#' @rdname as_rvar
#' @export
as_rvar_complex <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) {
out <- as_rvar(x, dim = dim, dimnames = dimnames, nchains = nchains)
draws_of(out) <- while_preserving_dims(as.complex, draws_of(out))
out
}

#' @rdname as_rvar
#' @export
as_rvar_integer <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) {
Expand Down Expand Up @@ -121,6 +129,51 @@ is_rvar <- function(x) {
inherits(x, "rvar")
}

#' Is `x` a complex random variable?
#'
#' Test if `x` is an [`rvar`] backed by [`complex`] draws.
#'
#' @inheritParams is_rvar
#'
#' @seealso [as_rvar_complex()] to convert objects to `rvar`s backed by [`complex`] draws.
#'
#' @return `TRUE` if `x` is an [`rvar`] backed by [`complex`] draws, `FALSE` otherwise.
#'
#' @export
is_rvar_complex <- function(x) {
is.complex(draws_of(x))
}

#' Is `x` an integer random variable?
#'
#' Test if `x` is an [`rvar`] backed by [`integer`] draws.
#'
#' @inheritParams is_rvar
#'
#' @seealso [as_rvar_integer()] to convert objects to `rvar`s backed by [`integer`] draws.
#'
#' @return `TRUE` if `x` is an [`rvar`] backed by [`integer`] draws, `FALSE` otherwise.
#'
#' @export
is_rvar_integer <- function(x) {
is.integer(draws_of(x))
}

#' Is `x` a logical random variable?
#'
#' Test if `x` is an [`rvar`] backed by [`logical`] draws.
#'
#' @inheritParams is_rvar
#'
#' @seealso [as_rvar_logical()] to convert objects to `rvar`s backed by [`logical`] draws.
#'
#' @return `TRUE` if `x` is an [`rvar`] backed by [`logical`] draws, `FALSE` otherwise.
#'
#' @export
is_rvar_logical <- function(x) {
is.logical(draws_of(x))
}

#' @export
is.matrix.rvar <- function(x) {
length(dim(draws_of(x))) == 3
Expand Down Expand Up @@ -384,6 +437,22 @@ vec_cast.rvar_factor.double <- function(x, to, ...) new_constant_rvar(while_pres
#' @export
vec_cast.rvar_ordered.double <- function(x, to, ...) new_constant_rvar(while_preserving_dims(as.ordered, x))

# complex -> rvar
#' @export
vec_ptype2.complex.rvar <- function(x, y, ...) new_rvar()
#' @export
vec_ptype2.rvar.complex <- function(x, y, ...) new_rvar()
#' @export
vec_cast.rvar.complex <- function(x, to, ...) new_constant_rvar(x)

# complex -> rvar_factor
#' @export
vec_cast.rvar_factor.complex <- function(x, to, ...) new_constant_rvar(while_preserving_dims(as.factor, x))

# complex -> rvar_ordered
#' @export
vec_cast.rvar_ordered.complex <- function(x, to, ...) new_constant_rvar(while_preserving_dims(as.ordered, x))

# integer -> rvar
#' @export
vec_ptype2.integer.rvar <- function(x, y, ...) new_rvar()
Expand Down
12 changes: 8 additions & 4 deletions R/rvar-dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#' @name rvar-dist
#' @export
density.rvar <- function(x, at, ...) {
check_rvar_not_complex(x, "density")
summarise_rvar_by_element(x, function(draws) {
d <- density(draws, cut = 0, ...)
f <- approxfun(d$x, d$y, yleft = 0, yright = 0)
Expand All @@ -66,6 +67,7 @@ distributional::cdf
#' @rdname rvar-dist
#' @export
cdf.rvar <- function(x, q, ...) {
check_rvar_not_complex(x, "cdf")
summarise_rvar_by_element(x, function(draws) {
ecdf(draws)(q)
})
Expand All @@ -91,13 +93,15 @@ cdf.rvar_ordered <- function(x, q, ...) {
#' @rdname rvar-dist
#' @export
quantile.rvar <- function(x, probs, ...) {
check_rvar_not_complex(x, "quantile")
summarise_rvar_by_element_via_matrix(x,
"quantile",
function(draws) {
t(matrixStats::colQuantiles(draws, probs = probs, useNames = TRUE, ...))
},
function(..., names) t(matrixStats::colQuantiles(..., useNames = FALSE)),
.extra_dim = length(probs),
.extra_dimnames = list(NULL)
.extra_dimnames = list(NULL),
probs = probs,
names = FALSE,
...
)
}

Expand Down
14 changes: 12 additions & 2 deletions R/rvar-math.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,27 @@ Math.rvar <- function(x, ...) {
if (.Generic %in% c("cumsum", "cumprod", "cummax", "cummin")) {
# cumulative functions need to be handled differently
# from other functions in this generic
new_rvar(t(apply(draws_of(x), 1, f)), .nchains = nchains(x))
if (length(x) > 1) {
draws_of(x) <- t(apply(draws_of(x), 1, f))
}
} else {
new_rvar(f(draws_of(x), ...), .nchains = nchains(x))
draws_of(x) <- f(draws_of(x), ...)
}

x
}

#' @export
Math.rvar_factor <- function(x, ...) {
stop_no_call("Cannot apply `", .Generic, "` function to rvar_factor objects.")
}

#' @export
Complex.rvar <- function(z) {
f <- get(.Generic)
rvar_apply_vec_fun(f, z)
}

# matrix stuff ---------------------------------------------------

#' Matrix multiplication of random variables
Expand Down
Loading
Loading