This repository provides efficient diagonal Fisher Information Matrix (FIM) estimators for classifier networks, as used in the paper:
Deterministic Bounds and Random Estimates of Metric Tensors on Neuromanifolds, K. Sun (2026), The Fourteenth International Conference on Learning Representations (ICLR 2026).
This library focuses on diagonal FIM estimation for deep classifier networks. It is intended for:
- Optimization based on diagonal FIM
- Geometric analysis
- Quantization
It does not provide:
- Full FIM computation
- Task-specific training pipelines
We currently support the following estimators of the diagonal of the FIM for classifier networks:
- EXACT — Exact diagonal computation (very expensive).
- EMPIRICAL — Empirical FIM.
- HUTCHINSON — Hutchinson's unbiased stochastic estimate.
- HUTCHINSON_SQRT — Alternative Hutchinson FIM implementation.
- HUTCHINSON_DG — Upper bound of Hutchinson FIM.
- HUTCHINSON_RANKK — Lower bound of Hutchinson FIM.
- HUTCHINSON_RANK1 — A special case of HUTCHINSON_RANKK with $k=1$.
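For reference, the quantity all of these estimators approximate is the diagonal of the Fisher information of a classifier, $\mathrm{diag}(\mathbb{F})_i = \mathbb{E}_x \sum_y p_y(x) \left(\partial_{\theta_i} \log p_y(x)\right)^2$. The following is a minimal PyTorch sketch of this definition for a single example, assuming a `model` that maps inputs to logits; it is illustrative only and not the library's implementation (which vectorizes and batches this computation):

```python
import torch
import torch.nn.functional as F

def exact_diag_fim_single(model, x):
    """Diagonal of the per-example FIM: sum_y p_y(x) * (d log p_y / d theta)^2.

    One backward pass per class is required, which is why the exact
    computation is very expensive.  (Sketch only, not the library's code.)
    """
    params = [p for p in model.parameters() if p.requires_grad]
    diag = [torch.zeros_like(p) for p in params]
    log_p = F.log_softmax(model(x.unsqueeze(0)), dim=-1)[0]  # (num_classes,)
    probs = log_p.exp().detach()
    for y in range(log_p.shape[0]):          # one backward pass per class
        grads = torch.autograd.grad(log_p[y], params, retain_graph=True)
        for d, g in zip(diag, grads):
            d += probs[y] * g.pow(2)         # weight squared grads by p_y
    return diag
```

The per-class backward loop makes the cost scale with the number of classes, which motivates the stochastic estimators above.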
The Hutchinson estimator is:

- Unbiased: its expectation gives the true FIM;
- Bounded variance: its standard deviation is bounded by $\sqrt{2}$ times the true Fisher information;
- Efficient: the overhead is one additional backward pass per batch, and thus it scales to large neural networks.
Each estimation step proceeds as follows:

- Compute a scalar $h$ based on the neural network logits;
- Run `h.backward()` to get $\nabla h$;
- Form a low-rank representation of the FIM, the matrix $\mathbb{F} = \nabla h \, (\nabla h)^\top$;
- Extract its diagonal elements $\mathbb{F}_{ii}$ and accumulate them into an internal buffer for each parameter of the neural network.
Note. Although we focus on diagonal FIM estimators in the current code, the Hutchinson FIM is a random low-rank representation of the full FIM and is not restricted to diagonal-only use.
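The steps above can be sketched in plain PyTorch. This is a hedged illustration, not the library's code: it assumes one common construction of the scalar, $h = 2\sum_y \varepsilon_y \sqrt{p_y(x)}$ with Rademacher signs $\varepsilon_y$, for which $\mathbb{E}\left[\nabla h \, (\nabla h)^\top\right]$ equals the per-example FIM (since $\nabla \sqrt{p_y} = \tfrac{\sqrt{p_y}}{2} \nabla \log p_y$); the paper's estimators may differ in detail.

```python
import torch
import torch.nn.functional as F

def hutchinson_diag_fim_step(model, x, buffers):
    """One Hutchinson step: accumulate (grad h)^2 into per-parameter buffers.

    Assumed construction: h = 2 * sum_y eps_y * sqrt(p_y(x)) with a
    Rademacher probe eps, so E[grad h (grad h)^T] is the per-example FIM.
    (Sketch only, not the library's implementation.)
    """
    params = [p for p in model.parameters() if p.requires_grad]
    log_p = F.log_softmax(model(x.unsqueeze(0)), dim=-1)[0]
    sqrt_p = (0.5 * log_p).exp()                              # sqrt of class probabilities
    eps = torch.randint(0, 2, sqrt_p.shape).float() * 2 - 1   # Rademacher probe in {-1, +1}
    h = 2.0 * (eps * sqrt_p).sum()                            # scalar h from the logits
    grads = torch.autograd.grad(h, params)                    # one extra backward pass
    for buf, g in zip(buffers, grads):
        buf += g.pow(2)                                       # diagonal of grad h (grad h)^T
    return buffers
```

Averaging the accumulated buffers over many probes converges to the exact diagonal FIM, while each step costs only one backward pass regardless of the number of classes.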
To use the library:

```python
from fim import FIMType, diag_fim_step, get_diag_fim, zero_diag_fim

fim_type = FIMType.HUTCHINSON
fim_opts = {"num_probes": 1, "lr_power_itrs": 30, "top_k": 2}

zero_diag_fim(model, fim_type=fim_type)  # reset internal diagonal-FIM buffers
for batch in loader:
    ...
    diag_fim_step(model, batch, fim_type, fim_opts, ema=None, reduction="sum")  # accumulate FIM
fim_buffers = get_diag_fim(model, fim_type=fim_type)
```

To run batch experiments benchmarking FIM estimators (e.g. MNLI on RoBERTa):
```sh
$ scripts/run.sh examples/mnli_roberta.py  # generates a list of npz files
$ scripts/viz.sh results/mnli_roberta_exact_nb128_bs64_seed42.npz  # visualize the FIM histogram
$ scripts/metrics.sh results/mnli_roberta_empirical_nb128_bs64_seed42.npz results/mnli_roberta_hutchinson_nb128_bs64_seed42.npz --gt results/mnli_roberta_exact_nb128_bs64_seed42.npz  # benchmark the empirical and Hutchinson FIMs against the exact diagonal FIM
```
If you use this code, please cite:

```bibtex
@inproceedings{sun2026neuromanifold,
  title={Deterministic Bounds and Random Estimates of Metric Tensors on Neuromanifolds},
  author={Sun, Ke},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2026},
  url={https://openreview.net/forum?id=Ssevs8KCsU}
}
```