Skip to content

Setting label_points=TRUE in plot() for loo or psis_loo object shifts plotted points to the left #287

@annariha

Description

@annariha

Hi, I encountered an issue when using plot() with label_points = TRUE for loo and psis_loo objects. When setting label_points = TRUE, the unlabelled points (those below the threshold) are shifted to the left: they are drawn at consecutive x-positions starting from 1 instead of at their own observation index. As a result, points are missing on the right-hand side of the plot, while some observations have several points plotted at the same x-position (see reprex below).

I think this is due to

loo/R/diagnostics.R

Lines 368 to 369 in b94b2b1

graphics::points(x = if (use_n_eff) n_eff[k < threshold] else k[k < threshold],
col = clrs[k < threshold], pch = 3, cex = .6)
because when I changed this to

graphics::points(x = which(k < threshold), 
                 y = if (use_n_eff) n_eff[k < threshold] else k[k < threshold], 
                 col = clrs[k < threshold], pch = 3, cex = .6)

it seems to fix the issue in all examples that I have tried so far (see also the example in the reprex below). Can I open a PR for this, or is the change too small, or should it be discussed further first? Thanks in advance!

# Reprex setup: rstanarm provides stan_glm() and the 'roaches' data,
# loo provides loo() and the plot() method under discussion.
library("rstanarm")
library("loo")

# use example from loo-vignette, 'roaches' is included with rstanarm
data(roaches)

# rescale to units of hundreds of roaches
roaches$roach1 <- roaches$roach1 / 100

# Poisson regression (log link) with log(exposure2) as offset; the seed is
# fixed so that the fit is reproducible.
fit1 <- stan_glm(
  formula = y ~ roach1 + treatment + senior,
  offset = log(exposure2),
  data = roaches,
  family = poisson(link = "log"),
  prior = normal(0, 2.5, autoscale = TRUE),
  prior_intercept = normal(0, 5, autoscale = TRUE),
  seed = 12345
  )

# Compute loo for the fit; the warning below reports 17 observations with
# pareto_k > 0.7, so the labelled branch of the plot is exercised.
loo1 <- loo(fit1)
#> Warning: Found 17 observations with a pareto_k > 0.7. With this many problematic observations we recommend calling 'kfold' with argument 'K=10' to perform 10-fold cross-validation rather than LOO.

# Baseline: plot without labels for comparison
plot(loo1)

# Bug: the unlabelled points are shifted to the left when labels are added
plot(loo1, label_points = TRUE)

# modified version of plot_diagnostic()-function from diagnostics.R
#' Draw the PSIS diagnostic plot, one point per data point
#'
#' Plots `k` (or `n_eff`, when supplied) against the observation index.
#' Colours are always chosen from `k`. With `label_points = TRUE`, every
#' observation whose `k` exceeds `threshold` is drawn as its index number
#' instead of a point marker; all other observations keep the marker and stay
#' at their own x-position.
#'
#' @param k Numeric vector of Pareto shape values, one per data point.
#' @param n_eff Optional numeric vector. When not `NULL` it is plotted on the
#'   y-axis instead of `k`, and the horizontal reference lines are omitted.
#' @param threshold Value of `k` above which a point is labelled.
#' @param ... Further arguments passed on to `graphics::text()` when labels
#'   are drawn. `adj`, `cex` and `col` get defaults if not supplied.
#' @param label_points Should observations with `k > threshold` be labelled
#'   with their index?
#' @param main Plot title.
#' @return `NULL`, invisibly. Called for its side effect of drawing a plot.
plot_diagnostic <-
  function(k,
           n_eff = NULL,
           threshold = 0.7,
           ...,
           label_points = FALSE,
           main = "PSIS diagnostic plot") {
    use_n_eff <- !is.null(n_eff)
    y_vals <- if (use_n_eff) n_eff else k

    # Empty canvas first (type = "n"); points and labels are added below.
    graphics::plot(
      x = y_vals,
      xlab = "Data point",
      # Print ESS as n_eff terms has been deprecated
      ylab = if (use_n_eff) "PSIS ESS" else "Pareto shape k",
      type = "n",
      bty = "l",
      yaxt = "n",
      main = main
    )
    graphics::axis(side = 2, las = 1)

    in_range <- function(x, lb_ub) {
      x >= lb_ub[1L] & x <= lb_ub[2L]
    }

    # Horizontal reference lines at 0, threshold and 1, but only those that
    # fall inside the observed range of k (and only when k is on the y-axis).
    if (!use_n_eff) {
      krange <- range(k, na.rm = TRUE)
      line_at <- c(0, threshold, 1)
      line_clrs <- c("#C79999", "#7C0000")
      ltys <- c(3, 2, 1)
      for (j in seq_along(line_at)) {
        val <- line_at[j]
        if (in_range(val, krange)) {
          graphics::abline(
            h = val,
            col = if (val == 0) "darkgray" else line_clrs[j - 1],
            lty = ltys[j],
            lwd = 1
          )
        }
      }
    }

    # Point colour by k: up to threshold, threshold to 1, above 1.
    breaks <- c(-Inf, threshold, 1)
    hex_clrs <- c("#6497b1", "#005b96", "#03396c")
    clrs <- ifelse(
      in_range(k, breaks[1:2]),
      hex_clrs[1],
      ifelse(in_range(k, breaks[2:3]), hex_clrs[2], hex_clrs[3])
    )

    # Split the observation indices into "labelled" and "plain" exactly once.
    # Deriving both groups from the same partition guarantees that
    #   * every observation is drawn exactly once (k == threshold included),
    #   * x and y always have equal length (which() drops NA on both sides),
    #   * plain points keep their own index as x-position (no left shift).
    obs <- seq_along(k)
    labelled <- if (label_points) which(k > threshold) else integer(0)
    plain <- setdiff(obs, labelled)

    if (length(plain) > 0L) {
      graphics::points(
        x = plain,
        y = y_vals[plain],
        col = clrs[plain],
        pch = 3,
        cex = .6
      )
    }

    # Nothing to label: graphics::text() must not get zero-length labels.
    if (length(labelled) == 0L) {
      return(invisible())
    }

    txt_args <- c(
      list(
        x = labelled,
        y = y_vals[labelled],
        labels = labelled
      ),
      list(...)
    )
    # Defaults apply only where the caller did not pass the argument via `...`.
    if (!("adj" %in% names(txt_args))) txt_args$adj <- 2 / 3
    if (!("cex" %in% names(txt_args))) txt_args$cex <- 0.75
    if (!("col" %in% names(txt_args))) txt_args$col <- clrs[labelled]

    do.call(graphics::text, txt_args)
    invisible()
  }

