Moment matching: normalizing constant of PSIS weights doesn't get updated #166

Closed
fweber144 opened this issue Feb 4, 2021 · 1 comment

@fweber144
Contributor

When performing moment matching with save_psis = TRUE, the normalizing constant stored in attr(loo$psis_object, "norm_const_log") does not seem to be updated, even though the PSIS log weights in loo$psis_object$log_weights are.

I think the issue is in these lines of loo/R/loo_moment_matching.R (lines 149 to 155 in 51fb2d0):

    if (!is.null(loo$psis_object)) {
      loo$psis_object$log_weights[, i] <- mm_list[[ii]]$lwi
    }
  }
  if (!is.null(loo$psis_object)) {
    loo$psis_object$diagnostics <- loo$diagnostics
  }

Here log_weights is updated, but its normalizing constant in attr(loo$psis_object, "norm_const_log") is not. I'll create a corresponding pull request for this issue. Here is a reprex:

options(mc.cores = 4)

### From <https://mc-stan.org/loo/articles/loo2-moment-matching.html>, but
### with `save_psis = TRUE`:
stancode <- "
data {
  int<lower=1> K;
  int<lower=1> N;
  matrix[N,K] x;
  int y[N];
  vector[N] offset;

  real beta_prior_scale;
  real alpha_prior_scale;
}
parameters {
  vector[K] beta;
  real intercept;
}
model {
  y ~ poisson(exp(x * beta + intercept + offset));
  beta ~ normal(0,beta_prior_scale);
  intercept ~ normal(0,alpha_prior_scale);
}
generated quantities {
  vector[N] log_lik;
  for (n in 1:N)
    log_lik[n] = poisson_lpmf(y[n] | exp(x[n] * beta + intercept + offset[n]));
}
"
library("rstan")
library("loo")
seed <- 9547
set.seed(seed)
# Prepare data
data(roaches, package = "rstanarm")
roaches$roach1 <- sqrt(roaches$roach1)
y <- roaches$y
x <- roaches[,c("roach1", "treatment", "senior")]
offset <- log(roaches[,"exposure2"])
n <- dim(x)[1]
k <- dim(x)[2]

standata <- list(N = n, K = k, x = as.matrix(x), y = y, offset = offset, beta_prior_scale = 2.5, alpha_prior_scale = 5.0)

# Compile
stanmodel <- stan_model(model_code = stancode)

# Fit model
fit <- sampling(stanmodel, data = standata, seed = seed, refresh = 0)
print(fit, pars = "beta")

loo1 <- loo(fit, save_psis = TRUE)

# available in rstan >= 2.21
loo2 <- loo(fit, moment_match = TRUE, save_psis = TRUE)
### 

# We now have:
identical(
  attr(loo1$psis_object, "norm_const_log"),
  attr(loo2$psis_object, "norm_const_log")
)
# --> TRUE
# We should have (I think):
all.equal(
  attr(loo2$psis_object, "norm_const_log"),
  matrixStats::colLogSumExps(loo2$psis_object$log_weights)
)
# --> "Mean relative difference: 1"
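For illustration, the invariant that the all.equal() check above relies on (each entry of norm_const_log equals the log-sum-exp of the corresponding column of log_weights) can be maintained by refreshing the affected entries whenever a column of log_weights is overwritten. This is a minimal base-R sketch on a toy object, not necessarily what the pull request does. It assumes the psis object is a list holding a log_weights matrix (draws x observations) with a norm_const_log attribute, as in the reprex; col_log_sum_exps() is a hypothetical helper standing in for matrixStats::colLogSumExps(), and the loop only mimics the shape of the quoted loo code with random stand-in weights.

```r
# Hypothetical helper: numerically stable column-wise log-sum-exp in base R,
# standing in for matrixStats::colLogSumExps().
col_log_sum_exps <- function(x) {
  apply(x, 2, function(col) {
    m <- max(col)
    m + log(sum(exp(col - m)))
  })
}

# Toy stand-in for a psis object (assumed structure): S draws x N observations.
set.seed(1)
S <- 100
N <- 5
psis_obj <- list(log_weights = matrix(rnorm(S * N), nrow = S, ncol = N))
attr(psis_obj, "norm_const_log") <- col_log_sum_exps(psis_obj$log_weights)
old_norm_const <- attr(psis_obj, "norm_const_log")

# Observations whose weights moment matching replaces (hypothetical indices).
I <- c(2, 4)
for (ii in seq_along(I)) {
  i <- I[ii]
  lwi <- rnorm(S, mean = 3)  # random stand-in for mm_list[[ii]]$lwi
  psis_obj$log_weights[, i] <- lwi
  # Keep the cached normalizing constant in sync with the new column.
  attr(psis_obj, "norm_const_log")[i] <- col_log_sum_exps(matrix(lwi, ncol = 1))
}

all.equal(
  attr(psis_obj, "norm_const_log"),
  col_log_sum_exps(psis_obj$log_weights)
)
# --> TRUE
```

Updating only the overwritten entries avoids a full pass over the matrix; recomputing the whole attribute once after the loop would be an equally valid, simpler alternative.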

Session info:

R version 4.0.3 (2020-10-10)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.2 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0

locale:
[...]

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] loo_2.4.1               rstan_2.26.0.9000       StanHeaders_2.26.0.9000

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.6         compiler_4.0.3     pillar_1.4.7       prettyunits_1.1.1  tools_4.0.3        pkgbuild_1.2.0    
 [7] checkmate_2.0.0    jsonlite_1.7.2     lifecycle_0.2.0    tibble_3.0.6       gtable_0.3.0       pkgconfig_2.0.3   
[13] rlang_0.4.10       DBI_1.1.1          cli_2.3.0          parallel_4.0.3     curl_4.3           gridExtra_2.3     
[19] dplyr_1.0.4        generics_0.1.0     vctrs_0.3.6        stats4_4.0.3       grid_4.0.3         tidyselect_1.1.0  
[25] glue_1.4.2         inline_0.3.17      R6_2.5.0           processx_3.4.5     ggplot2_3.3.3      callr_3.5.1       
[31] purrr_0.3.4        magrittr_2.0.1     backports_1.2.1    codetools_0.2-18   matrixStats_0.58.0 scales_1.1.1      
[37] ps_1.5.0           ellipsis_0.3.1     assertthat_0.2.1   colorspace_2.0-0   V8_3.4.0           RcppParallel_5.0.2
[43] munsell_0.5.0      crayon_1.4.0 
@topipa
Collaborator

topipa commented Feb 4, 2021

Thanks for reporting the issue! It is indeed true that the attributes of the psis object are not modified as they should be.
