diff --git a/docs/source/api.rst b/docs/source/api.rst
index 4d4c59d..c8e8445 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -362,6 +362,7 @@ or related representations.
SPDBatchNormMean
SPDBatchNormMeanVar
+ SPDBatchNormLie
BatchReNorm
diff --git a/docs/source/conf.py b/docs/source/conf.py
index fed715c..d7696fe 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -250,6 +250,18 @@
from sphinx_gallery.sorting import ExplicitOrder
+def _reset_torch_defaults(gallery_conf, fname):
+ """Reset torch global state between sphinx-gallery examples.
+
+ Some examples call ``torch.set_default_dtype(torch.float64)`` which
+ persists across examples when run in the same worker process and
+ causes dtype-mismatch errors in subsequent examples.
+ """
+ import torch
+
+ torch.set_default_dtype(torch.float32)
+
+
sphinx_gallery_conf = {
"examples_dirs": ["../../examples"],
"gallery_dirs": ["generated/auto_examples"],
@@ -258,10 +270,11 @@
# Point 3: Image optimization - compress images and reduce thumbnail size
"compress_images": ("images", "thumbnails"),
"thumbnail_size": (400, 280), # Smaller thumbnails for faster loading
- # Order: tutorials first, then visualizations, then applied examples
+ # Order: tutorials, how-to guides, visualizations, then applied examples
"subsection_order": ExplicitOrder(
[
"../../examples/tutorials",
+ "../../examples/howto",
"../../examples/visualizations",
"../../examples/applied_examples",
]
@@ -277,6 +290,8 @@
# Include both plot_* files and tutorial_* files
"filename_pattern": r"/(plot_|tutorial_)",
"ignore_pattern": r"(__init__|spd_visualization_utils)\.py",
+ # Reset torch default dtype between examples to prevent float64 leakage
+ "reset_modules": ("matplotlib", "seaborn", _reset_torch_defaults),
# Show signature link template (includes Colab launcher)
"show_signature": False,
# First cell in generated notebooks (for Colab compatibility)
diff --git a/docs/source/geometric_concepts.rst b/docs/source/geometric_concepts.rst
index 77c0379..8d286e8 100644
--- a/docs/source/geometric_concepts.rst
+++ b/docs/source/geometric_concepts.rst
@@ -693,6 +693,104 @@ where :math:`\frechet` is the Fréchet mean of the batch.
See :ref:`sphx_glr_generated_auto_examples_visualizations_plot_batchnorm_animation.py`
+Batch Normalization on SPD Manifolds
+=====================================
+
+In Euclidean deep learning, batch normalization centers activations to zero mean
+and unit variance, stabilizing gradient flow and accelerating convergence. On the
+SPD manifold, the same principle applies — but "mean" and "variance" must respect
+the curved Riemannian geometry.
+
+Why Euclidean BN Fails for SPD Matrices
+----------------------------------------
+
+Standard batch normalization computes :math:`\hat{x} = (x - \mu) / \sigma`. For SPD
+matrices this is problematic:
+
+- **Subtraction breaks SPD**: :math:`X - M` (with :math:`M` the arithmetic mean) may not
+ be positive definite.
+- **The swelling effect**: The Euclidean mean of SPD matrices can have a larger determinant
+ than any individual matrix, distorting the data distribution.
+- **Scale mismatch**: SPD matrices from different subjects or sessions can have vastly
+ different spectral profiles; Euclidean normalization ignores this geometric structure.
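+
+Both failure modes are easy to check numerically. The sketch below uses two
+hypothetical :math:`2 \times 2` SPD matrices; the values are chosen purely for
+illustration:
+
+.. code-block:: python
+
+   import torch
+
+   eps = 0.01
+   A = torch.diag(torch.tensor([1.0, eps]))
+   B = torch.diag(torch.tensor([eps, 1.0]))
+
+   M = (A + B) / 2  # Euclidean (arithmetic) mean
+   print(torch.det(A), torch.det(B), torch.det(M))
+   # det(A) = det(B) = 0.01, but det(M) is about 0.255: the swelling effect
+
+   D = A - M  # Euclidean "centering" by subtraction
+   print(torch.linalg.eigvalsh(D))
+   # one eigenvalue is negative, so A - M is no longer positive definite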
+
+Riemannian Batch Normalization
+-------------------------------
+
+:class:`~spd_learn.modules.SPDBatchNormMeanVar` addresses these issues by replacing
+Euclidean operations with their Riemannian counterparts under the AIRM:
+
+1. **Centering**: Compute the Fréchet mean :math:`\frechet` of the batch, then
+ apply congruence :math:`\tilde{X}_i = \frechet^{-1/2} X_i \frechet^{-1/2}` to center
+ the batch around the identity matrix.
+2. **Variance scaling**: Compute the batch's scalar dispersion and rescale it
+   toward a learnable dispersion parameter.
+3. **Biasing**: Apply a learnable SPD bias via congruence.
+
+This preserves the SPD structure at every step.
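+
+The centering step is easy to express in plain ``torch``. The helper below is an
+illustrative sketch of the congruence above, not the module's implementation:
+
+.. code-block:: python
+
+   import torch
+
+   def center_at_identity(X, M):
+       """Congruence that maps the batch mean M to the identity matrix."""
+       evals, evecs = torch.linalg.eigh(M)
+       M_invsqrt = evecs @ torch.diag(evals.rsqrt()) @ evecs.mT
+       # a congruence with an invertible matrix preserves positive definiteness
+       return M_invsqrt @ X @ M_invsqrt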
+
+Lie Group Batch Normalization (LieBN)
+--------------------------------------
+
+:class:`~spd_learn.modules.SPDBatchNormLie` :cite:p:`chen2024liebn` generalizes
+Riemannian BN by exploiting the Lie group structure of :math:`\spd`. The key insight
+is that each Riemannian metric induces a different group action for centering and biasing.
+
+The LieBN forward pass follows five steps:
+
+1. **Deformation** — Map SPD matrices to a codomain via the metric
+ (e.g., :math:`\log(X)` for LEM, Cholesky + log-diagonal for LCM, :math:`X^\theta` for AIM).
+2. **Centering** — Translate the batch to zero/identity mean using the group action.
+3. **Scaling** — Normalize variance by a learnable dispersion parameter.
+4. **Biasing** — Translate by a learnable location parameter.
+5. **Inverse deformation** — Map back to the SPD manifold.
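+
+As a concrete illustration, the Log-Euclidean instantiation of these five steps
+can be sketched in plain ``torch``. The learnable dispersion and bias are replaced
+by fixed arguments here; the actual :class:`~spd_learn.modules.SPDBatchNormLie`
+layer keeps them as parameters and maintains running statistics:
+
+.. code-block:: python
+
+   import torch
+
+   def liebn_lem_step(X, bias_logm, target_disp=1.0, eps=1e-6):
+       """Schematic LEM forward pass for a batch X of SPD matrices."""
+       # 1. Deformation: matrix logarithm
+       evals, evecs = torch.linalg.eigh(X)
+       L = evecs @ torch.diag_embed(evals.log()) @ evecs.mT
+       # 2. Centering: under LEM the Fréchet mean is the Euclidean mean of the logs
+       L = L - L.mean(dim=0, keepdim=True)
+       # 3. Scaling: normalize the batch dispersion to a target value
+       disp = L.norm(dim=(-2, -1)).pow(2).mean().sqrt()
+       L = L * (target_disp / (disp + eps))
+       # 4. Biasing: translate by a (symmetric) bias in the log domain
+       L = L + bias_logm
+       # 5. Inverse deformation: matrix exponential back to the SPD manifold
+       evals_b, evecs_b = torch.linalg.eigh(L)
+       return evecs_b @ torch.diag_embed(evals_b.exp()) @ evecs_b.mT
+
+The table below summarizes how each metric instantiates these steps: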
+
+.. list-table::
+ :header-rows: 1
+ :widths: 15 25 25 25
+
+ * - Metric
+ - Deformation
+ - Mean Computation
+ - Group Action
+ * - **LEM**
+ - :math:`\log(X)`
+ - Euclidean (closed-form)
+ - Additive
+ * - **LCM**
+ - Cholesky + log-diag
+ - Euclidean (closed-form)
+ - Additive
+ * - **AIM**
+ - :math:`X^\theta`
+ - Karcher (iterative)
+ - Cholesky congruence
+
+**Choosing a metric for batch normalization:**
+
+- **LEM**: Fastest (closed-form mean), good default for most tasks.
+- **AIM**: Full affine invariance, best when data scale varies (e.g., cross-subject EEG).
+- **LCM**: Fast like LEM, with Cholesky-based numerical stability.
+
+.. code-block:: python
+
+ from spd_learn.modules import SPDBatchNormLie
+
+ # LEM is the fastest — good default
+ bn_lem = SPDBatchNormLie(num_features=32, metric="LEM")
+
+ # AIM for affine-invariant normalization
+ bn_aim = SPDBatchNormLie(num_features=32, metric="AIM", theta=1.0)
+
+ # LCM for Cholesky stability
+ bn_lcm = SPDBatchNormLie(num_features=32, metric="LCM")
+
+.. seealso::
+
+   - :ref:`tutorial-batch-normalization` — Hands-on tutorial comparing all BN strategies
+   - :ref:`howto-add-batchnorm` — Quick integration guide
+   - :ref:`liebn-batch-normalization` — Full benchmark reproduction across three datasets
+
+
References
==========
diff --git a/docs/source/references.bib b/docs/source/references.bib
index 6fc0251..824aa12 100644
--- a/docs/source/references.bib
+++ b/docs/source/references.bib
@@ -138,6 +138,14 @@ @inproceedings{kobler2022spd
url={https://proceedings.neurips.cc/paper_files/paper/2022/hash/28ef7ee7cd3e03093acc39e1272411b7-Abstract-Conference.html}
}
+@inproceedings{chen2024liebn,
+ title={A Lie Group Approach to Riemannian Batch Normalization},
+ author={Chen, Ziheng and Song, Yue and Xu, Yunmei and Sebe, Nicu},
+ booktitle={International Conference on Learning Representations},
+ year={2024},
+ url={https://openreview.net/forum?id=okYdj8Ysru}
+}
+
@inproceedings{pan2022matt,
title={MAtt: A manifold attention network for EEG decoding},
author={Pan, Yue-Ting and Chou, Jing-Lun and Wei, Chun-Shu},
diff --git a/examples/applied_examples/liebn_batch_normalization.py b/examples/applied_examples/liebn_batch_normalization.py
new file mode 100644
index 0000000..4c610d6
--- /dev/null
+++ b/examples/applied_examples/liebn_batch_normalization.py
@@ -0,0 +1,1111 @@
+"""
+.. _liebn-batch-normalization:
+
+Reproducing LieBN Paper Results (Table 4)
+==========================================
+
+This example reproduces the SPDNet experiments from Table 4 of
+Chen et al., "A Lie Group Approach to Riemannian Batch
+Normalization", ICLR 2024 :cite:p:`chen2024liebn`.
+
+We compare batch normalization strategies on HDM05 (7 configs), Radar
+(6 configs), and AFEW (7 configs) datasets. HDM05 and Radar follow the
+paper's evaluation protocol (10 independent random-split runs with
+batch-mean accuracy). AFEW uses a fixed train/val split (10 runs varying
+only model initialization).
+
+**Configurations benchmarked:**
+
+- **SPDNet**: No batch normalization
+- **SPDNetBN**: Riemannian BN (Brooks et al. + variance normalization)
+- **LieBN-AIM**: LieBN under the Affine-Invariant Metric (theta=1, 1.5)
+- **LieBN-LEM**: LieBN under the Log-Euclidean Metric
+- **LieBN-LCM**: LieBN under the Log-Cholesky Metric (theta=1, 0.5, -0.5)
+
+.. note::
+
+ New to SPD batch normalization? Start with the
+ :ref:`tutorial-batch-normalization` tutorial for an introduction, or
+ see :ref:`howto-add-batchnorm` for a quick integration guide.
+
+.. contents:: This example covers:
+ :local:
+ :depth: 2
+
+"""
+
+######################################################################
+# Setup and Imports
+# -----------------
+#
+
+import json
+import os
+import random
+import tarfile
+import tempfile
+import time
+import urllib.request
+import warnings
+import zipfile
+
+from collections import defaultdict
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+
+from joblib import Parallel, delayed
+from torch.utils.data import DataLoader, TensorDataset
+
+from spd_learn.functional import ensure_sym
+from spd_learn.modules import (
+ BiMap,
+ LogEig,
+ ReEig,
+ SPDBatchNormLie,
+ SPDBatchNormMeanVar,
+)
+
+
+# Suppress noisy warnings from matplotlib and torch internals;
+# keep UserWarning and RuntimeWarning visible for diagnostic signals.
+warnings.filterwarnings("ignore", category=FutureWarning)
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.filterwarnings("ignore", module="matplotlib")
+
+
+def set_reproducibility(seed=1024):
+ """Set random seeds and enable deterministic behavior."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+ try:
+ torch.use_deterministic_algorithms(True, warn_only=True)
+ except TypeError:
+ torch.use_deterministic_algorithms(True)
+ if hasattr(torch.backends, "cudnn"):
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+GLOBAL_SEED = 1024
+set_reproducibility(GLOBAL_SEED)
+DATA_DIR = Path("data")
+DATA_DIR.mkdir(exist_ok=True)
+
+######################################################################
+# SPDNet Architecture & Training Setup
+# -------------------------------------
+#
+# We build a multi-layer SPDNet following the reference architecture:
+#
+# - Intermediate layers: ``BiMap -> [BN] -> ReEig``
+# - Final layer: ``BiMap -> [BN]`` (no ReEig)
+# - Classifier: ``LogEig -> flatten -> Linear``
+#
+
+
+def make_bn(n, bn_type, bn_kwargs):
+ """Create a batch normalization layer."""
+ if bn_type == "SPDBN":
+ return SPDBatchNormMeanVar(n, momentum=bn_kwargs.get("momentum", 0.1))
+ elif bn_type == "LieBN":
+ return SPDBatchNormLie(n, **bn_kwargs)
+ else:
+ raise ValueError(f"Unknown bn_type: {bn_type}")
+
+
+class SPDNetModel(nn.Module):
+ """Multi-layer SPDNet with optional batch normalization.
+
+ Parameters
+ ----------
+ dims : list of int
+ Sequence of SPD matrix dimensions, e.g. [93, 30].
+ n_classes : int
+ Number of output classes.
+ bn_type : str or None
+ None (no BN), 'SPDBN', or 'LieBN'.
+ bn_kwargs : dict or None
+ Keyword arguments for the BN layer.
+ """
+
+ def __init__(self, dims, n_classes, bn_type=None, bn_kwargs=None):
+ super().__init__()
+ layers = []
+ for i in range(len(dims) - 1):
+ layers.append(BiMap(dims[i], dims[i + 1]))
+ if bn_type is not None:
+ layers.append(make_bn(dims[i + 1], bn_type, bn_kwargs or {}))
+ if i < len(dims) - 2:
+ layers.append(ReEig())
+ self.features = nn.Sequential(*layers)
+ n_out = dims[-1]
+ # upper=False uses full n^2 features (matches reference's dims[-1]**2).
+ self.logeig = LogEig(upper=False, flatten=True)
+ self.classifier = nn.Linear(n_out**2, n_classes)
+
+ def forward(self, x):
+ x = self.features(x)
+ x = self.logeig(x)
+ return self.classifier(x)
+
+
+######################################################################
+# Training and Evaluation Utilities
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# We define a single atomic training function that is dispatched via
+# ``joblib`` with the loky (process-based) backend for true parallelism.
+#
+# **Checkpointing**: Results are saved per-run to a checkpoint file
+# so that a crashed/interrupted experiment can be resumed without
+# re-running completed work.
+#
+# **Performance tuning**:
+#
+# - Each worker uses 1 BLAS thread (``torch.set_num_threads(1)``)
+# since Apple Silicon Accelerate's ``eigh`` is fastest single-threaded.
+# With 14 workers this fully utilizes all CPU cores.
+# - ``optimizer.zero_grad(set_to_none=True)`` avoids memset overhead.
+# - Data arrives as numpy arrays for joblib memmapping (no pickle).
+#
+
+N_RUNS = 10
+EPOCHS = 200
+BATCH_SIZE = 30
+LR = 5e-3
+
+# Checkpoint file: per-run results saved as they complete.
+try:
+ CHECKPOINT_PATH = os.path.join(os.path.dirname(__file__), "liebn_checkpoint.json")
+except NameError:
+ CHECKPOINT_PATH = os.path.join(tempfile.gettempdir(), "liebn_checkpoint.json")
+
+
+def _load_checkpoint():
+ """Load existing checkpoint, return dict of completed runs."""
+ if os.path.exists(CHECKPOINT_PATH):
+ with open(CHECKPOINT_PATH) as f:
+ return json.load(f)
+ return {}
+
+
+def _save_checkpoint(checkpoint):
+ """Atomically save checkpoint (write tmp + rename)."""
+ tmp = CHECKPOINT_PATH + ".tmp"
+ with open(tmp, "w") as f:
+ json.dump(checkpoint, f, indent=2)
+ os.replace(tmp, CHECKPOINT_PATH)
+
+
+def _train_single_run(
+ run_seed,
+ X_train,
+ y_train,
+ X_test,
+ y_test,
+ dims,
+ n_classes,
+ bn_type,
+ bn_kwargs,
+ epochs=EPOCHS,
+ batch_size=BATCH_SIZE,
+ lr=LR,
+):
+ """Train and evaluate a single run (atomic unit of work).
+
+ Data arrives as numpy arrays (for joblib memmapping) and is
+ converted to tensors inside the worker process.
+
+ Returns (accuracy, fit_time) or (NaN, fit_time) on failure.
+ """
+ # One BLAS thread per worker; eigh on Apple Silicon Accelerate is
+ # fastest single-threaded. With 14 workers we use all 14 cores.
+ torch.set_num_threads(1)
+ torch.set_default_dtype(torch.float64)
+
+ random.seed(run_seed)
+ np.random.seed(run_seed)
+ torch.manual_seed(run_seed)
+
+ # Convert numpy → tensors inside each worker process.
+ # Use torch.tensor() (not as_tensor) to copy from read-only memmaps.
+ X_train = torch.tensor(X_train, dtype=torch.float64)
+ y_train = torch.tensor(y_train, dtype=torch.long)
+ X_test = torch.tensor(X_test, dtype=torch.float64)
+ y_test = torch.tensor(y_test, dtype=torch.long)
+
+ train_gen = torch.Generator()
+ train_gen.manual_seed(run_seed)
+
+ model = SPDNetModel(dims, n_classes, bn_type=bn_type, bn_kwargs=bn_kwargs)
+
+ train_loader = DataLoader(
+ TensorDataset(X_train, y_train),
+ batch_size=batch_size,
+ shuffle=True,
+ generator=train_gen,
+ )
+ test_loader = DataLoader(
+ TensorDataset(X_test, y_test),
+ batch_size=batch_size,
+ shuffle=False,
+ )
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True)
+ criterion = nn.CrossEntropyLoss()
+
+ epoch_times = []
+ try:
+ for epoch in range(epochs):
+ t0 = time.time()
+ model.train()
+ for xb, yb in train_loader:
+ optimizer.zero_grad(set_to_none=True)
+ loss = criterion(model(xb), yb)
+ loss.backward()
+ optimizer.step()
+ epoch_times.append(time.time() - t0)
+
+ # Batch-mean accuracy (matches paper's training_script.py)
+ model.eval()
+ batch_accs = []
+ with torch.no_grad():
+ for xb, yb in test_loader:
+ acc = (model(xb).argmax(1) == yb).sum().item() / yb.shape[0]
+ batch_accs.append(acc)
+
+ return np.mean(batch_accs) * 100.0, np.mean(epoch_times[-10:])
+ except torch._C._LinAlgError as e:
+ warnings.warn(f"Run {run_seed} failed (LinAlgError): {e}")
+ fit_time = np.mean(epoch_times[-10:]) if epoch_times else 0.0
+ return float("nan"), fit_time
+ except RuntimeError as e:
+ if "linalg" in str(e).lower() or "cholesky" in str(e).lower():
+ warnings.warn(f"Run {run_seed} failed (linalg RuntimeError): {e}")
+ fit_time = np.mean(epoch_times[-10:]) if epoch_times else 0.0
+ return float("nan"), fit_time
+ raise
+
+
+######################################################################
+# Dataset Loading: HDM05
+# ----------------------
+#
+# HDM05 contains 2086 pre-computed 93x93 SPD covariance matrices
+# representing 117 motion capture classes.
+#
+# - Architecture: ``[93, 30]``
+# - Source: HDM05 Motion Capture Database
+#
+
+
+def download_and_extract(url, dest_dir, zip_name, extract_tgz=None):
+ """Download a zip file and extract it."""
+ zip_path = dest_dir / zip_name
+ if not zip_path.exists():
+ print(f"Downloading {zip_name}...")
+ urllib.request.urlretrieve(url, zip_path)
+ with zipfile.ZipFile(zip_path, "r") as zf:
+ zf.extractall(dest_dir)
+ if extract_tgz:
+ tgz_path = dest_dir / extract_tgz
+ if tgz_path.exists():
+ with tarfile.open(tgz_path, "r:gz") as tf:
+ tf.extractall(dest_dir)
+
+
+def load_hdm05(data_dir):
+ """Load HDM05 dataset: pre-computed 93x93 SPD covariance matrices."""
+ hdm_path = data_dir / "HDM05"
+ if not hdm_path.exists():
+ download_and_extract(
+ "https://www.dropbox.com/scl/fi/x2ouxjwqj3zrb1idgkg2g/"
+ "HDM05.zip?rlkey=4f90ktgzfz28x3i2i4ylu6dvu&dl=1",
+ data_dir,
+ "HDM05.zip",
+ )
+ names = sorted(f for f in os.listdir(hdm_path) if f.endswith(".npy"))
+ X_list, y_list = [], []
+ for name in names:
+ x = np.load(hdm_path / name).real
+ label = int(name.split(".")[0].split("_")[-1])
+ X_list.append(x)
+ y_list.append(label)
+ X = torch.from_numpy(np.stack(X_list)).double()
+ y = torch.from_numpy(np.array(y_list)).long()
+ print(
+ f"HDM05: {X.shape[0]} samples, {len(set(y_list))} classes, "
+ f"matrix size {X.shape[1]}x{X.shape[2]}"
+ )
+ return X, y
+
+
+X_hdm, y_hdm = load_hdm05(DATA_DIR)
+eigvals_hdm = torch.linalg.eigvalsh(X_hdm)
+print(
+ f"HDM05 SPD check: min eigenvalue = {eigvals_hdm.min():.2e}, "
+ f"max = {eigvals_hdm.max():.2e}"
+)
+
+
+######################################################################
+# HDM05 Experiments
+# -----------------
+#
+# We run 10 experiments (varying model initialization and batch order)
+# with 7 configurations matching Table 4b of the paper: SPDNet, SPDNetBN,
+# AIM-(1), LEM-(1), LCM-(1), AIM-(1.5), and LCM-(0.5).
+#
+# Split: a single seeded 50/50 train/test shuffle, shared across runs.
+# Training: 200 epochs, batch_size=30, lr=5e-3, Adam with amsgrad.
+#
+
+HDM05_DIMS = [93, 30]
+HDM05_CLASSES = len(torch.unique(y_hdm))
+print(f"HDM05: dims={HDM05_DIMS}, n_classes={HDM05_CLASSES}")
+
+hdm05_configs = {
+ "SPDNet": {"bn_type": None, "bn_kwargs": None},
+ "SPDNetBN": {"bn_type": "SPDBN", "bn_kwargs": {"momentum": 0.1}},
+ "AIM-(1)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "AIM",
+ "theta": 1.0,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+ "LEM-(1)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "LEM",
+ "theta": 1.0,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+ "LCM-(1)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "LCM",
+ "theta": 1.0,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+ "AIM-(1.5)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "AIM",
+ "theta": 1.5,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+ "LCM-(0.5)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "LCM",
+ "theta": 0.5,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+}
+
+######################################################################
+# Dataset Loading: Radar
+# ----------------------
+#
+# The Radar dataset contains 3000 complex time-frequency signals
+# (3 classes), converted to 20x20 SPD matrices via
+# covariance pooling.
+#
+# Signal processing pipeline:
+#
+# 1. Split complex signals into overlapping windows (size=20, hop=10)
+# 2. Compute real covariance from complex windowed signal
+# 3. Apply ReEig to ensure well-conditioned SPD output
+#
+# - Architecture: ``[20, 16, 12]``
+#
+
+
+def _split_signal_cplx(x, window_size=20, hop_length=10):
+ """Window complex signals into overlapping segments.
+
+ Input: (batch, 2, T) where dim=1 is [real, imag]
+ Output: (batch, 2, window_size, T')
+ """
+ x_re = x[:, 0:1, :]
+ x_im = x[:, 1:2, :]
+ x_re_w = x_re.unfold(2, window_size, hop_length)
+ x_im_w = x_im.unfold(2, window_size, hop_length)
+ x_re_out = x_re_w.squeeze(1).permute(0, 2, 1)
+ x_im_out = x_im_w.squeeze(1).permute(0, 2, 1)
+ return torch.stack([x_re_out, x_im_out], dim=1)
+
+
+def _cov_pool_cplx(f):
+ """Compute real covariance matrix from complex windowed signal.
+
+ Input: (batch, 2, n, T)
+ Output: (batch, n, n) real SPD covariance matrix
+ """
+ f_re = f[:, 0, :, :].double()
+ f_im = f[:, 1, :, :].double()
+ f_re = f_re - f_re.mean(-1, keepdim=True)
+ f_im = f_im - f_im.mean(-1, keepdim=True)
+ T = f.shape[-1]
+ X_Re = (f_re @ f_re.mT + f_im @ f_im.mT) / (T - 1)
+ return ensure_sym(X_Re)
+
+
+def load_radar(data_dir):
+ """Load Radar dataset: complex signals -> 20x20 SPD covariance matrices."""
+ radar_path = data_dir / "radar"
+ if not radar_path.exists():
+ download_and_extract(
+ "https://www.dropbox.com/s/dfnlx2bnyh3kjwy/data.zip?e=1&dl=1",
+ data_dir,
+ "data.zip",
+ extract_tgz="data/radar.tgz",
+ )
+ names = sorted(f for f in os.listdir(radar_path) if f.endswith(".npy"))
+ signals, labels = [], []
+ for name in names:
+ x = np.load(radar_path / name)
+ x_ri = np.stack([x.real, x.imag], axis=0)
+ signals.append(x_ri)
+ labels.append(int(name.split(".")[0].split("_")[-1]))
+ signals_t = torch.from_numpy(np.stack(signals)).float()
+ y = torch.from_numpy(np.array(labels)).long()
+ with torch.no_grad():
+ windowed = _split_signal_cplx(signals_t, window_size=20, hop_length=10)
+ X_cov = _cov_pool_cplx(windowed)
+ reeig = ReEig()
+ X = reeig(X_cov)
+ print(
+ f"Radar: {X.shape[0]} samples, {len(set(labels))} classes, "
+ f"matrix size {X.shape[1]}x{X.shape[2]}"
+ )
+ return X, y
+
+
+X_radar, y_radar = load_radar(DATA_DIR)
+eigvals_radar = torch.linalg.eigvalsh(X_radar)
+print(
+ f"Radar SPD check: min eigenvalue = {eigvals_radar.min():.2e}, "
+ f"max = {eigvals_radar.max():.2e}"
+)
+
+
+######################################################################
+# Dataset Loading: AFEW
+# ---------------------
+#
+# AFEW (Acted Facial Expressions in the Wild) contains pre-computed
+# 400x400 SPD covariance matrices for facial expression recognition
+# (7 emotion classes). The dataset comes pre-split into train/val.
+#
+# - Architecture: ``[400, 200, 100, 50]``
+# - Source: AFEW Dataset
+#
+
+
+def load_afew(data_dir):
+ """Load AFEW dataset: pre-computed 400x400 SPD covariance matrices."""
+ afew_path = data_dir / "afew"
+ if not afew_path.exists():
+ # Try extracting from data/afew.tgz
+ tgz_candidates = [
+ data_dir / "data" / "afew.tgz",
+ data_dir / "afew.tgz",
+ ]
+ for tgz in tgz_candidates:
+ if tgz.exists():
+ with tarfile.open(tgz, "r:gz") as tf:
+ tf.extractall(data_dir)
+ break
+ else:
+ raise FileNotFoundError(
+ "AFEW data not found. Place afew.tgz in the data/ directory."
+ )
+
+ train_path = afew_path / "train"
+ val_path = afew_path / "val"
+
+ def _load_split(split_path):
+ names = sorted(f for f in os.listdir(split_path) if f.endswith(".npy"))
+ X_list, y_list = [], []
+ for name in names:
+ x = np.load(split_path / name).real
+ label = int(name.split(".")[0].split("_")[-1])
+ X_list.append(x)
+ y_list.append(label)
+ return (
+ torch.from_numpy(np.stack(X_list)).double(),
+ torch.from_numpy(np.array(y_list)).long(),
+ )
+
+ X_train, y_train = _load_split(train_path)
+ X_val, y_val = _load_split(val_path)
+ print(
+ f"AFEW: {X_train.shape[0]} train + {X_val.shape[0]} val samples, "
+ f"{len(set(y_train.tolist()) | set(y_val.tolist()))} classes, "
+ f"matrix size {X_train.shape[1]}x{X_train.shape[2]}"
+ )
+ return X_train, y_train, X_val, y_val
+
+
+X_afew_train, y_afew_train, X_afew_val, y_afew_val = load_afew(DATA_DIR)
+eigvals_afew = torch.linalg.eigvalsh(X_afew_train)
+print(
+ f"AFEW SPD check: min eigenvalue = {eigvals_afew.min():.2e}, "
+ f"max = {eigvals_afew.max():.2e}"
+)
+
+
+######################################################################
+# Radar Experiments
+# -----------------
+#
+# We run 10 independent random-split experiments with 6 configurations
+# matching Table 4a: SPDNet, SPDNetBN, AIM-(1), LEM-(1), LCM-(1),
+# and LCM-(-0.5).
+#
+# Split: 50/25/25 train/val/test (random shuffle per run). The 25%
+# validation portion is held out but unused, matching the paper's
+# split proportions.
+#
+
+RADAR_DIMS = [20, 16, 12]
+RADAR_CLASSES = len(torch.unique(y_radar))
+print(f"Radar: dims={RADAR_DIMS}, n_classes={RADAR_CLASSES}")
+
+radar_configs = {
+ "SPDNet": {"bn_type": None, "bn_kwargs": None},
+ "SPDNetBN": {"bn_type": "SPDBN", "bn_kwargs": {"momentum": 0.1}},
+ "AIM-(1)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "AIM",
+ "theta": 1.0,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+ "LEM-(1)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "LEM",
+ "theta": 1.0,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+ "LCM-(1)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "LCM",
+ "theta": 1.0,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+ "LCM-(-0.5)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "LCM",
+ "theta": -0.5,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+}
+
+######################################################################
+# AFEW Experiments
+# ----------------
+#
+# AFEW (Acted Facial Expressions in the Wild) has a fixed train/val split,
+# so we run 10 experiments varying only model initialization.
+#
+# Architecture: ``[400, 200, 100, 50]`` (from the original SPDNet paper,
+# Huang & Van Gool, 2017).
+#
+# .. note::
+#
+# The LieBN paper (Table 4) uses the **FPHA** dataset (63x63, 45 classes)
+# rather than AFEW (400x400, 7 classes). We include AFEW as an additional
+# benchmark; no paper comparison numbers are available.
+#
+
+AFEW_DIMS = [400, 200, 100, 50]
+AFEW_CLASSES = len(set(y_afew_train.tolist()) | set(y_afew_val.tolist()))
+print(f"AFEW: dims={AFEW_DIMS}, n_classes={AFEW_CLASSES}")
+
+afew_configs = {
+ "SPDNet": {"bn_type": None, "bn_kwargs": None},
+ "SPDNetBN": {"bn_type": "SPDBN", "bn_kwargs": {"momentum": 0.1}},
+ "AIM-(1)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "AIM",
+ "theta": 1.0,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+ "LEM-(1)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "LEM",
+ "theta": 1.0,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+ "LCM-(1)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "LCM",
+ "theta": 1.0,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+ "AIM-(1.5)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "AIM",
+ "theta": 1.5,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+ "LCM-(0.5)": {
+ "bn_type": "LieBN",
+ "bn_kwargs": {
+ "metric": "LCM",
+ "theta": 0.5,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "momentum": 0.1,
+ },
+ },
+}
+
+######################################################################
+# Run All Experiments (Parallelized with Checkpointing)
+# -----------------------------------------------------
+#
+# Experiments are dispatched per-dataset via ``joblib`` (loky backend)
+# for true multi-process parallelism. Each worker uses 1 BLAS thread
+# since Apple Silicon Accelerate's ``eigh`` is fastest single-threaded,
+# with 14 workers fully utilizing all CPU cores.
+#
+# **Checkpointing**: After each dataset completes, all finished
+# per-run results are saved to ``liebn_checkpoint.json``. On restart,
+# already-completed (dataset, method, run) tuples are skipped.
+#
+# Total jobs: HDM05 (7 x 10) + Radar (6 x 10) + AFEW (7 x 10) = 200
+#
+
+checkpoint = _load_checkpoint()
+n_cached = len(checkpoint)
+if n_cached > 0:
+ print(f"Loaded checkpoint with {n_cached} completed runs.")
+
+
+def _run_dataset_jobs(dataset_name, configs, jobs_and_keys, checkpoint):
+ """Dispatch jobs for one dataset, skipping already-checkpointed runs.
+
+ Returns dict of {method: [(acc, fit_time), ...]} for this dataset.
+ Mutates ``checkpoint`` in-place with newly completed runs.
+ """
+
+ filtered_jobs = []
+ filtered_keys = []
+ cached_results = defaultdict(list)
+
+ for (ds, method, run_idx), job in jobs_and_keys:
+ key = f"{ds}|{method}|{run_idx}"
+ if key in checkpoint:
+ acc = checkpoint[key]["acc"]
+ # Restore NaN for failed runs (stored as null in JSON).
+ if acc is None:
+ acc = float("nan")
+ cached_results[method].append((acc, checkpoint[key]["fit_time"]))
+ else:
+ filtered_jobs.append(job)
+ filtered_keys.append((ds, method, run_idx))
+
+ n_skip = len(jobs_and_keys) - len(filtered_jobs)
+ n_todo = len(filtered_jobs)
+ if n_skip > 0:
+ print(f" {dataset_name}: {n_skip} runs cached, {n_todo} remaining.")
+ if n_todo == 0:
+ print(f" {dataset_name}: all runs cached, nothing to do.")
+ else:
+ t0 = time.time()
+ # n_jobs=-1: 14 workers x 1 BLAS thread = 14 cores fully used.
+ raw = Parallel(n_jobs=-1, verbose=10)(filtered_jobs)
+ elapsed = time.time() - t0
+ print(f" {dataset_name}: {n_todo} runs finished in {elapsed:.1f}s")
+
+ # Save to checkpoint immediately.
+ for (ds, method, run_idx), result in zip(filtered_keys, raw):
+ key = f"{ds}|{method}|{run_idx}"
+ acc, ft = result
+ checkpoint[key] = {
+ "acc": None if np.isnan(acc) else acc,
+ "fit_time": ft,
+ "status": "failed" if np.isnan(acc) else "ok",
+ }
+ cached_results[method].append(result)
+ _save_checkpoint(checkpoint)
+ print(f" Checkpoint saved ({len(checkpoint)} total runs).")
+
+ return dict(cached_results)
+
+
+def _aggregate(runs):
+ """Aggregate per-run (acc, fit_time) tuples."""
+ accs = [r[0] for r in runs if not np.isnan(r[0])]
+ fts = [r[1] for r in runs if not np.isnan(r[0])]
+ n_failed = sum(1 for r in runs if np.isnan(r[0]))
+ if n_failed > 0:
+ warnings.warn(f"{n_failed}/{len(runs)} runs failed (NaN)")
+ if not accs:
+ return {
+ "mean": 0.0,
+ "std": 0.0,
+ "max": 0.0,
+ "folds": [],
+ "fit_time": 0.0,
+ }
+ return {
+ "mean": np.mean(accs),
+ "std": np.std(accs),
+ "max": np.max(accs),
+ "folds": accs,
+ "fit_time": np.mean(fts),
+ }
+
+
+# ---- Prepare HDM05 jobs (fixed 50/50 split) ----
+n_hdm = len(X_hdm)
+rng_hdm = np.random.RandomState(GLOBAL_SEED)
+perm_hdm = rng_hdm.permutation(n_hdm)
+n_test_hdm = int(0.5 * n_hdm)
+Xtr_hdm = X_hdm[perm_hdm[n_test_hdm:]].numpy()
+ytr_hdm = y_hdm[perm_hdm[n_test_hdm:]].numpy()
+Xte_hdm = X_hdm[perm_hdm[:n_test_hdm]].numpy()
+yte_hdm = y_hdm[perm_hdm[:n_test_hdm]].numpy()
+
+hdm05_jobs = []
+for name, cfg in hdm05_configs.items():
+ for i in range(N_RUNS):
+ job = delayed(_train_single_run)(
+ GLOBAL_SEED + i,
+ Xtr_hdm,
+ ytr_hdm,
+ Xte_hdm,
+ yte_hdm,
+ HDM05_DIMS,
+ HDM05_CLASSES,
+ cfg["bn_type"],
+ cfg["bn_kwargs"],
+ )
+ hdm05_jobs.append((("HDM05", name, i), job))
+
+# ---- Prepare Radar jobs (per-run 50/25/25 split) ----
+n_radar = len(X_radar)
+n_test_radar = int(0.25 * n_radar)
+n_val_radar = int(0.25 * n_radar)
+X_radar_np = X_radar.numpy()
+y_radar_np = y_radar.numpy()
+
+radar_jobs = []
+for name, cfg in radar_configs.items():
+ for i in range(N_RUNS):
+ run_seed = GLOBAL_SEED + i
+ rng = np.random.RandomState(run_seed)
+ perm = rng.permutation(n_radar)
+ Xte_r = X_radar_np[perm[:n_test_radar]]
+ yte_r = y_radar_np[perm[:n_test_radar]]
+ Xtr_r = X_radar_np[perm[n_test_radar + n_val_radar :]]
+ ytr_r = y_radar_np[perm[n_test_radar + n_val_radar :]]
+ job = delayed(_train_single_run)(
+ run_seed,
+ Xtr_r,
+ ytr_r,
+ Xte_r,
+ yte_r,
+ RADAR_DIMS,
+ RADAR_CLASSES,
+ cfg["bn_type"],
+ cfg["bn_kwargs"],
+ )
+ radar_jobs.append((("Radar", name, i), job))
+
+# ---- Prepare AFEW jobs (fixed train/val split) ----
+X_afew_train_np = X_afew_train.numpy()
+y_afew_train_np = y_afew_train.numpy()
+X_afew_val_np = X_afew_val.numpy()
+y_afew_val_np = y_afew_val.numpy()
+
+afew_jobs = []
+for name, cfg in afew_configs.items():
+ for i in range(N_RUNS):
+ job = delayed(_train_single_run)(
+ GLOBAL_SEED + i,
+ X_afew_train_np,
+ y_afew_train_np,
+ X_afew_val_np,
+ y_afew_val_np,
+ AFEW_DIMS,
+ AFEW_CLASSES,
+ cfg["bn_type"],
+ cfg["bn_kwargs"],
+ )
+ afew_jobs.append((("AFEW", name, i), job))
+
+# ---- Run datasets sequentially, saving after each ----
+# This way Radar+HDM05 results are safe even if AFEW crashes.
+
+total_jobs = len(hdm05_jobs) + len(radar_jobs) + len(afew_jobs)
+print(
+ f"\nTotal: {total_jobs} training runs "
+ f"({len(hdm05_configs)} + {len(radar_configs)} + "
+ f"{len(afew_configs)} methods x {N_RUNS} runs)."
+)
+
+t_wall_start = time.time()
+
+print("\n--- HDM05 ---")
+hdm05_raw = _run_dataset_jobs("HDM05", hdm05_configs, hdm05_jobs, checkpoint)
+print("\n--- Radar ---")
+radar_raw = _run_dataset_jobs("Radar", radar_configs, radar_jobs, checkpoint)
+print("\n--- AFEW ---")
+afew_raw = _run_dataset_jobs("AFEW", afew_configs, afew_jobs, checkpoint)
+
+t_wall = time.time() - t_wall_start
+print(f"\nAll experiments finished in {t_wall:.1f}s")
+
+# ---- Aggregate per-method results ----
+hdm05_results = {m: _aggregate(r) for m, r in hdm05_raw.items()}
+radar_results = {m: _aggregate(r) for m, r in radar_raw.items()}
+afew_results = {m: _aggregate(r) for m, r in afew_raw.items()}
+
+for ds, results in [
+ ("HDM05", hdm05_results),
+ ("Radar", radar_results),
+ ("AFEW", afew_results),
+]:
+ print(f"\n{ds}:")
+ for m, r in results.items():
+ print(
+ f" {m}: {r['mean']:.2f} +/- {r['std']:.2f} "
+ f"(max={r['max']:.2f}, fit_time={r['fit_time']:.2f}s)"
+ )
+
+
+######################################################################
+# Results Comparison & Visualization
+# -----------------------------------
+#
+# We compare our reproduction results against the paper's Table 4 numbers
+# for HDM05 and Radar. AFEW results are shown without paper comparison
+# (the paper uses FPHA, a different dataset).
+#
+
+# Paper Table 4 numbers: (mean, std, max, fit_time)
+paper_results = {
+ "Radar": {
+ "SPDNet": (93.25, 1.10, 94.4, 0.98),
+ "SPDNetBN": (94.85, 0.99, 96.13, 1.56),
+ "AIM-(1)": (95.47, 0.90, 96.27, 1.62),
+ "LEM-(1)": (94.89, 1.04, 96.8, 1.28),
+ "LCM-(1)": (93.52, 1.07, 95.2, 1.11),
+ "LCM-(-0.5)": (94.80, 0.71, 95.73, 1.43),
+ },
+ "HDM05": {
+ "SPDNet": (59.13, 0.67, 60.34, 0.57),
+ "SPDNetBN": (66.72, 0.52, 67.66, 0.97),
+ "AIM-(1)": (67.79, 0.65, 68.75, 1.14),
+ "LEM-(1)": (65.05, 0.63, 66.05, 0.87),
+ "LCM-(1)": (66.68, 0.71, 68.52, 0.66),
+ "AIM-(1.5)": (68.16, 0.68, 69.25, 1.46),
+ "LCM-(0.5)": (70.84, 0.92, 72.27, 1.01),
+ },
+ "AFEW": {},
+}
+
+radar_methods = list(radar_configs.keys())
+hdm05_methods = list(hdm05_configs.keys())
+afew_methods = list(afew_configs.keys())
+
+######################################################################
+# Results Tables
+# ~~~~~~~~~~~~~~
+#
+
+
+def _print_table(dataset, methods, our_results, paper):
+ """Print comparison table for one dataset."""
+ hdr = (
+ f"{'Method':<14} | {'Fit Time':>8} | "
+ f"{'Mean+-STD (Ours)':>18} {'Max (Ours)':>10} | "
+ f"{'Mean+-STD (Paper)':>18} {'Max (Paper)':>11}"
+ )
+ sep = "=" * len(hdr)
+ print(f"\n{dataset}")
+ print(sep)
+ print(hdr)
+ print(sep)
+ for m in methods:
+ ours = our_results.get(m, {})
+ p = paper.get(m)
+ ft = f"{ours.get('fit_time', 0):.2f}" if ours else "---"
+ o_str = f"{ours['mean']:.2f}+-{ours['std']:.2f}" if ours else "---"
+ o_max = f"{ours['max']:.2f}" if ours else "---"
+ if p:
+ p_str = f"{p[0]:.2f}+-{p[1]:.2f}"
+ p_max = f"{p[2]:.2f}"
+ else:
+ p_str, p_max = "---", "---"
+ print(f"{m:<14} | {ft:>8} | {o_str:>18} {o_max:>10} | {p_str:>18} {p_max:>11}")
+ print(sep)
+
+
+_print_table("Radar", radar_methods, radar_results, paper_results["Radar"])
+_print_table("HDM05", hdm05_methods, hdm05_results, paper_results["HDM05"])
+_print_table("AFEW", afew_methods, afew_results, paper_results["AFEW"])
+
+######################################################################
+# Save results to JSON for reproducibility.
+#
+
+try:
+ results_path = os.path.join(os.path.dirname(__file__), "liebn_table4_results.json")
+except NameError:
+ results_path = os.path.join(tempfile.gettempdir(), "liebn_table4_results.json")
+results_to_save = {
+ "radar": {k: dict(v) for k, v in radar_results.items()},
+ "hdm05": {k: dict(v) for k, v in hdm05_results.items()},
+ "afew": {k: dict(v) for k, v in afew_results.items()},
+ "paper_results": {
+ ds: {
+ m: {"mean": v[0], "std": v[1], "max": v[2], "fit_time": v[3]}
+ for m, v in mv.items()
+ }
+ for ds, mv in paper_results.items()
+ if mv # skip empty (AFEW has no paper numbers)
+ },
+}
+with open(results_path, "w") as f:
+ json.dump(results_to_save, f, indent=2)
+print(f"\nResults saved to {results_path}")
+
+######################################################################
+# Comparison Bar Chart
+# ~~~~~~~~~~~~~~~~~~~~
+#
+
+fig, axes = plt.subplots(1, 3, figsize=(22, 5))
+
+dataset_info = [
+ ("Radar", radar_results, radar_methods),
+ ("HDM05", hdm05_results, hdm05_methods),
+ ("AFEW", afew_results, afew_methods),
+]
+
+for ax, (dataset, our_results, methods) in zip(axes, dataset_info):
+ x_pos = np.arange(len(methods))
+ has_paper = bool(paper_results.get(dataset))
+
+ ours_means = [our_results[m]["mean"] for m in methods]
+ ours_stds = [our_results[m]["std"] for m in methods]
+
+ if has_paper:
+ width = 0.35
+ paper_means = [paper_results[dataset][m][0] for m in methods]
+ paper_stds = [paper_results[dataset][m][1] for m in methods]
+
+ ax.bar(
+ x_pos - width / 2,
+ ours_means,
+ width,
+ yerr=ours_stds,
+ label="Ours",
+ capsize=3,
+ color="#3498db",
+ alpha=0.85,
+ )
+ ax.bar(
+ x_pos + width / 2,
+ paper_means,
+ width,
+ yerr=paper_stds,
+ label="Paper",
+ capsize=3,
+ color="#e74c3c",
+ alpha=0.85,
+ )
+ all_vals = ours_means + paper_means
+ else:
+ width = 0.5
+ ax.bar(
+ x_pos,
+ ours_means,
+ width,
+ yerr=ours_stds,
+ label="Ours",
+ capsize=3,
+ color="#3498db",
+ alpha=0.85,
+ )
+ all_vals = ours_means
+
+ ax.set_xticks(x_pos)
+ ax.set_xticklabels(methods, rotation=30, ha="right")
+ ax.set_ylabel("Accuracy (%)")
+ ax.set_title(dataset)
+ ax.legend()
+ ax.grid(axis="y", alpha=0.3)
+ ymin = min(all_vals) - 5
+ ymax = max(all_vals) + 3
+ ax.set_ylim(ymin, ymax)
+
+plt.suptitle("LieBN Batch Normalization: SPDNet Results", fontweight="bold")
+plt.tight_layout()
+plt.show()
+
+######################################################################
+# References
+# ----------
+#
+# .. bibliography::
+# :filter: docname in docnames
+#
diff --git a/examples/applied_examples/plot_source_free_domain.py b/examples/applied_examples/plot_source_free_domain.py
index eb53aee..4868a31 100644
--- a/examples/applied_examples/plot_source_free_domain.py
+++ b/examples/applied_examples/plot_source_free_domain.py
@@ -96,75 +96,10 @@
# SPDIM Geometric Operations
# --------------------------
#
-# We define two core geometric operations needed for the SPDIM pipeline.
-# These will be included in a future release of ``spd_learn.functional``.
+# The Fréchet mean and geodesic distances used by SPDIM are available
+# directly from ``spd_learn.functional``.
#
-from spd_learn.functional import (
- get_epsilon,
- matrix_exp,
- matrix_log,
- matrix_sqrt_inv,
-)
-
-
-def frechet_mean(X, max_iter=50, return_distances=False):
- r"""Compute the Fréchet mean under the AIRM.
-
- .. math::
-
- \bar{X} = \arg\min_{G \in \mathcal{S}_{++}^n}
- \sum_{i=1}^{N} d_{\text{AIRM}}^2(G, X_i)
-
- Uses adaptive step-size Karcher flow.
- """
- eps = get_epsilon(X.dtype, "eigval_log")
- n_samples = X.shape[0]
-
- if n_samples == 1:
- mean = X[:1]
- if return_distances:
- return mean, torch.zeros(X.shape[:-2], dtype=X.dtype, device=X.device)
- return mean
-
- w = torch.ones((*X.shape[:-2], 1, 1), dtype=X.dtype, device=X.device)
- w = w / n_samples
- G = (X * w).sum(dim=0, keepdim=True)
-
- nu = 1.0
- tau = float("inf")
-
- for _ in range(max_iter):
- G_sqrt, G_invsqrt = matrix_sqrt_inv.apply(G)
- X_tangent = matrix_log.apply(G_invsqrt @ X @ G_invsqrt)
- G_tangent = (X_tangent * w).sum(dim=0, keepdim=True)
-
- crit = torch.norm(G_tangent, p="fro", dim=(-2, -1)).max().item()
- if crit <= eps:
- break
-
- G = G_sqrt @ matrix_exp.apply(nu * G_tangent) @ G_sqrt
-
- h = nu * crit
- if h < tau:
- nu = 0.95 * nu
- tau = h
- else:
- nu = 0.5 * nu
-
- if nu <= eps:
- break
-
- if return_distances:
- G_sqrt, G_invsqrt = matrix_sqrt_inv.apply(G)
- X_tangent = matrix_log.apply(G_invsqrt @ X @ G_invsqrt)
- G_tangent = (X_tangent * w).sum(dim=0, keepdim=True)
- distances = torch.norm(X_tangent - G_tangent, p="fro", dim=(-2, -1))
- return G, distances
-
- return G
-
-
######################################################################
# Loading the Dataset
# -------------------
@@ -179,12 +114,13 @@ def frechet_mean(X, max_iter=50, return_distances=False):
# - **Source domain**: Session A (training with labels)
# - **Target domain**: Session B (adaptation without labels)
#
-
from braindecode.datasets import create_from_X_y
from moabb.datasets import BNCI2015_001
from moabb.paradigms import MotorImagery
from sklearn.preprocessing import LabelEncoder
+from spd_learn.functional import frechet_mean
+
dataset = BNCI2015_001()
paradigm = MotorImagery(
diff --git a/examples/howto/README.txt b/examples/howto/README.txt
new file mode 100644
index 0000000..f213a19
--- /dev/null
+++ b/examples/howto/README.txt
@@ -0,0 +1,8 @@
+.. _howto_guides:
+
+How-to Guides
+=============
+
+Task-oriented guides for solving specific problems with SPD Learn. Each guide
+addresses one focused question and assumes you already understand the basics
+(see :ref:`tutorials` if you're just getting started).
diff --git a/examples/howto/plot_howto_add_batchnorm.py b/examples/howto/plot_howto_add_batchnorm.py
new file mode 100644
index 0000000..1fb9a3a
--- /dev/null
+++ b/examples/howto/plot_howto_add_batchnorm.py
@@ -0,0 +1,128 @@
+"""
+.. _howto-add-batchnorm:
+
+How to Add Batch Normalization to an SPDNet
+===========================================
+
+Insert Riemannian batch normalization into an existing SPDNet pipeline
+to stabilize training and improve convergence.
+
+**Prerequisites**: Familiarity with SPDNet building blocks
+(see :ref:`tutorial-building-blocks`).
+
+"""
+
+######################################################################
+# The Problem
+# -----------
+#
+# You have a working SPDNet but training is unstable or converges slowly.
+# Adding batch normalization after each ``BiMap`` layer can help.
+#
+
+import torch
+import torch.nn as nn
+
+from spd_learn.modules import BiMap, LogEig, ReEig, SPDBatchNormLie
+
+
+######################################################################
+# Step 1: Choose Your Normalization Layer
+# ----------------------------------------
+#
+# spd_learn provides three batch normalization modules:
+#
+# .. list-table::
+# :header-rows: 1
+# :widths: 30 70
+#
+# * - Module
+# - When to Use
+# * - :class:`~spd_learn.modules.SPDBatchNormMeanVar`
+# - Standard choice. AIRM Frechet mean + variance scaling.
+# * - :class:`~spd_learn.modules.SPDBatchNormLie`
+# - Multiple metrics (AIM, LEM, LCM). Based on Lie group
+# structure :cite:p:`chen2024liebn`.
+# * - :class:`~spd_learn.modules.SPDBatchNormMean`
+# - Mean-only centering (no variance scaling). Simplest option.
+#
+
+######################################################################
+# Step 2: Insert After BiMap, Before ReEig
+# -----------------------------------------
+#
+# The standard placement is ``BiMap -> BN -> ReEig``. The final BiMap
+# uses BN but skips ReEig:
+#
+
+dims = [64, 32, 16] # your SPD matrix dimensions
+layers = []
+for i in range(len(dims) - 1):
+ layers.append(BiMap(dims[i], dims[i + 1]))
+ layers.append(SPDBatchNormLie(dims[i + 1], metric="LEM"))
+ if i < len(dims) - 2: # no ReEig after last BiMap
+ layers.append(ReEig())
+
+features = nn.Sequential(*layers)
+print(features)
+
+######################################################################
+# Step 3: Complete Network
+# -------------------------
+#
+# Wrap the features with ``LogEig`` and a linear classifier:
+#
+
+
+class SPDNetWithBN(nn.Module):
+ """SPDNet with configurable batch normalization."""
+
+ def __init__(self, dims, n_classes, metric="LEM"):
+ super().__init__()
+ layers = []
+ for i in range(len(dims) - 1):
+ layers.append(BiMap(dims[i], dims[i + 1]))
+ layers.append(SPDBatchNormLie(dims[i + 1], metric=metric))
+ if i < len(dims) - 2:
+ layers.append(ReEig())
+ self.features = nn.Sequential(*layers)
+ self.logeig = LogEig(upper=False, flatten=True)
+ self.classifier = nn.Linear(dims[-1] ** 2, n_classes)
+
+ def forward(self, x):
+ return self.classifier(self.logeig(self.features(x)))
+
+
+model = SPDNetWithBN([64, 32, 16], n_classes=4, metric="LEM")
+
+######################################################################
+# Verify it works with a dummy forward pass:
+
+X = torch.randn(8, 64, 64)
+X = X @ X.mT + 0.01 * torch.eye(64)
+out = model(X)
+print(f"Input: {X.shape} -> Output: {out.shape}") # [8, 64, 64] -> [8, 4]
+
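+######################################################################
+# At inference time, switch to evaluation mode so the BN layers use their
+# running statistics instead of batch statistics. A minimal sketch reusing
+# the ``model`` and ``X`` defined above:
+
+model.eval()
+with torch.no_grad():
+    preds = model(X).argmax(dim=1)
+print(f"Predicted classes: {preds.tolist()}")
+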
+######################################################################
+# Key Points
+# ----------
+#
+# - Place BN **after** ``BiMap`` and **before** ``ReEig``
+# - Use ``model.train()`` / ``model.eval()`` -- BN uses running stats
+# at inference time
+# - ``momentum=0.1`` (default) works well in most cases
+# - Consider ``float64`` for numerical stability with Riemannian operations
+#
+# .. seealso::
+#
+# - :ref:`tutorial-batch-normalization` -- Learn how BN works on SPD manifolds
+# - :ref:`howto-choose-metric` -- Choosing between AIM, LEM, and LCM
+# - :class:`~spd_learn.modules.SPDBatchNormLie` -- API reference
+# - :class:`~spd_learn.modules.SPDBatchNormMeanVar` -- API reference
+#
+# References
+# ----------
+#
+# .. bibliography::
+# :filter: docname in docnames
+#
diff --git a/examples/howto/plot_howto_choose_metric.py b/examples/howto/plot_howto_choose_metric.py
new file mode 100644
index 0000000..533c313
--- /dev/null
+++ b/examples/howto/plot_howto_choose_metric.py
@@ -0,0 +1,267 @@
+"""
+.. _howto-choose-metric:
+
+How to Choose a Metric for Batch Normalization
+===============================================
+
+Select the right Riemannian metric for :class:`~spd_learn.modules.SPDBatchNormLie`.
+Each metric trades off speed, invariance, and numerical stability.
+
+**Prerequisites**: Familiarity with SPD batch normalization
+(see :ref:`tutorial-batch-normalization`).
+
+"""
+
+######################################################################
+# Quick Decision Guide
+# ---------------------
+#
+# .. list-table::
+# :header-rows: 1
+# :widths: 12 12 15 61
+#
+# * - Metric
+# - Speed
+# - Invariance
+# - Use When
+# * - **LEM**
+# - Fastest
+# - Orthogonal
+# - Default choice. Closed-form mean, good general performance.
+# * - **AIM**
+# - Slowest
+# - Full affine
+# - Data has varying scale (e.g., cross-subject EEG).
+# * - **LCM**
+# - Fast
+# - Lower-triangular
+# - Speed matters and you want Cholesky stability.
+#
+
+import time
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from pyriemann.datasets import make_gaussian_blobs
+
+from spd_learn.modules import SPDBatchNormLie
+
+
+torch.manual_seed(42)
+
+# Generate SPD data using pyriemann (2-class, 2*n_matrices total samples)
+n_matrices = 32
+n_dim = 8
+X_np, y = make_gaussian_blobs(
+ n_matrices=n_matrices,
+ n_dim=n_dim,
+ class_sep=1.5,
+ class_disp=0.5,
+ random_state=42,
+)
+X = torch.from_numpy(X_np).float() # shape: (64, 8, 8)
+
+metric_colors = {"LEM": "#2ecc71", "LCM": "#3498db", "AIM": "#e74c3c"}
+
+######################################################################
+# Comparing Forward Pass Speed
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# The cost depends on the metric's mean computation. AIM requires an
+# iterative Karcher mean, while LEM and LCM use closed-form means.
+# We use larger 32x32 matrices here so timing differences are visible.
+#
+
+n_bench = 32
+batch_size = 64
+A = torch.randn(batch_size, n_bench, n_bench)
+X_bench = A @ A.mT + 0.01 * torch.eye(n_bench)
+
+timings = {}
+for metric in ["LEM", "LCM", "AIM"]:
+ bn = SPDBatchNormLie(n_bench, metric=metric)
+ bn.train()
+ _ = bn(X_bench) # warmup
+ t0 = time.time()
+ for _ in range(20):
+ _ = bn(X_bench)
+ elapsed = (time.time() - t0) / 20
+ timings[metric] = elapsed * 1000
+ print(
+ f"{metric}: {elapsed * 1000:.1f} ms/batch ({n_bench}x{n_bench}, batch={batch_size})"
+ )
+
+fig, ax = plt.subplots(figsize=(6, 4))
+ax.bar(
+ timings.keys(),
+ timings.values(),
+ color=[metric_colors[m] for m in timings],
+)
+ax.set_ylabel("Time (ms)")
+ax.set_title("Forward Pass Speed by Metric")
+ax.grid(axis="y", alpha=0.3)
+plt.tight_layout()
+plt.show()
+
+######################################################################
+# Comparing Output Properties
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Each metric normalizes eigenvalues differently. Inspect the
+# eigenvalue distribution after normalization:
+#
+
+fig, axes = plt.subplots(1, 3, figsize=(14, 4), sharey=True)
+for ax, metric in zip(axes, ["LEM", "LCM", "AIM"]):
+ bn = SPDBatchNormLie(n_dim, metric=metric)
+ bn.train()
+ Y = bn(X)
+ eigvals = torch.linalg.eigvalsh(Y.detach())
+ ax.boxplot(
+ [eigvals[:, i].numpy() for i in range(n_dim)],
+ positions=range(n_dim),
+ )
+ ax.set_title(f"{metric}", color=metric_colors[metric], fontweight="bold")
+ ax.set_xlabel("Eigenvalue index")
+ ax.grid(True, alpha=0.3)
+
+axes[0].set_ylabel("Eigenvalue")
+plt.suptitle("Eigenvalue Distribution After Normalization", fontweight="bold")
+plt.tight_layout()
+plt.show()
+
+######################################################################
+# Effect of Each Metric on an SPD Matrix
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# To build visual intuition, let's examine how each metric transforms a
+# single SPD matrix. The top row shows the matrix entries as a heatmap
+# (before and after normalization), while the bottom row compares the
+# eigenvalue spectrum.
+#
+
+sample_idx = 0
+X_sample = X[sample_idx] # shape (8, 8)
+
+fig, axes = plt.subplots(2, 4, figsize=(16, 7))
+
+# Top row: matrix heatmaps
+im = axes[0, 0].imshow(X_sample.numpy(), cmap="RdBu_r", aspect="auto")
+axes[0, 0].set_title("Original", fontweight="bold")
+plt.colorbar(im, ax=axes[0, 0], shrink=0.8)
+
+normalized = {}
+for col, metric in enumerate(["LEM", "LCM", "AIM"], start=1):
+ bn = SPDBatchNormLie(n_dim, metric=metric)
+ bn.train()
+ Y = bn(X)
+ Y_sample = Y[sample_idx].detach().numpy()
+ normalized[metric] = Y_sample
+
+ ax = axes[0, col]
+ im = ax.imshow(Y_sample, cmap="RdBu_r", aspect="auto")
+ ax.set_title(f"After {metric}", fontweight="bold", color=metric_colors[metric])
+ plt.colorbar(im, ax=ax, shrink=0.8)
+
+# Bottom row: eigenvalue bar charts
+eigvals_orig = np.sort(np.linalg.eigvalsh(X_sample.numpy()))[::-1]
+axes[1, 0].bar(range(n_dim), eigvals_orig, color="gray", alpha=0.8)
+axes[1, 0].set_title("Eigenvalues", fontweight="bold")
+axes[1, 0].set_xlabel("Index")
+axes[1, 0].set_ylabel("Value")
+
+for col, metric in enumerate(["LEM", "LCM", "AIM"], start=1):
+ eigvals_after = np.sort(np.linalg.eigvalsh(normalized[metric]))[::-1]
+ ax = axes[1, col]
+ ax.bar(
+ range(n_dim),
+ eigvals_after,
+ color=metric_colors[metric],
+ alpha=0.8,
+ )
+ ax.set_title(
+ f"{metric} Eigenvalues", fontweight="bold", color=metric_colors[metric]
+ )
+ ax.set_xlabel("Index")
+
+plt.suptitle(
+ "Effect of Each Metric on a Single SPD Matrix",
+ fontweight="bold",
+ fontsize=13,
+)
+plt.tight_layout()
+plt.show()
+
+######################################################################
+# Tuning the Theta Parameter
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# The ``theta`` parameter controls the power deformation :math:`P^\theta`.
+# It adjusts how strongly scale differences are compressed or amplified:
+#
+# - ``theta=0.5``: Square root -- compresses large scale differences
+# - ``theta=1.0``: No deformation (default)
+# - ``theta=1.5``: Amplifies scale differences
+#
+# The violin plots below show how the eigenvalue distribution changes
+# with each theta value under the AIM metric:
+#
+
+thetas = [0.5, 1.0, 1.5]
+theta_colors = ["#3498db", "#9b59b6", "#f39c12"]
+
+fig, axes = plt.subplots(1, 3, figsize=(14, 4), sharey=True)
+
+for ax, theta, color in zip(axes, thetas, theta_colors):
+ bn = SPDBatchNormLie(n_dim, metric="AIM", theta=theta)
+ bn.train()
+ Y = bn(X)
+ eigvals = torch.linalg.eigvalsh(Y.detach()).numpy()
+
+ parts = ax.violinplot(
+ [eigvals[:, i] for i in range(n_dim)],
+ positions=range(n_dim),
+ showmedians=True,
+ showextrema=True,
+ )
+ for pc in parts["bodies"]:
+ pc.set_facecolor(color)
+ pc.set_alpha(0.6)
+ for key in ("cbars", "cmins", "cmaxes", "cmedians"):
+ parts[key].set_color(color)
+
+ ax.set_title(f"θ = {theta}", fontweight="bold", fontsize=12)
+ ax.set_xlabel("Eigenvalue index")
+ ax.grid(True, alpha=0.3)
+
+ print(
+ f"AIM (theta={theta}): eigval range [{eigvals.min():.3f}, {eigvals.max():.3f}]"
+ )
+
+axes[0].set_ylabel("Eigenvalue")
+plt.suptitle(
+ "Theta Parameter Effect on Eigenvalue Distribution (AIM Metric)",
+ fontweight="bold",
+)
+plt.tight_layout()
+plt.show()
+
+######################################################################
+# Recommendations
+# ----------------
+#
+# 1. **Start with LEM** -- fastest, closed-form mean, works well in most cases
+# 2. **Try AIM** if your data has varying scale across subjects or sessions
+# 3. **Use LCM** when you need speed similar to LEM with Cholesky stability
+# 4. **Tune theta** on a validation set when using AIM or LCM
+# (see :ref:`liebn-batch-normalization` for a multi-dataset benchmark)
+#
+# .. seealso::
+#
+# - :ref:`tutorial-batch-normalization` -- Detailed tutorial on SPD BN
+# - :ref:`howto-add-batchnorm` -- How to add BN to your pipeline
+# - :ref:`liebn-batch-normalization` -- Full benchmark reproduction
+# - :class:`~spd_learn.modules.SPDBatchNormLie` -- API reference
+#
diff --git a/examples/tutorials/tutorial_05_batch_normalization.py b/examples/tutorials/tutorial_05_batch_normalization.py
new file mode 100644
index 0000000..9f01ebe
--- /dev/null
+++ b/examples/tutorials/tutorial_05_batch_normalization.py
@@ -0,0 +1,374 @@
+"""
+.. _tutorial-batch-normalization:
+
+Batch Normalization on SPD Manifolds
+=====================================
+
+This tutorial teaches how batch normalization works on SPD matrices and
+why it matters for training SPD neural networks. You will train a simple
+SPDNet with and without normalization, observe the impact, and compare
+different Riemannian metrics.
+
+By the end, you will understand:
+
+- Why standard Euclidean batch normalization doesn't apply to SPD matrices
+- How Riemannian and Lie group batch normalization work
+- When to use each metric (AIM, LEM, LCM)
+
+.. contents:: This tutorial covers:
+ :local:
+ :depth: 2
+
+"""
+
+######################################################################
+# Setup and Imports
+# -----------------
+#
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+
+from spd_learn.modules import (
+ BiMap,
+ LogEig,
+ ReEig,
+ SPDBatchNormLie,
+ SPDBatchNormMeanVar,
+)
+
+
+torch.manual_seed(42)
+np.random.seed(42)
+
+######################################################################
+# Creating Synthetic SPD Data
+# ----------------------------
+#
+# We generate a 3-class classification problem with 8x8 SPD matrices.
+# Each class has a distinct eigenvalue profile, simulating how real-world
+# signals (EEG, radar) produce covariance matrices with different spectral
+# characteristics.
+#
+
+
+def make_spd_dataset(n_samples_per_class=100, n=8, n_classes=3, seed=42):
+ """Generate synthetic SPD classification data."""
+ rng = np.random.RandomState(seed)
+ X_list, y_list = [], []
+ for c in range(n_classes):
+ eigvals = np.exp(rng.randn(n) * 0.5 + c * 0.3)
+ for _ in range(n_samples_per_class):
+ Q, _ = np.linalg.qr(rng.randn(n, n))
+ S = Q @ np.diag(eigvals + rng.rand(n) * 0.1) @ Q.T
+ S = (S + S.T) / 2
+ X_list.append(S)
+ y_list.append(c)
+ X = torch.from_numpy(np.stack(X_list)).float()
+ y = torch.from_numpy(np.array(y_list))
+ perm = torch.randperm(len(X), generator=torch.Generator().manual_seed(seed))
+ return X[perm], y[perm]
+
+
+X, y = make_spd_dataset()
+n_train = int(0.7 * len(X))
+X_train, y_train = X[:n_train], y[:n_train]
+X_test, y_test = X[n_train:], y[n_train:]
+print(
+ f"Dataset: {len(X)} samples ({n_train} train, {len(X) - n_train} test), "
+ f"{len(torch.unique(y))} classes, {X.shape[1]}x{X.shape[2]} SPD matrices"
+)
+
+
+######################################################################
+# A Simple SPDNet
+# ----------------
+#
+# We define a minimal SPDNet with one BiMap layer, optional batch
+# normalization, ReEig activation, and a LogEig + linear classifier.
+#
+
+
+class SimpleSPDNet(nn.Module):
+ """Minimal SPDNet with optional batch normalization."""
+
+ def __init__(self, n_in, n_out, n_classes, bn=None):
+ super().__init__()
+ self.bimap = BiMap(n_in, n_out)
+ self.bn = bn
+ self.reeig = ReEig()
+ self.logeig = LogEig(upper=False, flatten=True)
+ self.classifier = nn.Linear(n_out**2, n_classes)
+
+ def forward(self, x):
+ x = self.bimap(x)
+ if self.bn is not None:
+ x = self.bn(x)
+ x = self.reeig(x)
+ x = self.logeig(x)
+ return self.classifier(x)
+
+
+def train_model(model, X_train, y_train, X_test, y_test, epochs=150, lr=5e-3):
+ """Train a model and record loss and accuracy curves."""
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+ criterion = nn.CrossEntropyLoss()
+ train_losses, test_accs = [], []
+
+ for epoch in range(epochs):
+ model.train()
+ optimizer.zero_grad(set_to_none=True)
+ loss = criterion(model(X_train), y_train)
+ loss.backward()
+ optimizer.step()
+ train_losses.append(loss.item())
+
+ if (epoch + 1) % 5 == 0:
+ model.eval()
+ with torch.no_grad():
+ acc = (model(X_test).argmax(1) == y_test).float().mean().item()
+ test_accs.append((epoch + 1, acc))
+
+ return train_losses, test_accs
+
+
+######################################################################
+# Baseline: No Batch Normalization
+# ----------------------------------
+#
+# First, let's train without any normalization and note the final
+# accuracy; we will compare the loss curves of all variants later on.
+#
+
+torch.manual_seed(42)
+model_none = SimpleSPDNet(8, 4, 3, bn=None)
+losses_none, accs_none = train_model(model_none, X_train, y_train, X_test, y_test)
+print(f"No BN: final accuracy = {accs_none[-1][1]:.1%}")
+
+
+######################################################################
+# Adding Riemannian Batch Normalization
+# ----------------------------------------
+#
+# :class:`~spd_learn.modules.SPDBatchNormMeanVar` normalizes using the
+# Frechet mean and dispersion under the Affine-Invariant Riemannian
+# Metric (AIRM). This is the Riemannian analogue of standard batch
+# normalization.
+#
+
+torch.manual_seed(42)
+model_rbn = SimpleSPDNet(8, 4, 3, bn=SPDBatchNormMeanVar(4, momentum=0.1))
+losses_rbn, accs_rbn = train_model(model_rbn, X_train, y_train, X_test, y_test)
+print(f"SPDBatchNormMeanVar: final accuracy = {accs_rbn[-1][1]:.1%}")
+
+######################################################################
+# Notice the improvement! Batch normalization stabilizes the loss
+# trajectory and allows the network to converge to a better solution.
+#
+
+######################################################################
+# Lie Group Batch Normalization
+# ------------------------------
+#
+# :class:`~spd_learn.modules.SPDBatchNormLie` :cite:p:`chen2024liebn`
+# exploits the Lie group structure of :math:`\spd`. Unlike
+# ``SPDBatchNormMeanVar``, which only supports AIRM, LieBN supports three
+# Riemannian metrics:
+#
+# - **AIM** (Affine-Invariant Metric): Iterative Karcher mean, full
+# affine invariance.
+# - **LEM** (Log-Euclidean Metric): Closed-form mean, fast computation.
+# - **LCM** (Log-Cholesky Metric): Cholesky-based, numerically stable.
+#
+# Let's try all three:
+#
+
+liebn_results = {}
+for metric in ["AIM", "LEM", "LCM"]:
+ torch.manual_seed(42)
+ model = SimpleSPDNet(8, 4, 3, bn=SPDBatchNormLie(4, metric=metric))
+ losses, accs = train_model(model, X_train, y_train, X_test, y_test)
+ liebn_results[metric] = (losses, accs)
+ print(f"LieBN ({metric}): final accuracy = {accs[-1][1]:.1%}")
+
+
+######################################################################
+# Comparing Training Dynamics
+# ----------------------------
+#
+# Let's visualize how each normalization strategy affects convergence.
+#
+
+fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
+
+# Loss curves
+ax1.plot(losses_none, label="No BN", alpha=0.7, color="gray")
+ax1.plot(losses_rbn, label="SPDBatchNormMeanVar", alpha=0.7, color="black")
+colors = {"AIM": "#e74c3c", "LEM": "#2ecc71", "LCM": "#3498db"}
+for metric, (losses, _) in liebn_results.items():
+ ax1.plot(losses, label=f"LieBN ({metric})", alpha=0.7, color=colors[metric])
+ax1.set_xlabel("Epoch")
+ax1.set_ylabel("Training Loss")
+ax1.set_title("Loss Convergence")
+ax1.legend()
+ax1.grid(True, alpha=0.3)
+
+# Accuracy curves
+for label, accs, color in [
+ ("No BN", accs_none, "gray"),
+ ("SPDBatchNormMeanVar", accs_rbn, "black"),
+]:
+ epochs, vals = zip(*accs)
+ ax2.plot(epochs, vals, "o-", label=label, color=color, markersize=3)
+for metric, (_, accs) in liebn_results.items():
+ epochs, vals = zip(*accs)
+ ax2.plot(
+ epochs,
+ vals,
+ "o-",
+ label=f"LieBN ({metric})",
+ color=colors[metric],
+ markersize=3,
+ )
+ax2.set_xlabel("Epoch")
+ax2.set_ylabel("Test Accuracy")
+ax2.set_title("Accuracy Over Training")
+ax2.legend()
+ax2.grid(True, alpha=0.3)
+ax2.set_ylim(0, 1.05)
+
+plt.tight_layout()
+plt.show()
+
+######################################################################
+# Inspecting the LieBN Pipeline
+# --------------------------------
+#
+# LieBN normalizes SPD matrices through five geometric steps:
+#
+# 1. **Deformation** -- map SPD matrices to a codomain via the chosen
+# metric (e.g., matrix log for LEM, Cholesky + log-diag for LCM)
+# 2. **Centering** -- translate the batch to zero/identity mean
+# 3. **Scaling** -- normalize variance by a learnable dispersion parameter
+# 4. **Biasing** -- translate by a learnable location parameter
+# 5. **Inverse Deformation** -- map back to the SPD manifold
+#
+# The three metrics differ in *how* they realize the deformation, the
+# mean, and the group action:
+#
+# .. list-table::
+# :header-rows: 1
+# :widths: 15 25 25 25
+#
+# * - Metric
+# - Deformation
+# - Mean
+# - Group Action
+# * - **LEM**
+# - :math:`\log(X)`
+# - Euclidean (closed-form)
+# - Additive
+# * - **LCM**
+# - Cholesky + log-diag
+# - Euclidean (closed-form)
+# - Additive
+# * - **AIM**
+# - :math:`X^\theta`
+# - Karcher (iterative)
+# - Cholesky congruence
+#
+# Let's watch the running variance converge as mini-batches are passed
+# through each layer in training mode:
+#
+
+torch.manual_seed(42)
+A = torch.randn(64, 8, 8)
+X_demo = A @ A.mT + 0.1 * torch.eye(8)
+
+fig, ax = plt.subplots(figsize=(8, 4))
+for metric in ["AIM", "LEM", "LCM"]:
+ bn = SPDBatchNormLie(8, metric=metric, momentum=0.1)
+ bn.train()
+ variances = []
+ for epoch in range(50):
+ perm = torch.randperm(len(X_demo))
+ for i in range(0, len(X_demo), 16):
+ batch = X_demo[perm[i : i + 16]]
+ if batch.shape[0] < 2:
+ continue
+ _ = bn(batch)
+ variances.append(bn.running_var.item())
+ ax.plot(variances, label=metric, color=colors[metric])
+
+ax.set_xlabel("Epoch")
+ax.set_ylabel("Running Variance")
+ax.set_title("Running Variance Convergence Across Metrics")
+ax.legend()
+ax.grid(True, alpha=0.3)
+plt.tight_layout()
+plt.show()
+
+######################################################################
+# All three metrics converge to a stable running variance. LEM and LCM
+# tend to settle faster because their Frechet means are closed-form,
+# whereas AIM relies on iterative Karcher steps (a single iteration per
+# batch with the default ``n_iter=1``).
+#
+
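+######################################################################
+# To make these five steps concrete, the sketch below pushes ``X_demo``
+# through the layer's private helpers (``_deform``, ``_frechet_mean``,
+# ``_translate``, ``_scale``, ``_inv_deform``) for the LEM metric. These
+# helpers are internal implementation details rather than public API, so
+# treat this purely as an illustration of the pipeline.
+#
+
+from spd_learn.functional import lie_group_variance
+
+
+bn_lem = SPDBatchNormLie(8, metric="LEM")
+bn_lem.train()
+
+S_def = bn_lem._deform(X_demo)  # 1. deformation: matrix log for LEM
+mean = bn_lem._frechet_mean(S_def)  # closed-form mean in log space
+S_cent = bn_lem._translate(S_def, mean, inverse=True)  # 2. centering
+var = lie_group_variance(S_cent, "LEM")  # batch dispersion
+S_scaled = bn_lem._scale(S_cent, var)  # 3. scaling
+bias_def = bn_lem._deform(bn_lem.bias)  # log of the learnable bias
+S_biased = bn_lem._translate(S_scaled, bias_def)  # 4. biasing
+Y_manual = bn_lem._inv_deform(S_biased)  # 5. inverse deformation
+
+eigvals_manual = torch.linalg.eigvalsh(Y_manual.detach())
+print("Manual pipeline output is SPD:", bool((eigvals_manual > 0).all()))
+print("Matches layer forward:", torch.allclose(Y_manual, bn_lem(X_demo)))
+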
+######################################################################
+# Verifying SPD Output and Gradient Flow
+# ----------------------------------------
+#
+# A critical property: batch normalization must produce valid SPD outputs
+# and let gradients flow through the layer. Let's verify both:
+#
+
+torch.manual_seed(42)
+A = torch.randn(8, 4, 4)
+X_check = (A @ A.mT + 0.1 * torch.eye(4)).requires_grad_(True)
+
+for metric in ["AIM", "LEM", "LCM"]:
+ bn = SPDBatchNormLie(4, metric=metric)
+ bn.train()
+ out = bn(X_check)
+ loss = (out * out).sum()
+ loss.backward()
+ eigvals = torch.linalg.eigvalsh(out.detach())
+ print(
+ f"{metric}: min_eigval={eigvals.min():.2e}, grad_norm={X_check.grad.norm():.4f}"
+ )
+ X_check.grad = None
+
+######################################################################
+# All eigenvalues are positive (valid SPD) and gradients flow
+# correctly through the normalization layer.
+#
+
+######################################################################
+# Summary
+# -------
+#
+# In this tutorial you learned:
+#
+# - **Why**: Batch normalization stabilizes SPDNet training by
+# normalizing the distribution of SPD activations
+# - **How**: Riemannian BN uses the Frechet mean and variance;
+# Lie group BN generalizes this to multiple metrics
+# - **Which metric**: LEM for speed, AIM for invariance, LCM for
+# Cholesky stability
+#
+# Next steps:
+#
+# .. seealso::
+#
+# - :ref:`howto-add-batchnorm` -- Add BN to an existing pipeline
+# - :ref:`howto-choose-metric` -- Decision guide for metric selection
+# - :ref:`liebn-batch-normalization` -- Full benchmark reproduction
+# across HDM05, Radar, and AFEW datasets
+# - :class:`~spd_learn.modules.SPDBatchNormLie` -- API reference
+# - :class:`~spd_learn.modules.SPDBatchNormMeanVar` -- API reference
+#
+# References
+# ----------
+#
+# .. bibliography::
+# :filter: docname in docnames
+#
diff --git a/spd_learn/__init__.py b/spd_learn/__init__.py
index 9b68aad..9a02dcf 100644
--- a/spd_learn/__init__.py
+++ b/spd_learn/__init__.py
@@ -22,6 +22,7 @@
PatchEmbeddingLayer,
ReEig,
Shrinkage,
+ SPDBatchNormLie,
SPDBatchNormMean,
SPDBatchNormMeanVar,
SPDDropout,
@@ -50,13 +51,14 @@
"BatchReNorm",
"BiMap",
"BiMapIncreaseDim",
- "SPDBatchNormMean",
"CovLayer",
"ExpEig",
"LogEig",
"PatchEmbeddingLayer",
"ReEig",
"Shrinkage",
+ "SPDBatchNormLie",
+ "SPDBatchNormMean",
"SPDBatchNormMeanVar",
"SPDDropout",
"TraceNorm",
diff --git a/spd_learn/functional/__init__.py b/spd_learn/functional/__init__.py
index 97b4f00..0182bac 100644
--- a/spd_learn/functional/__init__.py
+++ b/spd_learn/functional/__init__.py
@@ -16,8 +16,11 @@
from .autograd import modeig_backward, modeig_forward
from .batchnorm import (
+ frechet_mean,
karcher_mean_iteration,
+ lie_group_variance,
spd_centering,
+ spd_cholesky_congruence,
spd_rebiasing,
tangent_space_variance,
)
@@ -155,8 +158,11 @@
"ledoit_wolf",
"shrinkage_covariance",
# Batch normalization
+ "frechet_mean",
"karcher_mean_iteration",
+ "lie_group_variance",
"spd_centering",
+ "spd_cholesky_congruence",
"spd_rebiasing",
"tangent_space_variance",
# Bilinear operations
diff --git a/spd_learn/functional/batchnorm.py b/spd_learn/functional/batchnorm.py
index 16c4851..71eae7e 100644
--- a/spd_learn/functional/batchnorm.py
+++ b/spd_learn/functional/batchnorm.py
@@ -1,5 +1,6 @@
# Copyright (c) 2024-now SPD Learn Developers
# SPDX-License-Identifier: BSD-3-Clause
+
"""Functional operations for SPD batch normalization.
This module provides stateless mathematical operations for Riemannian batch
@@ -8,12 +9,18 @@
Functions
---------
+frechet_mean
+ Fréchet mean of SPD matrices under the AIRM via Karcher flow.
karcher_mean_iteration
Single iteration of the Karcher (Fréchet) mean algorithm.
spd_centering
Center SPD matrices around a given mean via congruence transformation.
+spd_cholesky_congruence
+ Congruence transformation using the Cholesky factor of an SPD matrix.
tangent_space_variance
Compute variance of SPD matrices in the tangent space.
+lie_group_variance
+ Fréchet variance under a Lie group structure on the SPD manifold.
See Also
--------
@@ -21,16 +28,20 @@
:class:`~spd_learn.modules.SPDBatchNormMeanVar` : Full Riemannian batch normalization.
"""
+from typing import Optional, Tuple, Union
+
import torch
from .core import matrix_exp, matrix_log, matrix_sqrt_inv
+from .utils import ensure_sym
def karcher_mean_iteration(
X: torch.Tensor,
current_mean: torch.Tensor,
detach: bool = True,
-) -> torch.Tensor:
+ return_tangent: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
r"""Perform one iteration of the Karcher mean algorithm.
The Karcher (Fréchet) mean on the SPD manifold is the minimizer of the sum
@@ -54,11 +65,15 @@ def karcher_mean_iteration(
If True, detaches ``current_mean`` from the computational graph before
computing the update. Set to False when gradients with respect to the
mean are needed.
+ return_tangent : bool, default=False
+ If True, also returns the mean tangent update used in this Karcher step.
Returns
-------
- torch.Tensor
- Updated Karcher mean estimate with shape `(1, ..., n, n)`.
+ torch.Tensor or Tuple[torch.Tensor, torch.Tensor]
+ Updated Karcher mean estimate with shape `(1, ..., n, n)`. When
+ ``return_tangent=True``, also returns the mean tangent update with the
+ same shape.
Notes
-----
@@ -85,9 +100,85 @@ def karcher_mean_iteration(
mean_tangent = X_tangent.mean(dim=0, keepdim=True)
# Map back to manifold
new_mean = mean_sqrt @ matrix_exp.apply(mean_tangent) @ mean_sqrt
+ if return_tangent:
+ return new_mean, mean_tangent
return new_mean
+def frechet_mean(
+ X: torch.Tensor,
+ max_iter: int = 1,
+ weights: Optional[torch.Tensor] = None,
+ return_distances: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ r"""Fréchet mean of SPD matrices under the AIRM via Karcher flow.
+
+ Computes the minimizer of the sum of squared geodesic distances:
+
+ .. math::
+
+ \bar{X} = \arg\min_{G \in \mathcal{S}_{++}^n}
+ \sum_{i=1}^{N} w_i \, d_{\text{AIRM}}^2(G, X_i)
+
+ using iterative Karcher flow initialized from the (weighted) Euclidean mean.
+
+ Parameters
+ ----------
+ X : torch.Tensor
+ Batch of SPD matrices with shape ``(batch_size, ..., n, n)``.
+ max_iter : int, default=1
+ Number of Karcher flow iterations. A single iteration is often
+ sufficient for batch normalization; use more (e.g. 50) when a
+ high-accuracy mean is needed.
+ weights : torch.Tensor, optional
+ Per-sample weights with shape broadcastable to ``X``; they are
+ assumed to sum to 1. When ``None``, uniform weights ``1/N`` are used.
+ return_distances : bool, default=False
+ If True, also returns the geodesic distances from each sample to
+ the mean.
+
+ Returns
+ -------
+ mean : torch.Tensor
+ Fréchet mean with shape ``(1, ..., n, n)``.
+ distances : torch.Tensor
+ Only returned when ``return_distances=True``. Geodesic distances
+ from each sample to the mean, with shape ``(batch_size, ...)``.
+
+ See Also
+ --------
+ :func:`karcher_mean_iteration` : Single Karcher step (lower-level).
+ :func:`~spd_learn.functional.airm_distance` : Pairwise AIRM distance.
+
+ References
+ ----------
+ See :cite:p:`pennec2006riemannian` for details on Karcher mean computation.
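+
+ Examples
+ --------
+ A minimal usage sketch (the data here is illustrative):
+
+ >>> import torch
+ >>> from spd_learn.functional import frechet_mean
+ >>> A = torch.randn(16, 4, 4, dtype=torch.float64)
+ >>> X = A @ A.mT + 0.1 * torch.eye(4, dtype=torch.float64)
+ >>> mean = frechet_mean(X, max_iter=10)
+ >>> mean.shape
+ torch.Size([1, 4, 4])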
+ """
+ batch = X.detach()
+
+ if weights is None:
+ mean = batch.mean(dim=0, keepdim=True)
+ else:
+ mean = (batch * weights).sum(dim=0, keepdim=True)
+
+ for _ in range(max_iter):
+ mean_sqrt, mean_invsqrt = matrix_sqrt_inv.apply(mean)
+ X_tangent = matrix_log.apply(mean_invsqrt @ batch @ mean_invsqrt)
+ if weights is None:
+ mean_tangent = X_tangent.mean(dim=0, keepdim=True)
+ else:
+ mean_tangent = (X_tangent * weights).sum(dim=0, keepdim=True)
+ mean = mean_sqrt @ matrix_exp.apply(mean_tangent) @ mean_sqrt
+
+ if return_distances:
+ mean_sqrt, mean_invsqrt = matrix_sqrt_inv.apply(mean)
+ X_tangent = matrix_log.apply(mean_invsqrt @ batch @ mean_invsqrt)
+ distances = torch.norm(X_tangent, p="fro", dim=(-2, -1))
+ return mean, distances
+
+ return mean
+
+
def spd_centering(
X: torch.Tensor,
mean_invsqrt: torch.Tensor,
@@ -209,9 +300,126 @@ def tangent_space_variance(
return variance
+def spd_cholesky_congruence(
+ X: torch.Tensor,
+ P: torch.Tensor,
+ inverse: bool = False,
+) -> torch.Tensor:
+ r"""Congruence transformation using the Cholesky factor of an SPD matrix.
+
+ Given an SPD matrix :math:`P = LL^T`, applies:
+
+ .. math::
+
+ \text{forward: } Y = LXL^T, \qquad
+ \text{inverse: } Y = L^{-1}X L^{-T}
+
+ This implements the Lie group action of ``GL(n)`` on the SPD manifold and
+ is used for centering and biasing under the affine-invariant metric.
+
+ Parameters
+ ----------
+ X : torch.Tensor
+ Batch of SPD matrices with shape `(..., n, n)`.
+ P : torch.Tensor
+ SPD matrix whose Cholesky factor defines the transformation,
+ with shape broadcastable to ``X``.
+ inverse : bool, default=False
+ If True, applies the inverse congruence :math:`L^{-1}X L^{-T}`.
+
+ Returns
+ -------
+ torch.Tensor
+ Transformed SPD matrices with the same shape as ``X``.
+
+ See Also
+ --------
+ :func:`spd_centering` : Eigendecomposition-based centering (uses :math:`M^{-1/2}`).
+ """
+ L = torch.linalg.cholesky(P)
+ if inverse:
+ Y = torch.linalg.solve_triangular(L, X, upper=False)
+ return ensure_sym(torch.linalg.solve_triangular(L, Y.mT, upper=False).mT)
+ return ensure_sym(L @ X @ L.mT)
+
+
+def lie_group_variance(
+ X_centered: torch.Tensor,
+ metric: str,
+ alpha: float = 1.0,
+ beta: float = 0.0,
+ theta: float = 1.0,
+) -> torch.Tensor:
+ r"""Fréchet variance under a Lie group structure on the SPD manifold.
+
+ Computes the scalar dispersion of centered data in the Lie algebra,
+ using the bi-invariant distance of Chen et al. :cite:p:`chen2024liebn`:
+
+ .. math::
+
+ \sigma^2 = \frac{1}{N} \sum_i
+ \bigl(\alpha \lVert V_i \rVert_F^2 + \beta \, g(V_i)^2\bigr)
+ \;/\; \theta^2
+
+ where the auxiliary term :math:`g` depends on the metric:
+
+ - **AIM**: :math:`V_i = \log(X_i)`, :math:`g(V) = \log\det(X)`
+ - **LEM**: :math:`V_i = X_i` (already in log space), :math:`g(V) = \operatorname{tr}(V)`,
+ no :math:`\theta` scaling
+ - **LCM**: same as LEM but with :math:`\theta` scaling
+
+ Parameters
+ ----------
+ X_centered : torch.Tensor
+ Centered data in the Lie algebra with shape `(batch_size, ..., n, n)`.
+ For AIM these are SPD matrices (centered around identity); for LEM/LCM
+ these are symmetric / lower-triangular matrices.
+ metric : {"AIM", "LEM", "LCM"}
+ Lie group structure.
+ alpha : float, default=1.0
+ Frobenius-norm weight.
+ beta : float, default=0.0
+ Trace / log-determinant weight.
+ theta : float, default=1.0
+ Power deformation parameter.
+
+ Returns
+ -------
+ torch.Tensor
+ Scalar variance (0-d tensor).
+
+ See Also
+ --------
+ :func:`tangent_space_variance` : Unweighted tangent-space dispersion used
+ by :class:`~spd_learn.modules.SPDBatchNormMeanVar`.
+ """
+ if metric == "AIM":
+ logX = matrix_log.apply(X_centered)
+ frob_sq = (logX * logX).sum(dim=(-2, -1))
+ dists = alpha * frob_sq
+ if beta != 0:
+ dists = dists + beta * torch.logdet(X_centered).square()
+ return dists.mean() / (theta**2)
+
+ frob_sq = (X_centered * X_centered).sum(dim=(-2, -1))
+ dists = alpha * frob_sq
+ if beta != 0:
+ trace = X_centered.diagonal(dim1=-2, dim2=-1).sum(dim=-1)
+ dists = dists + beta * trace.square()
+ var = dists.mean()
+ if metric == "LCM":
+ var = var / (theta**2)
+ elif metric != "LEM":
+ raise ValueError(f"metric must be 'AIM', 'LEM', or 'LCM', got '{metric}'")
+ return var
+
+
__all__ = [
+ "frechet_mean",
"karcher_mean_iteration",
+ "lie_group_variance",
"spd_centering",
+ "spd_cholesky_congruence",
"spd_rebiasing",
"tangent_space_variance",
]
diff --git a/spd_learn/modules/__init__.py b/spd_learn/modules/__init__.py
index 36a5b85..8183fd5 100644
--- a/spd_learn/modules/__init__.py
+++ b/spd_learn/modules/__init__.py
@@ -4,6 +4,7 @@
from .bilinear import BiMap, BiMapIncreaseDim
from .covariance import CovLayer
from .dropout import SPDDropout
+from .liebn import SPDBatchNormLie
from .manifold import PositiveDefiniteScalar, SymmetricPositiveDefinite
from .modeig import ExpEig, LogEig, ReEig
from .regularize import Shrinkage, TraceNorm
@@ -29,6 +30,7 @@
"SPDBatchNormMean",
"BatchReNorm",
"SPDBatchNormMeanVar",
+ "SPDBatchNormLie",
# dropout
"SPDDropout",
# residual
diff --git a/spd_learn/modules/batchnorm.py b/spd_learn/modules/batchnorm.py
index c87871c..e2c051c 100644
--- a/spd_learn/modules/batchnorm.py
+++ b/spd_learn/modules/batchnorm.py
@@ -13,7 +13,7 @@
matrix_sqrt,
)
from ..functional.batchnorm import (
- karcher_mean_iteration,
+ frechet_mean,
spd_centering,
spd_rebiasing,
tangent_space_variance,
@@ -194,10 +194,7 @@ def forward(self, input):
"""
if self.training:
- mean = input.mean(dim=0, keepdim=True)
- if input.shape[0] > 1:
- for _ in range(self.n_iter):
- mean = karcher_mean_iteration(input, mean)
+ mean = frechet_mean(input, max_iter=self.n_iter)
with torch.no_grad():
self.running_mean = airm_geodesic(
self.running_mean, mean, self.momentum
@@ -478,14 +475,8 @@ def forward(self, input):
Normalized tensor of the same shape as the input.
"""
- n_samples = input.shape[0]
if self.training:
- # Kobler et al. SPDMBN/SPDBN: estimate batch Fréchet mean via Karcher step
- batch_mean = input.mean(dim=0, keepdim=True)
- if n_samples > 1:
- for _ in range(self.n_iter):
- # Kobler et al. (Eq. 4): P2 L132-145; Karcher flow note: P2 L163-165
- batch_mean = karcher_mean_iteration(input, batch_mean)
+ batch_mean = frechet_mean(input, max_iter=self.n_iter)
# Scalar dispersion: mean squared Frobenius norm of log at the mean (a single scalar, not variance matrix)
mean_inv_sqrt = matrix_inv_sqrt.apply(batch_mean)
diff --git a/spd_learn/modules/liebn.py b/spd_learn/modules/liebn.py
new file mode 100644
index 0000000..1c8fdc0
--- /dev/null
+++ b/spd_learn/modules/liebn.py
@@ -0,0 +1,380 @@
+# Copyright (c) 2024-now SPD Learn Developers
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Lie Group Batch Normalization for SPD matrices.
+
+This module implements SPDBatchNormLie based on:
+Ziheng Chen, Yue Song, Yunmei Liu, and Nicu Sebe,
+"A Lie Group Approach to Riemannian Batch Normalization," ICLR 2024.
+
+The implementation is integrated into ``spd_learn`` from the original LieBN
+repository: https://github.com/GitZH-Chen/LieBN/tree/main/LieBN
+"""
+
+import torch
+
+from torch import nn
+from torch.nn.utils.parametrize import register_parametrization
+
+from ..functional import (
+ airm_geodesic,
+ ensure_sym,
+ matrix_exp,
+ matrix_inv_sqrt,
+ matrix_log,
+ matrix_power,
+ matrix_sqrt,
+)
+from ..functional.batchnorm import (
+ frechet_mean,
+ lie_group_variance,
+ spd_centering,
+ spd_cholesky_congruence,
+ spd_rebiasing,
+)
+from ..functional.metrics import (
+ cholesky_exp,
+ cholesky_log,
+ log_euclidean_scalar_multiply,
+)
+from .manifold import PositiveDefiniteScalar, SymmetricPositiveDefinite
+
+
+class SPDBatchNormLie(nn.Module):
+ r"""Lie Group Batch Normalization for SPD matrices.
+
+ Implements the LieBN framework :cite:p:`chen2024liebn` for SPD manifolds.
+ Unlike :class:`SPDBatchNormMeanVar`, which normalizes under a single
+ Riemannian metric (AIRM), this layer exploits the **Lie group structure**
+ of three classical SPD geometries to define centering, scaling, and biasing
+ as group-theoretic operations with formal statistical guarantees.
+
+ **Algorithm.**
+ Given a batch :math:`\{P_i\}_{i=1}^N \subset \mathcal{S}_{++}^n`, the
+ forward pass applies three steps in the Lie algebra selected by ``metric``:
+
+ 1. **Centering** -- translate the batch mean :math:`M` to the group
+ identity :math:`E` via the inverse left translation:
+
+ .. math::
+
+ \bar{P}_i = L_{M_\odot^{-1}}(P_i)
+
+ 2. **Scaling** -- normalize the Fréchet variance :math:`v^2` with a
+ learnable shift :math:`s \in \mathbb{R}_{>0}`:
+
+ .. math::
+
+ \hat{P}_i = \operatorname{Exp}_E
+ \!\left[\frac{s}{\sqrt{v^2 + \epsilon}}\,
+ \operatorname{Log}_E(\bar{P}_i)\right]
+
+ 3. **Biasing** -- translate to the learnable SPD parameter :math:`B`:
+
+ .. math::
+
+ \tilde{P}_i = L_B(\hat{P}_i)
+
+ **Theoretical guarantees** (Proposition 4.2 of the paper):
+
+ * *Mean control*: after centering and biasing with :math:`B = E`,
+ the Fréchet mean of the output batch equals :math:`E`.
+ * *Variance control*: after scaling, the output dispersion satisfies
+ :math:`\sum_i w_i\,d^2(\hat{P}_i, E) = s^2`.
+
+ **Supported metrics.**
+ The ``metric`` parameter selects one of three Lie group structures, each
+ inducing a family of parameterized metrics via the power deformation
+ :math:`\mathrm{P}_\theta`. The table below summarizes how each step is
+ realized (see Table 2 in :cite:p:`chen2024liebn`):
+
+ .. list-table::
+ :header-rows: 1
+ :widths: 25 25 25 25
+
+ * - Operation
+ - :math:`(\theta,\alpha,\beta)`-AIM
+ - :math:`(\alpha,\beta)`-LEM
+ - :math:`\theta`-LCM
+ * - Pullback map
+ - :math:`\mathrm{P}_\theta`
+ - :math:`\operatorname{mlog}`
+ - :math:`\psi_{\mathrm{LC}} \circ \mathrm{P}_\theta`
+ * - Left translation :math:`L_Q(P)`
+ - :math:`Q^{1/2} P\, Q^{1/2}`
+ - :math:`P + Q`
+ - :math:`P + Q`
+ * - Scaling
+ - :math:`\operatorname{Exp}_I[s\,\operatorname{Log}_I(P)]`
+ - :math:`s \cdot P`
+ - :math:`s \cdot P`
+ * - Fréchet mean
+ - Karcher flow
+ - Arithmetic mean
+ - Arithmetic mean
+ * - Running mean update
+ - AIRM geodesic
+ - Linear interpolation
+ - Linear interpolation
+
+ **Bi-invariant distance.**
+ The Fréchet variance uses the :math:`(\alpha, \beta)` bi-invariant metric
+ (Definition 3 and Eq. 3 of the paper):
+
+ .. math::
+
+ d^2(P, Q) = \alpha \lVert V \rVert_F^2
+ + \beta \, g(V)^2
+
+ where :math:`V` is the tangent representation (log-map) and
+ :math:`g(V) = \log\det(P)` for AIM or :math:`\operatorname{tr}(V)`
+ for LEM/LCM. The variance is normalized by :math:`\theta^2` for AIM
+ and LCM.
+
+ Parameters
+ ----------
+ num_features : int
+ Size of the SPD matrices (:math:`n \times n`).
+ metric : {"AIM", "LEM", "LCM"}, default="AIM"
+ Lie group invariant metric.
+ theta : float, default=1.0
+ Power deformation parameter :math:`\theta`. When
+ :math:`\theta = 1`, no deformation is applied.
+ alpha : float, default=1.0
+ Frobenius norm weight :math:`\alpha` in the bi-invariant distance.
+ beta : float, default=0.0
+ Trace / log-determinant weight :math:`\beta` in the bi-invariant
+ distance. Must satisfy :math:`\min(\alpha, \alpha + n\beta) > 0`.
+ momentum : float, default=0.1
+ Momentum :math:`\gamma` for exponential moving average of running
+ statistics.
+ eps : float, default=1e-5
+ Numerical stability constant :math:`\epsilon` added to the variance
+ before taking the square root.
+ n_iter : int, default=1
+ Number of Karcher flow iterations for the AIM Fréchet mean.
+ Ignored by LEM and LCM (which use arithmetic means).
+ congruence : {"cholesky", "eig"}, default="cholesky"
+ Implementation of the AIM congruence action (centering/biasing).
+ ``"cholesky"`` uses the Cholesky factor :math:`L` of :math:`P` to
+ compute :math:`L X L^\top` (as in the original LieBN paper).
+ ``"eig"`` uses eigendecomposition-based :math:`M^{-1/2} X M^{-1/2}`
+ (matching :func:`~spd_learn.functional.spd_centering`).
+ Both map the batch mean to the identity and preserve AIRM distances
+ (their outputs differ only by an orthogonal conjugation); Cholesky is
+ typically faster, while eigendecomposition reuses the infrastructure
+ of :class:`~spd_learn.modules.SPDBatchNormMeanVar`.
+ Only affects the AIM metric.
+ device : torch.device or str, optional
+ Device on which to create parameters and buffers.
+ dtype : torch.dtype, optional
+ Data type of parameters and buffers.
+
+ Attributes
+ ----------
+ bias : nn.Parameter
+ Learnable SPD bias matrix :math:`B \in \mathcal{S}_{++}^n`,
+ parametrized via :class:`~spd_learn.modules.SymmetricPositiveDefinite`.
+ Initialized to the identity.
+ shift : nn.Parameter
+ Learnable positive scalar :math:`s > 0`,
+ parametrized via :class:`~spd_learn.modules.PositiveDefiniteScalar`.
+ Initialized to 1.
+ running_mean : torch.Tensor
+ Exponential moving average of the batch Fréchet mean.
+ running_var : torch.Tensor
+ Exponential moving average of the batch variance.
+
+ See Also
+ --------
+ :class:`SPDBatchNormMean` :
+ Mean-only Riemannian batch normalization (AIRM centering without
+ variance normalization) :cite:p:`brooks2019riemannian`.
+ :class:`SPDBatchNormMeanVar` :
+ Full Riemannian batch normalization under the AIRM
+ :cite:p:`kobler2022spd`.
+ :func:`~spd_learn.functional.frechet_mean` :
+ Fréchet mean via Karcher flow (used internally for AIM).
+ :func:`~spd_learn.functional.lie_group_variance` :
+ Bi-invariant Fréchet variance computation.
+
+ References
+ ----------
+ .. bibliography::
+ :filter: key == "chen2024liebn"
+
+ Examples
+ --------
+ >>> import torch
+ >>> from spd_learn.modules import SPDBatchNormLie
+ >>> bn = SPDBatchNormLie(num_features=4, metric="AIM")
+ >>> X = torch.randn(8, 4, 4, dtype=torch.float64)
+ >>> X = X @ X.mT + 0.1 * torch.eye(4, dtype=torch.float64)
+ >>> bn = bn.to(dtype=torch.float64)
+ >>> Y = bn(X)
+ >>> Y.shape
+ torch.Size([8, 4, 4])
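+
+ A sketch using the Log-Euclidean metric in the default ``float32``
+ precision (illustrative only):
+
+ >>> bn_lem = SPDBatchNormLie(num_features=4, metric="LEM")
+ >>> A = torch.randn(8, 4, 4)
+ >>> Z = bn_lem(A @ A.mT + 0.1 * torch.eye(4))
+ >>> Z.shape
+ torch.Size([8, 4, 4])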
+ """
+
+ def __init__(
+ self,
+ num_features,
+ metric="AIM",
+ theta=1.0,
+ alpha=1.0,
+ beta=0.0,
+ momentum=0.1,
+ eps=1e-5,
+ n_iter=1,
+ congruence="cholesky",
+ device=None,
+ dtype=None,
+ ):
+ super().__init__()
+ supported_metrics = ("AIM", "LEM", "LCM")
+ if metric not in supported_metrics:
+ raise ValueError(
+ f"metric must be one of {supported_metrics}, got '{metric}'"
+ )
+ if congruence not in ("cholesky", "eig"):
+ raise ValueError(
+ f"congruence must be 'cholesky' or 'eig', got '{congruence}'"
+ )
+ self.num_features = num_features
+ self.metric = metric
+ self.theta = theta
+ self.alpha = alpha
+ self.beta = beta
+ self.momentum = momentum
+ self.eps = eps
+ self.n_iter = n_iter
+ self.congruence = congruence
+
+ self.bias = nn.Parameter(
+ torch.empty(1, num_features, num_features, device=device, dtype=dtype)
+ )
+ self.shift = nn.Parameter(torch.empty((), device=device, dtype=dtype))
+
+ if metric == "AIM":
+ self.register_buffer(
+ "running_mean",
+ torch.eye(num_features, device=device, dtype=dtype).unsqueeze(0),
+ )
+ else:
+ self.register_buffer(
+ "running_mean",
+ torch.zeros(1, num_features, num_features, device=device, dtype=dtype),
+ )
+ self.register_buffer("running_var", torch.ones((), device=device, dtype=dtype))
+
+ self.reset_parameters()
+ self._parametrize()
+
+ @torch.no_grad()
+ def reset_parameters(self):
+ self.bias.zero_()
+ self.bias[0].fill_diagonal_(1.0)
+ self.shift.fill_(1.0)
+
+ def _parametrize(self):
+ register_parametrization(self, "bias", SymmetricPositiveDefinite())
+ register_parametrization(self, "shift", PositiveDefiniteScalar())
+
+ # ------------------------------------------------------------------
+ # Thin dispatch helpers — each delegates to existing functional ops.
+ # ------------------------------------------------------------------
+
+ def _deform(self, X):
+ """Map SPD matrices to the Lie algebra."""
+ if self.metric == "AIM":
+ return X if self.theta == 1.0 else matrix_power.apply(X, self.theta)
+ if self.metric == "LEM":
+ return matrix_log.apply(X)
+ # LCM
+ Xp = X if self.theta == 1.0 else matrix_power.apply(X, self.theta)
+ return cholesky_log.apply(Xp)
+
+ def _inv_deform(self, S):
+ """Map from the Lie algebra back to SPD matrices."""
+ if self.metric == "AIM":
+ return S if self.theta == 1.0 else matrix_power.apply(S, 1.0 / self.theta)
+ if self.metric == "LEM":
+ return matrix_exp.apply(S)
+ # LCM
+ spd = ensure_sym(cholesky_exp.apply(S))
+ return spd if self.theta == 1.0 else matrix_power.apply(spd, 1.0 / self.theta)
+
+ def _translate(self, X, P, inverse=False):
+ """Group translation (centering / biasing) in the Lie algebra."""
+ if self.metric == "AIM":
+ if self.congruence == "cholesky":
+ return spd_cholesky_congruence(X, P, inverse=inverse)
+ # Eigendecomposition path
+ if inverse:
+ return spd_centering(X, matrix_inv_sqrt.apply(P))
+ return spd_rebiasing(X, matrix_sqrt.apply(P))
+ return X - P if inverse else X + P
+
+ def _frechet_mean(self, X_def):
+ """Fréchet mean in the deformed space."""
+ if self.metric == "AIM":
+ return frechet_mean(X_def, max_iter=self.n_iter)
+ return X_def.detach().mean(dim=0, keepdim=True)
+
+ def _scale(self, X, var):
+ """Variance normalization in the Lie algebra."""
+ factor = self.shift / (var + self.eps).sqrt()
+ if self.metric == "AIM":
+ return log_euclidean_scalar_multiply(factor, X)
+ return X * factor
+
+ # ------------------------------------------------------------------
+ # Running statistics & forward
+ # ------------------------------------------------------------------
+
+ def _update_running_stats(self, batch_mean, batch_var):
+ with torch.no_grad():
+ if self.metric == "AIM":
+ self.running_mean = airm_geodesic(
+ self.running_mean, batch_mean, self.momentum
+ )
+ else:
+ self.running_mean = (
+ 1 - self.momentum
+ ) * self.running_mean + self.momentum * batch_mean
+ self.running_var = (
+ 1 - self.momentum
+ ) * self.running_var + self.momentum * batch_var
+
+ def forward(self, X):
+ X_def = self._deform(X)
+ bias_def = self._deform(self.bias)
+
+ if self.training:
+ batch_mean = self._frechet_mean(X_def)
+ X_centered = self._translate(X_def, batch_mean, inverse=True)
+ if X.shape[0] > 1:
+ batch_var = lie_group_variance(
+ X_centered.detach(),
+ self.metric,
+ self.alpha,
+ self.beta,
+ self.theta,
+ )
+ X_scaled = self._scale(X_centered, batch_var)
+ else:
+ batch_var = self.running_var.clone()
+ X_scaled = X_centered
+ self._update_running_stats(batch_mean.detach(), batch_var.detach())
+ else:
+ X_centered = self._translate(X_def, self.running_mean, inverse=True)
+ X_scaled = self._scale(X_centered, self.running_var)
+
+ X_biased = self._translate(X_scaled, bias_def, inverse=False)
+ return self._inv_deform(X_biased)
+
+ def extra_repr(self):
+ return (
+ f"num_features={self.num_features}, metric={self.metric}, "
+ f"theta={self.theta}, alpha={self.alpha}, beta={self.beta}, "
+ f"momentum={self.momentum}, congruence={self.congruence}"
+ )
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 0e10aa1..092d54b 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -21,6 +21,7 @@
"SPDBatchNormMean": dict(num_features=10),
"BatchReNorm": dict(num_features=10),
"SPDBatchNormMeanVar": dict(num_features=10),
+ "SPDBatchNormLie": dict(num_features=10),
"PatchEmbeddingLayer": dict(n_chans=10, n_patches=2),
"BiMapIncreaseDim": dict(in_features=10, out_features=20),
"Shrinkage": dict(n_chans=10),
@@ -141,6 +142,8 @@ def test_module_dtype(module_name, dtype, device):
pytest.skip(
"PositiveDefiniteScalar is a scalar parametrization, not a matrix layer."
)
+ if module_name == "SPDBatchNormLie" and dtype.is_complex:
+ pytest.skip("SPDBatchNormLie only supports real-valued SPD matrices.")
module_class = getattr(spd_learn.modules, module_name)
mandatory_param = mandatory_parameters_per_module.get(module_name, {})
diff --git a/tests/test_liebn.py b/tests/test_liebn.py
new file mode 100644
index 0000000..8822f0d
--- /dev/null
+++ b/tests/test_liebn.py
@@ -0,0 +1,313 @@
+"""Tests for Lie Group Batch Normalization (LieBN) for SPD matrices.
+
+Verifies the theoretical guarantees from Chen et al., ICLR 2024
+(Proposition 4.2):
+ - Mean property: after centering+biasing with bias=I, output Frechet mean ≈ I
+ - Variance property: after scaling with shift=1, dispersion ≈ 1.0
+ - Running statistics converge to population statistics
+"""
+
+from math import sqrt
+
+import pytest
+import torch
+
+from spd_learn.functional import (
+ ensure_sym,
+ matrix_exp,
+ matrix_log,
+ vec_to_sym,
+)
+from spd_learn.functional.batchnorm import karcher_mean_iteration
+from spd_learn.modules import SPDBatchNormLie
+
+
+DTYPE = torch.float64
+
+# ---------------------------------------------------------------------------
+# Data fixture
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture()
+def simulated_data():
+ """Generate SPD data with known Frechet mean for testing.
+
+ Strategy: zero-mean tangent vectors -> matrix_exp -> SPD at Identity,
+ then apply linear mixing x = A z A^T so Frechet mean = A A^T.
+ """
+ ndim = 10
+ nobs = 128
+ generator = torch.Generator().manual_seed(42)
+
+ # Zero-mean tangent vectors -> SPD matrices centered at Identity
+ logz = vec_to_sym(
+ torch.randn((nobs, ndim * (ndim + 1) // 2), generator=generator, dtype=DTYPE)
+ )
+ logz = logz - logz.mean(dim=0, keepdim=True)
+ z = matrix_exp.apply(logz)
+
+ # Linear mixing model: shifts Frechet mean to A @ A^T
+ eps = 0.1
+ forward_model = (
+ torch.rand((ndim, ndim), generator=generator, dtype=DTYPE) - 0.5
+ ) * (1 - eps) + eps * torch.eye(ndim, dtype=DTYPE)
+ x = forward_model @ z @ forward_model.mT
+
+ # Analytic Frechet mean (by invariance)
+ x_mean_expected = (forward_model @ forward_model.mT).unsqueeze(0)
+
+ return x, x_mean_expected, ndim, nobs
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+METRICS = ["AIM", "LEM", "LCM"]
+CONGRUENCES = ["cholesky", "eig"]
+
+
+@pytest.mark.parametrize(
+ "metric,theta,atol",
+ [
+ ("AIM", 1.0, 1e-10),
+ ("LCM", 1.0, 1e-5), # Cholesky + log/exp introduces numerical error
+ ("LCM", 0.5, 1e-4), # Additional matrix_power roundtrip error
+ ],
+)
+def test_deform_inv_deform_roundtrip(simulated_data, metric, theta, atol):
+ """_inv_deform(_deform(X)) should recover X."""
+ x, _, ndim, _ = simulated_data
+ layer = SPDBatchNormLie(ndim, metric=metric, theta=theta, dtype=DTYPE)
+
+ X_def = layer._deform(x)
+ X_recovered = layer._inv_deform(X_def)
+
+ assert torch.allclose(X_recovered, x, atol=atol, rtol=0.0)
+
+
+@pytest.mark.parametrize("congruence", CONGRUENCES)
+@pytest.mark.parametrize("metric", METRICS)
+def test_post_normalization_mean(simulated_data, metric, congruence):
+ """After LieBN forward (bias=I, shift=1), codomain mean should be neutral.
+
+ - AIM: Karcher mean of output ≈ Identity
+ - LEM/LCM: arithmetic mean of deformed output ≈ zero matrix
+ """
+ x, _, ndim, nobs = simulated_data
+ layer = SPDBatchNormLie(
+ ndim, metric=metric, n_iter=64, congruence=congruence, dtype=DTYPE
+ )
+ layer.train()
+
+ with torch.no_grad():
+ output = layer(x)
+
+ tol = 2 * sqrt(1.0 / nobs)
+
+ if metric == "AIM":
+ # Compute Karcher mean of output
+ mean = output.mean(dim=0, keepdim=True)
+ for _ in range(64):
+ mean = karcher_mean_iteration(output, mean, detach=True)
+ identity = torch.eye(ndim, dtype=DTYPE).unsqueeze(0)
+ assert torch.allclose(mean, identity, atol=tol, rtol=0.0), (
+ f"AIM: Karcher mean of output deviates from Identity by "
+ f"{(mean - identity).abs().max().item():.6f}"
+ )
+ else:
+ # In codomain, mean should be ≈ zero
+ output_def = layer._deform(output)
+ codomain_mean = output_def.mean(dim=0, keepdim=True)
+ zeros = torch.zeros_like(codomain_mean)
+ assert torch.allclose(codomain_mean, zeros, atol=tol, rtol=0.0), (
+ f"{metric}: codomain mean deviates from zero by "
+ f"{codomain_mean.abs().max().item():.6f}"
+ )
+
+
+@pytest.mark.parametrize("metric", METRICS)
+def test_post_normalization_variance(simulated_data, metric):
+ """After LieBN forward (shift=1), output variance should be ≈ 1.0.
+
+ Theoretical: shift^2 * v^2 / (v^2 + eps). With shift=1 and large v^2,
+ this is close to 1.0.
+ """
+ x, _, ndim, nobs = simulated_data
+ layer = SPDBatchNormLie(ndim, metric=metric, n_iter=64, dtype=DTYPE)
+ layer.train()
+
+ with torch.no_grad():
+ output = layer(x)
+
+ # Recompute the output variance the same way the layer does internally
+ # (lie_group_variance on the centered data), but on the re-centered output
+ output_def = layer._deform(output)
+ if metric == "AIM":
+ output_mean = output_def.mean(dim=0, keepdim=True)
+ for _ in range(64):
+ output_mean = karcher_mean_iteration(output_def, output_mean, detach=True)
+ L = torch.linalg.cholesky(output_mean)
+ Y = torch.linalg.solve_triangular(L, output_def, upper=False)
+ centered = ensure_sym(torch.linalg.solve_triangular(L, Y.mT, upper=False).mT)
+ logX = matrix_log.apply(centered)
+ frob_sq = (logX * logX).sum(dim=(-2, -1))
+ output_var = frob_sq.mean()
+ else:
+ centered = output_def - output_def.mean(dim=0, keepdim=True)
+ frob_sq = (centered * centered).sum(dim=(-2, -1))
+ output_var = frob_sq.mean()
+
+ # Expected: shift^2 * v^2 / (v^2 + eps) ≈ 1.0
+ tol = 3 * sqrt(1.0 / nobs)
+ assert abs(output_var.item() - 1.0) < tol, (
+ f"{metric}: output variance = {output_var.item():.6f}, expected ≈ 1.0"
+ )
+
+
+@pytest.mark.parametrize("metric", METRICS)
+def test_running_stats_single_batch(simulated_data, metric):
+ """With momentum=1.0, running stats should match batch stats exactly."""
+ x, _, ndim, nobs = simulated_data
+ layer = SPDBatchNormLie(ndim, metric=metric, momentum=1.0, n_iter=64, dtype=DTYPE)
+ layer.train()
+
+ with torch.no_grad():
+ layer(x)
+
+ # Independently compute batch statistics
+ X_def = layer._deform(x)
+ if metric == "AIM":
+ expected_mean = X_def.mean(dim=0, keepdim=True)
+ for _ in range(64):
+ expected_mean = karcher_mean_iteration(X_def, expected_mean, detach=True)
+ else:
+ expected_mean = X_def.mean(dim=0, keepdim=True)
+
+ tol = sqrt(1.0 / nobs)
+ assert torch.allclose(layer.running_mean, expected_mean, atol=tol, rtol=0.0), (
+ f"{metric}: running_mean deviates from batch mean by "
+ f"{(layer.running_mean - expected_mean).abs().max().item():.6f}"
+ )
+
+ # Compute expected variance from centered data
+ if metric == "AIM":
+ L = torch.linalg.cholesky(expected_mean)
+ Y = torch.linalg.solve_triangular(L, X_def, upper=False)
+ centered = ensure_sym(torch.linalg.solve_triangular(L, Y.mT, upper=False).mT)
+ logX = matrix_log.apply(centered)
+ frob_sq = (logX * logX).sum(dim=(-2, -1))
+ expected_var = frob_sq.mean()
+ else:
+ centered = X_def - expected_mean
+ frob_sq = (centered * centered).sum(dim=(-2, -1))
+ expected_var = frob_sq.mean()
+
+ assert torch.allclose(layer.running_var, expected_var, atol=tol, rtol=0.0), (
+ f"{metric}: running_var = {layer.running_var.item():.6f}, "
+ f"expected = {expected_var.item():.6f}"
+ )
+
+
+@pytest.mark.parametrize("metric", METRICS)
+def test_running_stats_convergence(simulated_data, metric):
+ """Running stats should converge to population stats over mini-batches."""
+ x, _, ndim, nobs = simulated_data
+ layer = SPDBatchNormLie(ndim, metric=metric, n_iter=1, dtype=DTYPE)
+
+ # Full-batch reference statistics (high precision)
+ with torch.no_grad():
+ ref_layer = SPDBatchNormLie(
+ ndim, metric=metric, momentum=1.0, n_iter=64, dtype=DTYPE
+ )
+ ref_layer.train()
+ ref_layer(x)
+ ref_mean = ref_layer.running_mean.clone()
+ ref_var = ref_layer.running_var.clone()
+
+ # Train with mini-batches and decaying momentum
+ ds = torch.utils.data.TensorDataset(x)
+ loader = torch.utils.data.DataLoader(ds, batch_size=nobs // 4, drop_last=True)
+
+ layer.train()
+ n_epochs = 64 // len(loader) * 4
+ for epoch in range(n_epochs):
+ layer.momentum = 1 / (epoch + 1)
+ for batch in loader:
+ with torch.no_grad():
+ layer(batch[0])
+
+ tol = 5 * sqrt(1.0 / nobs)
+ assert torch.allclose(layer.running_mean, ref_mean, atol=tol, rtol=0.0), (
+ f"{metric}: running_mean did not converge. "
+ f"Max deviation: {(layer.running_mean - ref_mean).abs().max().item():.6f}"
+ )
+ assert torch.allclose(layer.running_var, ref_var, atol=tol, rtol=0.05), (
+ f"{metric}: running_var did not converge. "
+ f"running={layer.running_var.item():.4f}, ref={ref_var.item():.4f}"
+ )
+
+
+@pytest.mark.parametrize("metric", METRICS)
+def test_gradient_flow(simulated_data, metric):
+ """Verify gradients flow through LieBN to input and parameters."""
+ x, _, ndim, _ = simulated_data
+ # Use a small batch to keep computation fast
+ x_small = x[:8].clone().requires_grad_(True)
+
+ layer = SPDBatchNormLie(ndim, metric=metric, n_iter=1, dtype=DTYPE)
+ layer.train()
+
+ output = layer(x_small)
+ loss = output.sum()
+ loss.backward()
+
+ # Input gradient
+ assert x_small.grad is not None, f"{metric}: no gradient on input"
+ assert x_small.grad.abs().sum() > 0, f"{metric}: zero gradient on input"
+
+ # Bias parameter gradient (underlying unconstrained parameter)
+ bias_param = layer.parametrizations.bias.original
+ assert bias_param.grad is not None, f"{metric}: no gradient on bias"
+ assert bias_param.grad.abs().sum() > 0, f"{metric}: zero gradient on bias"
+
+ # Shift parameter gradient
+ shift_param = layer.parametrizations.shift.original
+ assert shift_param.grad is not None, f"{metric}: no gradient on shift"
+ assert shift_param.grad.abs().sum() > 0, f"{metric}: zero gradient on shift"
+
+
+@pytest.mark.parametrize("metric", METRICS)
+def test_default_initialization(metric):
+ """Verify default parameter initialization."""
+ ndim = 4
+ layer = SPDBatchNormLie(ndim, metric=metric, dtype=DTYPE)
+
+ # Bias should be Identity
+ identity = torch.eye(ndim, dtype=DTYPE).unsqueeze(0)
+ assert torch.allclose(layer.bias, identity, atol=1e-10), (
+ f"{metric}: bias not initialized to Identity"
+ )
+
+ # Shift should be 1.0
+ assert torch.allclose(layer.shift, torch.ones((), dtype=DTYPE), atol=1e-10), (
+ f"{metric}: shift not initialized to 1.0"
+ )
+
+ # Running mean: Identity for AIM, zeros for LEM/LCM
+ if metric == "AIM":
+ assert torch.allclose(layer.running_mean, identity, atol=1e-10)
+ else:
+ assert torch.allclose(
+ layer.running_mean, torch.zeros(1, ndim, ndim, dtype=DTYPE), atol=1e-10
+ )
+
+ # Running var should be 1.0
+ assert torch.allclose(layer.running_var, torch.ones((), dtype=DTYPE), atol=1e-10)
+
+ # Verify dtype propagated correctly
+ assert layer.bias.dtype == DTYPE
+ assert layer.shift.dtype == DTYPE
+ assert layer.running_mean.dtype == DTYPE
+ assert layer.running_var.dtype == DTYPE