# Small-Sample Benchmark on SuSiE vs SuSiE with Servin-Stephens Prior

In our data analyses we observe that when the sample size is small
relative to the number of variants ($n \ll p$),
standard SuSiE can **overfit** by underestimating the residual variance $\sigma^2$.
This leads to inflated Bayes factors and spurious credible sets (CS).
The Servin-Stephens prior integrates out $\sigma^2$ analytically using a
Normal-Inverse-Gamma (NIG) conjugate prior, producing $t$-distributed marginals
that are naturally more conservative yet remain calibrated both in small samples and beyond.
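
To illustrate why a $t$-distributed marginal is more conservative at small $n$, the sketch below compares two-sided tail probabilities of the same statistic under a normal reference (residual variance treated as known) and under a Student-$t$ reference with $n - 2$ degrees of freedom. It only illustrates the heavier tails and is not the Bayes factor computation used in susieR. The statistic value `z = 3`, the choice of $n - 2$ degrees of freedom, and the extra sample size of 1000 are assumptions made for illustration.

```r
# Illustration only: tail probabilities for the same statistic under a
# normal reference versus a Student-t reference with n - 2 degrees of freedom.
z      <- 3                      # hypothetical test statistic
n_vals <- c(20, 30, 47, 1000)    # sample sizes (1000 added for contrast)

p_normal <- 2 * pnorm(-abs(z))                 # variance treated as known
p_t      <- 2 * pt(-abs(z), df = n_vals - 2)   # variance uncertainty included

comparison <- data.frame(n        = n_vals,
                         p_normal = p_normal,
                         p_t      = p_t,
                         ratio    = p_t / p_normal)
print(comparison)
```

The ratio `p_t / p_normal` is largest at $n = 20$ and shrinks toward 1 as $n$ grows, which is consistent with the two references agreeing in large samples.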

In this notebook we present a benchmark focused on the small-sample setting, using realistic simulations based on
eQTL data that we analyze as part of the FunGen-xQTL project.

### Simulation design

We use **real genotype data** from a small eQTL study ($n = 47$, $p = 7{,}430$)
shipped with susieR as `data_small`, subsampled to $N \in \{20, 30, 47\}$.

**FIXME: this vignette is just using the SuSiE small sample vignette example and repeatedly generating the phenotypes through multiple replicates. Ruixi as discussed let's adopt the scheme into more genes as replicate**

To create realistic noise that reproduces the overfitting
pattern seen on real data, we:

1. Fit **LASSO** (with `lambda.1se` from cross-validation) on the real $(X, y)$
   to obtain an empirical residual $r = y - X\hat{\beta}_{\text{lasso}}$. LASSO is agnostic to the choice between SuSiE and SuSiE-SS, and because it is a sparse regression method, its residual should approximate the residual variance that SuSiE and SuSiE-SS encounter.
2. Compute $w = U^\top r$, where $U$ contains the left singular vectors of the centered genotype matrix, so each component $w_k$ is the projection of the residual onto the $k$-th principal axis of $X$. This tells us how the residual is distributed across the directions that the genotype matrix can explain. We expect this projection to be meaningful because the overfitting problem arises precisely when real residual noise concentrates its variance along these principal axes, making it look like genetic signal to the model.
3. Using a wild bootstrap approach, draw $s_k \in \{-1,+1\}$ independently for each component and form new noise $\tilde{r} = U\,(w \odot s)$. The random sign flips break any association between the noise and the columns of $X$ but retain the per-eigencomponent variance profile, because $(\pm w_k)^2 = w_k^2$. The realistic residual variance structure of the original data is therefore preserved.
4. Simulate causal signals by randomly drawing columns of $X$ with effect sizes calibrated to a
   target proportion of variance explained, `h2_sparse` $= \mathrm{var}(\text{signal}) / \mathrm{var}(y)$.
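
Steps 2 and 3 can be sketched on toy data as follows. This is a minimal illustration, assuming a random Gaussian matrix in place of the genotypes and a random vector in place of the LASSO residual. The names `w_tilde` and `r_tilde` are introduced here for illustration only.

```r
set.seed(1)
n <- 30; p <- 200
X <- matrix(rnorm(n * p), n, p)   # toy stand-in for the genotype matrix
r <- rnorm(n)                     # toy stand-in for the LASSO residual

# Left singular vectors of the column-centered matrix (n x n, orthonormal)
Xs <- scale(X, center = TRUE, scale = FALSE)
U  <- svd(Xs, nu = n, nv = 0)$u

# Step 2: project the centered residual onto the principal axes
r_centered <- r - mean(r)
w <- as.vector(crossprod(U, r_centered))

# Step 3: flip signs independently and map back to sample space
s       <- sample(c(-1, 1), length(w), replace = TRUE)
r_tilde <- as.vector(U %*% (w * s))

# Projecting the new noise back recovers w * s, so the squared
# components (and hence the total sum of squares) are unchanged.
w_tilde <- as.vector(crossprod(U, r_tilde))
max(abs(w_tilde^2 - w^2))
sum(r_tilde^2) - sum(r_centered^2)
```

Both printed differences should be zero up to floating-point error, since $U$ is orthonormal and only the signs of the components change.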


### Metrics

| Metric | Definition |
|--------|------------|
| **Power** | (distinct causal variants found in any filtered CS) / (total causal) |
| **Coverage** | (filtered CS containing $\geq 1$ causal) / (total filtered CS) |
| **CS size** | mean number of variants per filtered CS |
| **CS / rep** | total filtered CS / number of replicates |
| $\hat{\sigma}^2$ | estimated residual variance (overfitting diagnostic) |
| $\sum V$ | sum of estimated prior variances across $L$ effects |
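
As a small worked example of the first three metrics, the sketch below evaluates them for a single hypothetical replicate. The causal indices and credible sets are made up for illustration. In the notebook itself, the numerators and denominators are summed over replicates before the ratios are taken.

```r
causal_idx <- c(3, 10, 25)            # hypothetical causal variants
cs <- list(c(2, 3, 4), c(40, 41))     # hypothetical filtered credible sets

# Distinct causal variants captured by any CS, and CS holding >= 1 causal
discovered <- length(intersect(unique(unlist(cs)), causal_idx))
n_true_cs  <- sum(sapply(cs, function(s) any(causal_idx %in% s)))

power    <- discovered / length(causal_idx)   # 1 of 3 causal variants found
coverage <- n_true_cs / length(cs)            # 1 of 2 CS contains a causal
cs_size  <- mean(sapply(cs, length))          # (3 + 2) / 2

c(power = power, coverage = coverage, cs_size = cs_size)
```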

### References

- Servin, B. & Stephens, M. (2007). *PLoS Genetics*, 3(7): e114.
- Denault et al. (2025). *bioRxiv* doi:10.1101/2025.05.16.654543.


In [1]:
library(susieR)
library(glmnet)
library(future)
library(future.apply)

# --- Configuration ---
ncores    <- max(1, parallelly::availableCores() - 2)
n_rep     <- 100       # replicates per setting
L         <- 10
N_vals    <- c(20, 30, 47)
h2_sparse <- c(0.25, 0.50, 0.80)  # var(signal) / var(y)
L_causal  <- c(1, 2, 3, 4, 5)     # number of causal variants

# Use multisession (PSOCK) to avoid fork + BLAS crashes
plan(multisession, workers = ncores)

# --- Load real data ---
data(data_small)
X_full <- data_small$X
y_full <- data_small$y

cat(sprintf("susieR version : %s\n", packageVersion("susieR")))
cat(sprintf("Workers        : %d  (multisession / PSOCK)\n", ncores))
cat(sprintf("Data           : n = %d, p = %d\n", nrow(X_full), ncol(X_full)))
cat(sprintf("N values       : %s\n", paste(N_vals, collapse = ", ")))
cat(sprintf("h2_sparse      : %s\n", paste(h2_sparse, collapse = ", ")))
cat(sprintf("n_causal       : %s\n", paste(L_causal, collapse = ", ")))
cat(sprintf("Settings       : %d\n", length(N_vals) * length(h2_sparse) * length(L_causal)))
cat(sprintf("Total reps     : %d\n",
    length(N_vals) * length(h2_sparse) * length(L_causal) * n_rep))

Loading required package: Matrix

Loaded glmnet 4.1-10



susieR version : 0.15.51
Workers        : 10  (multisession / PSOCK)
Data           : n = 47, p = 7430
N values       : 20, 30, 47
h2_sparse      : 0.25, 0.5, 0.8
n_causal       : 1, 2, 3, 4, 5
Settings       : 45
Total reps     : 4500


## Wild bootstrap

In [2]:
# Regress y on top-20 PCs of X to estimate the fraction of variance
# explained by genotype.  
Xs_cal    <- scale(X_full, center = TRUE, scale = FALSE)
svd_cal   <- svd(Xs_cal, nu = min(20, nrow(X_full)), nv = 0)
PC20      <- svd_cal$u[, 1:min(20, nrow(X_full))]
R2_20     <- summary(lm(y_full ~ PC20))$r.squared
noise_scale_factor <- sqrt(1 - R2_20)

cat(sprintf("Noise calibration (top-20 PC regression on full data):\n"))
cat(sprintf("  R²(20 PCs) = %.4f  →  noise_scale = sqrt(1 - R²) = %.4f\n",
    R2_20, noise_scale_factor))
cat(sprintf("  noise_sd   = sd(y) × %.4f = %.4f\n",
    noise_scale_factor, sd(y_full) * noise_scale_factor))

# --- LASSO residual ---
get_lasso_residual <- function(X, y, seed = 42) {
  set.seed(seed)
  n <- nrow(X)
  cv_fit    <- cv.glmnet(X, y, alpha = 1, nfolds = min(10, n))
  lasso_fit <- glmnet(X, y, alpha = 1, lambda = cv_fit$lambda.1se)
  as.vector(y - predict(lasso_fit, X))
}

# --- Sign-flip noise generation ---
# Noise is scaled to target_sd (calibrated to real data's noise level).
gen_signflip_noise <- function(U, resid, target_sd, seed = 1) {
  set.seed(seed)
  r_centered <- resid - mean(resid)
  proj  <- as.vector(crossprod(U, r_centered))
  k     <- length(proj)
  signs <- sample(c(-1, 1), k, replace = TRUE)
  noise <- as.vector(U[, 1:k] %*% (proj * signs))
  # Scale to target SD (from PC regression calibration)
  noise * target_sd / sd(noise)
}

# --- Run one replicate ---
run_one_rep <- function(rep_i, n_causal, h2_sp, X, y_real,
                        U_pre = NULL, resid_pre = NULL, L = 10,
                        noise_scale = noise_scale_factor) {
  suppressWarnings({
    n <- nrow(X); p <- ncol(X)

    # Use precomputed LASSO residual + SVD if provided
    if (!is.null(U_pre) && !is.null(resid_pre)) {
      U     <- U_pre
      resid <- resid_pre
    } else {
      seed_sub <- rep_i * 7919 + n
      resid    <- get_lasso_residual(X, y_real, seed = seed_sub)
      Xs       <- scale(X, center = TRUE, scale = FALSE)
      svd_X    <- svd(Xs, nu = n, nv = 0)
      U        <- svd_X$u
    }

    # Sign-flip noise, scaled to match real data's noise level
    target_noise_sd <- sd(y_real) * noise_scale
    seed_flip <- rep_i * 1009 + n * 17 + n_causal * 101 + round(h2_sp * 1000)
    noise <- gen_signflip_noise(U, resid, target_noise_sd, seed = seed_flip)

    # Plant causal signal:
    # h2_sp = var(signal) / var(y), so var(signal) = h2_sp/(1-h2_sp) * var(noise)
    seed_causal <- rep_i * 3331 + n * 23 + n_causal * 107 + round(h2_sp * 1000)
    set.seed(seed_causal)
    causal_idx <- sample(p, n_causal)
    y <- noise
    for (j in causal_idx) {
      bj <- sqrt(h2_sp * var(noise) / ((1 - h2_sp) * n_causal * var(X[, j])))
      y  <- y + X[, j] * bj
    }

    # Fit both methods
    fit_gaus <- tryCatch(
      susie(X, y, L = L, verbose = FALSE),
      error = function(e) NULL)
    fit_ss <- tryCatch(
      susie(X, y, L = L, estimate_residual_method = "Servin_Stephens",
            verbose = FALSE),
      error = function(e) NULL)

    # Extract metrics (filtered CS only, + V diagnostics)
    extract <- function(fit, tag) {
      na_row <- data.frame(
        method = tag, rep = rep_i,
        discovered = NA_real_, n_true_cs = NA_real_, n_cs = NA_real_,
        mean_size = NA_real_, sigma2 = NA_real_,
        mean_V = NA_real_, max_V = NA_real_, sum_V = NA_real_,
        stringsAsFactors = FALSE)
      if (is.null(fit)) return(na_row)

      cs_obj <- susie_get_cs(fit, X = X, min_abs_corr = 0.5)
      cs     <- cs_obj$cs
      ncs    <- length(cs)

      discovered <- 0; n_true_cs <- 0; avg_size <- NA_real_
      if (ncs > 0) {
        discovered <- length(intersect(unique(unlist(cs)), causal_idx))
        n_true_cs  <- sum(sapply(cs, function(s) any(causal_idx %in% s)))
        avg_size   <- mean(sapply(cs, length))
      }

      V_vec <- fit$V
      if (is.null(V_vec)) V_vec <- rep(NA_real_, L)
      if (length(V_vec) == 1) V_vec <- rep(V_vec, L)

      data.frame(
        method     = tag,
        rep        = rep_i,
        discovered = discovered,
        n_true_cs  = n_true_cs,
        n_cs       = ncs,
        mean_size  = avg_size,
        sigma2     = fit$sigma2,
        mean_V     = mean(V_vec, na.rm = TRUE),
        max_V      = max(V_vec, na.rm = TRUE),
        sum_V      = sum(V_vec, na.rm = TRUE),
        stringsAsFactors = FALSE)
    }

    rbind(extract(fit_gaus, "Gaussian"), extract(fit_ss, "SS"))
  })
}

Noise calibration (top-20 PC regression on full data):
  R²(20 PCs) = 0.8701  →  noise_scale = sqrt(1 - R²) = 0.3604
  noise_sd   = sd(y) × 0.3604 = 0.1346


Simulation functions defined.
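
The effect-size calibration in `run_one_rep` can be checked on toy data. The sketch below uses a single causal variant with a simulated genotype column and Gaussian noise, both hypothetical, and confirms that the chosen effect size gives $\mathrm{var}(\text{signal}) / (\mathrm{var}(\text{signal}) + \mathrm{var}(\text{noise}))$ equal to the target. `h2_realized` is a name introduced here for illustration.

```r
set.seed(2)
n     <- 47
x     <- rbinom(n, 2, 0.3)       # toy genotype column (hypothetical)
noise <- rnorm(n, sd = 0.13)     # toy noise vector (hypothetical)
h2    <- 0.5                     # target var(signal) / var(y)

# Effect size such that var(x * b) = h2 / (1 - h2) * var(noise)
b      <- sqrt(h2 * var(noise) / ((1 - h2) * var(x)))
signal <- x * b

# Realized proportion, treating signal and noise as uncorrelated
h2_realized <- var(signal) / (var(signal) + var(noise))
h2_realized
```

In a finite sample the signal and the noise have a nonzero sample covariance, so `var(signal) / var(signal + noise)` only approximates the target. With several causal variants, splitting the signal variance equally across them likewise assumes that the causal columns are uncorrelated.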


## Run simulation

Per-setting results are checkpointed as individual `.rds` files.
If all expected files are present, the simulation is skipped and
results are loaded from disk.
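
The checkpoint pattern can be sketched in isolation as follows. `run_or_load` is a hypothetical helper, not a function from susieR or from this notebook, and the sketch writes to a throwaway directory under `tempdir()` rather than the real output directory. It omits the validity checks that the notebook applies to each file.

```r
# Hypothetical helper: load a cached result if present,
# otherwise compute it and save it for next time.
run_or_load <- function(path, compute) {
  if (file.exists(path)) {
    return(readRDS(path))
  }
  res <- compute()
  saveRDS(res, path)
  res
}

ckpt_dir <- file.path(tempdir(), "ckpt_demo")   # throwaway directory
dir.create(ckpt_dir, showWarnings = FALSE, recursive = TRUE)
path <- file.path(ckpt_dir, "N20_h2025_nc1.rds")
unlink(path)                                    # start from a clean state

n_calls <- 0
compute <- function() {
  n_calls <<- n_calls + 1                       # count actual computations
  data.frame(method = c("Gaussian", "SS"), n_cs = c(1, 0))
}

first  <- run_or_load(path, compute)   # computes and writes the file
second <- run_or_load(path, compute)   # reads the file, no recomputation
n_calls
```

The counter stays at 1 after the second call, because the result is read back from the `.rds` file instead of being recomputed.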


In [3]:
outdir <- "/home/gw/GIT/susieR/inst/notebooks/benchmark_results"
dir.create(outdir, showWarnings = FALSE, recursive = TRUE)

all_rds <- file.path(outdir, "all_results.rds")

# --- Validate an RDS checkpoint file ---
valid_rds <- function(path) {
  if (!file.exists(path)) return(FALSE)
  tryCatch({
    x <- readRDS(path)
    is.data.frame(x) && nrow(x) > 0 && "method" %in% names(x) &&
      !any(grepl("fatal|error", x$method, ignore.case = TRUE))
  }, error = function(e) FALSE)
}

# --- Check if all expected per-setting files exist AND are valid ---
expected_tags <- c()
for (N in N_vals) {
  for (h2 in h2_sparse) {
    for (nc in L_causal) {
      expected_tags <- c(expected_tags,
                         sprintf("N%d_h2%03d_nc%d", N, round(h2 * 100), nc))
    }
  }
}
expected_files <- file.path(outdir, paste0(expected_tags, ".rds"))
all_valid <- all(sapply(expected_files, valid_rds))

if (all_valid && valid_rds(all_rds)) {
  cat("All expected outputs found and valid. Loading pre-computed results.\n")
  results <- readRDS(all_rds)
  cat(sprintf("Loaded: %d rows  (%d settings x %d reps x 2 methods)\n",
      nrow(results), length(expected_files), n_rep))
} else {
  n_valid   <- sum(sapply(expected_files, valid_rds))
  n_missing <- length(expected_files) - n_valid
  cat(sprintf("Valid: %d / %d settings. Running %d remaining ...\n",
      n_valid, length(expected_files), n_missing))

  all_results <- list()
  t_total <- proc.time()
  n_full  <- nrow(X_full)

  # Capture noise_scale for explicit passing to workers
  ns <- noise_scale_factor

  # ── Outer loop over N: precompute LASSO + SVD in main process ──
  for (N in N_vals) {
    cat(sprintf("\n=== N = %d: precomputing LASSO residuals + SVD ===\n", N))
    t_pre <- proc.time()

    keep_list  <- vector("list", n_rep)
    U_list     <- vector("list", n_rep)
    resid_list <- vector("list", n_rep)

    if (N == n_full) {
      resid0 <- get_lasso_residual(X_full, y_full, seed = 42)
      Xs0    <- scale(X_full, center = TRUE, scale = FALSE)
      svd0   <- svd(Xs0, nu = N, nv = 0)
      for (i in seq_len(n_rep)) {
        keep_list[[i]]  <- seq_len(n_full)
        U_list[[i]]     <- svd0$u
        resid_list[[i]] <- resid0
      }
      cat(sprintf("  Full sample: LASSO var(resid) = %.6f\n", var(resid0)))
    } else {
      for (i in seq_len(n_rep)) {
        set.seed(i * 7919 + N)
        keep <- sample(n_full, N)
        keep_list[[i]] <- keep
        Xi <- X_full[keep, ]
        yi <- y_full[keep]
        resid_list[[i]] <- get_lasso_residual(Xi, yi, seed = i * 7919 + N)
        Xs <- scale(Xi, center = TRUE, scale = FALSE)
        svd_i <- svd(Xs, nu = N, nv = 0)
        U_list[[i]] <- svd_i$u
      }
      cat(sprintf("  %d subsamples: mean var(resid) = %.6f\n",
          n_rep, mean(sapply(resid_list, var))))
    }
    cat(sprintf("  Precompute time: %.0f sec\n", (proc.time() - t_pre)[3]))

    # ── Inner loops: future_lapply workers do sign-flip + SuSiE ──
    for (h2 in h2_sparse) {
      for (nc in L_causal) {
        tag      <- sprintf("N%d_h2%03d_nc%d", N, round(h2 * 100), nc)
        rds_path <- file.path(outdir, paste0(tag, ".rds"))

        if (valid_rds(rds_path)) {
          cat(sprintf("[SKIP] %s\n", tag))
          res <- readRDS(rds_path)
        } else {
          if (file.exists(rds_path)) unlink(rds_path)  # remove corrupt
          cat(sprintf("[RUN]  %s ... ", tag))
          t0 <- proc.time()

          res <- do.call(rbind, future_lapply(seq_len(n_rep), function(i) {
            keep <- keep_list[[i]]
            run_one_rep(i, nc, h2, X_full[keep, ], y_full[keep],
                        U_pre = U_list[[i]], resid_pre = resid_list[[i]],
                        L = L, noise_scale = ns)
          }, future.seed = TRUE))

          res$N         <- N
          res$h2_sparse <- h2
          res$n_causal  <- nc
          saveRDS(res, rds_path)
          cat(sprintf("done  (%.0f sec)\n", (proc.time() - t0)[3]))
        }
        all_results[[length(all_results) + 1]] <- res
      }
    }
  }

  results <- do.call(rbind, all_results)
  saveRDS(results, all_rds)
  cat(sprintf("\nTotal: %d rows, %.1f min\n", nrow(results),
      (proc.time() - t_total)[3] / 60))
  cat(sprintf("Results saved to %s\n", all_rds))
}

All expected outputs found and valid. Loading pre-computed results.
Loaded: 9000 rows  (45 settings x 100 reps x 2 methods)


## Aggregate results

In [4]:
# Ensure numeric types
for (col in c("discovered", "n_true_cs", "n_cs", "mean_size",
              "sigma2", "mean_V", "max_V", "sum_V",
              "N", "h2_sparse", "n_causal")) {
  if (col %in% names(results))
    results[[col]] <- as.numeric(results[[col]])
}

# Aggregate across replicates
groups   <- unique(results[, c("method", "N", "h2_sparse", "n_causal")])
agg_list <- vector("list", nrow(groups))

for (gi in seq_len(nrow(groups))) {
  m  <- groups$method[gi]
  nn <- groups$N[gi]
  h2 <- groups$h2_sparse[gi]
  nc <- groups$n_causal[gi]
  df <- results[results$method == m & results$N == nn &
                results$h2_sparse == h2 & results$n_causal == nc, ]
  if (nrow(df) == 0) next

  nr           <- nrow(df)
  total_causal <- nc * nr
  s  <- function(x) sum(x, na.rm = TRUE)
  mn <- function(x) mean(x, na.rm = TRUE)

  agg_list[[gi]] <- data.frame(
    method      = m,
    N           = nn,
    h2_sparse   = h2,
    n_causal    = nc,
    power       = s(df$discovered) / total_causal,
    coverage    = ifelse(s(df$n_cs) > 0, s(df$n_true_cs) / s(df$n_cs), NA),
    total_cs    = s(df$n_cs),
    cs_size     = mn(df$mean_size),
    mean_sigma2 = mn(df$sigma2),
    mean_V      = mn(df$mean_V),
    max_V       = mn(df$max_V),
    sum_V       = mn(df$sum_V),
    stringsAsFactors = FALSE)
}

agg <- do.call(rbind, agg_list)
rownames(agg) <- NULL

# Factor columns for plotting
agg$BF        <- factor(agg$method, levels = c("SS", "Gaussian"))
agg$L         <- agg$n_causal
agg$n         <- agg$N
agg$cs_per_rep <- agg$total_cs / n_rep

cat(sprintf("Aggregated: %d rows\n\n", nrow(agg)))

# Print summary table
cat(sprintf("%-10s %3s %5s %3s  %6s %6s %6s %6s %8s %8s\n",
    "method", "N", "h2", "nc", "power", "cover", "cs/rep", "size", "sigma2", "sum_V"))
cat(paste(rep("-", 82), collapse = ""), "\n")
for (i in seq_len(nrow(agg))) {
  a <- agg[i, ]
  cat(sprintf("%-10s %3d %5.2f %3d  %6.3f %6.3f %6.2f %6.1f %8.4f %8.4f\n",
      a$method, a$N, a$h2_sparse, a$n_causal,
      a$power,
      ifelse(is.na(a$coverage), 0, a$coverage),  # NA (no CS found) prints as 0
      a$cs_per_rep,
      ifelse(is.na(a$cs_size), 0, a$cs_size),    # NA (no CS found) prints as 0
      a$mean_sigma2,
      a$sum_V))
}


Aggregated: 90 rows

method       N    h2  nc   power  cover cs/rep   size   sigma2    sum_V
---------------------------------------------------------------------------------- 
Gaussian    20  0.25   1   0.000  0.000   0.15   28.1   0.0148   0.0072
SS          20  0.25   1   0.000  0.000   0.00    0.0   0.4074   0.0010
Gaussian    20  0.25   2   0.010  0.154   0.13   12.2   0.0145   0.0070
SS          20  0.25   2   0.000  0.000   0.00    0.0   0.4324   0.0009
Gaussian    20  0.25   3   0.010  0.176   0.17   17.7   0.0143   0.0083
SS          20  0.25   3   0.000  0.000   0.00    0.0   0.4040   0.0010
Gaussian    20  0.25   4   0.005  0.250   0.08   20.5   0.0150   0.0062
SS          20  0.25   4   0.000  0.000   0.00    0.0   0.4520   0.0009
Gaussian    20  0.25   5   0.000  0.000   0.11   23.1   0.0146   0.0069
SS          20  0.25   5   0.000  0.000   0.00    0.0   0.4608   0.0009
Gaussian    20  0.50   1   0.120  0.571   0.21   27.4   0.0183   0.0157
SS          20  0.50   1   0.000  0.000   0.00    0.0   0.1975   0.0026
Gaussian    20  0.50   2   0.010  0.200   0.10   23.0   0.0196   0.0119
SS          20  0.50   2   0.000  0.000   0.00    0.0   0.2254  

## Plots

Grid layout: columns = sample sizes ($N = 20, 30, 47$),
rows = metrics.  All CS metrics use the purity-filtered credible sets
(`min_abs_corr = 0.5`).  Figures are saved to the `benchmark_results/` directory.


In [5]:
library(ggplot2)
library(cowplot)
library(gridExtra)
library(grid)

figdir <- "/home/gw/GIT/susieR/inst/notebooks/benchmark_results"

methods_colors <- c("SS" = "#D41159", "Gaussian" = "#1A85FF")

perf_theme <- theme_cowplot(font_size = 16) +
  theme(
    legend.position  = "none",
    panel.grid.major.y = element_line(color = "gray80"),
    panel.grid.major.x = element_blank(),
    panel.grid.minor   = element_blank(),
    axis.line   = element_line(linewidth = 1, color = "black"),
    axis.ticks  = element_line(linewidth = 1, color = "black"),
    axis.ticks.length = unit(0.25, "cm"),
    plot.margin = margin(t = 2, r = 2, b = 2, l = 2, unit = "mm"),
    axis.text   = element_text(size = 14, face = "bold"),
    axis.title  = element_text(size = 16, face = "bold"),
    plot.title  = element_text(size = 16, face = "bold")
  )
dot_size <- 4
cat("Plot theme set.\n")


Plot theme set.


In [6]:
for (h2 in h2_sparse) {
  d <- agg[agg$h2_sparse == h2, ]

  plots <- list()
  for (ni in seq_along(N_vals)) {
    nn <- N_vals[ni]
    dd <- d[d$N == nn, ]
    ylab_fn <- function(lab) if (ni == 1) lab else ""

    # ── Coverage ──
    p_cov <- ggplot(dd, aes(x = as.factor(L), y = coverage, col = BF)) +
      geom_point(size = dot_size) +
      geom_hline(yintercept = 0.95, linetype = "dashed", linewidth = 1) +
      scale_color_manual(values = methods_colors) +
      coord_cartesian(ylim = c(
        min(c(0.5, agg$coverage), na.rm = TRUE) - 0.02, 1)) +
      ylab(ylab_fn("Coverage")) + xlab("") + perf_theme
    plots[[paste0("cov_", nn)]] <- p_cov

    # ── Power (log scale) ──
    pow_floor <- 0.001
    p_pow <- ggplot(dd, aes(x = as.factor(L),
                            y = pmax(power, pow_floor), col = BF)) +
      geom_point(size = dot_size) +
      scale_color_manual(values = methods_colors) +
      scale_y_log10(limits = c(pow_floor, 1),
                    breaks = c(0.001, 0.01, 0.05, 0.1, 0.25, 0.5, 1)) +
      ylab(ylab_fn("Power")) + xlab("") + perf_theme
    plots[[paste0("pow_", nn)]] <- p_pow

    # ── CS per replicate ──
    p_ncs <- ggplot(dd, aes(x = as.factor(L), y = cs_per_rep, col = BF)) +
      geom_point(size = dot_size) +
      scale_color_manual(values = methods_colors) +
      ylim(c(0, max(agg$cs_per_rep, na.rm = TRUE) * 1.05)) +
      ylab(ylab_fn("CS / replicate")) + xlab("") + perf_theme
    plots[[paste0("ncs_", nn)]] <- p_ncs

    # ── CS size ──
    cs_range <- range(agg$cs_size, na.rm = TRUE)
    p_size <- ggplot(dd, aes(x = as.factor(L), y = cs_size, col = BF)) +
      geom_point(size = dot_size) +
      scale_color_manual(values = methods_colors) +
      coord_cartesian(ylim = c(0, cs_range[2] * 1.05)) +
      ylab(ylab_fn("CS size")) + xlab("") + perf_theme
    plots[[paste0("size_", nn)]] <- p_size

    # ── sum V (prior variance diagnostic) ──
    p_sumv <- ggplot(dd, aes(x = as.factor(L), y = sum_V, col = BF)) +
      geom_point(size = dot_size) +
      scale_color_manual(values = methods_colors) +
      ylim(c(0, max(agg$sum_V, na.rm = TRUE) * 1.05)) +
      ylab(ylab_fn(expression(Sigma~V))) + xlab("") + perf_theme
    plots[[paste0("sumv_", nn)]] <- p_sumv

    # ── sigma2 ──
    p_sig <- ggplot(dd, aes(x = as.factor(L), y = mean_sigma2, col = BF)) +
      geom_point(size = dot_size) +
      scale_color_manual(values = methods_colors) +
      ylim(c(0, max(agg$mean_sigma2, na.rm = TRUE) * 1.1)) +
      ylab(ylab_fn(expression(hat(sigma)^2))) +
      xlab("Number of causal variants") + perf_theme
    plots[[paste0("sig_", nn)]] <- p_sig
  }

  # Column headers
  titles <- lapply(N_vals, function(nn)
    textGrob(label = paste0("N = ", nn),
             gp = gpar(fontsize = 16, fontface = "bold")))

  metric_rows <- c("cov", "pow", "ncs", "size", "sumv", "sig")
  plot_grobs <- lapply(metric_rows, function(r)
    lapply(N_vals, function(nn) plots[[paste0(r, "_", nn)]]))
  plot_grobs <- do.call(c, plot_grobs)

  fig <- arrangeGrob(
    arrangeGrob(grobs = titles, ncol = 3),
    arrangeGrob(grobs = plot_grobs, ncol = 3, nrow = 6),
    heights = c(0.04, 1),
    top = textGrob(sprintf("h2_sparse = %d%%", round(h2 * 100)),
                   gp = gpar(fontsize = 18, fontface = "bold")))

  fn <- sprintf("benchmark_h2%03d", round(h2 * 100))
  pdf(file.path(figdir, paste0(fn, ".pdf")), width = 16, height = 20)
  grid.draw(fig); dev.off()
  png(file.path(figdir, paste0(fn, ".png")), width = 16, height = 20,
      units = "in", res = 150)
  grid.draw(fig); dev.off()
  cat(sprintf("Saved: %s.{pdf,png}\n", fn))
}


“Removed 5 rows containing missing values or values outside the scale range
(`geom_point()`).”

“Removed 8 rows containing missing values or values outside the scale range
(`geom_point()`).”

“Removed 1 row containing missing values or values outside the scale range
(`geom_point()`).”

“Removed 5 rows containing missing values or values outside the scale range
(`geom_point()`).”

“Removed 8 rows containing missing values or values outside the scale range
(`geom_point()`).”

“Removed 1 row containing missing values or values outside the scale range
(`geom_point()`).”


Saved: benchmark_h2025.{pdf,png}


“Removed 5 rows containing missing values or values outside the scale range
(`geom_point()`).”

“Removed 5 rows containing missing values or values outside the scale range
(`geom_point()`).”

“Removed 5 rows containing missing values or values outside the scale range
(`geom_point()`).”

“Removed 5 rows containing missing values or values outside the scale range
(`geom_point()`).”


Saved: benchmark_h2050.{pdf,png}
Saved: benchmark_h2080.{pdf,png}


In [7]:
# --- Standalone legend ---
legend_df <- data.frame(
  x = c(1, 1), y = c(1, 2),
  grp = factor(c("SS (Servin-Stephens)", "Gaussian (standard)"),
               levels = c("SS (Servin-Stephens)", "Gaussian (standard)")))
legend_colors <- c("SS (Servin-Stephens)" = "#D41159",
                   "Gaussian (standard)"  = "#1A85FF")
p_leg <- ggplot(legend_df, aes(x, y, col = grp)) +
  geom_point(size = 5) +
  scale_color_manual(values = legend_colors, name = "Method") +
  theme_void() +
  theme(legend.position  = "bottom",
        legend.text  = element_text(size = 14, face = "bold"),
        legend.title = element_text(size = 16, face = "bold"))
legend_grob <- cowplot::get_legend(p_leg)

pdf(file.path(figdir, "benchmark_legend.pdf"), width = 8, height = 1)
grid::grid.draw(legend_grob); dev.off()
png(file.path(figdir, "benchmark_legend.png"), width = 8, height = 1,
    units = "in", res = 150)
grid::grid.draw(legend_grob); dev.off()
cat("Saved: benchmark_legend.{pdf,png}\n")


Saved: benchmark_legend.{pdf,png}
