-
Notifications
You must be signed in to change notification settings - Fork 2
/
SEIR_outbreak_model_fitting.R
138 lines (125 loc) · 5.14 KB
/
SEIR_outbreak_model_fitting.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#' list of parameterisations for priors
#' @export
prior_list <- list(gamma_mean = 0.125,gamma_sd = 0.0125,
sigma_mean = 0.2,sigma_sd = 0.025,
S0_mean = 0.9, S0_sd = 0.01,
r0_mean=3.0,r0_sd=1.0,
zeta_mean=0.1,zeta_sd=0.1,
tau_mean = 0.1,
phi = c(1, 1))
#' SEIR model fitting method to a multiple outbreak model
#' @description
#' Create an instance of the hierarchical SEIR Stan model incorporating
#' various data elements and sample model.
#' @param stan_model [rstan] model object
#' @param tmax Total number of time-points in observation
#' @param n_outbreaks Total number of outbreaks
#' @param outbreak_cases Number of daily reported cases by outbreak
#' @param outbreak_sizes The total size of each facility (initial number of susceptible and exposed)
#' @param intervention_switch Describes whether interventions occur in data (default TRUE)
#' @param multilevel_intervention Describes whether intervention occurs
#' @param prior_list List of priors. See [prior_list]
#' @param chains Number of chains to sample
#' @param iter number of iterations of MCMC
#' @param seed The seed for random number generation. Set to replicate results.
#' @param fit_type string "NUTS" or "VB" (VB quicker but less accurate).
#' @param data_model string "poisson" or "negative_binomial". If negative_binomial
#' selected then uses phi prior to control for overdispersion
#'
#' @examples
#' stan_mod <- rstan::stan_model(system.file("stan",
#' "hierarchical_SEIR_incidence_model.stan", package = "cr0eso"))
#' tmax <- 5
#' pop_size <- 100
#' dim(pop_size) <- c(1)
#' example_incidence <- matrix(c(1,1,2,3,2),ncol=1)
#' fit <- seir_model_fit(stan_model = stan_mod, tmax,1,example_incidence,pop_size)
#' @author Mike Irvine
#' @return An object of class `stanfit` returned by \link[rstan]{sampling}
#' @export
seir_model_fit <- function(
stan_model = NULL,
tmax,
n_outbreaks,
outbreak_cases,
outbreak_sizes,
intervention_switch = TRUE,
multilevel_intervention = FALSE,
priors = prior_list,
chains = 4,
iter=600,
seed = 42,
fit_type = "NUTS",
data_model="poisson"
){
if(is.null(stan_model)){
stop("Need to include a stan model. Check examples")
}
# parameter checks
if(!fit_type %in% c("NUTS","VB")){
stop(glue::glue("{fit_type} should be NUTS or VB."))
}
if(!data_model %in% c("poisson","negative_binomial")){
stop(glue::glue("{data_model} should be poisson or negative_binomial."))
}
data_model_code <- dplyr::case_when(data_model == "poisson" ~ 0,
data_model == "negative_binomial" ~ 1)
# define STAN data
stan_data = list(n_obs = tmax,
n_outbreaks = n_outbreaks,
n_difeq = 5,
n_fake = tmax,
y = outbreak_cases,
intervention_switch = intervention_switch,
multilevel_intervention = multilevel_intervention,
independent_r0 = FALSE,
independent_zeta = FALSE,
data_model = data_model_code,
n_prior_mean = outbreak_sizes,
tau_prior_mean = priors$tau_mean,
t0 = 0,
tn = tmax,
ts = c(1:tmax),
fake_ts = c(1:tmax),
# define priors
gamma_mean = priors$gamma_mean,
gamma_sd = priors$gamma_sd,
sigma_mean = priors$sigma_mean,
sigma_sd = priors$sigma_sd,
S0_mean = priors$S0_mean,
S0_sd = priors$S0_sd,
r0_mean = priors$r0_mean,
r0_sd = priors$r0_sd,
zeta_mean = priors$zeta_mean,
zeta_sd = priors$zeta_sd,
phi_prior = priors$phi
)
# Which parameters to monitor in the model:
params_monitor = c("y0", "params","r0k" , "r0" , "r0_sigma", "zetak",
"incidence" , "hyper_priors","counterfactual_cases",
"pp_cases","predictive_r0","phi")
# Fit and sample from the posterior
if(fit_type=="NUTS"){
mod <- rstan::sampling(stan_model,
data = stan_data,
pars = params_monitor,
seed = seed,
chains = chains,
warmup = floor(iter/2),
iter = iter,
control = list(adapt_delta = 0.95))
}else{
cat("Sampling with the VB algorithm.\n")
mod <- rstan::vb(
stanmodels$hierarchical_SEIR_incidence_model,
data = stan_data,
iter = iter,
seed = seed,
pars = params_monitor,
algorithm = "fullrank",
tol_rel_obj = 0.0001,
importance_resampling = TRUE
)
}
list(model = mod)
}