Join GitHub today
GitHub is home to over 50 million developers working together to host and review code, manage projects, and build software together.
Sign up| ## TODO: handle dynamic parameters, i.e. | |
| ## > VdP.sys = ' | |
| ## > dxdt[0] = x[1]; | |
| ## > dxdt[1] = mu * (1 - x[0] * x[0]) * x[1] - x[0]; | |
| ## > ' | |
| ## > compile_sys("vpol2", VdP.sys, "mu", method = "bsd") | |
| ## > vpol2_set_params(mu = 0.5) | |
| setClass("DxdtOdeConstructor", slots = c(x="Basic")) | |
| setClass("DxdtOde", slots = c(x = "Basic", rhs = "Basic")) | |
| setClass("ODESystem", slots = c(odesys = "list", cpp_source = "character", | |
| compile_env = "environment")) | |
| #' Solve System of Ordinary Differential Equations | |
| #' | |
| #' This is a wrapper of the `odeintr` R package using | |
| #' symengine objects to specify the ODE system and C code | |
| #' generation functionality from symengine to generate the | |
| #' C++ source. The `dxdt` function and defined `==` S4 method | |
| #' allow one to intuitively specify the ODE system with symengine | |
| #' objects. The `ODESystem` will generate C++ source | |
| #' and compile on the fly with Rcpp. Then `predict` can be used to get | |
| #' results. | |
| #' | |
| #' @param x A SymEngine Basic object of type Symbol or a R object | |
| #' that will be converted to `Symbol(x)`. | |
| #' @rdname ODESystem | |
| #' @return `dxdt` returns a DxdtOdeConstructor S4 object. | |
| #' @export | |
| #' @md | |
| #' @examples | |
| #' # A differential equation specified with dxdt and == | |
| #' x <- Symbol("x") | |
| #' eq <- dxdt(x) == 1/exp(x) | |
| #' print(eq) | |
| dxdt <- function(x) { | |
| x <- Symbol(x) | |
| new("DxdtOdeConstructor", x = x) | |
| } | |
| #' @param e1 A DxdtOdeConstructor S4 object which can be returned by `dxdt`. | |
| #' @param e2 A Basic object or an R object that will be converted to `S(e2)`. | |
| #' @return S4 method of `==` for "DxdtOdeConstructor" returns a DxdtOde S4 object. | |
| #' @rdname ODESystem | |
| setMethod("==", signature=c(e1 = "DxdtOdeConstructor", e2 = "ANY"), | |
| function(e1, e2) { | |
| x <- e1@x | |
| rhs <- S(e2) | |
| new("DxdtOde", x = x, rhs = rhs) | |
| } | |
| ) | |
| setMethod("show", signature = c(object = "DxdtOde"), | |
| function(object) { | |
| x <- object@x | |
| rhs <- object@rhs | |
| cat("Ordinary differential equation:\n") | |
| x_str <- as.character(x) | |
| rhs_str <- as.character(rhs) | |
| cat(sprintf("d(%s)/dt == %s\n", x_str, rhs_str)) | |
| } | |
| ) | |
| #' @param odesys,... DxdtOde S4 objects that can be returned with | |
| #' `dxdt(x) == rhs`. Or `odesys` can be a list of DxdtOde S4 objects | |
| #' when there is no dot arguments. | |
| #' @param method,atol,rtol Passed to `odeintr::compile_sys`. | |
| #' @param compile Logical, whether to compile the C++ source. Useful if | |
| #' you only want to obtain the code. | |
| #' @return `ODESystem` returns a ODESystem S4 object. | |
| #' @rdname ODESystem | |
| #' @export | |
| #' @examples | |
| #' \donttest{ | |
| #' ## Lorenz system | |
| #' use_vars(x, y, z) | |
| #' sigma <- 10 | |
| #' rho <- 28 | |
| #' beta <- 8/3 | |
| #' lorenz_sys <- ODESystem( | |
| #' dxdt(x) == sigma * (y - x), | |
| #' dxdt(y) == (rho - z) * x - y, | |
| #' dxdt(z) == - beta * z + x * y | |
| #' ) | |
| #' res <- predict( | |
| #' lorenz_sys, init = c(x = 1, y = 1, z = 1), duration = 100, step_size = 0.001 | |
| #' ) | |
| #' plot(res[, c(2, 4)], type = 'l', col = "steelblue", main = "Lorenz Attractor") | |
| #' } | |
| ODESystem <- function(odesys, ..., method = "rk5_i", | |
| atol = 1e-06, rtol = 1e-06, compile = TRUE) { | |
| if (!missing(...)) | |
| odesys <- list(odesys, ...) | |
| if (is(odesys, "DxdtOde")) | |
| odesys <- list(odesys) | |
| ## Some checks | |
| check_odesys <- function(odesys) { | |
| stopifnot(is.list(odesys)) | |
| for (el in odesys) | |
| stopifnot(is(el, "DxdtOde")) | |
| rhs_list <- lapply(odesys, function(x) x@rhs) | |
| x_list <- lapply(odesys, function(x) x@x) | |
| check_symbol <- function(names) { | |
| #if (any(grepl("^_", names))) | |
| # stop(sprintf("variable name [%s] can not start with '_'", | |
| # names[grepl("^_", names)[1]])) | |
| if (any(names == "")) | |
| stop("variable name can not be empty string") | |
| if (any(grepl("(\\[|\\]| )", names))) | |
| stop(sprintf("variable name [%s] can not have space, '[' or ']'", | |
| names[grepl("^_", names)[1]])) | |
| } | |
| check_symbol(unlist(lapply(rhs_list, function(rhs) as.character(free_symbols(rhs))))) | |
| check_symbol(unlist(lapply(x_list, as.character))) | |
| if (length(unique(Vector(x_list))) != length(x_list)) | |
| stop("lhs variables should be unique") | |
| } | |
| check_odesys(odesys) | |
| rhs_list <- lapply(odesys, function(x) x@rhs) | |
| x_list <- lapply(odesys, function(x) x@x) | |
| x_mapping_to <- sprintf("x[%s]", seq(length(x_list)) - 1L) | |
| x_mapping_to <- lapply(x_mapping_to, Symbol) | |
| lhs_mapping_to <- sprintf("dxdt[%s]", seq(length(x_list)) - 1L) | |
| lhs_mapping_to <- lapply(lhs_mapping_to, Symbol) | |
| rhs_mapping_to <- rhs_list | |
| for (i in seq_along(x_list)) | |
| for (j in seq_along(rhs_mapping_to)) | |
| rhs_mapping_to[[j]] <- subs(rhs_mapping_to[[j]], x_list[[i]], x_mapping_to[[i]]) | |
| ccode_str <- sprintf("%s = %s", codegen(Vector(lhs_mapping_to), type="ccode"), | |
| codegen(Vector(rhs_mapping_to), type="ccode")) | |
| compile_sys_env <- new.env() | |
| if (!requireNamespace("odeintr", quietly = TRUE)) | |
| stop("'odeintr' package needs to be installed for this functionality") | |
| rcpp_code <- odeintr::compile_sys( | |
| name = "predict", | |
| sys = ccode_str, | |
| pars = NULL, | |
| const = FALSE, | |
| method = method, | |
| sys_dim = length(x_list), | |
| atol = atol, | |
| rtol = rtol, | |
| compile = FALSE, | |
| observer = NULL, | |
| env = compile_sys_env | |
| ) | |
| # Sanity check | |
| stopifnot(length(ls(compile_sys_env, all.names = TRUE)) == 0L) | |
| if (compile) { | |
| res <- Rcpp::sourceCpp( | |
| code = rcpp_code, env = compile_sys_env, verbose = FALSE) | |
| } | |
| new("ODESystem", odesys = odesys, | |
| cpp_source = rcpp_code, compile_env = compile_sys_env) | |
| } | |
| #' @param object A ODESystem S4 object. | |
| #' @param init A numeric vector specifying the initial conditions. It can | |
| #' be named with the variable names or it can be unnamed but in the | |
| #' same of order of equations. | |
| #' @param duration,step_size,start Passed to the function generated by | |
| #' `odeintr::compile_sys`. | |
| #' @return `predict` returns a dataframe. | |
| #' @rdname ODESystem | |
| #' @export | |
| setMethod("predict", c(object = "ODESystem"), | |
| function(object, init, duration, step_size = 1, start = 0) { | |
| x_list <- lapply(object@odesys, function(x) x@x) | |
| x_names <- vapply(x_list, as.character, character(1)) | |
| stopifnot(length(init) == length(x_list)) | |
| # Reorder 'init' if it is named | |
| if (!is.null(names(init))) { | |
| matching <- match(x_names, names(init)) | |
| if (anyDuplicated(matching) || any(is.na(matching))) | |
| stop("Invalid names specified 'init'") | |
| init <- init[matching] | |
| } | |
| ans <- object@compile_env$predict(init, duration, step_size, start) | |
| stopifnot( | |
| identical(names(ans), c("Time", paste0("X", seq_along(x_list)))) | |
| ) | |
| names(ans)[2:length(ans)] <- x_names | |
| ans | |
| } | |
| ) | |
| setMethod("show", c(object = "ODESystem"), | |
| function(object) { | |
| x_list <- lapply(object@odesys, function(x) x@x) | |
| x_names <- vapply(x_list, as.character, character(1)) | |
| rhs_list <- lapply(object@odesys, function(x) x@rhs) | |
| cat(sprintf('%s with variables %s:\n', class(object), | |
| paste(x_names, collapse = ", "))) | |
| for (i in seq_along(x_list)) { | |
| cat(sprintf(" d(%s)/dt = %s\n", x_names[i], as.character(rhs_list[[i]]))) | |
| } | |
| } | |
| ) |