Research code accompanying the ICLR 2026 paper DriftLite: Lightweight Drift Control for Inference-Time Scaling of Diffusion Models.
TL;DR: DriftLite is a lightweight, training-free, particle-based method for inference-time scaling, i.e., adapting a pre-trained diffusion model to a new target distribution without retraining. It exploits a fundamental degree of freedom in the Feynman–Kac-type Fokker–Planck equation to compute an on-the-fly control drift that stabilizes particle weights and reduces variance, with minimal computational overhead.
Paper abstract
We study inference-time scaling for diffusion models, where the goal is to adapt a pre-trained model to new target distributions without retraining. Existing guidance-based methods are simple but introduce bias, while particle-based corrections suffer from weight degeneracy and high computational cost. We introduce DriftLite, a lightweight, training-free particle-based approach that steers the inference dynamics on the fly with provably optimal stability control. DriftLite exploits a fundamental degree of freedom in the Fokker-Planck equation between the drift and particle potential, and yields two practical instantiations: Variance- and Energy-Controlling Guidance (VCG/ECG) for approximating the optimal drift with modest and scalable overhead. Across Gaussian mixture models, particle systems, and large-scale protein-ligand co-folding problems, DriftLite consistently reduces variance and improves sample quality over pure guidance and sequential Monte Carlo baselines. These results highlight a principled, efficient route toward scalable inference-time adaptation of diffusion models.
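The drift/potential degree of freedom the abstract refers to can be sketched schematically as follows (our notation, not the paper's; diffusion terms suppressed for brevity):

```latex
% Feynman–Kac-type Fokker–Planck dynamics: a drift b transports mass,
% a scalar potential g reweights it:
%   \partial_t \rho_t = -\nabla \cdot (\rho_t\, b_t) + g_t\, \rho_t .
% Any extra drift u can be traded against the potential without changing \rho_t:
\partial_t \rho_t
  = -\nabla \cdot\!\bigl(\rho_t\,(b_t + u_t)\bigr)
  + \Bigl(g_t + \frac{\nabla \cdot (\rho_t\, u_t)}{\rho_t}\Bigr)\rho_t
  = -\nabla \cdot (\rho_t\, b_t) + g_t\, \rho_t .
```

Intuitively, the two extra terms cancel, so the same density evolution can be realized by many (drift, potential) pairs; DriftLite picks the extra drift within a tractable family so that the reweighting term, and hence the particle-weight variance, stays small. The precise criteria (VCG/ECG) are given in the paper.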
- `run_unified.py`: main entry point for unified experiments (sampling + printed evaluation tables) for GMM / Lennard-Jones / Double-Well, with different control strategies: Pure Guidance (G), Variance-Controlling Guidance (VCG), Energy-Controlling Guidance (ECG), and their SMC variants.
- `train_score_network.py`: score network training loop (writes to `checkpoints/`; used by LJ/DW; optional for GMM).
- `examples/`: problem setups (GMM, Lennard-Jones, Double-Well).
- `methods/`: drift-control strategies (G / VCG / ECG, plus an experimental neural control guidance).
- `models/`: score network architectures and EDM helpers.
- `schedules.py`, `simulation.py`: schedules and simulation routines.
- `checkpoints/<example_type>/score_<TIME>/`: saved training configs and parameters (generated by training; not committed).
- `logs/<example_type>/<RUN_ID>/`: run outputs created by `run_unified.py` (per-run `args.json`, plus optional `*_results.npz` + `*_results_meta.json` when `--save_results` is used).
Create a conda environment and install dependencies:
```bash
conda env create -f environment.yml
conda activate jax
```

Notes:

- This environment is pinned to Python 3.12 + JAX 0.6.x and includes CUDA 12 JAX wheels (`jax-cuda12-*`) in `environment.yml`.
- If you want a CPU-only environment, remove the `jax-cuda12-*` pip entries and install a CPU build of JAX instead (see the JAX install guide).
To start exploring quickly, see the interactive notebook `gmm_minimal.ipynb`.
This is the fastest “it works” path because GMM has an analytical score.
```bash
python run_unified.py \
    --example_type gmm \
    --mode annealing \
    --gamma 1.0 \
    --n_steps 500 \
    --n_particles 4096 \
    --n_samples 8192 \
    --use_guidance \
    --use_variance_control_guidance \
    --use_smc
```

You should see a printed metrics table (NLL diff / MMD / SWD) and a new run folder under `logs/gmm/<RUN_ID>/` containing `args.json`.
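Since the GMM target's score is available in closed form, the quantity that makes this the fastest sanity check can be illustrated with a minimal standalone sketch (plain Python, 1D, two components; the setup and names here are illustrative, not the repo's API):

```python
import math

# Two-component 1D Gaussian mixture: weights W, means MU, shared std SIGMA.
W = (0.5, 0.5)
MU = (-2.0, 2.0)
SIGMA = 1.0

def log_density(x):
    # log p(x) = log sum_k w_k N(x; mu_k, sigma^2), via a stable log-sum-exp.
    logs = [
        math.log(w)
        - 0.5 * ((x - m) / SIGMA) ** 2
        - 0.5 * math.log(2 * math.pi * SIGMA**2)
        for w, m in zip(W, MU)
    ]
    mx = max(logs)
    return mx + math.log(sum(math.exp(l - mx) for l in logs))

def score(x):
    # d/dx log p(x) = sum_k r_k(x) * (mu_k - x) / sigma^2, where r_k(x) are
    # the posterior responsibilities of the mixture components.
    logs = [math.log(w) - 0.5 * ((x - m) / SIGMA) ** 2 for w, m in zip(W, MU)]
    mx = max(logs)
    r = [math.exp(l - mx) for l in logs]
    z = sum(r)
    return sum(ri / z * (m - x) / SIGMA**2 for ri, m in zip(r, MU))

# Sanity check: analytical score matches a central finite difference.
x0, h = 0.7, 1e-5
fd = (log_density(x0 + h) - log_density(x0 - h)) / (2 * h)
assert abs(score(x0) - fd) < 1e-6
```

An analytical score like this is what lets the GMM experiment run without a trained network, so any discrepancy there points at the sampler rather than the score model.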
Run the unified experiment script:
```bash
./run_unified.sh
```

Notes:

- Outputs are written under `logs/<example_type>/<RUN_ID>/`. If you run under Slurm, the folder name defaults to `SLURM_JOB_ID`; otherwise it uses the printed `unique_run_id`.
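To inspect what a past run was configured with, you can read its `args.json` back from the log folder. A minimal helper sketch, assuming only the `logs/<example_type>/<RUN_ID>/args.json` layout described above (the function name is ours):

```python
import json
import os

def latest_run_args(log_root="logs/gmm"):
    """Return the parsed args.json of the most recently modified run.

    Assumes the logs/<example_type>/<RUN_ID>/args.json layout; picks the
    run directory with the newest modification time.
    """
    runs = [
        os.path.join(log_root, d)
        for d in os.listdir(log_root)
        if os.path.isdir(os.path.join(log_root, d))
    ]
    if not runs:
        raise FileNotFoundError(f"no run folders under {log_root}")
    newest = max(runs, key=os.path.getmtime)
    with open(os.path.join(newest, "args.json")) as f:
        return json.load(f)
```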
Train a score network (example configuration in `run_training.sh`; note it is configured for very long runs by default):

```bash
./run_training.sh
```

Training writes to `checkpoints/<example_type>/score_<TIME>/` and stores the model path in `config.json` (field: `network.checkpoint_dir`). You can resume training with:

```bash
python train_score_network.py --example_type lj --resume_from_time <TIME>
```

For `lj` and `dw`, `run_unified.py` typically uses a trained score network. After training, pass the `--config_path_time` flag to load the checkpoint.
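When several checkpoints have accumulated, you need the `<TIME>` suffix of the one to resume from or load. A small helper sketch, assuming only the `checkpoints/<example_type>/score_<TIME>/` layout described above and that timestamps sort lexicographically (e.g. `YYYYMMDD_HHMMSS`); the function name is ours:

```python
import os

def latest_checkpoint_time(example_type, root="checkpoints"):
    """Return the <TIME> suffix of the newest score_<TIME> directory.

    Useful for picking a value to pass to --resume_from_time or
    --config_path_time; assumes lexicographically sortable timestamps.
    """
    base = os.path.join(root, example_type)
    times = [
        d[len("score_"):]
        for d in os.listdir(base)
        if d.startswith("score_") and os.path.isdir(os.path.join(base, d))
    ]
    if not times:
        raise FileNotFoundError(f"no score_<TIME> checkpoints under {base}")
    return max(times)
```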
Example LJ command using a trained checkpoint:
```bash
python run_unified.py \
    --example_type lj \
    --config_path_time <TIME> \
    --mode annealing \
    --gamma 2.0 \
    --n_steps 1000 \
    --n_particles 32768 \
    --n_samples 32768 \
    --eval_batch_size 32768 \
    --use_guidance \
    --use_variance_control_guidance \
    --use_smc \
    --resampling_mode ess \
    --ess_threshold 0.9 \
    --use_trained_score \
    --save_results
```

For an interactive exploration of the GMM setup and variance-controlling guidance, see `gmm_minimal.ipynb`.
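With `--save_results`, each run leaves a `*_results.npz` plus `*_results_meta.json` pair in its log folder. A loading sketch, assuming only that naming convention (the array names inside the `.npz` are whatever the run saved, so the dict is returned as-is; NumPy is already a dependency via JAX):

```python
import json

import numpy as np

def load_results(npz_path):
    """Load a <prefix>_results.npz and its <prefix>_results_meta.json.

    Returns (arrays, meta): a dict of saved arrays and the parsed metadata.
    """
    arrays = dict(np.load(npz_path))
    meta_path = npz_path.replace("_results.npz", "_results_meta.json")
    with open(meta_path) as f:
        meta = json.load(f)
    return arrays, meta
```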
If you use this code, please cite:
```bibtex
@inproceedings{ren2026driftlite,
  title={DriftLite: Lightweight Drift Control for Inference-Time Scaling of Diffusion Models},
  author={Ren, Yinuo and Gao, Wenhao and Ying, Lexing and Rotskoff, Grant M. and Han, Jiequn},
  booktitle={The Fourteenth International Conference on Learning Representations},
  year={2026},
  url={https://openreview.net/forum?id=l01eG3Qikl}
}
```

This project is released under the MIT License (see `LICENSE`).
