/
getters.R
365 lines (348 loc) · 10.1 KB
/
getters.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
#' Get Prior Definitions of a Dynamite Model
#'
#' Returns the prior definitions of a dynamite model as a data frame. To use
#' custom priors, edit the contents of the `prior` column and supply the
#' modified data frame to the [dynamite()] function via the argument
#' `priors`. See the vignettes for details.
#'
#' @note Only the `prior` column of the output should be altered when
#'   defining user-defined priors for [dynamite()].
#'
#' @export
#' @family fitting
#' @rdname get_priors
#' @param x \[`dynamiteformula` or `dynamitefit`]\cr The model formula or an
#'   existing `dynamitefit` object. See [dynamiteformula()] and [dynamite()].
#' @inheritParams dynamite
#' @param ... Additional arguments. The `dynamiteformula` method passes these
#'   on to [dynamite()].
#' @return A `data.frame` containing the prior definitions.
#' @srrstats {BS5.2} Provides access to the prior definitions of the model.
#' @examples
#' data.table::setDTthreads(1) # For CRAN
#' d <- data.frame(y = rnorm(10), x = 1:10, time = 1:10, id = 1)
#' get_priors(obs(y ~ x, family = "gaussian"),
#'   data = d, time = "time", group = "id"
#' )
#'
get_priors <- function(x, ...) {
  # S3 generic: dispatches on the class of the first argument `x`.
  UseMethod("get_priors")
}
#' @rdname get_priors
#' @export
get_priors.dynamiteformula <- function(x, data, time, group = NULL, ...) {
  # Build the model object through `dynamite()` with the `no_compile` debug
  # flag set and return only its `priors` element.
  dynamite(
    dformula = x,
    data = data,
    time = time,
    group = group,
    debug = list(no_compile = TRUE),
    ...
  )$priors
}
#' @rdname get_priors
#' @export
get_priors.dynamitefit <- function(x, ...) {
  # The prior definitions are read directly from the `priors` element of the
  # fit object; no model is rebuilt and `...` is not used.
  prior_definitions <- x$priors
  prior_definitions
}
#' Extract the Stan Code of the Dynamite Model
#'
#' Returns the Stan code of the model, either in full or restricted to
#' selected program blocks. This is mostly useful for debugging or as a
#' starting point for a customized version of the model.
#'
#' @export
#' @family output
#' @rdname get_code
#' @inheritParams get_priors.dynamiteformula
#' @param blocks \[`character()`]\cr Stan block names to extract. If `NULL`,
#'   extracts the full model code.
#' @return The Stan model blocks as a `character` string.
#' @examples
#' data.table::setDTthreads(1) # For CRAN
#' d <- data.frame(y = rnorm(10), x = 1:10, time = 1:10, id = 1)
#' cat(get_code(obs(y ~ x, family = "gaussian"),
#'   data = d, time = "time", group = "id"
#' ))
#' # same as
#' cat(dynamite(obs(y ~ x, family = "gaussian"),
#'   data = d, time = "time", group = "id",
#'   debug = list(model_code = TRUE, no_compile = TRUE)
#' )$model_code)
#'
get_code <- function(x, ...) {
  # S3 generic: dispatches on the class of the first argument `x`.
  UseMethod("get_code")
}
#' @rdname get_code
#' @export
get_code.dynamiteformula <- function(x, data, time,
                                     group = NULL, blocks = NULL, ...) {
  # Ask `dynamite()` for the model code through its debug flags, then reduce
  # the code to the requested blocks.
  model_code <- dynamite(
    dformula = x,
    data = data,
    time = time,
    group = group,
    debug = list(no_compile = TRUE, model_code = TRUE),
    ...
  )$model_code
  get_code_(model_code, blocks)
}
#' @rdname get_code
#' @export
get_code.dynamitefit <- function(x, blocks = NULL, ...) {
  if (is.null(x$stanfit)) {
    # No Stan fit is stored, so the model code is regenerated from the
    # components of the fit object. The formula is evaluated here, in this
    # function's frame.
    dformula <- eval(formula(x))
    # Default to the prior definitions stored in the fit, as
    # `get_data.dynamitefit` does, so that the regenerated model is built with
    # the fit's priors rather than with `dynamite()`'s defaults. A `priors`
    # argument supplied through `...` still takes precedence.
    regenerate <- function(priors = x$priors, ...) {
      dynamite(
        dformula = dformula,
        data = x$data,
        time = x$time_var,
        group = x$group_var,
        priors = priors,
        debug = list(no_compile = TRUE, model_code = TRUE),
        verbose = FALSE,
        ...
      )
    }
    out <- regenerate(...)$model_code
  } else {
    # Otherwise use the model code stored in the Stan fit.
    out <- x$stanfit@stanmodel@model_code[1L]
  }
  get_code_(out, blocks)
}
#' Internal Stan Code Block Extraction
#'
#' Splits the model code into rows, locates the header row of each Stan
#' program block, and returns the rows of the requested blocks in the order
#' given by `blocks`. Headers are labelled by the rows actually found in the
#' code, so model code that lacks some of the blocks is still handled; a
#' requested block that does not occur in the code contributes no rows.
#'
#' @param x \[`character(1L)`]\cr The Stan model code string.
#' @param blocks \[`character`]\cr Stan block names to extract. If `NULL`,
#'   extracts the full model code.
#' @return `x` unchanged if `blocks` is `NULL`, otherwise the result of
#'   `paste_rows()` applied to the rows of the selected blocks.
#' @noRd
get_code_ <- function(x, blocks = NULL) {
  if (is.null(blocks)) {
    return(x)
  }
  stopifnot_(
    checkmate::test_character(blocks, null.ok = TRUE),
    "Argument {.arg blocks} must be a {.cls character} vector or NULL."
  )
  # NOTE: `blocks`, `invalid_blocks` and `block_names` are referenced by name
  # inside the message templates below, so these locals must keep their names.
  block_names <- c(
    "data",
    "transformed data",
    "parameters",
    "transformed parameters",
    "model",
    "generated quantities"
  )
  invalid_blocks <- !blocks %in% block_names
  stopifnot_(
    all(!invalid_blocks),
    c(
      "Invalid Stan blocks provided: {cs(blocks[invalid_blocks])}",
      `i` = paste0(
        "Argument {.arg blocks} must be NULL or a subset of\n",
        "{cs(paste0(\"'\", block_names, \"'\"))}."
      )
    )
  )
  code_rows <- strsplit(x, "\n")[[1L]]
  # A block header is a row that consists exactly of "<block name> {".
  block_start <- which(code_rows %in% paste0(block_names, " {"))
  # Each block runs until the row before the next header; the last block runs
  # to the end of the code. Subsetting keeps both vectors the same length
  # even when no header was found.
  block_end <- c(
    block_start[-1L] - 1L,
    length(code_rows)
  )[seq_along(block_start)]
  # Name the positions by the header each one actually matched instead of
  # assuming that all six blocks are present and in canonical order.
  names(block_start) <- names(block_end) <- sub(
    " {", "", code_rows[block_start],
    fixed = TRUE
  )
  available <- blocks[blocks %in% names(block_start)]
  selected_rows <- lapply(
    available,
    function(block) {
      code_rows[block_start[[block]]:block_end[[block]]]
    }
  )
  # The leading empty row is retained from the original implementation.
  paste_rows(c("", unlist(selected_rows, use.names = FALSE)), .parse = FALSE)
}
#' Extract the Model Data of the Dynamite Model
#'
#' Returns the data that is passed as input to the Stan model. This is mostly
#' useful for debugging.
#'
#' @export
#' @family output
#' @rdname get_data
#' @inheritParams get_priors.dynamiteformula
#' @return A `list` containing the input data to Stan.
#' @examples
#' data.table::setDTthreads(1) # For CRAN
#' d <- data.frame(y = rnorm(10), x = 1:10, time = 1:10, id = 1)
#' str(get_data(obs(y ~ x, family = "gaussian"),
#'   data = d, time = "time", group = "id"
#' ))
#'
get_data <- function(x, ...) {
  # S3 generic: dispatches on the class of the first argument `x`.
  UseMethod("get_data")
}
#' @rdname get_data
#' @export
get_data.dynamiteformula <- function(x, data, time, group = NULL, ...) {
  # Debug flags passed to `dynamite()`: request the Stan input, do not
  # request the model code, and set `no_compile`.
  debug_flags <- list(no_compile = TRUE, stan_input = TRUE, model_code = FALSE)
  dynamite(
    dformula = x,
    data = data,
    time = time,
    group = group,
    debug = debug_flags,
    ...
  )$stan_input$sampling_vars
}
#' @rdname get_data
#' @export
get_data.dynamitefit <- function(x, ...) {
  # Rebuild the Stan input from the components stored in the fit object,
  # including its prior definitions, and return the `sampling_vars` element.
  debug_flags <- list(no_compile = TRUE, stan_input = TRUE, model_code = FALSE)
  dynamite(
    dformula = eval(formula(x)),
    data = x$data,
    time = x$time_var,
    group = x$group_var,
    debug = debug_flags,
    verbose = FALSE,
    priors = x$priors,
    ...
  )$stan_input$sampling_vars
}
#' Get Parameter Dimensions of the Dynamite Model
#'
#' Extracts the names and dimensions of all parameters used in the
#' `dynamite` model. See also [dynamite::get_parameter_types()] and
#' [dynamite::get_parameter_names()]. The returned dimensions match those of
#' the `stanfit` element of the `dynamitefit` object. When applied to a
#' `dynamiteformula` object, the model is compiled and sampled for 1
#' iteration in order to obtain the parameter dimensions.
#'
#' @rdname get_parameter_dims
#' @inheritParams get_priors.dynamiteformula
#' @return A named list with all parameter dimensions of the input model.
#' @export
#' @family output
#' @examples
#' data.table::setDTthreads(1) # For CRAN
#' get_parameter_dims(multichannel_example_fit)
#'
get_parameter_dims <- function(x, ...) {
  # S3 generic: dispatches on the class of the first argument `x`.
  UseMethod("get_parameter_dims")
}
#' @rdname get_parameter_dims
#' @export
get_parameter_dims.dynamiteformula <- function(x, data, time,
                                               group = NULL, ...) {
  # Fit the model with the "Fixed_param" algorithm, one chain and a single
  # iteration. Warnings are suppressed and an error is captured as a
  # condition object so that it can be reported below.
  fit <- tryCatch(
    suppressWarnings(
      dynamite(
        dformula = x,
        data = data,
        time = time,
        group = group,
        algorithm = "Fixed_param",
        chains = 1,
        iter = 1,
        refresh = 0,
        backend = "rstan",
        verbose_stan = FALSE,
        ...
      )
    ),
    error = function(e) e
  )
  failed <- inherits(fit, "error")
  # Only a captured error carries a message to report.
  failure_message <- if (failed) fit$message else NULL
  stopifnot_(
    !failed,
    c(
      "Unable to determine parameter dimensions:",
      `x` = failure_message
    )
  )
  # Delegate to the method for the fitted model object.
  get_parameter_dims(fit)
}
#' @rdname get_parameter_dims
#' @export
get_parameter_dims.dynamitefit <- function(x, ...) {
  stopifnot_(
    !is.null(x$stanfit),
    "No Stan model fit is available."
  )
  # Names declared in the "parameters" block of the model code.
  declared <- get_parameters(get_code(x, blocks = "parameters"))
  # Initial values of the first element returned by `rstan::get_inits()`;
  # only the entries that correspond to declared parameters are kept.
  inits <- rstan::get_inits(x$stanfit)[[1L]]
  is_declared <- names(inits) %in% declared
  lapply(
    inits[is_declared],
    function(value) {
      # Values without a `dim` attribute are reported with dimension 1.
      value_dim <- dim(value)
      ifelse_(is.null(value_dim), 1L, value_dim)
    }
  )
}
#' Internal Parameter Block Variable Name Extraction
#'
#' Splits the code into rows, keeps the rows that contain a semicolon, and
#' extracts from each of them the whitespace-free token that directly
#' precedes a semicolon.
#'
#' @param x \[`character(1L)`]\cr The Stan model code string of the
#'   "Parameters" block.
#' @return A `character` vector of the extracted names, one per retained row.
#' @noRd
get_parameters <- function(x) {
  code_rows <- strsplit(x, split = "\n")[[1L]]
  # Only rows with a semicolon can hold a declaration.
  declarations <- grep(";", code_rows, value = TRUE)
  hits <- regmatches(
    declarations,
    regexec(
      pattern = "^.+\\s([^\\s]+);.*$",
      text = declarations,
      perl = TRUE
    )
  )
  # Element 1 of each match is the full match, element 2 the captured name.
  vapply(hits, function(hit) hit[[2L]], character(1L))
}
#' Get Parameter Types of the Dynamite Model
#'
#' Extracts all parameter types used in the `dynamitefit` object. See
#' [dynamite::as.data.frame.dynamitefit()] for explanations of the different
#' types.
#'
#' @param x \[`dynamitefit`]\cr The model fit object.
#' @param ... Ignored.
#' @return A `character` vector with all parameter types of the input model.
#' @export
#' @family output
#' @examples
#' data.table::setDTthreads(1) # For CRAN
#' get_parameter_types(multichannel_example_fit)
#'
get_parameter_types <- function(x, ...) {
  # S3 generic: dispatches on the class of the first argument `x`.
  UseMethod("get_parameter_types")
}
#' @rdname get_parameter_types
#' @export
get_parameter_types.dynamitefit <- function(x, ...) {
  # Request every candidate type; the distinct values found in the `type`
  # column of the resulting table are returned.
  candidate_types <- c(
    "alpha", "beta", "delta", "tau", "tau_alpha", "xi",
    "sigma_nu", "sigma", "phi", "nu", "lambda", "sigma_lambda",
    "psi", "tau_psi", "corr", "corr_psi", "corr_nu",
    "omega", "omega_alpha", "omega_psi"
  )
  unique(as.data.table(x, types = candidate_types)$type)
}
#' Get Parameter Names of the Dynamite Model
#'
#' Extracts all parameter names used in the `dynamitefit` object.
#'
#' Parameter names generally start with the parameter type (e.g., beta for a
#' time-invariant regression coefficient), followed by an underscore and the
#' name of the response variable, and, in the case of a time-invariant,
#' time-varying or random effect, the name of the predictor. An exception to
#' this are the spline coefficients omega, whose names also contain the knot
#' number.
#'
#' @param x \[`dynamitefit`]\cr The model fit object.
#' @param types \[`character()`]\cr Extract only the names of parameters of
#'   certain types. See [dynamite::get_parameter_types()].
#' @param ... Ignored.
#' @return A `character` vector with parameter names of the input model.
#' @export
#' @family output
#' @examples
#' data.table::setDTthreads(1) # For CRAN
#' get_parameter_names(multichannel_example_fit)
#'
get_parameter_names <- function(x, types = NULL, ...) {
  # S3 generic: dispatches on the class of the first argument `x`.
  UseMethod("get_parameter_names")
}
#' @rdname get_parameter_names
#' @export
get_parameter_names.dynamitefit <- function(x, types = NULL, ...) {
  # Without an explicit selection, fall back to the types reported by
  # `get_parameter_types()` for this fit.
  selected_types <- if (is.null(types)) get_parameter_types(x) else types
  unique(as.data.table(x, types = selected_types)$parameter)
}