Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch stanfunctions support for future versions of Stanc3 #103

Merged
merged 1 commit into from
Mar 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
55 changes: 39 additions & 16 deletions R/rstan_config.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,7 @@ rstan_config <- function(pkgdir = ".") {
## path to src/stan_files
## stan_path <- file.path(pkgdir, "src", "stan_files")
# create c++ code
# .stanfunction compatibility only available after 2.29, so need
# to manually wrap function definitions in functions { } before calling stanc
if (grepl("\\.stanfunctions$", file_name) &&
(utils::packageVersion('rstan') < 2.29)) {
if (grepl("\\.stanfunctions$", file_name)) {
mod <- readLines(file_name)
if (!any(grepl("\\bfunctions(\\s*|)\\{", mod))) {
writeLines(c("functions {", mod, "}"), sep = "\n", con = file_name)
Expand Down Expand Up @@ -210,11 +207,17 @@ rstan_config <- function(pkgdir = ".") {

# Replace auto return type with function plain type
for(i in 1:length(decl_lines)) {
cpp_lines[decl_lines[i]] <- .replace_auto(decl_lines[i],cppcode, cpp_lines)
next_decl = ifelse(i == length(decl_lines), length(cpp_lines), decl_lines[i] + 1)
cpp_lines[decl_lines[i]] <- .replace_auto(decl_lines[i], next_decl, cppcode, cpp_lines)
}
}
# The default template parameters emitted by stanc3 can error under some clang versions
cpp_lines <- gsub(">* = 0>", ">* = nullptr>", cpp_lines, fixed = TRUE)
eigen_incl <- ifelse(utils::packageVersion('rstan') >= 2.31,
"#include <stan/math/prim/fun/Eigen.hpp>",
"#include <stan/math/prim/mat/fun/Eigen.hpp>")
cat("#include <exporter.h>",
"#include <stan/math/prim/mat/fun/Eigen.hpp>",
eigen_incl,
"#include <stan/math/prim/meta.hpp>",
file = file.path(pkgdir, "src",
paste(basename(pkgdir), "types.h", sep = "_")),
Expand Down Expand Up @@ -349,26 +352,46 @@ rstan_config <- function(pkgdir = ".") {
}

# Replace auto return type in function exports with the plain type from the main body.
.replace_auto <- function(decl_line, cppcode, cpp_lines) {
.replace_auto <- function(decl_line, next_decl, cppcode, cpp_lines) {
# Extract the name of function
fun_name <- cpp_lines[decl_line]
fun_name <- paste0(cpp_lines[decl_line:next_decl], collapse = " ")
fun_name <- gsub("auto ","",fun_name,fixed=T)
fun_name <- sub("\\(.*","",fun_name,perl=T)

struct_start <- grep(paste0("struct ", fun_name, "_functor"), cpp_lines)
struct_op_start <- grep("operator()", cpp_lines[-(1:struct_start)])[1] + struct_start
# Depending on the version of stanc3, the standalone functions
# with a plain return type can either be wrapped in a struct as a functor,
# or as a separate forward declaration
struct_name <- paste0("struct ", fun_name, "_functor")

rtn_type <- paste0(cpp_lines[struct_start:struct_op_start], collapse = " ")
if (grepl(struct_name, cppcode)) {
struct_start <- grep(struct_name, cpp_lines)
struct_op_start <- grep("operator()", cpp_lines[-(1:struct_start)])[1] + struct_start

rm_operator <- gsub("operator().*", "", rtn_type)
rm_struct_decl <- gsub(".*\\{", "", rm_operator)
repl_dbl <- gsub("T([0-9])*__", "double", rm_struct_decl)
rtn_type <- paste0(cpp_lines[struct_start:struct_op_start], collapse = " ")

rm_operator <- gsub("operator().*", "", rtn_type)
rm_prev <- gsub(".*\\{", "", rm_operator)
} else {
# Find first declaration of function (will be the forward declaration)
first_decl <- grep(fun_name, cpp_lines)[1]

# The return type will be between the function name and the semicolon terminating
# the previous line
last_scolon <- grep(";", cpp_lines[1:first_decl])
last_scolon <- ifelse(last_scolon[length(last_scolon)] == first_decl,
last_scolon[length(last_scolon) - 1],
last_scolon[length(last_scolon)])
rtn_type_full <- paste0(cpp_lines[last_scolon:first_decl], collapse = " ")
rm_fun_name <- gsub(paste0(fun_name, ".*"), "", rtn_type_full)
rm_prev <- gsub(".*;", "", rm_fun_name)

}

repl_dbl <- gsub("T([0-9])*__", "double", rm_prev)

# Extract return type declaration and replace promoted scalar
# type with double
rtn_type <- gsub("template <typename(.*?)> ", "", repl_dbl)

# Update model code with type declarations
gsub("auto ", rtn_type, cpp_lines[decl_line],fixed=T)
gsub("auto", rtn_type, cpp_lines[decl_line], fixed=T)
}