Diversity by Design: Addressing Mode Collapse Improves scRNA-seq Perturbation Modeling on Well-Calibrated Metrics
Gabriel Mejía1*, Henry E. Miller1*, Francis J. A. Leblanc1, Bo Wang2, Brendan Swain1, Lucas Paulo de Lima Camillo1
*Equal contribution.
1 Shift Bioscience, Cambridge, UK.
2 University of Toronto, Vector Institute, Toronto, Canada.
- ArXiv Preprint here
Recent benchmarks reveal that models for single-cell perturbation response are often outperformed by simply predicting the dataset mean. We trace this anomaly to a metric artifact: control-referenced deltas and unweighted error metrics reward mode collapse whenever the control is biased or the biological signal is sparse. Large-scale in silico simulations and analysis of two real-world perturbation datasets confirm that shared reference shifts, not genuine biological change, drive high performance in these evaluations. We introduce differentially expressed gene (DEG)–aware metrics, weighted mean-squared error (WMSE) and weighted delta
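The metric artifact described above can be illustrated with a small synthetic sketch. It is not the repository's implementation: the function names `pearson_delta` and `wmse`, the binary DEG weights, and all simulated values are assumptions made for illustration, and the weighting used in the paper may differ.

```python
import numpy as np


def pearson_delta(pred, truth, reference):
    """Pearson correlation between predicted and observed deltas,
    both taken relative to the same reference profile (hypothetical helper)."""
    return float(np.corrcoef(pred - reference, truth - reference)[0, 1])


def wmse(pred, truth, weights):
    """Weighted mean-squared error with weights normalized to sum to 1.
    The binary DEG weighting used below is an assumption of this sketch."""
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()
    return float(np.sum(w * (pred - truth) ** 2))


rng = np.random.default_rng(0)
n_genes = 2000

# Simulated control profile and a shift shared by all perturbed cells
# (a stand-in for a biased control; values are synthetic).
control = rng.normal(0.0, 1.0, n_genes)
shared_shift = rng.normal(0.0, 0.5, n_genes)

# Sparse perturbation-specific signal: 20 of 2000 genes change.
effect = np.zeros(n_genes)
deg_idx = rng.choice(n_genes, 20, replace=False)
effect[deg_idx] = 3.0

truth = control + shared_shift + effect

# A collapsed prediction that captures only the shared shift,
# with no perturbation-specific information.
collapsed = control + shared_shift

# Control-referenced delta correlation is high despite the collapse.
r = pearson_delta(collapsed, truth, control)

# Unweighted MSE is small because most genes do not change ...
mse_uniform = wmse(collapsed, truth, np.ones(n_genes))

# ... while weighting by DEG membership exposes the missed signal.
deg_weights = np.zeros(n_genes)
deg_weights[deg_idx] = 1.0
wmse_deg = wmse(collapsed, truth, deg_weights)

print(f"Pearson delta (control reference): {r:.2f}")
print(f"Unweighted MSE: {mse_uniform:.3f}")
print(f"DEG-weighted MSE: {wmse_deg:.3f}")
```

In this toy setting the collapsed prediction looks strong under the control-referenced correlation and the unweighted error, and only the DEG-weighted error reflects that every perturbed gene was missed.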
Install uv to manage the dependencies:
curl -LsSf https://astral.sh/uv/install.sh | sh
Install the dependencies:
uv sync
- Get the data:
uv run data/norman19/get_data.py # Will take a few minutes
uv run data/replogle22/get_data.py # Will take a few minutes
- Run synthetic data simulations:
uv run analyses/synthetic_simulations/parameter_estimation.py
uv run analyses/synthetic_simulations/random_sweep.py
Plots will be stored in analyses/synthetic_simulations/paper_plots.
- Run simulations on real datasets:
Dataset can be norman19 or replogle22.
cd analyses/real_data_simulations/
uv run simulation.py --dataset norman19
uv run simulation.py --dataset replogle22
Figures/results are in analyses/real_data_simulations/<dataset>/
- Run the niche signal sensitivity analysis:
cd analyses/sensitivity_to_niche_signals/
uv run sensitivity_analysis.py --dataset norman19
uv run sensitivity_analysis.py --dataset replogle22
Figures/results are in analyses/sensitivity_to_niche_signals/<dataset>/
- Train GEARS with and without the WMSE loss and analyze the output:
cd analyses/modeling_metrics/
uv run GEARS_norman19.py # Include --multiprocessing if you have 6 GPUs available locally
uv run GEARS_replogle22.py # Include --multiprocessing if you have 6 GPUs available locally
uv run plotting.py --dataset norman19
uv run plotting.py --dataset replogle22
Figures/results are in analyses/modeling_metrics/<dataset>/.
Note: GEARS training with MSE alone is unstable, so repeated runs may show numerical differences. Training with WMSE gives more stable results.
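The two per-dataset analyses above (simulation.py and sensitivity_analysis.py) take the same --dataset argument, so they could be driven from the repository root with a small script. The following is a convenience sketch, not part of the repository: `build_jobs` and `run_all` are hypothetical names, and it assumes uv is installed and the data have already been downloaded.

```python
import subprocess
from pathlib import Path

# Dataset names and script locations as listed in the steps above.
DATASETS = ("norman19", "replogle22")
STEPS = (
    ("analyses/real_data_simulations", "simulation.py"),
    ("analyses/sensitivity_to_niche_signals", "sensitivity_analysis.py"),
)


def build_jobs(repo_root="."):
    """Return (working_directory, command) pairs, one per script and dataset."""
    root = Path(repo_root)
    return [
        (root / subdir, ["uv", "run", script, "--dataset", dataset])
        for subdir, script in STEPS
        for dataset in DATASETS
    ]


def run_all(repo_root=".", dry_run=True):
    """Print each command; execute it in its directory unless dry_run is set."""
    for cwd, cmd in build_jobs(repo_root):
        print(f"[{cwd}] {' '.join(cmd)}")
        if not dry_run:
            # check=True stops at the first failing analysis.
            subprocess.run(cmd, cwd=cwd, check=True)


if __name__ == "__main__":
    # Dry run by default: only lists the commands that would be executed.
    run_all(dry_run=True)
```

Calling `run_all(dry_run=False)` from the repository root would execute the four commands in sequence, each from the directory the steps above `cd` into.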
