
Integrate LieBNSPD into spd_learn #21

Merged
bruAristimunha merged 22 commits into spdlearn:main from GitZH-Chen:main
Mar 22, 2026

Conversation


@GitZH-Chen GitZH-Chen commented Mar 18, 2026

Summary

This PR integrates LieBNSPD into spd_learn.modules as a shared package module
and updates the related examples and tests.

It also refines the AIM Karcher mean computation in LieBNSPD by adding an
early-stopping condition (which can also be reused in other AIM-based Karcher
mean computations), and updates the LieBN tests to run in double precision.

Main changes

  • add spd_learn/modules/LieBN.py as the shared LieBNSPD implementation
  • export LieBNSPD from spd_learn.modules
  • update the LieBN example scripts and tests/test_liebn.py to import the
    shared package implementation
  • add AIM Karcher early stopping in LieBNSPD
  • extend karcher_mean_iteration with optional tangent return support for
    convergence checking
  • set tests/test_liebn.py to use float64 and increase the simulated SPD
    dimension for more stable double-precision testing
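The early-stopping idea described above (iterate the AIM Karcher update, stop once the norm of the tangent-space mean falls below a threshold) can be sketched in a torch-free form. The 1e-5 threshold matches the value documented later in this PR; everything else here, including the function names, is an illustrative numpy reconstruction under stated assumptions, not the spd_learn implementation:

```python
import numpy as np

def _logm(X):
    # matrix logarithm of an SPD matrix via eigendecomposition
    w, V = np.linalg.eigh(X)
    return (V * np.log(w)) @ V.T

def _expm(X):
    # matrix exponential of a symmetric matrix via eigendecomposition
    w, V = np.linalg.eigh(X)
    return (V * np.exp(w)) @ V.T

def karcher_mean(mats, n_iter=50, tol=1e-5):
    # AIM Karcher mean with early stopping on the tangent-mean norm
    M = np.mean(mats, axis=0)  # Euclidean initialization
    for _ in range(n_iter):
        w, V = np.linalg.eigh(M)
        M_half = (V * np.sqrt(w)) @ V.T
        M_inv_half = (V * (1.0 / np.sqrt(w))) @ V.T
        # project all samples to the tangent space at M and average
        tangent = np.mean(
            [_logm(M_inv_half @ X @ M_inv_half) for X in mats], axis=0
        )
        if np.linalg.norm(tangent) < tol:
            break  # tangent mean ~ 0 means M is (near) the Frechet mean
        M = M_half @ _expm(tangent) @ M_half
    return M
```

At the fixed point the mean of the tangent vectors vanishes, so checking its norm each iteration is a natural convergence test and saves the remaining iterations.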

GitZH-Chen and others added 12 commits March 18, 2026 23:25
Integrate LieBNSPD as a package module and update the LieBN examples and tests to import the shared implementation from spd_learn.modules.

This commit is based on Bruno Aristimunha's initial integration draft. I reviewed it and adjusted the package wiring while adding and refining comments/docstrings around the LieBN implementation and its provenance.
Extend karcher_mean_iteration with an optional tangent return and use it in LieBNSPD to stop AIM Karcher iterations early once the tangent mean norm is sufficiently small.
Set test_liebn to use float64 by default and increase the simulated SPD dimension to keep the LieBN test configuration stable under double-precision runs.
- Rename LieBNSPD → SPDBatchNormLie to match naming convention
  (SPDBatchNormMean, SPDBatchNormMeanVar)
- Add device/dtype params to SPDBatchNormLie.__init__
- Use buffer rebinding instead of .copy_() for running stats,
  matching the pattern of other batchnorm layers and enabling
  broadcast over arbitrary leading batch dimensions
- Fix Karcher early stopping for multi-batch tensors (.max())
- Add SPDBatchNormLie to public API exports and docs autosummary
- Add mandatory params and complex dtype skip in integration tests
- Fix __file__ NameError in sphinx-gallery example
- Add geoopt to brain dependencies for TSMNet example
- Add chen2024liebn bibtex entry to references.bib
- Add rtol=0.05 to variance convergence test for LEM/LCM
Extract reusable math from the module into functional/batchnorm.py:
- spd_cholesky_congruence: Cholesky-based congruence action (L X L^T)
- lie_group_variance: weighted Fréchet variance for AIM/LEM/LCM

The module now delegates to existing functional operations:
- cholesky_log/cholesky_exp for LCM deform/inv_deform
- log_euclidean_scalar_multiply for AIM variance scaling
- matrix_log/matrix_exp/matrix_power for AIM/LEM deform
- karcher_mean_iteration for AIM Fréchet mean
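The extracted `lie_group_variance` is described above as a weighted Fréchet variance. For the AIM case that is the weighted mean of squared affine-invariant distances to the mean. The following is a hedged numpy sketch of that formula only; the function name and signature are my own, not the `functional/batchnorm.py` API:

```python
import numpy as np

def aim_frechet_variance(mats, M, weights=None):
    # weighted Frechet variance under AIM: sum_i w_i * d_AIM(X_i, M)^2
    n = len(mats)
    w = np.full(n, 1.0 / n) if weights is None else np.asarray(weights)
    ew, V = np.linalg.eigh(M)
    M_ih = (V * (1.0 / np.sqrt(ew))) @ V.T  # M^{-1/2}
    var = 0.0
    for wi, X in zip(w, mats):
        # squared AIM distance: Frobenius norm of log(M^{-1/2} X M^{-1/2})
        lam = np.linalg.eigvalsh(M_ih @ X @ M_ih)
        var += wi * np.sum(np.log(lam) ** 2)
    return var
```

Under LEM/LCM the same structure applies with the distance computed in the matrix-log or Cholesky-log domain instead.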
Support two implementations of the AIM group action:
- "cholesky": L⁻¹ X L⁻ᵀ via Cholesky factor (original LieBN paper)
- "eig": M⁻¹/² X M⁻¹/² via eigendecomposition (spd_centering)

Both are valid SPD batch normalizations that center the mean to
identity, but use different geometric transports. Default remains
"cholesky" to match the paper.
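The two transports can be put side by side. This is an illustrative numpy sketch of the formulas named in the commit message (L⁻¹ X L⁻ᵀ via the Cholesky factor versus M⁻¹/² X M⁻¹/² via eigendecomposition); the function names are hypothetical and the spd_learn code is not reproduced here:

```python
import numpy as np

def center_cholesky(X, M):
    # "cholesky" action: L^{-1} X L^{-T} where M = L L^T
    L = np.linalg.cholesky(M)
    L_inv = np.linalg.inv(L)
    return L_inv @ X @ L_inv.T

def center_eig(X, M):
    # "eig" action: M^{-1/2} X M^{-1/2} via eigendecomposition
    w, V = np.linalg.eigh(M)
    M_inv_half = (V * (1.0 / np.sqrt(w))) @ V.T
    return M_inv_half @ X @ M_inv_half
```

Both maps send the batch mean M itself to the identity, which is the centering property batch normalization needs; for other inputs they generally produce different (but congruent, hence still SPD) outputs.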
The TSMNet tutorial depends on geoopt (RiemannianAdam) which is not
a project dependency. Remove the tutorial and the geoopt entry from
the brain extras.
P1: Validate metric parameter in SPDBatchNormLie.__init__
P1: Move .detach() from lie_group_variance to caller (keep
    functional API stateless)
P1: Move ensure_sym import to top-level in batchnorm.py
P2: Document Karcher convergence threshold (1e-5) in docstring
P2: Show congruence in extra_repr
P2: Validate metric in lie_group_variance
P2: Fix torch.set_default_dtype leak in test_liebn.py (autouse fixture)
P3: Alphabetize SPDBatchNormLie in __all__
P3: Group SPDBatchNorm* together in api.rst
P3: List new functions in batchnorm.py module docstring

Also add congruence parametrization to test_post_normalization_mean
for coverage of both cholesky and eig paths.
Replace the autouse fixture that set global default dtype with
explicit dtype=torch.float64 passed to every SPDBatchNormLie
constructor and to data generation tensors. This avoids leaking
global state to other test modules and properly exercises the
dtype parameter.
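The pattern of threading an explicit dtype through data generation, rather than mutating the global default, can be sketched as follows; the helper name `make_spd_batch` is hypothetical and this is a numpy stand-in for the torch tensors used in the actual tests:

```python
import numpy as np

def make_spd_batch(n, dim, dtype=np.float64, seed=0):
    # simulate a batch of SPD matrices at an explicit dtype;
    # A A^T is PSD, and adding dim * I keeps eigenvalues safely positive
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((n, dim, dim)).astype(dtype)
    return A @ A.transpose(0, 2, 1) + dim * np.eye(dim, dtype=dtype)
```

Because the dtype travels with the data and the constructor arguments, no test can leak a changed global default into other test modules.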
Co-authored-by: Bru <b.aristimunha@gmail.com>
Add frechet_mean() to spd_learn.functional.batchnorm, unifying the
duplicated Karcher flow logic from SPDBatchNormMean, SPDBatchNormMeanVar,
SPDBatchNormLie, and the SPDIM tutorial into a single reusable function.

Rename LieBN.py to liebn.py for snake_case consistency with all other
module files, and rename karcher_steps to n_iter in SPDBatchNormLie to
match the other batchnorm modules.
…tency

All other batchnorm modules (SPDBatchNormMean, SPDBatchNormMeanVar,
BatchReNorm) use `num_features` as their matrix-size parameter.
This aligns SPDBatchNormLie with the same convention.
@bruAristimunha

Docs CI build analysis

The docs build is hanging for 6+ hours (vs ~23 min on main) because plot_liebn_batch_normalization.py is executed by Sphinx-Gallery (files matching plot_* are auto-executed).

Issues found

1. Training parameters are too heavy for CI

The example runs EPOCHS=200 x N_RUNS=10 x 20 configurations = ~40,000 training iterations on a CI CPU runner. Other examples in the repo use reduced parameters for docs builds:

```python
n_clusters = 3   # Reduced from 5 for faster documentation build
max_epochs = 20  # Reduced from 100 for faster documentation build
```

Suggested fix — detect CI and reduce:

```python
import os

EPOCHS = 5 if os.environ.get("CI") else 200
N_RUNS = 2 if os.environ.get("CI") else 10
```

2. __file__ is undefined in Sphinx-Gallery context

Line 291: os.path.dirname(__file__) raises NameError because Sphinx-Gallery executes examples without setting __file__. Replace with:

```python
CHECKPOINT_PATH = Path(tempfile.gettempdir()) / "liebn_checkpoint.json"
```

3. plot_liebn_tsmnet.py needs geoopt

The failed build also showed ModuleNotFoundError: No module named 'geoopt' from plot_liebn_tsmnet.py.

CI improvements on cache-CI branch

I've added Sphinx-Gallery output caching to docs.yml and stabilized data cache keys. Once merged, the gallery will cache results so unchanged examples don't re-run.

…dtype(torch.float64)

Several examples set torch.set_default_dtype(torch.float64) globally,
which persists across sphinx-gallery examples in the same worker process.
This caused MATT and GREEN examples to fail with dtype mismatches when
their models were initialized with float64 parameters but braindecode
cast inputs to float32.

Changes:
- Remove torch.set_default_dtype(torch.float64) from 4 example scripts
- Add explicit .float() conversion where numpy float64 data enters torch
- Add _reset_torch_defaults to sphinx-gallery reset_modules as safety net
- Add howto subsection to sphinx-gallery ordering
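The `reset_modules` safety net mentioned above could look roughly like this in the Sphinx `conf.py`. The helper name `_reset_torch_defaults` comes from the commit message; the surrounding wiring is a sketch based on sphinx-gallery's resetter convention (resetters are callables receiving the gallery config and the example filename), not the repository's actual configuration:

```python
# docs/conf.py (sketch)
import torch

def _reset_torch_defaults(gallery_conf, fname):
    # sphinx-gallery invokes resetters around each example; undo any
    # torch.set_default_dtype(...) a previous example may have leaked
    torch.set_default_dtype(torch.float32)

sphinx_gallery_conf = {
    # built-in "matplotlib" resetter plus the custom torch safety net
    "reset_modules": ("matplotlib", _reset_torch_defaults),
}
```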
…amples

The docs CI was taking ~6 hours because every push re-executed all
gallery examples (data downloads + model training). Sphinx-gallery
already skips unchanged examples locally via MD5 checks on the
generated/ directory, but CI never persisted this cache.

The cache key hashes all example source files so only modified
examples re-execute on subsequent runs.
…EADME

Add a new "Batch Normalization on SPD Manifolds" section to
geometric_concepts.rst covering why Euclidean BN fails, Riemannian BN,
and the LieBN 5-step pipeline with metric comparison table.

Add README.txt for the new howto gallery section.
@bruAristimunha bruAristimunha merged commit 113efad into spdlearn:main Mar 22, 2026
11 checks passed