# Baseline with the modified plot_diagnostic(): plot without labels for comparison
plot_diagnostic(loo1$diagnostics[["pareto_k"]])

# With labels: the unlabelled points are no longer shifted
plot_diagnostic(loo1$diagnostics[["pareto_k"]], label_points = TRUE)

Created on 2025-06-08 with reprex v2.1.1

Session info

sessionInfo()
#> R version 4.5.0 (2025-04-11)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 24.04.2 LTS
#> 
#> Matrix products: default
#> BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.12.0 
#> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.12.0  LAPACK version 3.12.0
#> 
#> locale:
#>  [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=en_DK.UTF-8        LC_COLLATE=en_GB.UTF-8    
#>  [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
#>  [7] LC_PAPER=fi_FI.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       
#> 
#> time zone: Europe/Helsinki
#> tzcode source: system (glibc)
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] loo_2.8.0       rstanarm_2.32.1 Rcpp_1.0.14    
#> 
#> loaded via a namespace (and not attached):
#>  [1] tidyselect_1.2.1     dplyr_1.1.4          farver_2.1.2        
#>  [4] fastmap_1.2.0        tensorA_0.36.2.1     shinystan_2.6.0     
#>  [7] promises_1.3.3       shinyjs_2.1.0        reprex_2.1.1        
#> [10] digest_0.6.37        mime_0.13            lifecycle_1.0.4     
#> [13] StanHeaders_2.32.10  survival_3.8-3       magrittr_2.0.3      
#> [16] posterior_1.6.1      compiler_4.5.0       rlang_1.1.6         
#> [19] tools_4.5.0          igraph_2.1.4         yaml_2.3.10         
#> [22] knitr_1.50           htmlwidgets_1.6.4    curl_6.2.2          
#> [25] pkgbuild_1.4.7       xml2_1.3.8           plyr_1.8.9          
#> [28] RColorBrewer_1.1-3   dygraphs_1.1.1.6     abind_1.4-8         
#> [31] miniUI_0.1.2         withr_3.0.2          grid_4.5.0          
#> [34] stats4_4.5.0         xts_0.14.1           xtable_1.8-4        
#> [37] inline_0.3.21        ggplot2_3.5.2        scales_1.4.0        
#> [40] gtools_3.9.5         MASS_7.3-65          cli_3.6.5           
#> [43] rmarkdown_2.29       reformulas_0.4.1     generics_0.1.4      
#> [46] RcppParallel_5.1.10  rstudioapi_0.17.1    reshape2_1.4.4      
#> [49] minqa_1.2.8          rstan_2.32.7         stringr_1.5.1       
#> [52] shinythemes_1.2.0    splines_4.5.0        bayesplot_1.12.0    
#> [55] parallel_4.5.0       matrixStats_1.5.0    base64enc_0.1-3     
#> [58] vctrs_0.6.5          boot_1.3-31          Matrix_1.7-3        
#> [61] crosstalk_1.2.1      glue_1.8.0           nloptr_2.2.1        
#> [64] codetools_0.2-20     distributional_0.5.0 DT_0.33             
#> [67] stringi_1.8.7        gtable_0.3.6         later_1.4.2         
#> [70] QuickJSR_1.7.0       lme4_1.1-37          tibble_3.2.1        
#> [73] colourpicker_1.3.0   pillar_1.10.2        htmltools_0.5.8.1   
#> [76] R6_2.6.1             Rdpack_2.6.4         evaluate_1.0.3      
#> [79] shiny_1.10.0         lattice_0.22-5       markdown_2.0        
#> [82] rbibutils_2.3        backports_1.5.0      threejs_0.3.4       
#> [85] httpuv_1.6.16        rstantools_2.4.0     gridExtra_2.3       
#> [88] nlme_3.1-168         checkmate_2.3.2      xfun_0.52           
#> [91] fs_1.6.6             zoo_1.8-14           pkgconfig_2.0.3

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions