DriftLite: Lightweight Drift Control for Inference-Time Scaling of Diffusion Models (ICLR 2026)


Research code accompanying the ICLR 2026 paper DriftLite: Lightweight Drift Control for Inference-Time Scaling of Diffusion Models.

TL;DR: DriftLite is a lightweight, training-free, particle-based method for inference-time scaling, i.e., adapting a pre-trained diffusion model to a new target distribution without retraining. It exploits a fundamental degree of freedom in the Feynman–Kac-type Fokker–Planck equation to compute an on-the-fly control drift that stabilizes particle weights and reduces variance, with minimal computational overhead.

Paper abstract
We study inference-time scaling for diffusion models, where the goal is to adapt a pre-trained model to new target distributions without retraining. Existing guidance-based methods are simple but introduce bias, while particle-based corrections suffer from weight degeneracy and high computational cost. We introduce DriftLite, a lightweight, training-free particle-based approach that steers the inference dynamics on the fly with provably optimal stability control. DriftLite exploits a fundamental degree of freedom in the Fokker-Planck equation between the drift and particle potential, and yields two practical instantiations: Variance- and Energy-Controlling Guidance (VCG/ECG) for approximating the optimal drift with modest and scalable overhead. Across Gaussian mixture models, particle systems, and large-scale protein-ligand co-folding problems, DriftLite consistently reduces variance and improves sample quality over pure guidance and sequential Monte Carlo baselines. These results highlight a principled, efficient route toward scalable inference-time adaptation of diffusion models.
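The drift–potential degree of freedom that DriftLite exploits can be sketched schematically; the notation below is generic and may differ from the paper's. For an unnormalized density evolving under a Feynman–Kac-type Fokker–Planck equation, any reweighting term can be traded for an extra drift without changing the density:

```latex
% Fokker–Planck evolution of the unnormalized density \rho_t
\partial_t \rho_t
  = -\nabla \cdot (b_t \rho_t) + \tfrac{\sigma_t^2}{2} \Delta \rho_t + g_t \rho_t

% Degree of freedom: since
%   (\nabla \cdot u_t + u_t \cdot \nabla \log \rho_t)\,\rho_t = \nabla \cdot (u_t \rho_t),
% the controlled pair
%   \big(b_t + u_t,\; g_t + \nabla \cdot u_t + u_t \cdot \nabla \log \rho_t\big)
% generates the same \rho_t, so u_t can be chosen to flatten the effective
% potential and thereby stabilize particle weights.
```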

What’s in this repo

  • run_unified.py: main entry point for unified experiments (sampling + printed evaluation tables) for GMM / Lennard–Jones / Double-Well, with different control strategies (Pure Guidance (G), Variance-Controlling Guidance (VCG), Energy-Controlling Guidance (ECG), and their SMC variants).
  • train_score_network.py: score network training loop (writes to checkpoints/; used by LJ/DW; optional for GMM).
  • examples/: problem setups (GMM, Lennard-Jones, Double Well).
  • methods/: drift-control strategies (G / VCG / ECG; plus an experimental neural control guidance).
  • models/: score network architectures and EDM helpers.
  • schedules.py, simulation.py: schedules and simulation routines.
  • checkpoints/<example_type>/score_<TIME>/: saved training configs and parameters (generated by training; not committed).
  • logs/<example_type>/<RUN_ID>/: run outputs created by run_unified.py (per-run args.json, optional *_results.npz + *_results_meta.json when --save_results is used).
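The run outputs listed above can be inspected with plain JSON/NumPy tooling. A minimal sketch, where the file names follow the README but the keys and values are dummies created purely for illustration:

```python
import json
import os
import tempfile

import numpy as np

# Build a dummy run folder mimicking logs/<example_type>/<RUN_ID>/
# (layout follows the README; contents here are made up for illustration).
run_dir = tempfile.mkdtemp()
with open(os.path.join(run_dir, "args.json"), "w") as f:
    json.dump({"example_type": "gmm", "n_steps": 500}, f)
np.savez(os.path.join(run_dir, "gmm_results.npz"), samples=np.zeros((8, 2)))

# Inspect the run: the recorded arguments plus any saved result arrays.
with open(os.path.join(run_dir, "args.json")) as f:
    args = json.load(f)
results = np.load(os.path.join(run_dir, "gmm_results.npz"))
samples = results["samples"]
```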

Setup

Create a conda environment and install dependencies:

conda env create -f environment.yml
conda activate jax

Notes:

  • This environment is pinned to Python 3.12 + JAX 0.6.x and includes CUDA12 JAX wheels (jax-cuda12-*) in environment.yml.
  • If you want a CPU-only environment, remove the jax-cuda12-* pip entries and install a CPU build of JAX instead (see the JAX install guide).

Quickstart

0) Recommended: open the interactive notebook

If you want to start exploring quickly, open the interactive notebook gmm_minimal.ipynb.

1) Run a small GMM experiment (no training required)

This is the fastest “it works” path because GMM has an analytical score.

python run_unified.py \
  --example_type gmm \
  --mode annealing \
  --gamma 1.0 \
  --n_steps 500 \
  --n_particles 4096 \
  --n_samples 8192 \
  --use_guidance \
  --use_variance_control_guidance \
  --use_smc

You should see a printed metrics table (NLL diff / MMD / SWD) and a new run folder under logs/gmm/<RUN_ID>/ containing args.json.
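The "analytical score" mentioned above is simply the gradient of the mixture log-density, which autodiff gives for free. A minimal JAX sketch (the mixture parameters below are illustrative, not the repo's defaults):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def gmm_logpdf(x, means, log_weights, sigma=1.0):
    """Log-density of an isotropic Gaussian mixture at a single point x."""
    d = x.shape[-1]
    sq_dists = jnp.sum((x[None, :] - means) ** 2, axis=-1)  # (K,) squared distances
    log_comps = -0.5 * sq_dists / sigma**2 - 0.5 * d * jnp.log(2.0 * jnp.pi * sigma**2)
    return logsumexp(log_weights + log_comps)

# The score is the gradient of the log-density; no training required.
score_fn = jax.grad(gmm_logpdf)

means = jnp.array([[-2.0, 0.0], [2.0, 0.0]])  # illustrative 2-component mixture
log_w = jnp.log(jnp.array([0.5, 0.5]))
s = score_fn(jnp.array([0.0, 0.0]), means, log_w)  # symmetric point: score vanishes
```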

2) Run the provided unified script

From the repository root, execute:

./run_unified.sh

Notes:

  • Outputs are written under logs/<example_type>/<RUN_ID>/. If you run under Slurm, the folder name defaults to SLURM_JOB_ID; otherwise it uses the printed unique_run_id.
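The folder-naming rule above can be sketched as follows; the concrete format of `unique_run_id` is assumed here for illustration, not taken from the code:

```python
import datetime
import os
import uuid

# Fallback id used when not running under Slurm (format assumed for illustration).
unique_run_id = datetime.datetime.now().strftime("%Y%m%d-%H%M%S-") + uuid.uuid4().hex[:6]

# Under Slurm, SLURM_JOB_ID takes precedence; otherwise the unique run id is used.
run_id = os.environ.get("SLURM_JOB_ID", unique_run_id)
log_dir = os.path.join("logs", "gmm", run_id)
```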

3) Train a score network

Train a score network; run_training.sh provides an example configuration (note its defaults are set for very long runs):

./run_training.sh

Training writes to checkpoints/<example_type>/score_<TIME>/ and stores the model path in config.json (field: network.checkpoint_dir). You can resume training with:

python train_score_network.py --example_type lj --resume_from_time <TIME>

4) Lennard–Jones / Double-Well runs (require a trained score checkpoint)

For lj and dw, run_unified.py typically uses a trained score network. After training, you can use the --config_path_time flag to load the checkpoint.

Example LJ command using a trained checkpoint:

python run_unified.py \
  --example_type lj \
  --config_path_time <TIME> \
  --mode annealing \
  --gamma 2.0 \
  --n_steps 1000 \
  --n_particles 32768 \
  --n_samples 32768 \
  --eval_batch_size 32768 \
  --use_guidance \
  --use_variance_control_guidance \
  --use_smc \
  --resampling_mode ess \
  --ess_threshold 0.9 \
  --use_trained_score \
  --save_results

For an interactive exploration of the GMM setup and variance-controlling guidance, see gmm_minimal.ipynb.

Citation

If you use this code, please cite:

@inproceedings{ren2026driftlite,
    title={DriftLite: Lightweight Drift Control for Inference-Time Scaling of Diffusion Models},
    author={Ren, Yinuo and Gao, Wenhao and Ying, Lexing and Rotskoff, Grant M. and Han, Jiequn},
    booktitle={The Fourteenth International Conference on Learning Representations},
    year={2026},
    url={https://openreview.net/forum?id=l01eG3Qikl}
}

License

This project is released under the MIT License (see LICENSE).
