JAX implementation for the ICML 2026 paper Deep Coupling Learning for Solving PDEs.
This repository contains the code needed to run the experiments from Sections 5.1 and 5.2.
Create or activate a Python environment with JAX installed, then install this package:
pip install -e .
For GPU runs, install the jax/jaxlib build matching your CUDA version before installing this package.
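After installation, a quick sanity check can confirm that JAX imports and sees the expected accelerator. This is a minimal sketch using only the public `jax.default_backend()` and `jax.devices()` calls. The helper name `jax_backend_summary` is illustrative and not part of this package.

```python
import importlib.util


def jax_backend_summary() -> dict:
    """Report whether JAX is importable and which backend/devices it sees.

    Illustrative helper, not part of the couplednet package.
    """
    if importlib.util.find_spec("jax") is None:
        # JAX is not installed in the active environment.
        return {"installed": False, "backend": None, "devices": []}
    import jax

    return {
        "installed": True,
        # Platform name of the default backend, e.g. "cpu" or "gpu".
        "backend": jax.default_backend(),
        # Human-readable descriptions of every device JAX can use.
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    print(jax_backend_summary())
```

If a CUDA-enabled build was installed correctly, `backend` should read `"gpu"`. A value of `"cpu"` usually means a CPU-only jaxlib was picked up instead.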
Due to numerical precision and hardware differences, reproduced values may differ slightly from those reported in the paper; such small variations do not affect the experimental conclusions.
couplednet/             Core model, PDE, sampler, and training utilities
experiments/            Experiment entry points
experiments/configs/    Reproducible experiment configs
experiments/records/    Search records retained for traceability
Run a CoupledNet depth config:
python -m experiments.section_5_1_high_frequency \
--config experiments/configs/section_5_1_couplednet/couplednet_8L_best.json
Other CoupledNet depth configs are in:
experiments/configs/section_5_1_couplednet/
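To run every CoupledNet depth config in one pass, a small driver can loop over the JSON files in that directory and invoke the same entry point with the `--config` and `--output-dir` flags documented in this README. This is a minimal sketch that assumes each `*.json` file in the directory is a standalone depth config. The helper `depth_sweep_commands` and the `results/section_5_1_<config name>` output naming are hypothetical choices, not part of the repository.

```python
import subprocess
import sys
from pathlib import Path
from typing import List

MODULE = "experiments.section_5_1_high_frequency"
CONFIG_DIR = Path("experiments/configs/section_5_1_couplednet")


def depth_sweep_commands(
    config_dir: Path = CONFIG_DIR, results_root: Path = Path("results")
) -> List[List[str]]:
    """Build one command per JSON config, each with its own output directory.

    Hypothetical helper: assumes every *.json file in config_dir is a
    complete config accepted by the Section 5.1 entry point.
    """
    commands = []
    for config in sorted(config_dir.glob("*.json")):
        # Hypothetical naming scheme: results/section_5_1_<config stem>
        out_dir = results_root / f"section_5_1_{config.stem}"
        commands.append([
            sys.executable, "-m", MODULE,
            "--config", str(config),
            "--output-dir", str(out_dir),
        ])
    return commands


if __name__ == "__main__":
    # Run configs sequentially; check=True stops at the first failing run.
    for cmd in depth_sweep_commands():
        print("Running:", " ".join(cmd))
        subprocess.run(cmd, check=True)
```

Giving each config its own output directory keeps results from different depths apart.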
Run the Appendix baseline architectures:
python -m experiments.section_5_1_high_frequency \
--config experiments/configs/section_5_1_appendix_baselines/appendix_baselines_paper.json
Run all Section 5.2 configs:
python -m experiments.section_5_2_high_dynamic_range \
--config experiments/configs/section_5_2_default.json
Run one Section 5.2 config by name:
python -m experiments.section_5_2_high_dynamic_range \
--config experiments/configs/section_5_2_default.json \
--run couplednet_8L_lr1e-3
Available run names are listed in:
experiments/configs/section_5_2_default.json
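When only some Section 5.2 runs are needed, the `--run` flag can be applied once per name. This is a minimal sketch that builds one command per run and optionally adds the `--output-dir` flag documented in this README. The helpers `named_run_command` and `run_named` are hypothetical. Apart from `couplednet_8L_lr1e-3`, the run names must be taken from the config file itself.

```python
import subprocess
import sys
from pathlib import Path
from typing import Iterable, List, Optional

MODULE = "experiments.section_5_2_high_dynamic_range"
CONFIG = "experiments/configs/section_5_2_default.json"


def named_run_command(run_name: str, output_dir: Optional[Path] = None) -> List[str]:
    """Build the CLI call for one named run from the Section 5.2 config."""
    cmd = [sys.executable, "-m", MODULE, "--config", CONFIG, "--run", run_name]
    if output_dir is not None:
        # Optional: keep each run's outputs apart via --output-dir.
        cmd += ["--output-dir", str(output_dir)]
    return cmd


def run_named(run_names: Iterable[str], results_root: Path = Path("results")) -> None:
    """Run each named config sequentially, stopping at the first failure."""
    for name in run_names:
        # Hypothetical layout: results/section_5_2_<run name>
        out_dir = results_root / f"section_5_2_{name}"
        subprocess.run(named_run_command(name, out_dir), check=True)
```

For example, `run_named(["couplednet_8L_lr1e-3"])` runs the single config shown above and writes its outputs to `results/section_5_2_couplednet_8L_lr1e-3`.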
Experiment outputs are written to results/ by default. Use --output-dir to choose another directory:
python -m experiments.section_5_1_high_frequency \
--config experiments/configs/section_5_1_couplednet/couplednet_8L_best.json \
--output-dir results/section_5_1_8L
Several additional benchmark experiments in the paper follow the benchmark structure of the JAX-PI project, and the corresponding paper configs can be run in a JAX-PI-style benchmark setup. We thank the JAX-PI authors for making their benchmark code available.
If you use this code, please cite:
@inproceedings{meng2026deep,
title={Deep Coupling Learning for Solving PDEs},
author={Meng, Lingshi and Shi, Haosen and Pan, Sinno Jialin},
booktitle={Proceedings of the 43rd International Conference on Machine Learning},
year={2026}
}
This code is released under the MIT License.