# Small-Sample Benchmark: SuSiE vs SuSiE-SS

We observe in our data analyses that when the sample size is small
relative to the number of variants ($n \ll p$),
standard SuSiE can **overfit** by underestimating the residual variance $\sigma^2$.
This inflates Bayes factors and produces spurious credible sets (CS).
The Servin-Stephens prior instead integrates out $\sigma^2$ analytically under a
Normal-Inverse-Gamma (NIG) conjugate prior, which yields $t$-distributed marginals.
These are naturally more conservative, yet remain calibrated both in small samples and beyond.
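To make the overfitting mechanism concrete, the sketch below contrasts two single-variant log Bayes factors on centered data. Both use the prior $\beta \mid \sigma^2 \sim N(0, \sigma^2 \sigma_a^2)$. The first plugs in $\sigma^2$ as if it were known. The second integrates $\sigma^2$ out, assuming the improper limit $p(\sigma^2) \propto 1/\sigma^2$ of the inverse-gamma prior and using $n$ rather than $n - 1$ degrees of freedom. It illustrates the idea and is not necessarily the exact expression susieR implements. The prior ratio `sa2` and the toy data are hypothetical.

```r
# Single-variant log Bayes factors for y = x * beta + e, e ~ N(0, sigma2 * I),
# with prior beta | sigma2 ~ N(0, sigma2 * sa2); x and y are assumed centered.

# Gaussian BF: sigma2 is plugged in as if known (e.g. an estimate).
log_bf_gaussian <- function(x, y, sigma2, sa2) {
  xtx <- sum(x^2); xty <- sum(x * y)
  -0.5 * log(1 + sa2 * xtx) + xty^2 / (2 * sigma2 * (xtx + 1 / sa2))
}

# Servin-Stephens-style BF: sigma2 integrated out, p(sigma2) proportional to 1/sigma2:
# BF = (1 + sa2 x'x)^(-1/2) * (1 - frac)^(-n/2), frac = (x'y)^2 / (y'y (x'x + 1/sa2)).
log_bf_servin_stephens <- function(x, y, sa2) {
  n <- length(y); xtx <- sum(x^2); xty <- sum(x * y); yty <- sum(y^2)
  frac <- xty^2 / (yty * (xtx + 1 / sa2))   # 0 <= frac < 1 by Cauchy-Schwarz
  -0.5 * log(1 + sa2 * xtx) - (n / 2) * log1p(-frac)
}

# Toy null example (hypothetical): y is pure noise with true sigma2 = 1.
set.seed(1)
n <- 30
x <- rnorm(n); x <- x - mean(x)
y <- rnorm(n); y <- y - mean(y)
sa2 <- 0.2

c(gaussian_true_sigma2    = log_bf_gaussian(x, y, sigma2 = 1, sa2 = sa2),
  gaussian_sigma2_too_low = log_bf_gaussian(x, y, sigma2 = 0.35, sa2 = sa2),
  servin_stephens         = log_bf_servin_stephens(x, y, sa2 = sa2))
```

The evidence term of the Gaussian BF scales like $1/\sigma^2$. Plugging in an estimate that is about three times too small therefore multiplies that term by roughly 2.9. The Servin-Stephens version instead depends on the data only through a shrunken $R^2$ (`frac`) and $n$.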

This notebook presents a benchmark focused on the small-sample regime, using realistic
simulations built from the eQTL data we analyze.

### Data

This example uses real genotype and expression data for the FMO2 gene in GTEx
thyroid tissue (`Thyroid.FMO2.1Mb.RDS`). The dataset contains $n = 574$ samples,
$p = 7{,}651$ variants in a 1 Mb window, and 68 covariates
(5 genotype PCs, 60 inferred covariates, PCR method, platform, and sex).
We first regress covariates out of both $X$ and $y$ on the full cohort,
then subsample $N \in \{30, 50, 70, 100\}$ individuals per replicate.
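Covariate adjustment is a projection onto the orthogonal complement of the column space of $[1, Z]$. The loading cell forms the $n \times n$ hat matrix explicitly, which is fine at $n = 574$. The minimal sketch below, on hypothetical toy data, shows an alternative that is not what the notebook uses: it obtains the same residuals from a QR decomposition without building $H$.

```r
# Hypothetical toy stand-ins for the covariates Z, phenotype y, and genotypes X.
set.seed(2)
n <- 50
Z <- matrix(rnorm(n * 3), n, 3)
y <- rnorm(n)
X <- matrix(rbinom(n * 5, size = 2, prob = 0.3) * 1.0, n, 5)  # genotypes as doubles

Z1 <- cbind(1, Z)                         # prepend an intercept column
qz <- qr(Z1)
y_adj <- as.vector(qr.resid(qz, y))       # y - Z1 (Z1'Z1)^{-1} Z1' y
X_adj <- qr.resid(qz, X)                  # same projection, column by column

# The explicit hat-matrix route from the loading cell gives the same residuals
# (differences at floating-point round-off level).
H <- Z1 %*% solve(crossprod(Z1), t(Z1))
max(abs(y_adj - as.vector(y - H %*% y)))
max(abs(X_adj - (X - H %*% X)))
```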

### Simulation design

To create realistic noise that reproduces the overfitting
pattern seen on real data, we:

1. Fit **LASSO** (with `lambda.1se` from cross-validation) on the subsampled $(X, y)$
   to obtain an empirical residual $r = y - X\hat{\beta}_{\text{lasso}}$.
   LASSO is neutral with respect to SuSiE and SuSiE-SS. As a sparse regression
   method, it should leave behind a residual similar to the one that SuSiE and
   SuSiE-SS will encounter.
2. Compute $w = U^{\top} r$ (after centering $r$), where $U$ contains the left
   singular vectors of the centered genotype matrix. Each component $w_k$ is
   then the coordinate of the residual along the $k$-th principal axis of $X$,
   so $w$ shows how the residual is distributed across the directions that the
   genotype matrix can explain. The projection is informative because
   overfitting arises precisely when real residual noise concentrates its
   variance along these same principal axes, where the model can mistake it
   for genetic signal.
3. Using a wild bootstrap, draw $s_k \in \lbrace -1, +1 \rbrace$
   independently for each component and form new noise
   $\tilde{r} = U(w \odot s)$. The random sign flips break any
   association between the noise and particular columns of $X$. They keep
   the per-component variance profile because $(s_k w_k)^2 = w_k^2$, so the
   realistic residual variance structure of the original data is retained.
4. Rescale the sign-flip noise to a target standard deviation calibrated
   on the full Z-adjusted cohort. Regress $y$ on the top 20 PCs of $X$
   (all 574 samples after covariate adjustment) and set
   `noise_scale = sqrt(1 - R2)`. Since $1 - R^2$ is the fraction of phenotypic
   variance not explained by the genotype PCs, the target noise SD in each
   subsample is `sd(y) * noise_scale`. Calibrating on the full cohort avoids the
   overfitting that the same regression would suffer on a small subsample.
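5. Simulate causal signals by randomly drawing `n_causal` columns of $X$.
   Effect sizes are chosen so that the signal explains a fraction `h2_sparse`
   of $\operatorname{var}(y)$. This holds exactly only when the causal columns
   are uncorrelated with each other and with the noise.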
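The effect-size rule in `run_one_rep()` can be checked directly. Each causal variant receives $b_j = \sqrt{h^2 \operatorname{var}(\varepsilon) / \big((1 - h^2)\, n_c \operatorname{var}(x_j)\big)}$, where $\varepsilon$ is the noise and $n_c$ is the number of causal variants. When all in-sample cross-covariances vanish, this gives $\operatorname{var}(\text{signal}) = \sum_j b_j^2 \operatorname{var}(x_j) = \frac{h^2}{1-h^2}\operatorname{var}(\varepsilon)$, and hence $\operatorname{var}(\text{signal}) / \operatorname{var}(y) = h^2$. The sketch below builds a hypothetical toy design in which that orthogonality holds exactly. With real genotypes, LD and chance correlation with the noise make the realized proportion only approximate.

```r
# Effect size as computed in run_one_rep(): each of n_causal variants
# contributes h2 * var(noise) / ((1 - h2) * n_causal) to the signal variance.
effect_size <- function(x_j, noise, h2, n_causal) {
  sqrt(h2 * var(noise) / ((1 - h2) * n_causal * var(x_j)))
}

set.seed(3)
n <- 40; n_causal <- 3; h2 <- 0.5
# Hypothetical toy design: orthonormalize [1, noise, X] so the causal columns
# are centered and uncorrelated with each other and with the noise in-sample.
M <- qr.Q(qr(cbind(1, rnorm(n), matrix(rnorm(n * n_causal), n, n_causal))))
noise <- M[, 2] * 2                      # arbitrary noise scale
Xc    <- M[, 3:(2 + n_causal)]

signal <- rowSums(sapply(seq_len(n_causal), function(j)
  Xc[, j] * effect_size(Xc[, j], noise, h2, n_causal)))
y <- signal + noise
var(signal) / var(y)                     # equals h2 up to round-off in this design
```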

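### Metrics

A *filtered* CS is one returned by `susie_get_cs(fit, X = X, min_abs_corr = 0.5)`, that is, a CS whose purity (the smallest absolute correlation between any two of its variants) is at least 0.5.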

| Metric | Definition |
|--------|------------|
| **Power** | (distinct causal variants found in any filtered CS) / (total causal) |
| **Coverage** | (filtered CS containing $\geq 1$ causal) / (total filtered CS) |
| **CS size** | mean number of variants per filtered CS (averaged within each replicate, then across replicates) |
| **CS / rep** | total filtered CS / number of replicates |
| $\hat{\sigma}^2$ | estimated residual variance (overfitting diagnostic) |
| $\sum V$ | sum of estimated prior variances across $L$ effects |
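As a worked example of these definitions, the sketch below summarizes three hypothetical replicates with two causal variants each. It records the same per-replicate counts as `run_one_rep()` and pools them the way the aggregation cell does. The CS lists and causal indices are invented for illustration.

```r
# Per-replicate summary: cs is a list of filtered CSs (variant indices).
rep_summary <- function(cs, causal_idx) {
  data.frame(
    discovered = length(intersect(unique(unlist(cs)), causal_idx)),
    n_true_cs  = sum(vapply(cs, function(s) any(causal_idx %in% s), logical(1))),
    n_cs       = length(cs),
    mean_size  = if (length(cs) > 0) mean(lengths(cs)) else NA_real_)
}

n_causal <- 2
reps <- rbind(
  rep_summary(list(c(5, 6, 7), c(40)), causal_idx = c(6, 90)),  # 1 true, 1 false CS
  rep_summary(list(),                  causal_idx = c(12, 33)), # no CS
  rep_summary(list(c(33)),             causal_idx = c(12, 33))) # 1 true CS

power      <- sum(reps$discovered) / (n_causal * nrow(reps))
coverage   <- sum(reps$n_true_cs) / sum(reps$n_cs)
cs_size    <- mean(reps$mean_size, na.rm = TRUE)
cs_per_rep <- sum(reps$n_cs) / nrow(reps)
c(power = power, coverage = coverage, cs_size = cs_size, cs_per_rep = cs_per_rep)
```

Here power is 2/6 and coverage is 2/3, because the singleton CS `40` contains no causal variant. The mean CS size is 1.5, since the replicate without a CS is skipped, and there is one CS per replicate. Coverage is pooled as a ratio of sums across replicates, not as an average of per-replicate ratios.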

### References

- Servin, B. & Stephens, M. (2007). *PLoS Genetics*, 3(7): e114.
- Denault et al. (2025). *bioRxiv*. doi:10.1101/2025.05.16.654543.

In [None]:
library(susieR)
library(glmnet)
library(digest)
library(future)
library(future.apply)

# --- Configuration ---
ncores    <- max(1, parallelly::availableCores() - 2)
n_rep     <- 200       # replicates per setting
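L         <- 10                   # number of single effects in each SuSiE fit
N_vals    <- c(30, 50, 70, 100)   # subsample sizes
h2_sparse <- c(0.25, 0.50, 0.75)  # proportion of var(y) explained by causal variants
L_causal  <- c(1, 2, 3, 4, 5)     # number of simulated causal variants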

# Use multisession (PSOCK) to avoid fork + BLAS crashes
plan(multisession, workers = ncores)

cat(sprintf("susieR version : %s\n", packageVersion("susieR")))
cat(sprintf("Workers        : %d  (multisession / PSOCK)\n", ncores))
cat(sprintf("N values       : %s\n", paste(N_vals, collapse = ", ")))
cat(sprintf("h2_sparse      : %s\n", paste(h2_sparse, collapse = ", ")))
cat(sprintf("n_causal       : %s\n", paste(L_causal, collapse = ", ")))
cat(sprintf("Settings       : %d\n", length(N_vals) * length(h2_sparse) * length(L_causal)))
cat(sprintf("Total reps     : %d\n",
    length(N_vals) * length(h2_sparse) * length(L_causal) * n_rep))

In [2]:
# --- Load Thyroid FMO2 data ---
dat <- readRDS("Thyroid.FMO2.1Mb.RDS")
X_raw <- dat$X      # 574 x 7651 integer genotype matrix (0/1/2)
y_raw <- dat$y      # length-574 normalized expression vector
Z     <- dat$Z      # 574 x 68 covariate matrix

cat(sprintf("Raw data: n = %d, p = %d, ncov = %d\n",
    nrow(X_raw), ncol(X_raw), ncol(Z)))
cat(sprintf("var(y_raw) = %.4f\n", var(y_raw)))

# --- Full-cohort covariate adjustment ---
# Regress Z out of both y and X via hat matrix H = Z1 (Z1'Z1)^{-1} Z1'
# where Z1 = [1, Z] includes an intercept
Z1 <- cbind(1, Z)
H  <- Z1 %*% solve(crossprod(Z1), t(Z1))

y_full <- as.vector(y_raw - H %*% y_raw)
X_full <- X_raw - H %*% X_raw

# Summary
R2_Z <- 1 - var(y_full) / var(y_raw)
cat(sprintf("\nAfter Z adjustment (full cohort, n = %d):\n", length(y_full)))
cat(sprintf("  var(y_raw)  = %.4f\n", var(y_raw)))
cat(sprintf("  var(y_adj)  = %.4f\n", var(y_full)))
cat(sprintf("  R2(y ~ Z)   = %.4f\n", R2_Z))
cat(sprintf("  n = %d, p = %d\n", nrow(X_full), ncol(X_full)))

Raw data: n = 574, p = 7651, ncov = 68
var(y_raw) = 0.9827

After Z adjustment (full cohort, n = 574):
  var(y_raw)  = 0.9827
  var(y_adj)  = 0.3317
  R2(y ~ Z)   = 0.6625
  n = 574, p = 7651



## Sign-flip noise model

The sign-flip (wild bootstrap) procedure generates realistic null noise that preserves the variance structure of LASSO residuals projected onto the principal axes of $X$, while breaking any true genotype-phenotype association. The noise level is calibrated once on the full Z-adjusted cohort using a top-20 PC regression, then applied consistently across all subsample sizes.
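The sketch below numerically checks the two properties claimed for the sign flip. It uses hypothetical toy genotypes and a random stand-in for the LASSO residual, not the FMO2 data. First, the magnitudes $|w_k|$ along every principal axis are unchanged. Second, the total sum of squares is unchanged, because $U$ is a full orthonormal basis of $\mathbb{R}^n$. The flip step mirrors `gen_signflip_noise()` before its final rescaling.

```r
set.seed(4)
n <- 20; p <- 100
X <- matrix(rbinom(n * p, size = 2, prob = 0.3) * 1.0, n, p)  # toy genotypes
resid <- rnorm(n)                                  # stand-in for a LASSO residual

U <- svd(scale(X, center = TRUE, scale = FALSE), nu = n, nv = 0)$u   # n x n
w <- as.vector(crossprod(U, resid - mean(resid)))  # coordinates on the axes of X
s <- sample(c(-1, 1), length(w), replace = TRUE)   # Rademacher signs
r_tilde <- as.vector(U %*% (w * s))                # sign-flipped noise

# Per-axis magnitudes |w_k| are preserved ...
max(abs(abs(as.vector(crossprod(U, r_tilde))) - abs(w)))
# ... and so is the total sum of squares
c(original = sum((resid - mean(resid))^2), flipped = sum(r_tilde^2))
```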

In [None]:
# --- Noise calibration on full Z-adjusted data ---
Xs_cal    <- scale(X_full, center = TRUE, scale = FALSE)
svd_cal   <- svd(Xs_cal, nu = min(20, nrow(X_full)), nv = 0)
PC20      <- svd_cal$u[, 1:min(20, nrow(X_full))]
R2_20     <- summary(lm(y_full ~ PC20))$r.squared
noise_scale_factor <- sqrt(1 - R2_20)

cat(sprintf("Noise calibration (top-20 PC regression on full Z-adjusted data):\n"))
cat(sprintf("  n = %d (full cohort after Z adjustment)\n", nrow(X_full)))
cat(sprintf("  R2(20 PCs) = %.4f  ->  noise_scale = sqrt(1 - R2) = %.4f\n",
    R2_20, noise_scale_factor))
cat(sprintf("  For a subsample with sd(y) = s, noise_sd = s * %.4f\n",
    noise_scale_factor))

# ============================================================
# Seed management
# ============================================================
# Each random operation uses a deterministic seed derived from the
# simulation coordinates.  Sign-flip and causal seeds depend on
# (rep_id, N, h2, n_causal); subsampling seeds depend only on (rep_id, N),
# so all (h2, n_causal) settings share one subsample per replicate.
# This ensures:
#
#   1. Reproducibility: rerunning the same rep_id always yields
#      identical results.
#   2. Distinct draws: each "purpose" uses its own set of prime
#      multipliers, so seeds within a purpose do not collide on the
#      default grid.  The map is linear, so collisions are not ruled
#      out for arbitrary grids; re-check uniqueness after changing it.
#   3. Safe extension: adding reps 11-200 to an existing 1-10 run
#      produces genuinely new draws (different rep_ids -> different seeds).
#      The checkpoint system (below) tracks completed rep_ids and never
#      reruns them, preventing accidental duplication.

make_seed <- function(rep_i, N,
                      purpose = c("subsample", "flip", "causal"),
                      h2 = 0, nc = 0) {
  purpose <- match.arg(purpose)
  base <- switch(purpose,
    subsample = rep_i * 7919L + N,
    flip      = rep_i * 1009L + N * 17L + nc * 101L + round(h2 * 1000),
    causal    = rep_i * 3331L + N * 23L + nc * 107L + round(h2 * 1000)
  )
  as.integer(abs(base) %% (.Machine$integer.max - 1L) + 1L)
}

# ============================================================
# Checkpoint utilities
# ============================================================
# The checkpoint system uses an MD5 hash of all simulation parameters
# (data fingerprint + design settings) to detect configuration changes.
# Results are stored per-setting as individual .rds files in outdir.
#
# Workflow:
#   1. compute_config_md5()  — hash all parameters
#   2. checkpoint_init()     — verify existing results or clear
#   3. checkpoint_completed_reps()  — which rep_ids are done?
#   4. checkpoint_save()     — merge new results, deduplicate
#   5. checkpoint_save_meta() — write _meta.rds after completion

#' Compute MD5 signature of all simulation parameters.
#' Any change in data, design, or noise calibration produces a
#' different hash, triggering automatic invalidation.
compute_config_md5 <- function(X, y, N_vals, h2_sparse, L_causal,
                               L, noise_scale) {
  nr <- min(50, nrow(X)); nc_dim <- min(50, ncol(X))
  config <- list(
    data_nrow   = nrow(X),
    data_ncol   = ncol(X),
    data_corner = sum(X[1:nr, 1:nc_dim]),
    var_y       = round(var(as.vector(y)), 8),
    noise_scale = round(noise_scale, 8),
    L           = as.integer(L),
    N_vals      = sort(as.integer(N_vals)),
    h2_sparse   = sort(round(h2_sparse, 6)),
    L_causal    = sort(as.integer(L_causal))
  )
  md5 <- digest(config, algo = "md5")
  list(config = config, md5 = md5)
}

#' Initialize checkpoint directory.
#' Returns list(status, meta) where status is "match" or "fresh".
#' On mismatch, all existing .rds files are deleted.
checkpoint_init <- function(outdir, config_sig) {
  dir.create(outdir, showWarnings = FALSE, recursive = TRUE)
  meta_path <- file.path(outdir, "_meta.rds")

  if (file.exists(meta_path)) {
    old_meta <- tryCatch(readRDS(meta_path), error = function(e) NULL)
    if (!is.null(old_meta) && !is.null(old_meta$md5) &&
        identical(old_meta$md5, config_sig$md5)) {
      cat(sprintf("Config MD5 verified: %s\n", config_sig$md5))
      cat(sprintf("  Previous run: n_rep=%d, timestamp=%s\n",
          old_meta$n_rep, old_meta$timestamp))
      return(list(status = "match", meta = old_meta))
    }
    old_md5 <- if (!is.null(old_meta$md5)) old_meta$md5 else "(missing/corrupt)"
    cat(sprintf("Config MD5 MISMATCH — clearing old results.\n"))
    cat(sprintf("  Old: %s\n  New: %s\n", old_md5, config_sig$md5))
    old_files <- list.files(outdir, pattern = "[.]rds$", full.names = TRUE)
    if (length(old_files) > 0) {
      file.remove(old_files)
      cat(sprintf("  Removed %d old file(s).\n", length(old_files)))
    }
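    # Record the new signature right away, so that an interrupted run can be
    # verified (and resumed) on the next start; checkpoint_save_meta()
    # fills in n_rep once the run completes.
    saveRDS(list(config = config_sig$config, md5 = config_sig$md5,
                 n_rep = NA_integer_, timestamp = Sys.time()), meta_path)
    return(list(status = "fresh", meta = NULL))
  }

  cat(sprintf("No existing checkpoint. Config MD5: %s\n", config_sig$md5))
  saveRDS(list(config = config_sig$config, md5 = config_sig$md5,
               n_rep = NA_integer_, timestamp = Sys.time()), meta_path)
  list(status = "fresh", meta = NULL)
}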

#' Get sorted vector of completed rep_ids for one setting tag.
checkpoint_completed_reps <- function(outdir, tag) {
  path <- file.path(outdir, paste0(tag, ".rds"))
  if (!file.exists(path)) return(integer(0))
  tryCatch({
    x <- readRDS(path)
    if (is.data.frame(x) && "rep" %in% names(x) && nrow(x) > 0)
      sort(unique(as.integer(x$rep)))
    else
      integer(0)
  }, error = function(e) integer(0))
}

#' Load checkpoint data for one setting (NULL if absent/corrupt).
checkpoint_load <- function(outdir, tag) {
  path <- file.path(outdir, paste0(tag, ".rds"))
  if (!file.exists(path)) return(NULL)
  tryCatch({
    x <- readRDS(path)
    if (is.data.frame(x) && nrow(x) > 0 && "rep" %in% names(x)) x
    else NULL
  }, error = function(e) NULL)
}

#' Save new results, merging with existing and deduplicating by rep_id.
#' Returns the combined data.frame (invisibly).
checkpoint_save <- function(outdir, tag, new_data) {
  path <- file.path(outdir, paste0(tag, ".rds"))
  existing <- checkpoint_load(outdir, tag)
  if (!is.null(existing)) {
    # Safety: remove any pre-existing rows for rep_ids we're about to add
    overlap <- existing$rep %in% new_data$rep
    if (any(overlap)) {
      n_dup <- length(unique(existing$rep[overlap]))
      warning(sprintf("%s: deduplicating %d overlapping rep_id(s)", tag, n_dup))
      existing <- existing[!overlap, ]
    }
    combined <- rbind(existing, new_data)
  } else {
    combined <- new_data
  }
  saveRDS(combined, path)
  invisible(combined)
}

#' Write checkpoint metadata after successful completion.
checkpoint_save_meta <- function(outdir, config_sig, n_rep) {
  saveRDS(list(
    config    = config_sig$config,
    md5       = config_sig$md5,
    n_rep     = n_rep,
    timestamp = Sys.time()
  ), file.path(outdir, "_meta.rds"))
}

# ============================================================
# Simulation functions
# ============================================================

get_lasso_residual <- function(X, y, seed = 42) {
  set.seed(seed)
  n <- nrow(X)
  cv_fit    <- cv.glmnet(X, y, alpha = 1, nfolds = min(10, n))
  lasso_fit <- glmnet(X, y, alpha = 1, lambda = cv_fit$lambda.1se)
  as.vector(y - predict(lasso_fit, X))
}

gen_signflip_noise <- function(U, resid, target_sd, seed = 1) {
  set.seed(seed)
  r_centered <- resid - mean(resid)
  proj  <- as.vector(crossprod(U, r_centered))
  k     <- length(proj)
  signs <- sample(c(-1, 1), k, replace = TRUE)
  noise <- as.vector(U[, 1:k] %*% (proj * signs))
  noise * target_sd / sd(noise)
}

run_one_rep <- function(rep_i, n_causal, h2_sp, X, y_real,
                        U_pre = NULL, resid_pre = NULL, L = 10,
                        noise_scale = noise_scale_factor) {
  suppressWarnings({
    n <- nrow(X); p <- ncol(X)

    if (!is.null(U_pre) && !is.null(resid_pre)) {
      U     <- U_pre
      resid <- resid_pre
    } else {
      seed_sub <- make_seed(rep_i, n, "subsample")
      resid    <- get_lasso_residual(X, y_real, seed = seed_sub)
      Xs       <- scale(X, center = TRUE, scale = FALSE)
      svd_X    <- svd(Xs, nu = n, nv = 0)
      U        <- svd_X$u
    }

    # Sign-flip noise, scaled to match real data's noise level
    target_noise_sd <- sd(y_real) * noise_scale
    seed_flip <- make_seed(rep_i, n, "flip", h2 = h2_sp, nc = n_causal)
    noise <- gen_signflip_noise(U, resid, target_noise_sd, seed = seed_flip)

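    # Causal signal: effect sizes target h2 = var(signal) / var(y); this is
    # exact only if causal columns are uncorrelated with each other and the noise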
    seed_causal <- make_seed(rep_i, n, "causal", h2 = h2_sp, nc = n_causal)
    set.seed(seed_causal)
    causal_idx <- sample(p, n_causal)
    y <- noise
    for (j in causal_idx) {
      bj <- sqrt(h2_sp * var(noise) / ((1 - h2_sp) * n_causal * var(X[, j])))
      y  <- y + X[, j] * bj
    }

    # Fit both methods
    fit_gaus <- tryCatch(
      susie(X, y, L = L, verbose = FALSE),
      error = function(e) NULL)
    fit_ss <- tryCatch(
      susie(X, y, L = L, estimate_residual_method = "Servin_Stephens",
            verbose = FALSE),
      error = function(e) NULL)

    extract <- function(fit, tag) {
      na_row <- data.frame(
        method = tag, rep = rep_i,
        discovered = NA_real_, n_true_cs = NA_real_, n_cs = NA_real_,
        mean_size = NA_real_, sigma2 = NA_real_,
        mean_V = NA_real_, max_V = NA_real_, sum_V = NA_real_,
        stringsAsFactors = FALSE)
      if (is.null(fit)) return(na_row)

      cs_obj <- susie_get_cs(fit, X = X, min_abs_corr = 0.5)
      cs     <- cs_obj$cs
      ncs    <- length(cs)

      discovered <- 0; n_true_cs <- 0; avg_size <- NA_real_
      if (ncs > 0) {
        discovered <- length(intersect(unique(unlist(cs)), causal_idx))
        n_true_cs  <- sum(sapply(cs, function(s) any(causal_idx %in% s)))
        avg_size   <- mean(sapply(cs, length))
      }

      V_vec <- fit$V
      if (is.null(V_vec)) V_vec <- rep(NA_real_, L)
      if (length(V_vec) == 1) V_vec <- rep(V_vec, L)

      data.frame(
        method     = tag,
        rep        = rep_i,
        discovered = discovered,
        n_true_cs  = n_true_cs,
        n_cs       = ncs,
        mean_size  = avg_size,
        sigma2     = fit$sigma2,
        mean_V     = mean(V_vec, na.rm = TRUE),
        max_V      = max(V_vec, na.rm = TRUE),
        sum_V      = sum(V_vec, na.rm = TRUE),
        stringsAsFactors = FALSE)
    }

    rbind(extract(fit_gaus, "Gaussian"), extract(fit_ss, "SS"))
  })
}

cat("Functions defined:\n")
cat("  Seeds:       make_seed(rep_i, N, purpose, h2, nc)\n")
cat("  Checkpoint:  compute_config_md5, checkpoint_init, checkpoint_completed_reps,\n")
cat("               checkpoint_load, checkpoint_save, checkpoint_save_meta\n")
cat("  Simulation:  get_lasso_residual, gen_signflip_noise, run_one_rep\n")

## Run simulation

Per-setting results are checkpointed as individual `.rds` files.
An MD5 hash of the full configuration (data fingerprint, noise calibration,
all design parameters) is stored in `_meta.rds` and verified before every run.

**Checkpoint behavior:**
- **MD5 matches** → existing results are loaded; only new `rep_id`s are computed.
  Changing `n_rep` from 10 to 200 runs only reps 11–200.
- **MD5 mismatch** → all old results are automatically deleted and the
  simulation starts fresh. This triggers whenever the data, noise calibration,
  sample sizes, heritabilities, or causal counts change.
- `checkpoint_save()` deduplicates by `rep_id` as a safety net against
  accidental re-runs.

**Seed management:**
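- Each random operation (subsampling, sign-flip, causal placement) uses a
  deterministic seed from `make_seed()`. Sign-flip and causal seeds depend on
  `(rep_id, N, h2, n_causal)`. Subsampling seeds depend only on `(rep_id, N)`,
  so all `h2` and `n_causal` settings share the same subsample for a given replicate.
- Same `rep_id` always produces the same result (reproducibility).
- Different `rep_id`s produce different draws. Each purpose uses its own prime
  multipliers, which avoids collisions within a purpose on the default grid.
  A linear seed map cannot rule collisions out for arbitrary grids, however.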
- The checkpoint tracks which `rep_id`s are complete and never re-runs them.
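The uniqueness claim is cheap to verify. The sketch below copies `make_seed()` verbatim from the functions cell so that it runs on its own. It then counts distinct seeds per purpose over the default grid: 200 replicates, the four `N_vals`, three `h2_sparse` values and five `L_causal` values.

```r
# Verbatim copy of make_seed() from the functions cell
make_seed <- function(rep_i, N, purpose = c("subsample", "flip", "causal"),
                      h2 = 0, nc = 0) {
  purpose <- match.arg(purpose)
  base <- switch(purpose,
    subsample = rep_i * 7919L + N,
    flip      = rep_i * 1009L + N * 17L + nc * 101L + round(h2 * 1000),
    causal    = rep_i * 3331L + N * 23L + nc * 107L + round(h2 * 1000)
  )
  as.integer(abs(base) %% (.Machine$integer.max - 1L) + 1L)
}

# Default design grid with n_rep = 200
grid <- expand.grid(rep = 1:200, N = c(30, 50, 70, 100),
                    h2 = c(0.25, 0.50, 0.75), nc = 1:5)
seeds_for <- function(purpose)
  mapply(make_seed, grid$rep, grid$N, purpose, grid$h2, grid$nc)

sub_seeds    <- seeds_for("subsample")
flip_seeds   <- seeds_for("flip")
causal_seeds <- seeds_for("causal")

# Subsample seeds are shared across (h2, n_causal) by design, so compare them
# with the number of distinct (rep, N) pairs rather than with grid rows.
c(subsample = length(unique(sub_seeds)) == nrow(unique(grid[, c("rep", "N")])),
  flip      = length(unique(flip_seeds)) == nrow(grid),
  causal    = length(unique(causal_seeds)) == nrow(grid))
```

For the default grid, all three checks should come out `TRUE`. Rerunning the check after changing `N_vals`, `h2_sparse`, or `L_causal` is a cheap safeguard.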

In [None]:
outdir  <- "benchmark_results"
all_rds <- file.path(outdir, "all_results.rds")

# --- Compute and verify config signature ---
config_sig <- compute_config_md5(X_full, y_full, N_vals, h2_sparse,
                                 L_causal, L, noise_scale_factor)
init <- checkpoint_init(outdir, config_sig)

# --- Main simulation (incremental) ---
all_results <- list()
t_total     <- proc.time()
n_full      <- nrow(X_full)
ns          <- noise_scale_factor
any_new     <- FALSE

for (N in N_vals) {
  # ── Which rep_ids does ANY setting at this N still need? ──
  # We precompute LASSO+SVD only for these, saving time when extending.
  needed_reps <- integer(0)
  for (h2 in h2_sparse) {
    for (nc in L_causal) {
      tag  <- sprintf("N%d_h2%03d_nc%d", N, round(h2 * 100), nc)
      done <- checkpoint_completed_reps(outdir, tag)
      todo <- setdiff(seq_len(n_rep), done)
      needed_reps <- union(needed_reps, todo)
    }
  }
  needed_reps <- sort(needed_reps)

  if (length(needed_reps) == 0) {
    # Everything cached — just load
    cat(sprintf("\n=== N = %d: all settings complete, loading ===\n", N))
    for (h2 in h2_sparse) {
      for (nc in L_causal) {
        tag <- sprintf("N%d_h2%03d_nc%d", N, round(h2 * 100), nc)
        res <- checkpoint_load(outdir, tag)
        # Trim to exactly n_rep reps
        done_ids <- sort(unique(res$rep))
        if (length(done_ids) > n_rep)
          res <- res[res$rep %in% done_ids[1:n_rep], ]
        all_results[[length(all_results) + 1]] <- res
      }
    }
    next
  }

  # ── Precompute LASSO + SVD only for needed reps ──
  cat(sprintf("\n=== N = %d: precomputing %d / %d subsamples ===\n",
      N, length(needed_reps), n_rep))
  t_pre <- proc.time()

  keep_list  <- vector("list", n_rep)
  U_list     <- vector("list", n_rep)
  resid_list <- vector("list", n_rep)

  for (i in needed_reps) {
    seed_i <- make_seed(i, N, "subsample")
    set.seed(seed_i)
    keep <- sample(n_full, N)
    keep_list[[i]] <- keep
    Xi <- X_full[keep, ]
    yi <- y_full[keep]
    resid_list[[i]] <- get_lasso_residual(Xi, yi, seed = seed_i)
    Xs <- scale(Xi, center = TRUE, scale = FALSE)
    svd_i <- svd(Xs, nu = N, nv = 0)
    U_list[[i]] <- svd_i$u
  }
  cat(sprintf("  Precompute: %.0f sec\n", (proc.time() - t_pre)[3]))

  # ── Run or extend each setting ──
  for (h2 in h2_sparse) {
    for (nc in L_causal) {
      tag       <- sprintf("N%d_h2%03d_nc%d", N, round(h2 * 100), nc)
      done_reps <- checkpoint_completed_reps(outdir, tag)
      todo_reps <- sort(setdiff(seq_len(n_rep), done_reps))

      if (length(todo_reps) == 0) {
        # Already complete — load
        cat(sprintf("[DONE] %s (%d reps)\n", tag, length(done_reps)))
        res <- checkpoint_load(outdir, tag)
        done_ids <- sort(unique(res$rep))
        if (length(done_ids) > n_rep)
          res <- res[res$rep %in% done_ids[1:n_rep], ]
        all_results[[length(all_results) + 1]] <- res
        next
      }

      if (length(done_reps) > 0) {
        cat(sprintf("[EXT]  %s: %d -> %d reps ... ",
            tag, length(done_reps), length(done_reps) + length(todo_reps)))
      } else {
        cat(sprintf("[RUN]  %s (%d reps) ... ", tag, length(todo_reps)))
      }
      t0 <- proc.time()
      any_new <- TRUE

      new_res <- do.call(rbind, future_lapply(todo_reps, function(i) {
        run_one_rep(i, nc, h2,
                    X_full[keep_list[[i]], ], y_full[keep_list[[i]]],
                    U_pre = U_list[[i]], resid_pre = resid_list[[i]],
                    L = L, noise_scale = ns)
      }, future.seed = TRUE))
      new_res$N         <- N
      new_res$h2_sparse <- h2
      new_res$n_causal  <- nc

      # Save with dedup, then collect
      res <- checkpoint_save(outdir, tag, new_res)
      all_results[[length(all_results) + 1]] <- res
      cat(sprintf("done (%.0f sec)\n", (proc.time() - t0)[3]))
    }
  }
}

# --- Finalize ---
results <- do.call(rbind, all_results)
saveRDS(results, all_rds)
checkpoint_save_meta(outdir, config_sig, n_rep)

total_settings <- length(N_vals) * length(h2_sparse) * length(L_causal)
if (any_new) {
  cat(sprintf("\nCompleted: %d rows across %d settings, %.1f min\n",
      nrow(results), total_settings, (proc.time() - t_total)[3] / 60))
} else {
  cat(sprintf("\nAll %d settings x %d reps loaded from cache (%d rows)\n",
      total_settings, n_rep, nrow(results)))
}
cat(sprintf("Config MD5: %s\nSaved: %s\n", config_sig$md5, all_rds))

## Aggregate results

In [5]:
# Ensure numeric types
for (col in c("discovered", "n_true_cs", "n_cs", "mean_size",
              "sigma2", "mean_V", "max_V", "sum_V",
              "N", "h2_sparse", "n_causal")) {
  if (col %in% names(results))
    results[[col]] <- as.numeric(results[[col]])
}

# Aggregate across replicates
groups   <- unique(results[, c("method", "N", "h2_sparse", "n_causal")])
agg_list <- vector("list", nrow(groups))

for (gi in seq_len(nrow(groups))) {
  m  <- groups$method[gi]
  nn <- groups$N[gi]
  h2 <- groups$h2_sparse[gi]
  nc <- groups$n_causal[gi]
  df <- results[results$method == m & results$N == nn &
                results$h2_sparse == h2 & results$n_causal == nc, ]
  if (nrow(df) == 0) next

  nr           <- nrow(df)
  total_causal <- nc * nr
  s  <- function(x) sum(x, na.rm = TRUE)
  mn <- function(x) mean(x, na.rm = TRUE)

  agg_list[[gi]] <- data.frame(
    method      = m,
    N           = nn,
    h2_sparse   = h2,
    n_causal    = nc,
    power       = s(df$discovered) / total_causal,
    coverage    = ifelse(s(df$n_cs) > 0, s(df$n_true_cs) / s(df$n_cs), NA),
    total_cs    = s(df$n_cs),
    cs_size     = mn(df$mean_size),
    mean_sigma2 = mn(df$sigma2),
    mean_V      = mn(df$mean_V),
    max_V       = mn(df$max_V),
    sum_V       = mn(df$sum_V),
    stringsAsFactors = FALSE)
}

agg <- do.call(rbind, agg_list)
rownames(agg) <- NULL

# Factor columns for plotting
agg$BF        <- factor(agg$method, levels = c("SS", "Gaussian"))
agg$L         <- agg$n_causal
agg$n         <- agg$N
agg$cs_per_rep <- agg$total_cs / n_rep

cat(sprintf("Aggregated: %d rows\n\n", nrow(agg)))

# Print summary table
cat(sprintf("%-10s %3s %5s %3s  %6s %6s %6s %6s %8s %8s\n",
    "method", "N", "h2", "nc", "power", "cover", "cs/rep", "size", "sigma2", "sum_V"))
cat(paste(rep("-", 82), collapse = ""), "\n")
for (i in seq_len(nrow(agg))) {
  a <- agg[i, ]
  cat(sprintf("%-10s %3d %5.2f %3d  %6.3f %6.3f %6.2f %6.1f %8.4f %8.4f\n",
      a$method, a$N, a$h2_sparse, a$n_causal,
      a$power,
      ifelse(is.na(a$coverage), 0, a$coverage),
      a$cs_per_rep,
      ifelse(is.na(a$cs_size), 0, a$cs_size),
      a$mean_sigma2,
      a$sum_V))
}

Aggregated: 150 rows

method       N    h2  nc   power  cover cs/rep   size   sigma2    sum_V
---------------------------------------------------------------------------------- 
Gaussian    30  0.25   1   0.000  0.000   0.00    0.0   0.3548   0.0492
SS          30  0.25   1   0.000  0.000   0.00    0.0   0.9885   0.0144
Gaussian    30  0.25   2   0.000  0.000   0.00    0.0   0.3524   0.0804
SS          30  0.25   2   0.000  0.000   0.00    0.0   0.9383   0.0618
Gaussian    30  0.25   3   0.000  0.000   0.10  114.0   0.3595   0.0652
SS          30  0.25   3   0.000  0.000   0.10    1.0   0.9507   0.0629
Gaussian    30  0.25   4   0.025  1.000   0.10   38.0   0.3594   0.0737
SS          30  0.25   4   0.050  1.000   0.10   51.0   0.9465   0.0588
Gaussian    30  0.25   5   0.000  0.000   0.00    0.0   0.3946   0.0461
SS          30  0.25   5   0.000  0.000   0.00    0.0   0.9937   0.0209
Gaussian    30  0.50   1   0.500  1.000   0.50   16.4   0.3096   0.3436
SS          30  0.50   1   0.400  1.000   0.40   17.0   0.9013   0.3583
Gaussian    30  0.50   2   0.050  0.500   0.20   18.5   0.4310   0.1511
SS          30  0.50   2   0.050  1.000   0.10   44.0   0.9558  

In [6]:
library(ggplot2)
library(cowplot)
library(gridExtra)
library(grid)

figdir <- "benchmark_results"

methods_colors <- c("SS" = "#D41159", "Gaussian" = "#1A85FF")

perf_theme <- theme_cowplot(font_size = 16) +
  theme(
    legend.position  = "none",
    panel.grid.major.y = element_line(color = "gray80"),
    panel.grid.major.x = element_blank(),
    panel.grid.minor   = element_blank(),
    axis.line   = element_line(linewidth = 1, color = "black"),
    axis.ticks  = element_line(linewidth = 1, color = "black"),
    axis.ticks.length = unit(0.25, "cm"),
    plot.margin = margin(t = 2, r = 2, b = 2, l = 2, unit = "mm"),
    axis.text   = element_text(size = 14, face = "bold"),
    axis.title  = element_text(size = 16, face = "bold"),
    plot.title  = element_text(size = 16, face = "bold")
  )
dot_size <- 4
cat("Plot theme set.\n")

Plot theme set.


In [7]:
# --- Replace NA/NaN with 0 so every dot is plotted ---
agg$coverage_plot  <- ifelse(is.na(agg$coverage), 0, agg$coverage)
agg$cs_size_plot   <- ifelse(is.na(agg$cs_size) | is.nan(agg$cs_size), 0, agg$cs_size)

# Dodge width: horizontally separate the two methods
dodge <- position_dodge(width = 0.4)

for (h2 in h2_sparse) {
  d <- agg[agg$h2_sparse == h2, ]

  plots <- list()
  for (ni in seq_along(N_vals)) {
    nn <- N_vals[ni]
    dd <- d[d$N == nn, ]
    ylab_fn <- function(lab) if (ni == 1) lab else ""

    # ── Coverage (use coverage_plot: NA → 0) ──
    p_cov <- ggplot(dd, aes(x = as.factor(L), y = coverage_plot, col = BF)) +
      geom_point(size = dot_size, position = dodge) +
      geom_hline(yintercept = 0.95, linetype = "dashed", linewidth = 1) +
      scale_color_manual(values = methods_colors) +
      coord_cartesian(ylim = c(-0.02, 1.02)) +
      ylab(ylab_fn("Coverage")) + xlab("") + perf_theme
    plots[[paste0("cov_", nn)]] <- p_cov

    # ── Power (linear scale, 0 shown at bottom) ──
    p_pow <- ggplot(dd, aes(x = as.factor(L), y = power, col = BF)) +
      geom_point(size = dot_size, position = dodge) +
      scale_color_manual(values = methods_colors) +
      coord_cartesian(ylim = c(-0.02, 1.02)) +
      ylab(ylab_fn("Power")) + xlab("") + perf_theme
    plots[[paste0("pow_", nn)]] <- p_pow

    # ── CS per replicate ──
    p_ncs <- ggplot(dd, aes(x = as.factor(L), y = cs_per_rep, col = BF)) +
      geom_point(size = dot_size, position = dodge) +
      scale_color_manual(values = methods_colors) +
      coord_cartesian(ylim = c(-0.05, max(agg$cs_per_rep, na.rm = TRUE) * 1.05)) +
      ylab(ylab_fn("CS / replicate")) + xlab("") + perf_theme
    plots[[paste0("ncs_", nn)]] <- p_ncs

    # ── CS size (use cs_size_plot: NaN → 0) ──
    cs_max <- max(agg$cs_size_plot, na.rm = TRUE)
    p_size <- ggplot(dd, aes(x = as.factor(L), y = cs_size_plot, col = BF)) +
      geom_point(size = dot_size, position = dodge) +
      scale_color_manual(values = methods_colors) +
      coord_cartesian(ylim = c(-0.5, cs_max * 1.05)) +
      ylab(ylab_fn("CS size")) + xlab("") + perf_theme
    plots[[paste0("size_", nn)]] <- p_size

    # ── sum V (prior variance diagnostic) ──
    p_sumv <- ggplot(dd, aes(x = as.factor(L), y = sum_V, col = BF)) +
      geom_point(size = dot_size, position = dodge) +
      scale_color_manual(values = methods_colors) +
      coord_cartesian(ylim = c(-0.01, max(agg$sum_V, na.rm = TRUE) * 1.05)) +
      ylab(ylab_fn(expression(Sigma~V))) + xlab("") + perf_theme
    plots[[paste0("sumv_", nn)]] <- p_sumv

    # ── sigma2 ──
    p_sig <- ggplot(dd, aes(x = as.factor(L), y = mean_sigma2, col = BF)) +
      geom_point(size = dot_size, position = dodge) +
      scale_color_manual(values = methods_colors) +
      coord_cartesian(ylim = c(-0.01, max(agg$mean_sigma2, na.rm = TRUE) * 1.1)) +
      ylab(ylab_fn(expression(hat(sigma)^2))) +
      xlab("Number of causal variants") + perf_theme
    plots[[paste0("sig_", nn)]] <- p_sig
  }

  # Column headers
  titles <- lapply(N_vals, function(nn)
    textGrob(label = paste0("N = ", nn),
             gp = gpar(fontsize = 16, fontface = "bold")))

  metric_rows <- c("cov", "pow", "ncs", "size", "sumv", "sig")
  plot_grobs <- lapply(metric_rows, function(r)
    lapply(N_vals, function(nn) plots[[paste0(r, "_", nn)]]))
  plot_grobs <- do.call(c, plot_grobs)

  n_cols <- length(N_vals)
  fig <- arrangeGrob(
    arrangeGrob(grobs = titles, ncol = n_cols),
    arrangeGrob(grobs = plot_grobs, ncol = n_cols, nrow = 6),
    heights = c(0.04, 1),
    top = textGrob(sprintf("h2_sparse = %d%%", round(h2 * 100)),
                   gp = gpar(fontsize = 18, fontface = "bold")))

  fn <- sprintf("benchmark_h2%03d", round(h2 * 100))
  pdf(file.path(figdir, paste0(fn, ".pdf")), width = 26, height = 20)
  grid.draw(fig); dev.off()
  png(file.path(figdir, paste0(fn, ".png")), width = 26, height = 20,
      units = "in", res = 150)
  grid.draw(fig); dev.off()
  cat(sprintf("Saved: %s.{pdf,png}\n", fn))
}

Saved: benchmark_h2025.{pdf,png}
Saved: benchmark_h2050.{pdf,png}
Saved: benchmark_h2075.{pdf,png}


In [8]:
# --- Standalone legend ---
legend_df <- data.frame(
  x = c(1, 1), y = c(1, 2),
  grp = factor(c("SS (Servin-Stephens)", "Gaussian (standard)"),
               levels = c("SS (Servin-Stephens)", "Gaussian (standard)")))
legend_colors <- c("SS (Servin-Stephens)" = "#D41159",
                   "Gaussian (standard)"  = "#1A85FF")
p_leg <- ggplot(legend_df, aes(x, y, col = grp)) +
  geom_point(size = 5) +
  scale_color_manual(values = legend_colors, name = "Method") +
  theme_void() +
  theme(legend.position  = "bottom",
        legend.text  = element_text(size = 14, face = "bold"),
        legend.title = element_text(size = 16, face = "bold"))
legend_grob <- cowplot::get_legend(p_leg)

pdf(file.path(figdir, "benchmark_legend.pdf"), width = 8, height = 1)
grid::grid.draw(legend_grob); dev.off()
png(file.path(figdir, "benchmark_legend.png"), width = 8, height = 1,
    units = "in", res = 150)
grid::grid.draw(legend_grob); dev.off()
cat("Saved: benchmark_legend.{pdf,png}\n")

Saved: benchmark_legend.{pdf,png}


## Results

*Results will be summarized here after reviewing the simulation output.*

## How to run

From the `inst/notebooks/` directory:

```bash
jupyter nbconvert --to notebook --execute \
  --ExecutePreprocessor.timeout=0 \
  --output small_sample_benchmark_executed.ipynb \
  small_sample_benchmark.ipynb
```

The `--ExecutePreprocessor.timeout=0` flag disables the cell timeout so the
simulation can run as long as needed. With 60 settings x 200 replicates and
10 parallel workers, expect roughly 4-8 hours on a modern machine.

The simulation is **incremental**: existing results in `benchmark_results/`
are preserved and only new replicates are computed. To extend from 200 to 400
replicates, change `n_rep` in the config cell and re-run. To start fresh
(e.g., after switching to a different dataset), delete the whole
`benchmark_results/` directory. Deleting only `_meta.rds` leaves the per-setting
`.rds` files in place, and they would be reused. Alternatively, change the data
file: when `_meta.rds` is present, the config signature check detects the
mismatch and clears old results automatically.
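The sketch below illustrates the incremental behavior. It uses simplified, hypothetical stand-ins for `checkpoint_completed_reps()` and `checkpoint_save()`, with no corruption handling and no MD5 check, plus a placeholder result per replicate, and writes to a temporary directory. A first run with `n_rep = 10` computes replicates 1 to 10. Raising `n_rep` to 200 then computes only replicates 11 to 200.

```r
# Simplified, hypothetical stand-ins for the notebook's checkpoint helpers
completed_reps <- function(path) {
  if (!file.exists(path)) return(integer(0))
  sort(unique(as.integer(readRDS(path)$rep)))
}
save_merge <- function(path, new_data) {
  old <- if (file.exists(path)) readRDS(path) else NULL
  if (!is.null(old)) old <- old[!old$rep %in% new_data$rep, ]  # dedup by rep_id
  combined <- rbind(old, new_data)
  saveRDS(combined, path)
  invisible(combined)
}
fake_rep <- function(i) data.frame(rep = i, value = i^2)  # placeholder result

path <- file.path(tempdir(), "N30_h2050_nc1.rds")
unlink(path)
for (n_rep in c(10, 200)) {                    # first run, then an extension
  todo <- setdiff(seq_len(n_rep), completed_reps(path))
  cat(sprintf("n_rep = %d: computing %d rep(s)\n", n_rep, length(todo)))
  save_merge(path, do.call(rbind, lapply(todo, fake_rep)))
}
```

The loop reports 10 replicates on the first pass and 190 on the second, and the file ends up holding replicates 1 to 200 exactly once.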