shiftbioscience/diversity_by_design

Diversity by Design: Addressing Mode Collapse Improves scRNA-seq Perturbation Modeling on Well-Calibrated Metrics

Gabriel Mejía1*, Henry E. Miller1*, Francis J. A. Leblanc1, Bo Wang2, Brendan Swain1, Lucas Paulo de Lima Camillo1


*Equal contribution.
1 Shift Bioscience, Cambridge, UK.
2 University of Toronto, Vector Institute, Toronto, Canada.
  • ArXiv Preprint here

Abstract

Recent benchmarks reveal that models for single-cell perturbation response are often outperformed by simply predicting the dataset mean. We trace this anomaly to a metric artifact: control-referenced deltas and unweighted error metrics reward mode collapse whenever the control is biased or the biological signal is sparse. Large-scale in silico simulations and analysis of two real-world perturbation datasets confirm that shared reference shifts, not genuine biological change, drive high performance in these evaluations. We introduce differentially expressed gene (DEG)-aware metrics, weighted mean-squared error (WMSE) and weighted delta $R^{2}$ ($R^{2}_{w}(\Delta)$) with respect to all perturbations, that measure error in niche signals with high sensitivity. We further introduce negative and positive performance baselines to calibrate these metrics. With these improvements, the mean baseline sinks to null performance while genuine predictors are correctly rewarded. Finally, we show that using WMSE as a loss function reduces mode collapse and improves model performance.
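
The metric artifact described in the abstract can be illustrated with a small toy simulation. This is a minimal sketch under stated assumptions, not the code or the exact metric definitions used in the paper: the gene counts, effect sizes, the control offset, and the DEG weights (1.0 for DEGs, 0.01 otherwise) are all hypothetical, and WMSE is taken here to be a weight-normalized squared error. It uses only the Python standard library.

```python
import math
import random


def pearson(x, y):
    """Pearson correlation of two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def mse(pred, true):
    """Unweighted mean-squared error over genes."""
    return sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true)


def wmse(pred, true, weights):
    """Weighted MSE; weights are normalized by their sum (an assumption)."""
    total = sum(weights)
    return sum(w * (p - t) ** 2 for w, p, t in zip(weights, pred, true)) / total


rng = random.Random(0)
n_genes = 200
deg_idx = set(range(5))  # hypothetical: only 5 of 200 genes respond

# Mean expression over all perturbations (toy values).
dataset_mean = [rng.gauss(5.0, 1.0) for _ in range(n_genes)]
# A biased control: offset from the perturbed cells on every gene.
control = [m + rng.gauss(0.0, 1.0) for m in dataset_mean]
# One perturbation: sparse true effect on the DEGs plus small noise.
perturbed = [
    m + (2.0 if g in deg_idx else 0.0) + rng.gauss(0.0, 0.1)
    for g, m in enumerate(dataset_mean)
]

# "Mode-collapsed" predictor: always predict the dataset mean.
mean_baseline = dataset_mean

# Control-referenced deltas: both deltas share the control offset,
# so they correlate even though the baseline predicts no effect.
delta_true = [p - c for p, c in zip(perturbed, control)]
delta_pred = [m - c for m, c in zip(mean_baseline, control)]
r_control_delta = pearson(delta_pred, delta_true)

# Hypothetical DEG-aware weights.
weights = [1.0 if g in deg_idx else 0.01 for g in range(n_genes)]
plain_error = mse(mean_baseline, perturbed)
weighted_error = wmse(mean_baseline, perturbed, weights)

print(f"Pearson on control-referenced deltas: {r_control_delta:.2f}")
print(f"MSE of mean baseline:  {plain_error:.3f}")
print(f"WMSE of mean baseline: {weighted_error:.3f}")
```

In this toy setting the mean baseline scores a high delta correlation and a low unweighted MSE despite missing every DEG, while the weighted error concentrates on the genes that actually changed.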

[Figure: graphical abstract]

Getting started

Install uv to manage the dependencies

curl -LsSf https://astral.sh/uv/install.sh | sh

Install the dependencies

uv sync

Workflow to run analyses from paper

  1. Get the data:
uv run data/norman19/get_data.py # Will take a few minutes
uv run data/replogle22/get_data.py # Will take a few minutes
  2. Run synthetic data simulations:
uv run analyses/synthetic_simulations/parameter_estimation.py
uv run analyses/synthetic_simulations/random_sweep.py

Plots will be stored in analyses/synthetic_simulations/paper_plots.

  3. Run simulations on real datasets:

The --dataset argument can be norman19 or replogle22.

cd analyses/real_data_simulations/ # Run from the repository root
uv run simulation.py --dataset norman19
uv run simulation.py --dataset replogle22

Figures/results are in analyses/real_data_simulations/<dataset>/

  4. Run the niche signal sensitivity analysis:
cd analyses/sensitivity_to_niche_signals/ # Run from the repository root
uv run sensitivity_analysis.py --dataset norman19
uv run sensitivity_analysis.py --dataset replogle22

Figures/results are in analyses/sensitivity_to_niche_signals/<dataset>/

  5. Train GEARS with and without the WMSE loss, then analyze the output:
cd analyses/modeling_metrics/ # Run from the repository root
uv run GEARS_norman19.py # Include --multiprocessing if you have 6 GPUs available locally
uv run GEARS_replogle22.py # Include --multiprocessing if you have 6 GPUs available locally
uv run plotting.py --dataset norman19
uv run plotting.py --dataset replogle22

Figures/results are in analyses/modeling_metrics/<dataset>/.
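
To give some intuition for what switching the training loss from MSE to WMSE changes, the sketch below compares the gradients of the two losses with respect to the prediction. It is an illustration only, not the GEARS training code from this repository: the toy expression values, the weights, and the function names are hypothetical, and WMSE is assumed to be a weight-normalized squared error.

```python
def mse_grad(pred, true):
    """Gradient of mean((pred - true)^2) with respect to pred."""
    n = len(true)
    return [2.0 * (p - t) / n for p, t in zip(pred, true)]


def wmse_grad(pred, true, weights):
    """Gradient of sum(w * (pred - true)^2) / sum(w) with respect to pred."""
    total = sum(weights)
    return [2.0 * w * (p - t) / total for w, p, t in zip(weights, pred, true)]


def share_on_first_gene(grad):
    """Fraction of the total absolute gradient that falls on gene 0."""
    return abs(grad[0]) / sum(abs(g) for g in grad)


# Toy example: gene 0 is the only DEG; the prediction has collapsed
# towards the mean and misses its effect.
true = [3.0, 1.0, 1.0, 1.0, 1.0]
pred = [1.0, 1.1, 0.9, 1.1, 0.9]
weights = [1.0, 0.05, 0.05, 0.05, 0.05]  # hypothetical DEG-aware weights

g_mse = mse_grad(pred, true)
g_wmse = wmse_grad(pred, true, weights)

print("MSE  gradient share on the DEG:", round(share_on_first_gene(g_mse), 3))
print("WMSE gradient share on the DEG:", round(share_on_first_gene(g_wmse), 3))
```

Under these assumptions, the weighted loss directs a larger share of the gradient to the gene that actually changed, so a prediction that collapses to the mean is corrected more strongly on that gene.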

Note: Training GEARS with MSE alone is very unstable, so repeated runs may produce numerically different results. Training with WMSE gives more stable results.
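
The five workflow steps above can also be chained from a single script. The following is a minimal sketch, not a script shipped with the repository: the names build_steps and run_all are hypothetical, it assumes that uv is installed and that all paths are relative to the repository root, and it defaults to a dry run that only prints the commands.

```python
import subprocess
from pathlib import Path

DATASETS = ("norman19", "replogle22")


def build_steps(datasets=DATASETS):
    """Return (working_directory, argv) pairs mirroring the workflow steps."""
    steps = []
    # Step 1: get the data.
    for ds in datasets:
        steps.append((".", ["uv", "run", f"data/{ds}/get_data.py"]))
    # Step 2: synthetic data simulations.
    for script in ("parameter_estimation.py", "random_sweep.py"):
        steps.append((".", ["uv", "run", f"analyses/synthetic_simulations/{script}"]))
    # Step 3: simulations on real datasets.
    for ds in datasets:
        steps.append(("analyses/real_data_simulations",
                      ["uv", "run", "simulation.py", "--dataset", ds]))
    # Step 4: niche signal sensitivity analysis.
    for ds in datasets:
        steps.append(("analyses/sensitivity_to_niche_signals",
                      ["uv", "run", "sensitivity_analysis.py", "--dataset", ds]))
    # Step 5: GEARS training, then plotting.
    for ds in datasets:
        steps.append(("analyses/modeling_metrics", ["uv", "run", f"GEARS_{ds}.py"]))
    for ds in datasets:
        steps.append(("analyses/modeling_metrics",
                      ["uv", "run", "plotting.py", "--dataset", ds]))
    return steps


def run_all(repo_root=".", dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    root = Path(repo_root)
    for cwd, argv in build_steps():
        print(f"[{cwd}] {' '.join(argv)}")
        if not dry_run:
            # check=True stops the workflow at the first failing step.
            subprocess.run(argv, cwd=root / cwd, check=True)


if __name__ == "__main__":
    run_all(".", dry_run=True)
```

Running each command with an explicit working directory avoids having to return to the repository root between steps. Call run_all(dry_run=False) to execute the commands.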
