Skip to content

Commit

Permalink
Merge pull request #103 from andrjohns/future-stanc3-compat
Browse files Browse the repository at this point in the history
Patch stanfunctions support for future versions of Stanc3
  • Loading branch information
jgabry committed Mar 9, 2023
2 parents 935c6df + 54ef99f commit 0ae1349
Showing 1 changed file with 39 additions and 16 deletions.
55 changes: 39 additions & 16 deletions R/rstan_config.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,7 @@ rstan_config <- function(pkgdir = ".") {
## path to src/stan_files
## stan_path <- file.path(pkgdir, "src", "stan_files")
# create c++ code
# .stanfunction compatibility only available after 2.29, so need
# to manually wrap function definitions in functions { } before calling stanc
if (grepl("\\.stanfunctions$", file_name) &&
(utils::packageVersion('rstan') < 2.29)) {
if (grepl("\\.stanfunctions$", file_name)) {
mod <- readLines(file_name)
if (!any(grepl("\\bfunctions(\\s*|)\\{", mod))) {
writeLines(c("functions {", mod, "}"), sep = "\n", con = file_name)
Expand Down Expand Up @@ -210,11 +207,17 @@ rstan_config <- function(pkgdir = ".") {

# Replace auto return type with function plain type
for(i in 1:length(decl_lines)) {
cpp_lines[decl_lines[i]] <- .replace_auto(decl_lines[i],cppcode, cpp_lines)
next_decl = ifelse(i == length(decl_lines), length(cpp_lines), decl_lines[i] + 1)
cpp_lines[decl_lines[i]] <- .replace_auto(decl_lines[i], next_decl, cppcode, cpp_lines)
}
}
# The default template parameters emitted by stanc3 can error under some clang versions
cpp_lines <- gsub(">* = 0>", ">* = nullptr>", cpp_lines, fixed = TRUE)
eigen_incl <- ifelse(utils::packageVersion('rstan') >= 2.31,
"#include <stan/math/prim/fun/Eigen.hpp>",
"#include <stan/math/prim/mat/fun/Eigen.hpp>")
cat("#include <exporter.h>",
"#include <stan/math/prim/mat/fun/Eigen.hpp>",
eigen_incl,
"#include <stan/math/prim/meta.hpp>",
file = file.path(pkgdir, "src",
paste(basename(pkgdir), "types.h", sep = "_")),
Expand Down Expand Up @@ -349,26 +352,46 @@ rstan_config <- function(pkgdir = ".") {
}

# Replace auto return type in function exports with the plain type from the main body.
.replace_auto <- function(decl_line, cppcode, cpp_lines) {
.replace_auto <- function(decl_line, next_decl, cppcode, cpp_lines) {
# Extract the name of function
fun_name <- cpp_lines[decl_line]
fun_name <- paste0(cpp_lines[decl_line:next_decl], collapse = " ")
fun_name <- gsub("auto ","",fun_name,fixed=T)
fun_name <- sub("\\(.*","",fun_name,perl=T)

struct_start <- grep(paste0("struct ", fun_name, "_functor"), cpp_lines)
struct_op_start <- grep("operator()", cpp_lines[-(1:struct_start)])[1] + struct_start
# Depending on the version of stanc3, the standalone functions
# with a plain return type can either be wrapped in a struct as a functor,
# or as a separate forward declaration
struct_name <- paste0("struct ", fun_name, "_functor")

rtn_type <- paste0(cpp_lines[struct_start:struct_op_start], collapse = " ")
if (grepl(struct_name, cppcode)) {
struct_start <- grep(struct_name, cpp_lines)
struct_op_start <- grep("operator()", cpp_lines[-(1:struct_start)])[1] + struct_start

rm_operator <- gsub("operator().*", "", rtn_type)
rm_struct_decl <- gsub(".*\\{", "", rm_operator)
repl_dbl <- gsub("T([0-9])*__", "double", rm_struct_decl)
rtn_type <- paste0(cpp_lines[struct_start:struct_op_start], collapse = " ")

rm_operator <- gsub("operator().*", "", rtn_type)
rm_prev <- gsub(".*\\{", "", rm_operator)
} else {
# Find first declaration of function (will be the forward declaration)
first_decl <- grep(fun_name, cpp_lines)[1]

# The return type will be between the function name and the semicolon terminating
# the previous line
last_scolon <- grep(";", cpp_lines[1:first_decl])
last_scolon <- ifelse(last_scolon[length(last_scolon)] == first_decl,
last_scolon[length(last_scolon) - 1],
last_scolon[length(last_scolon)])
rtn_type_full <- paste0(cpp_lines[last_scolon:first_decl], collapse = " ")
rm_fun_name <- gsub(paste0(fun_name, ".*"), "", rtn_type_full)
rm_prev <- gsub(".*;", "", rm_fun_name)

}

repl_dbl <- gsub("T([0-9])*__", "double", rm_prev)

# Extract return type declaration and replace promoted scalar
# type with double
rtn_type <- gsub("template <typename(.*?)> ", "", repl_dbl)

# Update model code with type declarations
gsub("auto ", rtn_type, cpp_lines[decl_line],fixed=T)
gsub("auto", rtn_type, cpp_lines[decl_line], fixed=T)
}

0 comments on commit 0ae1349

Please sign in to comment.