Clarification on using loo_moment_match() with non-Stan objects #209

Open
wlandau opened this issue Dec 1, 2022 · 4 comments

Comments

@wlandau

wlandau commented Dec 1, 2022

I am working on a model averaging problem with very simple models, and I am getting intermittently high Pareto k values even on simple well-behaved simulated datasets. I would like to apply the moment matching correction to both non-longitudinal JAGS models and longitudinal Stan models. The latter case is trivially easy with moment_match = TRUE in loo(), but I do not have a stanfit object in the former case.

Would you help me understand how to use loo_moment_match() in the case where my model fit is a posterior::as_draws_df() data frame with columns for parameters and pointwise log likelihoods? To set up a sufficiently motivating scenario, I converted the roaches example from the vignette into JAGS. I also put constrained priors on the scale parameters for the sake of learning what to do with the unconstrain_pars, log_prob_upars, and log_lik_i_upars arguments of loo_moment_match(). (Is it even appropriate to consider "unconstrained parameters" without HMC?)

    library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
    
    data(roaches, package = "rstanarm")
    roaches$roach1 <- sqrt(roaches$roach1)
    x <- roaches[, c("roach1", "treatment", "senior")]
    data <- list(
        N = nrow(x),
        K = ncol(x),
        x = as.matrix(x),
        y = roaches$y,
        offset = log(roaches[,"exposure2"])
    )
    
    model_text <- "
model {
  for (n in 1:N) {
    y[n] ~ dpois(exp(inprod(x[n,], beta) + intercept + offset[n]))
  }
  for (k in 1:K) {
    beta[k] ~ dnorm(0, 1 / (scale_beta * scale_beta))
  }
  intercept ~ dnorm(0, 1 / (scale_alpha * scale_alpha))
  scale_beta ~ dunif(0, 10)
  scale_alpha ~ dnorm(0, 10) T(0,)
  for (n in 1:N) {
    log_lik[n] <- log(dpois(y[n], exp(inprod(x[n,], beta) + intercept + offset[n])))
  }
}
"
    file <- tempfile()
    writeLines(model_text, file)
    
    tmp <- capture.output({
        model <- rjags::jags.model(
            file = file,
            data = data,
            n.chains = 4,
            n.adapt = 2e3
        )
        stats::update(model, n.iter = 2e3, quiet = TRUE)
        coda <- rjags::coda.samples(
            model = model,
            variable.names = c(
                "beta",
                "intercept",
                "scale_beta",
                "scale_alpha",
                "log_lik"
            ),
            n.iter = 4e3
        )
    })
    
    fit <- posterior::as_draws_df(coda)
    print(fit) # This is the model fit object I can work with.
#> # A draws_df: 4000 iterations, 4 chains, and 268 variables
#>    beta[1] beta[2] beta[3] intercept log_lik[1] log_lik[2] log_lik[3]
#> 1     0.16   -0.55   -0.30       2.5        -19        -16       -2.1
#> 2     0.16   -0.55   -0.29       2.5        -19        -16       -2.1
#> 3     0.16   -0.52   -0.27       2.5        -17        -14       -2.1
#> 4     0.16   -0.55   -0.38       2.5        -16        -14       -2.1
#> 5     0.16   -0.56   -0.25       2.5        -16        -14       -2.1
#> 6     0.16   -0.59   -0.26       2.5        -18        -15       -2.1
#> 7     0.16   -0.60   -0.25       2.5        -19        -16       -2.0
#> 8     0.16   -0.58   -0.34       2.5        -18        -16       -2.0
#> 9     0.16   -0.58   -0.30       2.5        -18        -15       -2.1
#> 10    0.16   -0.60   -0.33       2.5        -17        -15       -2.1
#>    log_lik[4]
#> 1        -2.2
#> 2        -2.2
#> 3        -2.3
#> 4        -2.2
#> 5        -2.2
#> 6        -2.2
#> 7        -2.2
#> 8        -2.2
#> 9        -2.2
#> 10       -2.2
#> # ... with 15990 more draws, and 260 more variables
#> # ... hidden reserved variables {'.chain', '.iteration', '.draw'}
    
    # Convergence looks okay.
    fit %>%
        select(starts_with(c("beta", "intercept", "scale"))) %>%
        posterior::summarize_draws() %>%
        print()
#> Warning: Dropping 'draws_df' class as required metadata was removed.
#> # A tibble: 6 × 10
#>   variable      mean median      sd     mad     q5    q95  rhat ess_bulk ess_t…¹
#>   <chr>        <dbl>  <dbl>   <dbl>   <dbl>  <dbl>  <dbl> <dbl>    <dbl>   <dbl>
#> 1 beta[1]      0.161  0.161 0.00193 0.00194  0.158  0.164  1.00    1704.   2993.
#> 2 beta[2]     -0.566 -0.566 0.0248  0.0246  -0.607 -0.524  1.00    3388.   5597.
#> 3 beta[3]     -0.312 -0.312 0.0334  0.0335  -0.368 -0.259  1.00    5957.   8365.
#> 4 intercept    2.52   2.52  0.0260  0.0260   2.48   2.56   1.00    1339.   2777.
#> 5 scale_alpha  0.905  0.889 0.156   0.153    0.673  1.19   1.00    9336.   8030.
#> 6 scale_beta   0.816  0.569 0.835   0.298    0.272  2.20   1.00    2056.    864.
#> # … with abbreviated variable name ¹ess_tail
    
    # LOO without the moment matching correction is straightforward.
    log_lik <- as.matrix(dplyr::select(fit, tidyselect::starts_with("log_lik")))
#> Warning: Dropping 'draws_df' class as required metadata was removed.
    r_eff <- loo::relative_eff(x = log_lik, chain_id = fit$.chain)
    loo <- loo::loo(x = log_lik, r_eff = r_eff)
#> Warning: Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.
    
    # But we get high Pareto k values.
    print(loo)
#> 
#> Computed from 16000 by 262 log-likelihood matrix
#> 
#>          Estimate     SE
#> elpd_loo  -5462.1  696.5
#> p_loo       261.3   57.6
#> looic     10924.3 1393.0
#> ------
#> Monte Carlo SE of elpd_loo is NA.
#> 
#> Pareto k diagnostic values:
#>                          Count Pct.    Min. n_eff
#> (-Inf, 0.5]   (good)     239   91.2%   537       
#>  (0.5, 0.7]   (ok)         9    3.4%   76        
#>    (0.7, 1]   (bad)        7    2.7%   11        
#>    (1, Inf)   (very bad)   7    2.7%   1         
#> See help('pareto-k-diagnostic') for details.
    
    # How do I use loo_moment_match() in this situation?
    # loo::loo_moment_match(
    #   x = fit,
    #   post_draws = function(x) as.matrix(x),
    #   log_lik_i = function(x, i) x[[sprintf("log_lik[%s]", i)]],
    #   unconstrain_pars = "???", # Do we even need to consider the unconstrained space for non-HMC MCMC?
    #   log_prob_upars = "???", # Here is where I start to get lost.
    #   log_lik_i_upars = "???" # Same here.
    # )

Created on 2022-12-01 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23)
#>  os       macOS Big Sur ... 10.16
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       America/Indiana/Indianapolis
#>  date     2022-12-01
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package        * version date (UTC) lib source
#>  abind            1.4-5   2016-07-21 [1] CRAN (R 4.2.0)
#>  assertthat       0.2.1   2019-03-21 [1] CRAN (R 4.2.0)
#>  backports        1.4.1   2021-12-13 [1] CRAN (R 4.2.0)
#>  checkmate        2.1.0   2022-04-21 [1] CRAN (R 4.2.0)
#>  cli              3.4.1   2022-09-23 [1] CRAN (R 4.2.0)
#>  coda             0.19-4  2020-09-30 [1] CRAN (R 4.2.0)
#>  colorspace       2.0-3   2022-02-21 [1] CRAN (R 4.2.0)
#>  DBI              1.1.3   2022-06-18 [1] CRAN (R 4.2.0)
#>  digest           0.6.30  2022-10-18 [1] CRAN (R 4.2.0)
#>  distributional   0.3.1   2022-09-02 [1] CRAN (R 4.2.0)
#>  dplyr          * 1.0.10  2022-09-01 [1] CRAN (R 4.2.0)
#>  evaluate         0.18    2022-11-07 [1] CRAN (R 4.2.0)
#>  fansi            1.0.3   2022-03-24 [1] CRAN (R 4.2.0)
#>  farver           2.1.1   2022-07-06 [1] CRAN (R 4.2.0)
#>  fastmap          1.1.0   2021-01-25 [1] CRAN (R 4.2.0)
#>  fs               1.5.2   2021-12-08 [1] CRAN (R 4.2.0)
#>  generics         0.1.3   2022-07-05 [1] CRAN (R 4.2.0)
#>  ggplot2          3.4.0   2022-11-04 [1] CRAN (R 4.2.0)
#>  glue             1.6.2   2022-02-24 [1] CRAN (R 4.2.0)
#>  gtable           0.3.1   2022-09-01 [1] CRAN (R 4.2.0)
#>  highr            0.9     2021-04-16 [1] CRAN (R 4.2.0)
#>  htmltools        0.5.3   2022-07-18 [1] CRAN (R 4.2.0)
#>  knitr            1.41    2022-11-18 [1] CRAN (R 4.2.0)
#>  lattice          0.20-45 2021-09-22 [1] CRAN (R 4.2.1)
#>  lifecycle        1.0.3   2022-10-07 [1] CRAN (R 4.2.0)
#>  loo              2.5.1   2022-03-24 [1] CRAN (R 4.2.0)
#>  magrittr         2.0.3   2022-03-30 [1] CRAN (R 4.2.0)
#>  matrixStats      0.63.0  2022-11-18 [1] CRAN (R 4.2.0)
#>  munsell          0.5.0   2018-06-12 [1] CRAN (R 4.2.0)
#>  pillar           1.8.1   2022-08-19 [1] CRAN (R 4.2.0)
#>  pkgconfig        2.0.3   2019-09-22 [1] CRAN (R 4.2.0)
#>  posterior        1.3.1   2022-09-06 [1] CRAN (R 4.2.0)
#>  purrr            0.3.5   2022-10-06 [1] CRAN (R 4.2.0)
#>  R.cache          0.16.0  2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3      1.8.2   2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo             1.25.0  2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils          2.12.2  2022-11-11 [1] CRAN (R 4.2.0)
#>  R6               2.5.1   2021-08-19 [1] CRAN (R 4.2.0)
#>  reprex           2.0.2   2022-08-17 [1] CRAN (R 4.2.0)
#>  rjags            4-13    2022-04-19 [1] CRAN (R 4.2.0)
#>  rlang            1.0.6   2022-09-24 [1] CRAN (R 4.2.0)
#>  rmarkdown        2.18    2022-11-09 [1] CRAN (R 4.2.0)
#>  rstudioapi       0.14    2022-08-22 [1] CRAN (R 4.2.0)
#>  scales           1.2.1   2022-08-20 [1] CRAN (R 4.2.0)
#>  sessioninfo      1.2.2   2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi          1.7.8   2022-07-11 [1] CRAN (R 4.2.0)
#>  stringr          1.4.1   2022-08-20 [1] CRAN (R 4.2.0)
#>  styler           1.8.1   2022-11-07 [1] CRAN (R 4.2.0)
#>  tensorA          0.36.2  2020-11-19 [1] CRAN (R 4.2.0)
#>  tibble           3.1.8   2022-07-22 [1] CRAN (R 4.2.0)
#>  tidyselect       1.2.0   2022-10-10 [1] CRAN (R 4.2.1)
#>  utf8             1.2.2   2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs            0.5.1   2022-11-16 [1] CRAN (R 4.2.0)
#>  withr            2.5.0   2022-03-03 [1] CRAN (R 4.2.0)
#>  xfun             0.35    2022-11-16 [1] CRAN (R 4.2.0)
#>  yaml             2.3.6   2022-10-18 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────
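Below is a minimal R sketch of how the ??? arguments could be filled in for this JAGS model. It relies on assumptions that do not come from the thread. The fit is wrapped in a hypothetical list(draws = ..., data = ...), and all helper names are made up. The signatures (post_draws(x, ...), log_lik_i(x, i, ...), unconstrain_pars(x, pars, ...), log_prob_upars(x, upars, ...), log_lik_i_upars(x, upars, i, ...)) are inferred from the argument descriptions in ?loo::loo_moment_match, so check them against the installed loo version.

On the HMC question: moment matching shifts and rescales the draws. Mapping the bounded parameters to an unbounded space keeps the transformed draws inside the support, whatever sampler produced them. Here that means a log transform for the half-normal scale_alpha and a scaled logit for the dunif(0, 10) scale_beta. The price is that log_prob_upars must add the log-Jacobian of the inverse transform.

```r
# Hypothetical helpers for loo::loo_moment_match() with a non-Stan (JAGS) fit.
# Assumed structure: x = list(draws = S x P numeric matrix with columns named
# "beta[1]", ..., "intercept", "scale_beta", "scale_alpha"; data = JAGS data list).

# Column positions in the order returned by post_draws_jags(). Positions are used
# instead of names in case transformed draws come back without column names.
par_index <- function(K) {
  list(beta = seq_len(K), intercept = K + 1, scale_beta = K + 2, scale_alpha = K + 3)
}

# post_draws: constrained parameter draws only (no log_lik columns).
post_draws_jags <- function(x, ...) {
  K <- x$data$K
  keep <- c(sprintf("beta[%d]", seq_len(K)), "intercept", "scale_beta", "scale_alpha")
  x$draws[, keep, drop = FALSE]
}

# Poisson log likelihood of observation i, one value per (constrained) draw.
loglik_i_from_pars <- function(pars, data, i) {
  j <- par_index(data$K)
  eta <- drop(pars[, j$beta, drop = FALSE] %*% data$x[i, ]) +
    pars[, j$intercept] + data$offset[i]
  dpois(data$y[i], lambda = exp(eta), log = TRUE)
}

log_lik_i_jags <- function(x, i, ...) {
  loglik_i_from_pars(post_draws_jags(x), x$data, i)
}

# Unconstrained space: scale_beta in (0, 10) -> logit(scale_beta / 10),
# scale_alpha in (0, Inf) -> log(scale_alpha); the rest are already unbounded.
unconstrain_pars_jags <- function(x, pars, ...) {
  j <- par_index(x$data$K)
  upars <- as.matrix(pars)
  upars[, j$scale_beta] <- qlogis(upars[, j$scale_beta] / 10)
  upars[, j$scale_alpha] <- log(upars[, j$scale_alpha])
  upars
}

# Inverse of unconstrain_pars_jags().
constrain_upars_jags <- function(upars, K) {
  j <- par_index(K)
  pars <- as.matrix(upars)
  pars[, j$scale_beta] <- 10 * plogis(pars[, j$scale_beta])
  pars[, j$scale_alpha] <- exp(pars[, j$scale_alpha])
  pars
}

# Unnormalized log posterior of the unconstrained draws:
# log prior + log likelihood + log |Jacobian| of the inverse transform.
log_prob_upars_jags <- function(x, upars, ...) {
  data <- x$data
  j <- par_index(data$K)
  upars <- as.matrix(upars)
  pars <- constrain_upars_jags(upars, data$K)
  S <- nrow(pars)
  beta <- pars[, j$beta, drop = FALSE]
  alpha <- pars[, j$intercept]
  sb <- pars[, j$scale_beta]
  sa <- pars[, j$scale_alpha]
  # Priors as in the JAGS model; JAGS dnorm() takes a precision, R's takes an sd.
  log_prior <- rowSums(matrix(dnorm(beta, 0, sb, log = TRUE), nrow = S)) +
    dnorm(alpha, 0, sa, log = TRUE) + dunif(sb, 0, 10, log = TRUE) +
    log(2) + dnorm(sa, 0, 1 / sqrt(10), log = TRUE)  # half-normal from T(0,)
  eta <- beta %*% t(data$x) + alpha + rep(data$offset, each = S)  # S x N
  log_lik <- rowSums(matrix(dpois(rep(data$y, each = S), exp(eta), log = TRUE), nrow = S))
  u_b <- upars[, j$scale_beta]
  log_jac <- log(10) + plogis(u_b, log.p = TRUE) +
    plogis(u_b, lower.tail = FALSE, log.p = TRUE) + upars[, j$scale_alpha]
  log_prior + log_lik + log_jac
}

log_lik_i_upars_jags <- function(x, upars, i, ...) {
  loglik_i_from_pars(constrain_upars_jags(upars, x$data$K), x$data, i)
}

# Possible usage with the objects from the reprex (untested sketch):
# draws <- as.matrix(dplyr::select(fit, tidyselect::starts_with(c("beta", "intercept", "scale"))))
# loo_mm <- loo::loo_moment_match(
#   x = list(draws = draws, data = data), loo = loo,
#   post_draws = post_draws_jags, log_lik_i = log_lik_i_jags,
#   unconstrain_pars = unconstrain_pars_jags, log_prob_upars = log_prob_upars_jags,
#   log_lik_i_upars = log_lik_i_upars_jags
# )
```

The log posterior only matters up to an additive constant, because the importance weights are self-normalized. The prior normalizing constants could therefore be dropped. They are kept here so that the density matches the JAGS model term by term.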
@wlandau
Author

wlandau commented Dec 1, 2022

Are there other convenient ways to make approximate LOO more robust?

@n-kall
Contributor

n-kall commented Dec 2, 2022

Hi, you might have some luck with using the generic moment matching functions from https://github.com/topipa/iwmm
You'll need to specify the target function or importance weight function manually, but it should work on a matrix object.

@wlandau
Author

wlandau commented Dec 7, 2022

Thanks, @n-kall. Is iwmm a generic implementation of https://mc-stan.org/loo/articles/loo2-moment-matching.html, or is the underlying statistical method itself different too?

@n-kall
Contributor

n-kall commented Dec 10, 2022

Yes, it is the same underlying mechanism, just generic (i.e., not tied to importance weights for leave-one-out posteriors). Given a log_ratio_fun, the moment_match function returns transformed draws and importance weights (and Pareto k diagnostic values). Setting k_threshold = 0.7 and split = TRUE would match the loo_moment_match() defaults.
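The transformed draws mentioned above come from affine transformations of the posterior draws. The R function below illustrates the two simplest ones. It shifts each parameter to its importance-weighted mean and rescales it to its importance-weighted standard deviation. This is a rough sketch, not iwmm's or loo's implementation, and the function name is hypothetical.

The sketch leaves out the full-covariance transformation and split = TRUE. It also omits the reweighting step. After each transformation, the weights are recomputed from the target density and the density of the transformed draws. That step is why the method needs an unnormalized log posterior such as log_prob_upars.

```r
# Hypothetical helper (not iwmm's or loo's code): affine transformation that
# matches the unweighted mean and variance of each column of `draws` to the
# importance-weighted mean and variance implied by log weights `log_w`.
match_moments <- function(draws, log_w) {
  draws <- as.matrix(draws)
  w <- exp(log_w - max(log_w))
  w <- w / sum(w)                                  # normalized importance weights
  mean_u <- colMeans(draws)                        # unweighted moments
  var_u <- colMeans(sweep(draws, 2, mean_u)^2)
  mean_w <- colSums(w * draws)                     # weighted moments
  var_w <- colSums(w * sweep(draws, 2, mean_w)^2)
  # Center, rescale to the weighted sd, then shift to the weighted mean.
  scaled <- sweep(sweep(draws, 2, mean_u), 2, sqrt(var_w / var_u), "*")
  sweep(scaled, 2, mean_w, "+")
}

# Example: 1000 draws of 2 independent N(0, 1) parameters, weighted by exp(0.5 * theta_1).
set.seed(1)
draws <- matrix(rnorm(2000), ncol = 2)
log_w <- 0.5 * draws[, 1]
new_draws <- match_moments(draws, log_w)
colMeans(new_draws)  # close to c(0.5, 0): tilting N(0, 1) by exp(0.5 * theta) gives N(0.5, 1)
```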

If you want to use it for the leave-one-out case, log_ratio_fun should be a function that returns the negative log-likelihood of the left-out observation; see the tests for an example. You would likely need to wrap it in a loop over observations and use the resulting draws and weights to calculate the elpd or any other metrics you're interested in.
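The leave-one-out recipe above can be sketched in R as follows. The helper names are hypothetical, and plain importance sampling stands in for iwmm. For observation i the log importance ratio is -log p(y_i | theta). Given draws with log weights, the elpd contribution is log(sum_s w_s p(y_i | theta_s)), with the weights normalized to sum to one.

With untransformed draws, this reduces to the classic harmonic-mean LOO estimate, which loo stabilizes with Pareto smoothing. With moment-matched draws, log_lik_i has to be re-evaluated at the transformed draws before calling elpd_i_from_weights().

```r
# Stable log(sum(exp(v))).
log_sum_exp <- function(v) {
  m <- max(v)
  m + log(sum(exp(v - m)))
}

# elpd contribution of one observation from draws and log importance weights:
# log(sum_s w_s * p(y_i | theta_s)) with the weights normalized to sum to one.
elpd_i_from_weights <- function(log_lik_i, log_w) {
  log_w <- log_w - log_sum_exp(log_w)
  log_sum_exp(log_lik_i + log_w)
}

# Loop over observations using the LOO log ratio -log p(y_i | theta), i.e. plain
# (unsmoothed, untransformed) importance sampling on an S x N log_lik matrix.
elpd_loo_is <- function(log_lik) {
  vapply(seq_len(ncol(log_lik)), function(i) {
    elpd_i_from_weights(log_lik[, i], -log_lik[, i])
  }, numeric(1))
}

# Possible usage with the reprex (no Pareto smoothing): sum(elpd_loo_is(log_lik))
```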
