Merged

22 commits
9893962
Integrate LieBNSPD into spd_learn modules
GitZH-Chen Mar 18, 2026
3a0ec09
Add AIM Karcher early stopping to LieBN
GitZH-Chen Mar 18, 2026
c87978a
Adjust LieBN tests for double precision
GitZH-Chen Mar 18, 2026
04dbcac
Merge branch 'main' into main
bruAristimunha Mar 20, 2026
b43486c
Rename LieBNSPD to SPDBatchNormLie and fix CI failures
bruAristimunha Mar 19, 2026
27b6a22
Refactor SPDBatchNormLie to follow functional-first pattern
bruAristimunha Mar 19, 2026
744b3b1
Add congruence parameter to SPDBatchNormLie for AIM centering
bruAristimunha Mar 19, 2026
f225ef7
Fix import sorting (ruff I001) across changed files
bruAristimunha Mar 19, 2026
7013e8d
Apply ruff formatting to LieBN module
bruAristimunha Mar 19, 2026
27ea986
Remove plot_liebn_tsmnet tutorial and geoopt dependency
bruAristimunha Mar 19, 2026
f23a3ae
Address code review findings (P0-P3)
bruAristimunha Mar 19, 2026
2e4b8f3
Use explicit dtype=torch.float64 in LieBN tests
bruAristimunha Mar 19, 2026
c303602
Apply suggestions from code review
bruAristimunha Mar 20, 2026
175369c
Fix missing blank line after imports (ruff E302)
bruAristimunha Mar 20, 2026
0786e4f
Extract frechet_mean into functional API and rename LieBN.py to liebn.py
bruAristimunha Mar 20, 2026
21a21cf
Rename SPDBatchNormLie parameter `n` to `num_features` for API consis…
bruAristimunha Mar 20, 2026
2c69dfe
Fix dtype mismatch in gallery examples by removing torch.set_default_…
bruAristimunha Mar 21, 2026
fdf0151
Cache sphinx-gallery outputs in CI to avoid re-executing unchanged ex…
bruAristimunha Mar 22, 2026
8ffc224
Add batch normalization explanation to geometric concepts and howto R…
bruAristimunha Mar 22, 2026
ae5fa40
Merge branch 'main' into main
bruAristimunha Mar 22, 2026
1e50bc3
Merge branch 'main' of https://github.com/GitZH-Chen/spd_learn into G…
bruAristimunha Mar 22, 2026
f3e0414
updating the pre-commit
bruAristimunha Mar 22, 2026
1 change: 1 addition & 0 deletions docs/source/api.rst
@@ -362,6 +362,7 @@ or related representations.

SPDBatchNormMean
SPDBatchNormMeanVar
SPDBatchNormLie
BatchReNorm


17 changes: 16 additions & 1 deletion docs/source/conf.py
@@ -250,6 +250,18 @@
from sphinx_gallery.sorting import ExplicitOrder


def _reset_torch_defaults(gallery_conf, fname):
"""Reset torch global state between sphinx-gallery examples.

Some examples call ``torch.set_default_dtype(torch.float64)`` which
persists across examples when run in the same worker process and
causes dtype-mismatch errors in subsequent examples.
"""
import torch

torch.set_default_dtype(torch.float32)


sphinx_gallery_conf = {
"examples_dirs": ["../../examples"],
"gallery_dirs": ["generated/auto_examples"],
@@ -258,10 +270,11 @@
# Point 3: Image optimization - compress images and reduce thumbnail size
"compress_images": ("images", "thumbnails"),
"thumbnail_size": (400, 280), # Smaller thumbnails for faster loading
# Order: tutorials first, then visualizations, then applied examples
# Order: tutorials, how-to guides, visualizations, then applied examples
"subsection_order": ExplicitOrder(
[
"../../examples/tutorials",
"../../examples/howto",
"../../examples/visualizations",
"../../examples/applied_examples",
]
@@ -277,6 +290,8 @@
# Include both plot_* files and tutorial_* files
"filename_pattern": r"/(plot_|tutorial_)",
"ignore_pattern": r"(__init__|spd_visualization_utils)\.py",
# Reset torch default dtype between examples to prevent float64 leakage
"reset_modules": ("matplotlib", "seaborn", _reset_torch_defaults),
# Show signature link template (includes Colab launcher)
"show_signature": False,
# First cell in generated notebooks (for Colab compatibility)
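The `_reset_torch_defaults` hook added to conf.py follows sphinx-gallery's `reset_modules` contract: each entry is a callable invoked with `(gallery_conf, fname)` around every example, so module-level state set by one example cannot leak into the next. A minimal stdlib-only sketch of that mechanism, where a plain dict stands in for torch's global default dtype (all names below are hypothetical, chosen for illustration):

```python
# "Global state" standing in for torch's default dtype (illustration only).
STATE = {"default_dtype": "float32"}

def example_a():
    # Misbehaving example: changes global state and never restores it,
    # like a gallery script calling torch.set_default_dtype(torch.float64).
    STATE["default_dtype"] = "float64"

def example_b():
    # Would hit a dtype mismatch if example_a's change leaked through.
    return STATE["default_dtype"]

def reset_state(gallery_conf, fname):
    # Same signature as _reset_torch_defaults in conf.py above.
    STATE["default_dtype"] = "float32"

def run_gallery(examples, resetters, gallery_conf=None):
    # Mimics sphinx-gallery: run each reset callable before every example.
    results = []
    for fname, example in examples:
        for reset in resetters:
            reset(gallery_conf, fname)
        results.append(example())
    return results

results = run_gallery(
    [("a.py", example_a), ("b.py", example_b)],
    [reset_state],
)
print(results[1])  # "float32": example_b sees a clean default
```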
98 changes: 98 additions & 0 deletions docs/source/geometric_concepts.rst
@@ -693,6 +693,104 @@ where :math:`\frechet` is the Fréchet mean of the batch.
See :ref:`sphx_glr_generated_auto_examples_visualizations_plot_batchnorm_animation.py`


Batch Normalization on SPD Manifolds
=====================================

In Euclidean deep learning, batch normalization centers activations to zero mean
and unit variance, stabilizing gradient flow and accelerating convergence. On the
SPD manifold, the same principle applies — but "mean" and "variance" must respect
the curved Riemannian geometry.

Why Euclidean BN Fails for SPD Matrices
----------------------------------------

Standard batch normalization computes :math:`\hat{x} = (x - \mu) / \sigma`. For SPD
matrices this is problematic:

- **Subtraction breaks SPD**: :math:`X - M` (with :math:`M` the arithmetic mean) may not
be positive definite.
- **The swelling effect**: The Euclidean mean of SPD matrices can have a larger determinant
than any individual matrix, distorting the data distribution.
- **Scale mismatch**: SPD matrices from different subjects or sessions can have vastly
different spectral profiles; Euclidean normalization ignores this geometric structure.
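Both failure modes can be checked numerically. A plain-Python illustration using diagonal 2x2 SPD matrices written as pairs of eigenvalues, a simplification chosen for hand-checkable arithmetic, not how spd_learn represents matrices:

```python
# Diagonal 2x2 SPD matrices as (d1, d2) pairs of positive eigenvalues.
A = (1.0, 0.01)
B = (0.01, 1.0)

# Euclidean (arithmetic) mean of the two matrices.
M = ((A[0] + B[0]) / 2, (A[1] + B[1]) / 2)   # (0.505, 0.505)

# Subtraction breaks SPD: A - M acquires a negative eigenvalue.
diff = (A[0] - M[0], A[1] - M[1])
print(min(diff) > 0)          # False -> no longer positive definite

# Swelling effect: det(M) exceeds the determinant of every input.
det = lambda d: d[0] * d[1]
print(det(M) > max(det(A), det(B)))   # True (0.255025 vs 0.01)
```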

Riemannian Batch Normalization
-------------------------------

:class:`~spd_learn.modules.SPDBatchNormMeanVar` addresses these issues by replacing
Euclidean operations with their Riemannian counterparts under the AIRM:

1. **Centering**: Compute the Fréchet mean :math:`\frechet` of the batch, then
apply congruence :math:`\tilde{X}_i = \frechet^{-1/2} X_i \frechet^{-1/2}` to center
the batch around the identity matrix.
2. **Variance scaling**: Compute a scalar dispersion and normalize by a learnable weight.
3. **Biasing**: Apply a learnable SPD bias via congruence.

This preserves the SPD structure at every step.
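The centering step can be sketched in the commuting (diagonal) special case, where the AIRM Fréchet mean has the closed form :math:`\exp(\mathrm{mean}(\log X_i))` and the congruence :math:`M^{-1/2} X M^{-1/2}` reduces to elementwise division. The diagonal restriction and the function names are assumptions for illustration, not the spd_learn implementation:

```python
import math

# Diagonal SPD matrices represented as lists of positive eigenvalues.

def frechet_mean_diag(batch):
    # Closed-form Frechet mean for commuting matrices: exp(mean(log X_i)).
    n = len(batch[0])
    return [
        math.exp(sum(math.log(x[i]) for x in batch) / len(batch))
        for i in range(n)
    ]

def center_diag(batch):
    m = frechet_mean_diag(batch)
    # Congruence M^{-1/2} X M^{-1/2}: elementwise division in the diagonal case.
    return [[x[i] / m[i] for i in range(len(m))] for x in batch]

batch = [[4.0, 0.25], [1.0, 1.0]]
centered = center_diag(batch)
# Entries stay positive (SPD preserved), and the batch Frechet mean
# of the centered matrices is the identity.
print(frechet_mean_diag(centered))
```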

Lie Group Batch Normalization (LieBN)
--------------------------------------

:class:`~spd_learn.modules.SPDBatchNormLie` :cite:p:`chen2024liebn` generalizes
Riemannian BN by exploiting the Lie group structure of :math:`\spd`. The key insight
is that each Riemannian metric induces a different group action for centering and biasing.

The LieBN forward pass follows five steps:

1. **Deformation** — Map SPD matrices to a codomain via the metric
(e.g., :math:`\log(X)` for LEM, Cholesky + log-diagonal for LCM, :math:`X^\theta` for AIM).
2. **Centering** — Translate the batch to zero/identity mean using the group action.
3. **Scaling** — Normalize variance by a learnable dispersion parameter.
4. **Biasing** — Translate by a learnable location parameter.
5. **Inverse deformation** — Map back to the SPD manifold.
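The five steps can be traced in the simplest possible setting: 1x1 SPD matrices, i.e. positive scalars, under LEM, where deformation and inverse deformation are the ordinary scalar log/exp and the group action is addition. ``scale`` and ``bias`` stand in for the learnable parameters and are names assumed for illustration, not spd_learn's API:

```python
import math

def liebn_lem_1x1(batch, scale=1.0, bias=0.0, eps=1e-5):
    # Five LieBN steps for positive scalars under the LEM metric.
    logs = [math.log(x) for x in batch]            # 1. deformation
    mean = sum(logs) / len(logs)
    centered = [v - mean for v in logs]            # 2. centering (additive)
    var = sum(v * v for v in centered) / len(centered)
    std = math.sqrt(var + eps)
    scaled = [scale * v / std for v in centered]   # 3. scaling
    biased = [v + bias for v in scaled]            # 4. biasing
    return [math.exp(v) for v in biased]           # 5. inverse deformation

out = liebn_lem_1x1([0.5, 2.0, 8.0])
# Outputs are again positive (valid SPD) with geometric mean 1,
# i.e. the batch is centered at the identity in log-space.
```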

.. list-table::
:header-rows: 1
:widths: 15 25 25 25

* - Metric
- Deformation
- Mean Computation
- Group Action
* - **LEM**
- :math:`\log(X)`
- Euclidean (closed-form)
- Additive
* - **LCM**
- Cholesky + log-diag
- Euclidean (closed-form)
- Additive
* - **AIM**
- :math:`X^\theta`
- Karcher (iterative)
- Cholesky congruence

**Choosing a metric for batch normalization:**

- **LEM**: Fastest (closed-form mean), good default for most tasks.
- **AIM**: Full affine invariance, best when data scale varies (e.g., cross-subject EEG).
- **LCM**: Fast like LEM, with Cholesky-based numerical stability.

.. code-block:: python

from spd_learn.modules import SPDBatchNormLie

# LEM is the fastest — good default
bn_lem = SPDBatchNormLie(num_features=32, metric="LEM")

# AIM for affine-invariant normalization
bn_aim = SPDBatchNormLie(num_features=32, metric="AIM", theta=1.0)

# LCM for Cholesky stability
bn_lcm = SPDBatchNormLie(num_features=32, metric="LCM")

.. seealso::

:ref:`tutorial-batch-normalization` — Hands-on tutorial comparing all BN strategies,
:ref:`howto-add-batchnorm` — Quick integration guide,
:ref:`liebn-batch-normalization` — Full benchmark reproduction across 3 datasets


References
==========

8 changes: 8 additions & 0 deletions docs/source/references.bib
@@ -138,6 +138,14 @@ @inproceedings{kobler2022spd
url={https://proceedings.neurips.cc/paper_files/paper/2022/hash/28ef7ee7cd3e03093acc39e1272411b7-Abstract-Conference.html}
}

@inproceedings{chen2024liebn,
title={A Lie Group Approach to Riemannian Batch Normalization},
author={Chen, Ziheng and Song, Yue and Xu, Yunmei and Sebe, Nicu},
booktitle={International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=okYdj8Ysru}
}

@inproceedings{pan2022matt,
title={MAtt: A manifold attention network for EEG decoding},
author={Pan, Yue-Ting and Chou, Jing-Lun and Wei, Chun-Shu},