## (GitHub web-page boilerplate removed so this file parses as an R script)
## September 2019, Pierre E. Jacob
## This script illustrates how one can use the log pdf and its gradient
## from a stan fit object and implement an algorithm of our choice;
## here a coupled HMC algorithm a la Heng & Jacob 2019 https://arxiv.org/abs/1709.00404
## with an illustration of the TV upper bounds a la Biswas, Jacob & Vanetti 2019 https://arxiv.org/abs/1905.09971
## note: this scripts requires the R packages:
## rstan, mvtnorm, parallel, doParallel, doRNG, ggplot2
## thanks to all who contributed to those
## disclaimer: this is meant to be proof of concept, this code has not been
## thoroughly tested; you're welcome to re-use any of this but at your own risk
## The model and data are taken from https://data.princeton.edu/pop510/hospmle
## and https://data.princeton.edu/pop510/hospStan
## Taken from first link above:
## "We will illustrate random intercept logit models using data from Lillard and Panis (2000)
## on 1060 births to 501 mothers. The outcome of interest is whether the birth was delivered
## in a hospital or elsewhere. The predictors include the log of income loginc, the distance
## to the nearest hospital distance, and two indicators of mothers’s education: dropout for less
## than high school and college for college graduates, so high school or
## some college is the reference cell."
##
## fix the RNG seed for reproducibility of the whole script
set.seed(19)
## download the Lillard & Panis (2000) hospital delivery data
## NOTE(review): requires network access; column meanings are set on the next line
hosp <- read.table("https://data.princeton.edu/pop510/hospital.dat", header = FALSE)
names(hosp) <- c("hosp","loginc","distance","dropout","college","mother")
## now Stan code taken from https://data.princeton.edu/pop510/hospStan
## random-intercept logistic regression:
## y[n] ~ Bernoulli(inv_logit(alpha + a[g[n]] + x[n] * beta)), one intercept a[m] per mother
hosp_code <- '
data {
int N; // number of obs (pregnancies)
int M; // number of groups (women)
int K; // number of predictors
int y[N]; // outcome
row_vector[K] x[N]; // predictors
int g[N]; // map obs to groups (pregnancies to women)
}
parameters {
real alpha;
real a[M];
vector[K] beta;
real<lower=0,upper=10> sigma;
}
model {
alpha ~ normal(0,100);
a ~ normal(0,sigma);
beta ~ normal(0,100);
for(n in 1:N) {
y[n] ~ bernoulli(inv_logit( alpha + a[g[n]] + x[n]*beta));
}
}
'
## prepare the data
## M = 501 mothers, K = 4 predictors (loginc, distance, dropout, college), column 6 maps birth -> mother
hosp_data <- list(N=nrow(hosp),M=501,K=4,y=hosp[,1],x=hosp[,2:5],g=hosp[,6])
## load Stan
library(rstan)
## use all cores but 2
options(mc.cores = parallel::detectCores()-2)
## warm-up and total number of iterations
warmup <- 1000
iter <- 2000
## compile model and run MCMC
## 'sample_file' makes Stan write its CSV output to disk; that file is re-read
## below to recover the adapted diagonal mass matrix
hfit <- stan(model_code=hosp_code, model_name="hospitals", data=hosp_data,
warmup = warmup, iter=iter, chains=2, sample_file = "sample_file_hosp")
# trace plot of the two chains
traceplot(hfit,c("alpha","beta[1]","beta[2]","beta[3]","beta[4]","sigma"),
ncol=1,nrow=6,inc_warmup=F)
## OK, so it seems to be working alright, at first glance
## extract samples from stan fit
postsamples <- as.matrix(hfit)
## transform samples so that they are "unconstrained" (i.e. take values in all of R^d)
## drop the last column (lp__) and map sigma to the real line via log
## NOTE(review): sigma is declared with lower=0,upper=10, for which Stan uses a
## logit-type transform rather than log — confirm this approximation is intended
unconstr_posterior <- postsamples[,1:(dim(postsamples)[2]-1)]
unconstr_posterior[,dim(unconstr_posterior)[2]] <- log(unconstr_posterior[,dim(unconstr_posterior)[2]])
## load inverse of mass matrix, i.e. an approximation of the diagonal elements of the posterior covariance matrix
## this was found on https://discourse.mc-stan.org/t/extracting-the-mass-matrix-euclidean-metric-in-rstan/9897
datadump <- readLines("sample_file_hosp_1.csv") # read the output generated by the call to 'stan'
massmatindex <- grep("^# Diagonal elements", datadump) + 1 # one line after the line "# Diagonal elements of inverse mass matrix:"
yy <- unlist(strsplit(datadump[massmatindex], split=", "))
yy[1] <- substring(yy[1], 3) # eliminate the '#' symbol
invmassdiag <- as.numeric(yy)
## mass matrix and its sqrt, for convenience
massdiag <- 1/invmassdiag
sqrt_massdiag <- sqrt(massdiag)
## get the stepsize for HMC
sampler_params <- get_sampler_params(hfit, inc_warmup = TRUE)
## column 2 is stepsize__; take the first post-warmup (i.e. adapted) value of chain 1
stepsize <- as.numeric(sampler_params[[1]][(warmup+1),2])
## number of leap frog steps post warmup (column 4 is n_leapfrog__)
table(sampler_params[[1]][(warmup+1):iter,4])
## pick number of leap frog step, somewhat arbitrarily
nleapfrogsteps <- 10
## target dimension
target_dim <- get_num_upars(hfit)
## function to evaluate log density on unconstrained space
stan_logtarget <- function(x) log_prob(hfit, x)
## function to evaluate gradient of log density on unconstrained space
stan_gradlogtarget <- function(x) grad_log_prob(hfit, x)
## estimated posterior moments, on the unconstrained space
postcov <- cov(unconstr_posterior)
postmean <- colMeans(unconstr_posterior)
## Now let us define functions for HMC, Metropolis Hastings, an initial distribution,
## and a way of sampling meeting times
## initial distribution of the chains
library(mvtnorm)
## Draw an initial chain state from a Gaussian moment-matched to the
## (unconstrained) posterior; returns list(chain_state, current_pdf).
rinit <- function(){
  draw <- mvtnorm::rmvnorm(1, postmean, postcov)
  logpdf <- stan_logtarget(draw)
  list(chain_state = draw[1,], current_pdf = logpdf)
}
## HMC kernel with fixed stepsize and fixed number of leap frog steps
## One Hamiltonian Monte Carlo transition: leapfrog integration with fixed
## step size ('stepsize') and fixed number of steps ('nleapfrogsteps'),
## followed by a Metropolis accept/reject correction.
## 'state' is a list(chain_state, current_pdf); the return value carries the
## (possibly unchanged) state plus a logical 'accept' flag.
hmc_kernel <- function(state){
  x <- state$chain_state
  logpdf_x <- state$current_pdf
  # momentum ~ N(0, M) with diagonal mass matrix M
  p0 <- rnorm(target_dim, 0, sqrt_massdiag)
  # leapfrog: half momentum step, alternating full position/momentum steps,
  # final half momentum step
  q <- x
  p <- p0 + stepsize * stan_gradlogtarget(q) / 2
  for (s in seq_len(nleapfrogsteps)){
    q <- q + stepsize * invmassdiag * p
    if (s != nleapfrogsteps){
      p <- p + stepsize * stan_gradlogtarget(q)
    }
  }
  p <- p + stepsize * stan_gradlogtarget(q) / 2
  logpdf_q <- stan_logtarget(q)
  # log Metropolis ratio: potential energy difference plus kinetic energy
  # difference of the extended target
  logratio <- logpdf_q - logpdf_x +
    (-0.5 * sum(p * invmassdiag * p)) -
    (-0.5 * sum(p0 * invmassdiag * p0))
  # non-finite ratios (e.g. divergent trajectories) are always rejected
  accepted <- is.finite(logratio) && (log(runif(1)) < logratio)
  if (accepted){
    x <- q
    logpdf_x <- logpdf_q
  }
  list(chain_state = x, current_pdf = logpdf_x, accept = accepted)
}
## run HMC for a while to see if we get agreement with the Stan output
niterations <- 2000
state <- rinit()
## trajectory storage: one row per iteration, one column per unconstrained parameter
hmc_chain <- matrix(nrow = niterations, ncol = target_dim)
for (iteration in 1:niterations){
state <- hmc_kernel(state)
hmc_chain[iteration,] <- state$chain_state
}
## overlay histograms of the first component: custom HMC (after discarding 1000
## iterations as burn-in) vs the Stan draws (semi-transparent red)
hist(hmc_chain[1000:niterations,1], nclass = 100, prob = TRUE)
hist(unconstr_posterior[,1], nclass = 100, prob = TRUE, col = rgb(1,0,0,0.5), add = TRUE)
## looks like our custom sampler roughly agrees with stan
## now, coupled HMC kernel
## Coupled HMC kernel: advances two chains with a common random numbers
## coupling (same momentum draw, same leapfrog integrator, same uniform in
## the accept/reject step). This makes nearby chains contract towards each
## other but never produces exact meetings; exact meetings are triggered by
## the coupled MH kernel defined further below.
## Arguments and return values are lists of the same shape as in hmc_kernel.
coupled_hmc_kernel <- function(state1, state2){
chain_state1 <- state1$chain_state; current_pdf1 <- state1$current_pdf
chain_state2 <- state2$chain_state; current_pdf2 <- state2$current_pdf
# draw same momentum for two chains (common random numbers coupling)
initial_momentum <- sqrt_massdiag * rnorm(target_dim, 0, 1)
position1 <- chain_state1
position2 <- chain_state2
# leap frog integrator, run in lockstep on both chains
momentum1 <- initial_momentum + stepsize * stan_gradlogtarget(position1) / 2
momentum2 <- initial_momentum + stepsize * stan_gradlogtarget(position2) / 2
for (step in 1:nleapfrogsteps){
position1 <- position1 + stepsize * invmassdiag * momentum1
position2 <- position2 + stepsize * invmassdiag * momentum2
if (step != nleapfrogsteps){
momentum1 <- momentum1 + stepsize * stan_gradlogtarget(position1)
momentum2 <- momentum2 + stepsize * stan_gradlogtarget(position2)
}
}
# final half momentum step of the leapfrog scheme
momentum1 <- momentum1 + stepsize * stan_gradlogtarget(position1) / 2
momentum2 <- momentum2 + stepsize * stan_gradlogtarget(position2) / 2
proposed_pdf1 <- stan_logtarget(position1)
proposed_pdf2 <- stan_logtarget(position2)
accept_ratio1 <- proposed_pdf1 - current_pdf1
# the acceptance ratio also features the "kinetic energy" term of the extended target
accept_ratio1 <- accept_ratio1 + (-0.5 * sum(momentum1 * invmassdiag * momentum1)) -
(-0.5 * sum(initial_momentum * invmassdiag * initial_momentum))
accept_ratio2 <- proposed_pdf2 - current_pdf2
accept_ratio2 <- accept_ratio2 + (-0.5 * sum(momentum2 * invmassdiag * momentum2)) -
(-0.5 * sum(initial_momentum * invmassdiag * initial_momentum))
# same uniform to accept/reject proposed state (part of the coupling)
logu <- log(runif(1))
accept1 <- FALSE; accept2 <- FALSE
# non-finite acceptance ratios (divergent trajectories) are rejected
if (is.finite(accept_ratio1)){
accept1 <- (logu < accept_ratio1)
}
if (is.finite(accept_ratio2)){
accept2 <- (logu < accept_ratio2)
}
if (accept1){
chain_state1 <- position1
current_pdf1 <- proposed_pdf1
}
if (accept2){
chain_state2 <- position2
current_pdf2 <- proposed_pdf2
}
# identical is always FALSE: this coupling contracts chains but cannot
# make them meet exactly
return(list(state1 = list(chain_state = chain_state1, current_pdf = current_pdf1),
state2 = list(chain_state = chain_state2, current_pdf = current_pdf2),
identical = FALSE))
}
## run coupled HMC: two chains started independently from rinit(),
## then advanced jointly with the common-random-numbers coupling
niterations <- 1000
state1 <- rinit()
state2 <- rinit()
hmc_chain1 <- matrix(nrow = niterations, ncol = target_dim)
hmc_chain2 <- matrix(nrow = niterations, ncol = target_dim)
for (iteration in 1:niterations){
coupled_state <- coupled_hmc_kernel(state1, state2)
state1 <- coupled_state$state1
state2 <- coupled_state$state2
hmc_chain1[iteration,] <- state1$chain_state
hmc_chain2[iteration,] <- state2$chain_state
}
## plot distance between the chains (log scale on the y-axis to see the
## exponential contraction)
plot(sapply(1:niterations, function(i) sum(abs(hmc_chain1[i,] - hmc_chain2[i,]))), type = "l",
xlab = "# iterations", ylab = "L1 distance", log = "y")
## The pair of chains seems to contract quickly
## next we will define random walk MH kernels
## which will trigger exact meetings of the chains
## if the chains are very near to one another
## Sample from a reflection-maximal coupling of N(0, D) momenta for two
## chains at positions mu1 and mu2, where D is diagonal and given through
## its square root 'sqrtD'. Returns:
##   momentum1, momentum2 -- marginally N(0, D) increments for the two chains
##   samesame             -- TRUE when the coupling succeeded, i.e. (with
##                           kappa = 1) mu1 + momentum1 == mu2 + momentum2
## 'kappa' is a tuning parameter scaling the translation move.
reflmaxcoupling <- function(mu1, mu2, sqrtD, kappa = 1){
  dim_ <- length(mu1)
  momentum1 <- rnorm(dim_, 0, 1)
  # standardized difference between the two positions
  z <- (mu1 - mu2) / sqrtD
  normz <- sqrt(sum(z^2))
  if (normz == 0){
    # edge case: the chains sit at the same point, so the two proposal
    # distributions coincide and the maximal coupling returns identical
    # draws (the previous version divided by zero here, producing NaNs)
    momentum1 <- momentum1 * sqrtD
    return(list(momentum1 = momentum1, momentum2 = momentum1, samesame = TRUE))
  }
  evector <- z / normz
  edotxi <- sum(evector * momentum1)
  # accept the translated draw with the maximal-coupling probability ratio
  logu <- log(runif(1))
  if (logu < (dnorm(edotxi + kappa * normz, 0, 1, log = TRUE) - dnorm(edotxi, 0, 1, log = TRUE))){
    momentum2 <- momentum1 + kappa * z
    samesame <- TRUE
  } else {
    # otherwise reflect momentum1 across the hyperplane orthogonal to evector
    momentum2 <- momentum1 - 2 * edotxi * evector
    samesame <- FALSE
  }
  # scale both standardized draws so that their covariance is D
  momentum1 <- momentum1 * sqrtD
  momentum2 <- momentum2 * sqrtD
  return(list(momentum1 = momentum1, momentum2 = momentum2, samesame = samesame))
}
## single MH kernel
## with proposal standard deviation set to be very small compared to the target standard deviation
## as we want the proposed moves to be accepted with large probability (which is different
## from the usual MH setting where we would like the acceptance rate to be moderate)
sigma_mhproposal <- sqrt(invmassdiag) / 1e2
## Random walk Metropolis-Hastings kernel with a Normal proposal of (small)
## standard deviation sigma_mhproposal, used as the meeting-triggering
## component of the mixture kernel below.
## 'state' is a list(chain_state, current_pdf); returns a list of the same shape.
mh_kernel <- function(state){
  chain_state <- state$chain_state
  current_pdf <- state$current_pdf
  proposal_value <- chain_state + sigma_mhproposal * rnorm(target_dim)
  proposal_pdf <- stan_logtarget(proposal_value)
  # guard against non-finite log densities: a NaN proposal_pdf would make
  # 'if (accept)' error out; this also mirrors the is.finite() checks
  # already present in coupled_mh_kernel
  accept <- FALSE
  if (is.finite(proposal_pdf)){
    accept <- (log(runif(1)) < (proposal_pdf - current_pdf))
  }
  if (accept){
    return(list(chain_state = proposal_value, current_pdf = proposal_pdf))
  } else {
    return(list(chain_state = chain_state, current_pdf = current_pdf))
  }
}
## coupled MH kernel, with "reflection-maximally" coupled proposals
## Coupled MH kernel with "reflection-maximally" coupled proposals.
## When the reflection-maximal coupling returns identical proposals
## (samesame) AND both chains accept with the shared uniform, the chains
## land on exactly the same point: 'identical' is then TRUE and the
## meeting time has been reached.
coupled_mh_kernel <- function(state1, state2){
chain_state1 <- state1$chain_state; current_pdf1 <- state1$current_pdf
chain_state2 <- state2$chain_state; current_pdf2 <- state2$current_pdf
# coupled proposal increments, marginally N(0, diag(sigma_mhproposal^2))
proposal_value <- reflmaxcoupling(chain_state1, chain_state2, sigma_mhproposal, kappa = 1)
proposal1 <- chain_state1 + proposal_value$momentum1
proposal_pdf1 <- stan_logtarget(proposal1)
if (proposal_value$samesame){
# identical proposals: reuse the log density instead of recomputing it
proposal2 <- proposal1; proposal_pdf2 <- proposal_pdf1
} else {
proposal2 <- chain_state2 + proposal_value$momentum2
proposal_pdf2 <- stan_logtarget(proposal2)
}
# shared uniform so that identical proposals are accepted/rejected together
logu <- log(runif(1))
accept1 <- FALSE; accept2 <- FALSE
if (is.finite(proposal_pdf1)){
accept1 <- (logu < (proposal_pdf1 - current_pdf1))
}
if (is.finite(proposal_pdf2)){
accept2 <- (logu < (proposal_pdf2 - current_pdf2))
}
if (accept1){
chain_state1 <- proposal1
current_pdf1 <- proposal_pdf1
}
if (accept2){
chain_state2 <- proposal2
current_pdf2 <- proposal_pdf2
}
# meeting occurs only if the proposals coincided and both were accepted
identical_ <- proposal_value$samesame && accept1 && accept2
return(list(state1 = list(chain_state = chain_state1, current_pdf = current_pdf1),
state2 = list(chain_state = chain_state2, current_pdf = current_pdf2),
identical = identical_))
}
## finally define mixture kernels, equal to MH with probability omega and HMC otherwise
omega <- 1/10
## Marginal mixture kernel: with probability omega take a (small) random walk
## MH step, otherwise take an HMC step.
mixture_single_kernel <- function(state){
  use_mh <- runif(1) < omega
  if (use_mh) mh_kernel(state) else hmc_kernel(state)
}
## Coupled mixture kernel: one shared uniform decides the component, so both
## chains use coupled MH (probability omega) or coupled HMC together.
mixture_coupled_kernel <- function(state1, state2){
  use_mh <- runif(1) < omega
  if (use_mh){
    coupled_mh_kernel(state1, state2)
  } else {
    coupled_hmc_kernel(state1, state2)
  }
}
## sample meeting times
## Run an L-lagged pair of chains until they meet (or max_iterations is hit).
## The first chain is advanced 'lag' steps with the marginal kernel, then
## both chains evolve jointly under the coupled kernel until the coupled
## kernel reports 'identical'. Returns the meeting time (counted from the
## start of chain 1), or Inf if the chains never met within max_iterations.
##   single_kernel(state)           -> state
##   coupled_kernel(state1, state2) -> list(state1, state2, identical)
##   rinit()                        -> initial state
sample_meetingtime <- function(single_kernel, coupled_kernel, rinit, lag = 1, max_iterations = Inf){
  # initialize two independent chains
  state1 <- rinit(); state2 <- rinit()
  # move first chain 'lag' times; seq_len (instead of 1:lag) correctly does
  # nothing when lag == 0, where 1:0 would run the loop twice
  time <- 0
  for (t in seq_len(lag)){
    time <- time + 1
    state1 <- single_kernel(state1)
  }
  # move two chains jointly until meeting (or until max_iterations)
  meetingtime <- Inf
  while (is.infinite(meetingtime) && (time < max_iterations)){
    time <- time + 1
    coupledstates <- coupled_kernel(state1, state2)
    state1 <- coupledstates$state1
    state2 <- coupledstates$state2
    # check if meeting happens
    if (coupledstates$identical) meetingtime <- time
  }
  return(meetingtime)
}
## now we can draw meeting times
## with this setup, the following takes a few minutes on my machine (with a 3.6GHz 8-core Intel Core i9)
## lag between the chains
lag <- 500
## number of pairs of chains to be produced independently
nrep <- 250
## in parallel using doParallel
library(doParallel)
library(doRNG)
## e.g. using all the cores
registerDoParallel(cores = detectCores())
## generate meeting times
## %dorng% (rather than %dopar%) gives reproducible parallel RNG streams
meetingtimes <- foreach(irep = 1:nrep, .combine = c) %dorng% {
sample_meetingtime(mixture_single_kernel, mixture_coupled_kernel, rinit, lag = lag)
}
## Histogram of times to meet, counting from the time the chains start being coupled
hist(meetingtimes - lag, prob = TRUE)
## TV upper bounds as in Biswas, Jacob & Vanetti 2019 https://arxiv.org/abs/1905.09971
## Each meeting time tau of an L-lagged pair contributes
## max(0, ceil((tau - L - t) / L)) to the estimated bound on the total
## variation distance between the chain at iteration t and its limit.
tv_upper_bound_estimates <- function(coupling_times, L, t){
  pmax(0, ceiling((coupling_times - L - t) / L))
}
## plot TV upper bounds as a function of iteration
## evaluate the bound at 100 evenly spaced iterations in [1, 1000]
timerange <- floor(seq(from = 1, to = 1000, length.out = 100))
## average the per-pair bound estimates over the nrep independent meeting times
tvupperbounds <- sapply(timerange, function(t) mean(tv_upper_bound_estimates(meetingtimes, lag, t)))
plot(timerange, tvupperbounds, type = "l",
xlab = "iteration", ylab = "TV upper bounds")
library(ggplot2)
theme_set(theme_bw()) # a sober theme and various graphical tweaks
theme_update(axis.text.x = element_text(size = 20), axis.text.y = element_text(size = 20),
axis.title.x = element_text(size = 25, margin=margin(20,0,0,0)),
axis.title.y = element_text(size = 25, angle = 90, margin = margin(0,20,0,0)),
legend.text = element_text(size = 20), legend.title = element_text(size = 20), title = element_text(size = 30),
strip.text = element_text(size = 25), strip.background = element_rect(fill="white"), legend.position = "bottom")
## ggplot version of the same TV upper bound curve, saved to disk below
g_tvbounds <- qplot(x = timerange, y = tvupperbounds, geom = "line")
g_tvbounds <- g_tvbounds + ylab("TV upper bounds") + xlab("iteration") + ylim(0,1.1)
g_tvbounds
## seems like the chains converge well before 1000 iterations
ggsave(filename = "2019-09-stan-logistic.png", plot = g_tvbounds, width = 7, height = 7)
## (end of script; GitHub web-page boilerplate removed)