diff --git a/docs/source/api.rst b/docs/source/api.rst index 4d4c59d..c8e8445 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -362,6 +362,7 @@ or related representations. SPDBatchNormMean SPDBatchNormMeanVar + SPDBatchNormLie BatchReNorm diff --git a/docs/source/conf.py b/docs/source/conf.py index fed715c..d7696fe 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -250,6 +250,18 @@ from sphinx_gallery.sorting import ExplicitOrder +def _reset_torch_defaults(gallery_conf, fname): + """Reset torch global state between sphinx-gallery examples. + + Some examples call ``torch.set_default_dtype(torch.float64)`` which + persists across examples when run in the same worker process and + causes dtype-mismatch errors in subsequent examples. + """ + import torch + + torch.set_default_dtype(torch.float32) + + sphinx_gallery_conf = { "examples_dirs": ["../../examples"], "gallery_dirs": ["generated/auto_examples"], @@ -258,10 +270,11 @@ # Point 3: Image optimization - compress images and reduce thumbnail size "compress_images": ("images", "thumbnails"), "thumbnail_size": (400, 280), # Smaller thumbnails for faster loading - # Order: tutorials first, then visualizations, then applied examples + # Order: tutorials, how-to guides, visualizations, then applied examples "subsection_order": ExplicitOrder( [ "../../examples/tutorials", + "../../examples/howto", "../../examples/visualizations", "../../examples/applied_examples", ] @@ -277,6 +290,8 @@ # Include both plot_* files and tutorial_* files "filename_pattern": r"/(plot_|tutorial_)", "ignore_pattern": r"(__init__|spd_visualization_utils)\.py", + # Reset torch default dtype between examples to prevent float64 leakage + "reset_modules": ("matplotlib", "seaborn", _reset_torch_defaults), # Show signature link template (includes Colab launcher) "show_signature": False, # First cell in generated notebooks (for Colab compatibility) diff --git a/docs/source/geometric_concepts.rst b/docs/source/geometric_concepts.rst index 77c0379..8d286e8 100644 --- a/docs/source/geometric_concepts.rst +++ b/docs/source/geometric_concepts.rst @@ -693,6 +693,104 @@ where :math:`\frechet` is the Fréchet mean of the batch. See :ref:`sphx_glr_generated_auto_examples_visualizations_plot_batchnorm_animation.py` +Batch Normalization on SPD Manifolds +===================================== + +In Euclidean deep learning, batch normalization centers activations to zero mean +and unit variance, stabilizing gradient flow and accelerating convergence. On the +SPD manifold, the same principle applies — but "mean" and "variance" must respect +the curved Riemannian geometry. + +Why Euclidean BN Fails for SPD Matrices +---------------------------------------- + +Standard batch normalization computes :math:`\hat{x} = (x - \mu) / \sigma`. For SPD +matrices this is problematic: + +- **Subtraction breaks SPD**: :math:`X - M` (with :math:`M` the arithmetic mean) may not + be positive definite. +- **The swelling effect**: The Euclidean mean of SPD matrices can have a larger determinant + than any individual matrix, distorting the data distribution. +- **Scale mismatch**: SPD matrices from different subjects or sessions can have vastly + different spectral profiles; Euclidean normalization ignores this geometric structure. + +Riemannian Batch Normalization +------------------------------- + +:class:`~spd_learn.modules.SPDBatchNormMeanVar` addresses these issues by replacing +Euclidean operations with their Riemannian counterparts under the AIRM: + +1. 
**Centering**: Compute the Fréchet mean :math:`\frechet` of the batch, then + apply congruence :math:`\tilde{X}_i = \frechet^{-1/2} X_i \frechet^{-1/2}` to center + the batch around the identity matrix. +2. **Variance scaling**: Compute a scalar dispersion and normalize by a learnable weight. +3. **Biasing**: Apply a learnable SPD bias via congruence. + +This preserves the SPD structure at every step. + +Lie Group Batch Normalization (LieBN) +-------------------------------------- + +:class:`~spd_learn.modules.SPDBatchNormLie` :cite:p:`chen2024liebn` generalizes +Riemannian BN by exploiting the Lie group structure of :math:`\spd`. The key insight +is that each Riemannian metric induces a different group action for centering and biasing. + +The LieBN forward pass follows five steps: + +1. **Deformation** — Map SPD matrices to a codomain via the metric + (e.g., :math:`\log(X)` for LEM, Cholesky + log-diagonal for LCM, :math:`X^\theta` for AIM). +2. **Centering** — Translate the batch to zero/identity mean using the group action. +3. **Scaling** — Normalize variance by a learnable dispersion parameter. +4. **Biasing** — Translate by a learnable location parameter. +5. **Inverse deformation** — Map back to the SPD manifold. + +.. list-table:: + :header-rows: 1 + :widths: 15 25 25 25 + + * - Metric + - Deformation + - Mean Computation + - Group Action + * - **LEM** + - :math:`\log(X)` + - Euclidean (closed-form) + - Additive + * - **LCM** + - Cholesky + log-diag + - Euclidean (closed-form) + - Additive + * - **AIM** + - :math:`X^\theta` + - Karcher (iterative) + - Cholesky congruence + +**Choosing a metric for batch normalization:** + +- **LEM**: Fastest (closed-form mean), good default for most tasks. +- **AIM**: Full affine invariance, best when data scale varies (e.g., cross-subject EEG). +- **LCM**: Fast like LEM, with Cholesky-based numerical stability. + +.. code-block:: python + + from spd_learn.modules import SPDBatchNormLie + + # LEM is the fastest — good default + bn_lem = SPDBatchNormLie(num_features=32, metric="LEM") + + # AIM for affine-invariant normalization + bn_aim = SPDBatchNormLie(num_features=32, metric="AIM", theta=1.0) + + # LCM for Cholesky stability + bn_lcm = SPDBatchNormLie(num_features=32, metric="LCM") + +.. 
seealso:: + + :ref:`tutorial-batch-normalization` — Hands-on tutorial comparing all BN strategies, + :ref:`howto-add-batchnorm` — Quick integration guide, + :ref:`liebn-batch-normalization` — Full benchmark reproduction across 3 datasets + + References ========== diff --git a/docs/source/references.bib b/docs/source/references.bib index 6fc0251..824aa12 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -138,6 +138,14 @@ @inproceedings{kobler2022spd url={https://proceedings.neurips.cc/paper_files/paper/2022/hash/28ef7ee7cd3e03093acc39e1272411b7-Abstract-Conference.html} } +@inproceedings{chen2024liebn, + title={A Lie Group Approach to Riemannian Batch Normalization}, + author={Chen, Ziheng and Song, Yue and Xu, Yunmei and Sebe, Nicu}, + booktitle={International Conference on Learning Representations}, + year={2024}, + url={https://openreview.net/forum?id=okYdj8Ysru} +} + @inproceedings{pan2022matt, title={MAtt: A manifold attention network for EEG decoding}, author={Pan, Yue-Ting and Chou, Jing-Lun and Wei, Chun-Shu}, diff --git a/examples/applied_examples/liebn_batch_normalization.py b/examples/applied_examples/liebn_batch_normalization.py new file mode 100644 index 0000000..4c610d6 --- /dev/null +++ b/examples/applied_examples/liebn_batch_normalization.py @@ -0,0 +1,1111 @@ +""" +.. _liebn-batch-normalization: + +Reproducing LieBN Paper Results (Table 4) +========================================== + +This example reproduces the SPDNet experiments from Table 4 of +Chen et al., "A Lie Group Approach to Riemannian Batch +Normalization", ICLR 2024 :cite:p:`chen2024liebn`. + +We compare batch normalization strategies on HDM05 (7 configs), Radar +(6 configs), and AFEW (7 configs) datasets. HDM05 and Radar follow the +paper's evaluation protocol (10 independent random-split runs with +batch-mean accuracy). AFEW uses a fixed train/val split (10 runs varying +only model initialization). + +**Configurations benchmarked:** + +- **SPDNet**: No batch normalization +- **SPDNetBN**: Riemannian BN (Brooks et al. + variance normalization) +- **LieBN-AIM**: LieBN under the Affine-Invariant Metric (theta=1, 1.5) +- **LieBN-LEM**: LieBN under the Log-Euclidean Metric +- **LieBN-LCM**: LieBN under the Log-Cholesky Metric (theta=1, 0.5, -0.5) + +.. note:: + + New to SPD batch normalization? Start with the + :ref:`tutorial-batch-normalization` tutorial for an introduction, or + see :ref:`howto-add-batchnorm` for a quick integration guide. + +.. contents:: This example covers: + :local: + :depth: 2 + +""" + +###################################################################### +# Setup and Imports +# ----------------- +# + +import json +import os +import random +import tarfile +import tempfile +import time +import urllib.request +import warnings +import zipfile + +from collections import defaultdict +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn + +from joblib import Parallel, delayed +from torch.utils.data import DataLoader, TensorDataset + +from spd_learn.functional import ensure_sym +from spd_learn.modules import ( + BiMap, + LogEig, + ReEig, + SPDBatchNormLie, + SPDBatchNormMeanVar, +) + + +# Suppress noisy warnings from matplotlib and torch internals; +# keep UserWarning and RuntimeWarning visible for diagnostic signals. 
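# (Worker processes report failed runs via warnings.warn, so these two
# categories must stay enabled for the failure diagnostics to appear.)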
+warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", module="matplotlib") + + +def set_reproducibility(seed=1024): + """Set random seeds and enable deterministic behavior.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + try: + torch.use_deterministic_algorithms(True, warn_only=True) + except TypeError: + torch.use_deterministic_algorithms(True) + if hasattr(torch.backends, "cudnn"): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +GLOBAL_SEED = 1024 +set_reproducibility(GLOBAL_SEED) +DATA_DIR = Path("data") +DATA_DIR.mkdir(exist_ok=True) + +###################################################################### +# SPDNet Architecture & Training Setup +# ------------------------------------- +# +# We build a multi-layer SPDNet following the reference architecture: +# +# - Intermediate layers: ``BiMap -> [BN] -> ReEig`` +# - Final layer: ``BiMap -> [BN]`` (no ReEig) +# - Classifier: ``LogEig -> flatten -> Linear`` +# + + +def make_bn(n, bn_type, bn_kwargs): + """Create a batch normalization layer.""" + if bn_type == "SPDBN": + return SPDBatchNormMeanVar(n, momentum=bn_kwargs.get("momentum", 0.1)) + elif bn_type == "LieBN": + return SPDBatchNormLie(n, **bn_kwargs) + else: + raise ValueError(f"Unknown bn_type: {bn_type}") + + +class SPDNetModel(nn.Module): + """Multi-layer SPDNet with optional batch normalization. + + Parameters + ---------- + dims : list of int + Sequence of SPD matrix dimensions, e.g. [93, 30]. + n_classes : int + Number of output classes. + bn_type : str or None + None (no BN), 'SPDBN', or 'LieBN'. + bn_kwargs : dict or None + Keyword arguments for the BN layer. + """ + + def __init__(self, dims, n_classes, bn_type=None, bn_kwargs=None): + super().__init__() + layers = [] + for i in range(len(dims) - 1): + layers.append(BiMap(dims[i], dims[i + 1])) + if bn_type is not None: + layers.append(make_bn(dims[i + 1], bn_type, bn_kwargs or {})) + if i < len(dims) - 2: + layers.append(ReEig()) + self.features = nn.Sequential(*layers) + n_out = dims[-1] + # upper=False uses full n^2 features (matches reference's dims[-1]**2). + self.logeig = LogEig(upper=False, flatten=True) + self.classifier = nn.Linear(n_out**2, n_classes) + + def forward(self, x): + x = self.features(x) + x = self.logeig(x) + return self.classifier(x) + + +###################################################################### +# Training and Evaluation Utilities +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We define a single atomic training function that is dispatched via +# ``joblib`` with the loky (process-based) backend for true parallelism. +# +# **Checkpointing**: Results are saved per-run to a checkpoint file +# so that a crashed/interrupted experiment can be resumed without +# re-running completed work. +# +# **Performance tuning**: +# +# - Each worker uses 1 BLAS thread (``torch.set_num_threads(1)``) +# since Apple Silicon Accelerate's ``eigh`` is fastest single-threaded. +# With 14 workers this fully utilizes all CPU cores. +# - ``optimizer.zero_grad(set_to_none=True)`` avoids memset overhead. +# - Data arrives as numpy arrays for joblib memmapping (no pickle). +# + +N_RUNS = 10 +EPOCHS = 200 +BATCH_SIZE = 30 +LR = 5e-3 + +# Checkpoint file: per-run results saved as they complete. 
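# Each entry is keyed "{dataset}|{method}|{run_idx}" and records the run's
# accuracy, fit time, and status (see _run_dataset_jobs below).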
+try: + CHECKPOINT_PATH = os.path.join(os.path.dirname(__file__), "liebn_checkpoint.json") +except NameError: + CHECKPOINT_PATH = os.path.join(tempfile.gettempdir(), "liebn_checkpoint.json") + + +def _load_checkpoint(): + """Load existing checkpoint, return dict of completed runs.""" + if os.path.exists(CHECKPOINT_PATH): + with open(CHECKPOINT_PATH) as f: + return json.load(f) + return {} + + +def _save_checkpoint(checkpoint): + """Atomically save checkpoint (write tmp + rename).""" + tmp = CHECKPOINT_PATH + ".tmp" + with open(tmp, "w") as f: + json.dump(checkpoint, f, indent=2) + os.replace(tmp, CHECKPOINT_PATH) + + +def _train_single_run( + run_seed, + X_train, + y_train, + X_test, + y_test, + dims, + n_classes, + bn_type, + bn_kwargs, + epochs=EPOCHS, + batch_size=BATCH_SIZE, + lr=LR, +): + """Train and evaluate a single run (atomic unit of work). + + Data arrives as numpy arrays (for joblib memmapping) and is + converted to tensors inside the worker process. + + Returns (accuracy, fit_time) or (NaN, fit_time) on failure. + """ + # One BLAS thread per worker; eigh on Apple Silicon Accelerate is + # fastest single-threaded. With 14 workers we use all 14 cores. + torch.set_num_threads(1) + torch.set_default_dtype(torch.float64) + + random.seed(run_seed) + np.random.seed(run_seed) + torch.manual_seed(run_seed) + + # Convert numpy → tensors inside each worker process. + # Use torch.tensor() (not as_tensor) to copy from read-only memmaps. + X_train = torch.tensor(X_train, dtype=torch.float64) + y_train = torch.tensor(y_train, dtype=torch.long) + X_test = torch.tensor(X_test, dtype=torch.float64) + y_test = torch.tensor(y_test, dtype=torch.long) + + train_gen = torch.Generator() + train_gen.manual_seed(run_seed) + + model = SPDNetModel(dims, n_classes, bn_type=bn_type, bn_kwargs=bn_kwargs) + + train_loader = DataLoader( + TensorDataset(X_train, y_train), + batch_size=batch_size, + shuffle=True, + generator=train_gen, + ) + test_loader = DataLoader( + TensorDataset(X_test, y_test), + batch_size=batch_size, + shuffle=False, + ) + optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True) + criterion = nn.CrossEntropyLoss() + + epoch_times = [] + try: + for epoch in range(epochs): + t0 = time.time() + model.train() + for xb, yb in train_loader: + optimizer.zero_grad(set_to_none=True) + loss = criterion(model(xb), yb) + loss.backward() + optimizer.step() + epoch_times.append(time.time() - t0) + + # Batch-mean accuracy (matches paper's training_script.py) + model.eval() + batch_accs = [] + with torch.no_grad(): + for xb, yb in test_loader: + acc = (model(xb).argmax(1) == yb).sum().item() / yb.shape[0] + batch_accs.append(acc) + + return np.mean(batch_accs) * 100.0, np.mean(epoch_times[-10:]) + except torch._C._LinAlgError as e: + warnings.warn(f"Run {run_seed} failed (LinAlgError): {e}") + fit_time = np.mean(epoch_times[-10:]) if epoch_times else 0.0 + return float("nan"), fit_time + except RuntimeError as e: + if "linalg" in str(e).lower() or "cholesky" in str(e).lower(): + warnings.warn(f"Run {run_seed} failed (linalg RuntimeError): {e}") + fit_time = np.mean(epoch_times[-10:]) if epoch_times else 0.0 + return float("nan"), fit_time + raise + + +###################################################################### +# Dataset Loading: HDM05 +# ---------------------- +# +# HDM05 contains 2086 pre-computed 93x93 SPD covariance matrices +# representing 117 motion capture classes. 
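# Class labels are parsed from the trailing integer in each ``.npy`` filename.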
+# +# - Architecture: ``[93, 30]`` +# - Source: `HDM05 Motion Capture Database `_ +# + + +def download_and_extract(url, dest_dir, zip_name, extract_tgz=None): + """Download a zip file and extract it.""" + zip_path = dest_dir / zip_name + if not zip_path.exists(): + print(f"Downloading {zip_name}...") + urllib.request.urlretrieve(url, zip_path) + with zipfile.ZipFile(zip_path, "r") as zf: + zf.extractall(dest_dir) + if extract_tgz: + tgz_path = dest_dir / extract_tgz + if tgz_path.exists(): + with tarfile.open(tgz_path, "r:gz") as tf: + tf.extractall(dest_dir) + + +def load_hdm05(data_dir): + """Load HDM05 dataset: pre-computed 93x93 SPD covariance matrices.""" + hdm_path = data_dir / "HDM05" + if not hdm_path.exists(): + download_and_extract( + "https://www.dropbox.com/scl/fi/x2ouxjwqj3zrb1idgkg2g/" + "HDM05.zip?rlkey=4f90ktgzfz28x3i2i4ylu6dvu&dl=1", + data_dir, + "HDM05.zip", + ) + names = sorted(f for f in os.listdir(hdm_path) if f.endswith(".npy")) + X_list, y_list = [], [] + for name in names: + x = np.load(hdm_path / name).real + label = int(name.split(".")[0].split("_")[-1]) + X_list.append(x) + y_list.append(label) + X = torch.from_numpy(np.stack(X_list)).double() + y = torch.from_numpy(np.array(y_list)).long() + print( + f"HDM05: {X.shape[0]} samples, {len(set(y_list))} classes, " + f"matrix size {X.shape[1]}x{X.shape[2]}" + ) + return X, y + + +X_hdm, y_hdm = load_hdm05(DATA_DIR) +eigvals_hdm = torch.linalg.eigvalsh(X_hdm) +print( + f"HDM05 SPD check: min eigenvalue = {eigvals_hdm.min():.2e}, " + f"max = {eigvals_hdm.max():.2e}" +) + + +###################################################################### +# HDM05 Experiments +# ----------------- +# +# We run 10 independent random-split experiments with 7 configurations +# matching Table 4b of the paper: SPDNet, SPDNetBN, AIM-(1), LEM-(1), +# LCM-(1), AIM-(1.5), and LCM-(0.5). +# +# Split: 50/50 train/test (random shuffle). +# Training: 200 epochs, batch_size=30, lr=5e-3, Adam with amsgrad. +# + +HDM05_DIMS = [93, 30] +HDM05_CLASSES = len(torch.unique(y_hdm)) +print(f"HDM05: dims={HDM05_DIMS}, n_classes={HDM05_CLASSES}") + +hdm05_configs = { + "SPDNet": {"bn_type": None, "bn_kwargs": None}, + "SPDNetBN": {"bn_type": "SPDBN", "bn_kwargs": {"momentum": 0.1}}, + "AIM-(1)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "AIM", + "theta": 1.0, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, + "LEM-(1)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "LEM", + "theta": 1.0, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, + "LCM-(1)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "LCM", + "theta": 1.0, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, + "AIM-(1.5)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "AIM", + "theta": 1.5, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, + "LCM-(0.5)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "LCM", + "theta": 0.5, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, +} + +###################################################################### +# Dataset Loading: Radar +# ---------------------- +# +# The Radar dataset contains 3000 complex time-frequency signals +# (3 gesture classes), converted to 20x20 SPD matrices via +# covariance pooling. +# +# Signal processing pipeline: +# +# 1. Split complex signals into overlapping windows (size=20, hop=10) +# 2. Compute real covariance from complex windowed signal +# 3. 
Apply ReEig to ensure well-conditioned SPD output +# +# - Architecture: ``[20, 16, 12]`` +# + + +def _split_signal_cplx(x, window_size=20, hop_length=10): + """Window complex signals into overlapping segments. + + Input: (batch, 2, T) where dim=1 is [real, imag] + Output: (batch, 2, window_size, T') + """ + x_re = x[:, 0:1, :] + x_im = x[:, 1:2, :] + x_re_w = x_re.unfold(2, window_size, hop_length) + x_im_w = x_im.unfold(2, window_size, hop_length) + x_re_out = x_re_w.squeeze(1).permute(0, 2, 1) + x_im_out = x_im_w.squeeze(1).permute(0, 2, 1) + return torch.stack([x_re_out, x_im_out], dim=1) + + +def _cov_pool_cplx(f): + """Compute real covariance matrix from complex windowed signal. + + Input: (batch, 2, n, T) + Output: (batch, n, n) real SPD covariance matrix + """ + f_re = f[:, 0, :, :].double() + f_im = f[:, 1, :, :].double() + f_re = f_re - f_re.mean(-1, keepdim=True) + f_im = f_im - f_im.mean(-1, keepdim=True) + T = f.shape[-1] + X_Re = (f_re @ f_re.mT + f_im @ f_im.mT) / (T - 1) + return ensure_sym(X_Re) + + +def load_radar(data_dir): + """Load Radar dataset: complex signals -> 20x20 SPD covariance matrices.""" + radar_path = data_dir / "radar" + if not radar_path.exists(): + download_and_extract( + "https://www.dropbox.com/s/dfnlx2bnyh3kjwy/data.zip?e=1&dl=1", + data_dir, + "data.zip", + extract_tgz="data/radar.tgz", + ) + names = sorted(f for f in os.listdir(radar_path) if f.endswith(".npy")) + signals, labels = [], [] + for name in names: + x = np.load(radar_path / name) + x_ri = np.stack([x.real, x.imag], axis=0) + signals.append(x_ri) + labels.append(int(name.split(".")[0].split("_")[-1])) + signals_t = torch.from_numpy(np.stack(signals)).float() + y = torch.from_numpy(np.array(labels)).long() + with torch.no_grad(): + windowed = _split_signal_cplx(signals_t, window_size=20, hop_length=10) + X_cov = _cov_pool_cplx(windowed) + reeig = ReEig() + X = reeig(X_cov) + print( + f"Radar: {X.shape[0]} samples, {len(set(labels))} classes, " + f"matrix size {X.shape[1]}x{X.shape[2]}" + ) + return X, y + + +X_radar, y_radar = load_radar(DATA_DIR) +eigvals_radar = torch.linalg.eigvalsh(X_radar) +print( + f"Radar SPD check: min eigenvalue = {eigvals_radar.min():.2e}, " + f"max = {eigvals_radar.max():.2e}" +) + + +###################################################################### +# Dataset Loading: AFEW +# --------------------- +# +# AFEW (Acted Facial Expressions in the Wild) contains pre-computed +# 400x400 SPD covariance matrices for facial expression recognition +# (7 emotion classes). The dataset comes pre-split into train/val. +# +# - Architecture: ``[400, 200, 100, 50]`` +# - Source: `AFEW Dataset `_ +# + + +def load_afew(data_dir): + """Load AFEW dataset: pre-computed 400x400 SPD covariance matrices.""" + afew_path = data_dir / "afew" + if not afew_path.exists(): + # Try extracting from data/afew.tgz + tgz_candidates = [ + data_dir / "data" / "afew.tgz", + data_dir / "afew.tgz", + ] + for tgz in tgz_candidates: + if tgz.exists(): + with tarfile.open(tgz, "r:gz") as tf: + tf.extractall(data_dir) + break + else: + raise FileNotFoundError( + "AFEW data not found. Place afew.tgz in the data/ directory." 
+ ) + + train_path = afew_path / "train" + val_path = afew_path / "val" + + def _load_split(split_path): + names = sorted(f for f in os.listdir(split_path) if f.endswith(".npy")) + X_list, y_list = [], [] + for name in names: + x = np.load(split_path / name).real + label = int(name.split(".")[0].split("_")[-1]) + X_list.append(x) + y_list.append(label) + return ( + torch.from_numpy(np.stack(X_list)).double(), + torch.from_numpy(np.array(y_list)).long(), + ) + + X_train, y_train = _load_split(train_path) + X_val, y_val = _load_split(val_path) + print( + f"AFEW: {X_train.shape[0]} train + {X_val.shape[0]} val samples, " + f"{len(set(y_train.tolist()) | set(y_val.tolist()))} classes, " + f"matrix size {X_train.shape[1]}x{X_train.shape[2]}" + ) + return X_train, y_train, X_val, y_val + + +X_afew_train, y_afew_train, X_afew_val, y_afew_val = load_afew(DATA_DIR) +eigvals_afew = torch.linalg.eigvalsh(X_afew_train) +print( + f"AFEW SPD check: min eigenvalue = {eigvals_afew.min():.2e}, " + f"max = {eigvals_afew.max():.2e}" +) + + +###################################################################### +# Radar Experiments +# ----------------- +# +# We run 10 independent random-split experiments with 6 configurations +# matching Table 4a: SPDNet, SPDNetBN, AIM-(1), LEM-(1), LCM-(1), +# and LCM-(-0.5). +# +# Split: 50/25/25 train/val/test (random shuffle; val discarded). +# We use ``test_ratio=0.25, val_ratio=0.25`` to match the paper's split. +# + +RADAR_DIMS = [20, 16, 12] +RADAR_CLASSES = len(torch.unique(y_radar)) +print(f"Radar: dims={RADAR_DIMS}, n_classes={RADAR_CLASSES}") + +radar_configs = { + "SPDNet": {"bn_type": None, "bn_kwargs": None}, + "SPDNetBN": {"bn_type": "SPDBN", "bn_kwargs": {"momentum": 0.1}}, + "AIM-(1)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "AIM", + "theta": 1.0, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, + "LEM-(1)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "LEM", + "theta": 1.0, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, + "LCM-(1)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "LCM", + "theta": 1.0, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, + "LCM-(-0.5)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "LCM", + "theta": -0.5, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, +} + +###################################################################### +# AFEW Experiments +# ---------------- +# +# AFEW (Acted Facial Expressions in the Wild) has a fixed train/val split, +# so we run 10 experiments varying only model initialization. +# +# Architecture: ``[400, 200, 100, 50]`` (from the original SPDNet paper, +# Huang & Van Gool, 2017). +# +# .. note:: +# +# The LieBN paper (Table 4) uses the **FPHA** dataset (63x63, 45 classes) +# rather than AFEW (400x400, 7 classes). We include AFEW as an additional +# benchmark; no paper comparison numbers are available. 
+# + +AFEW_DIMS = [400, 200, 100, 50] +AFEW_CLASSES = len(set(y_afew_train.tolist()) | set(y_afew_val.tolist())) +print(f"AFEW: dims={AFEW_DIMS}, n_classes={AFEW_CLASSES}") + +afew_configs = { + "SPDNet": {"bn_type": None, "bn_kwargs": None}, + "SPDNetBN": {"bn_type": "SPDBN", "bn_kwargs": {"momentum": 0.1}}, + "AIM-(1)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "AIM", + "theta": 1.0, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, + "LEM-(1)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "LEM", + "theta": 1.0, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, + "LCM-(1)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "LCM", + "theta": 1.0, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, + "AIM-(1.5)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "AIM", + "theta": 1.5, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, + "LCM-(0.5)": { + "bn_type": "LieBN", + "bn_kwargs": { + "metric": "LCM", + "theta": 0.5, + "alpha": 1.0, + "beta": 0.0, + "momentum": 0.1, + }, + }, +} + +###################################################################### +# Run All Experiments (Parallelized with Checkpointing) +# ----------------------------------------------------- +# +# Experiments are dispatched per-dataset via ``joblib`` (loky backend) +# for true multi-process parallelism. Each worker uses 1 BLAS thread +# since Apple Silicon Accelerate's ``eigh`` is fastest single-threaded, +# with 14 workers fully utilizing all CPU cores. +# +# **Checkpointing**: After each dataset completes, all finished +# per-run results are saved to ``liebn_checkpoint.json``. On restart, +# already-completed (dataset, method, run) tuples are skipped. +# +# Total jobs: HDM05 (7 x 10) + Radar (6 x 10) + AFEW (7 x 10) = 200 +# + +checkpoint = _load_checkpoint() +n_cached = len(checkpoint) +if n_cached > 0: + print(f"Loaded checkpoint with {n_cached} completed runs.") + + +def _run_dataset_jobs(dataset_name, configs, jobs_and_keys, checkpoint): + """Dispatch jobs for one dataset, skipping already-checkpointed runs. + + Returns dict of {method: [(acc, fit_time), ...]} for this dataset. + Mutates ``checkpoint`` in-place with newly completed runs. + """ + + filtered_jobs = [] + filtered_keys = [] + cached_results = defaultdict(list) + + for (ds, method, run_idx), job in jobs_and_keys: + key = f"{ds}|{method}|{run_idx}" + if key in checkpoint: + acc = checkpoint[key]["acc"] + # Restore NaN for failed runs (stored as null in JSON). + if acc is None: + acc = float("nan") + cached_results[method].append((acc, checkpoint[key]["fit_time"])) + else: + filtered_jobs.append(job) + filtered_keys.append((ds, method, run_idx)) + + n_skip = len(jobs_and_keys) - len(filtered_jobs) + n_todo = len(filtered_jobs) + if n_skip > 0: + print(f" {dataset_name}: {n_skip} runs cached, {n_todo} remaining.") + if n_todo == 0: + print(f" {dataset_name}: all runs cached, nothing to do.") + else: + t0 = time.time() + # n_jobs=-1: 14 workers x 1 BLAS thread = 14 cores fully used. + raw = Parallel(n_jobs=-1, verbose=10)(filtered_jobs) + elapsed = time.time() - t0 + print(f" {dataset_name}: {n_todo} runs finished in {elapsed:.1f}s") + + # Save to checkpoint immediately. 
+ for (ds, method, run_idx), result in zip(filtered_keys, raw): + key = f"{ds}|{method}|{run_idx}" + acc, ft = result + checkpoint[key] = { + "acc": None if np.isnan(acc) else acc, + "fit_time": ft, + "status": "failed" if np.isnan(acc) else "ok", + } + cached_results[method].append(result) + _save_checkpoint(checkpoint) + print(f" Checkpoint saved ({len(checkpoint)} total runs).") + + return dict(cached_results) + + +def _aggregate(runs): + """Aggregate per-run (acc, fit_time) tuples.""" + accs = [r[0] for r in runs if not np.isnan(r[0])] + fts = [r[1] for r in runs if not np.isnan(r[0])] + n_failed = sum(1 for r in runs if np.isnan(r[0])) + if n_failed > 0: + warnings.warn(f"{n_failed}/{len(runs)} runs failed (NaN)") + if not accs: + return { + "mean": 0.0, + "std": 0.0, + "max": 0.0, + "folds": [], + "fit_time": 0.0, + } + return { + "mean": np.mean(accs), + "std": np.std(accs), + "max": np.max(accs), + "folds": accs, + "fit_time": np.mean(fts), + } + + +# ---- Prepare HDM05 jobs (fixed 50/50 split) ---- +n_hdm = len(X_hdm) +rng_hdm = np.random.RandomState(GLOBAL_SEED) +perm_hdm = rng_hdm.permutation(n_hdm) +n_test_hdm = int(0.5 * n_hdm) +Xtr_hdm = X_hdm[perm_hdm[n_test_hdm:]].numpy() +ytr_hdm = y_hdm[perm_hdm[n_test_hdm:]].numpy() +Xte_hdm = X_hdm[perm_hdm[:n_test_hdm]].numpy() +yte_hdm = y_hdm[perm_hdm[:n_test_hdm]].numpy() + +hdm05_jobs = [] +for name, cfg in hdm05_configs.items(): + for i in range(N_RUNS): + job = delayed(_train_single_run)( + GLOBAL_SEED + i, + Xtr_hdm, + ytr_hdm, + Xte_hdm, + yte_hdm, + HDM05_DIMS, + HDM05_CLASSES, + cfg["bn_type"], + cfg["bn_kwargs"], + ) + hdm05_jobs.append((("HDM05", name, i), job)) + +# ---- Prepare Radar jobs (per-run 50/25/25 split) ---- +n_radar = len(X_radar) +n_test_radar = int(0.25 * n_radar) +n_val_radar = int(0.25 * n_radar) +X_radar_np = X_radar.numpy() +y_radar_np = y_radar.numpy() + +radar_jobs = [] +for name, cfg in radar_configs.items(): + for i in range(N_RUNS): + run_seed = GLOBAL_SEED + i + rng = np.random.RandomState(run_seed) + perm = rng.permutation(n_radar) + Xte_r = X_radar_np[perm[:n_test_radar]] + yte_r = y_radar_np[perm[:n_test_radar]] + Xtr_r = X_radar_np[perm[n_test_radar + n_val_radar :]] + ytr_r = y_radar_np[perm[n_test_radar + n_val_radar :]] + job = delayed(_train_single_run)( + run_seed, + Xtr_r, + ytr_r, + Xte_r, + yte_r, + RADAR_DIMS, + RADAR_CLASSES, + cfg["bn_type"], + cfg["bn_kwargs"], + ) + radar_jobs.append((("Radar", name, i), job)) + +# ---- Prepare AFEW jobs (fixed train/val split) ---- +X_afew_train_np = X_afew_train.numpy() +y_afew_train_np = y_afew_train.numpy() +X_afew_val_np = X_afew_val.numpy() +y_afew_val_np = y_afew_val.numpy() + +afew_jobs = [] +for name, cfg in afew_configs.items(): + for i in range(N_RUNS): + job = delayed(_train_single_run)( + GLOBAL_SEED + i, + X_afew_train_np, + y_afew_train_np, + X_afew_val_np, + y_afew_val_np, + AFEW_DIMS, + AFEW_CLASSES, + cfg["bn_type"], + cfg["bn_kwargs"], + ) + afew_jobs.append((("AFEW", name, i), job)) + +# ---- Run datasets sequentially, saving after each ---- +# This way Radar+HDM05 results are safe even if AFEW crashes. + +total_jobs = len(hdm05_jobs) + len(radar_jobs) + len(afew_jobs) +print( + f"\nTotal: {total_jobs} training runs " + f"({len(hdm05_configs)} + {len(radar_configs)} + " + f"{len(afew_configs)} methods x {N_RUNS} runs)." 
+) + +t_wall_start = time.time() + +print("\n--- HDM05 ---") +hdm05_raw = _run_dataset_jobs("HDM05", hdm05_configs, hdm05_jobs, checkpoint) +print("\n--- Radar ---") +radar_raw = _run_dataset_jobs("Radar", radar_configs, radar_jobs, checkpoint) +print("\n--- AFEW ---") +afew_raw = _run_dataset_jobs("AFEW", afew_configs, afew_jobs, checkpoint) + +t_wall = time.time() - t_wall_start +print(f"\nAll experiments finished in {t_wall:.1f}s") + +# ---- Aggregate per-method results ---- +hdm05_results = {m: _aggregate(r) for m, r in hdm05_raw.items()} +radar_results = {m: _aggregate(r) for m, r in radar_raw.items()} +afew_results = {m: _aggregate(r) for m, r in afew_raw.items()} + +for ds, results in [ + ("HDM05", hdm05_results), + ("Radar", radar_results), + ("AFEW", afew_results), +]: + print(f"\n{ds}:") + for m, r in results.items(): + print( + f" {m}: {r['mean']:.2f} +/- {r['std']:.2f} " + f"(max={r['max']:.2f}, fit_time={r['fit_time']:.2f}s)" + ) + + +###################################################################### +# Results Comparison & Visualization +# ----------------------------------- +# +# We compare our reproduction results against the paper's Table 4 numbers +# for HDM05 and Radar. AFEW results are shown without paper comparison +# (the paper uses FPHA, a different dataset). +# + +# Paper Table 4 numbers: (mean, std, max, fit_time) +paper_results = { + "Radar": { + "SPDNet": (93.25, 1.10, 94.4, 0.98), + "SPDNetBN": (94.85, 0.99, 96.13, 1.56), + "AIM-(1)": (95.47, 0.90, 96.27, 1.62), + "LEM-(1)": (94.89, 1.04, 96.8, 1.28), + "LCM-(1)": (93.52, 1.07, 95.2, 1.11), + "LCM-(-0.5)": (94.80, 0.71, 95.73, 1.43), + }, + "HDM05": { + "SPDNet": (59.13, 0.67, 60.34, 0.57), + "SPDNetBN": (66.72, 0.52, 67.66, 0.97), + "AIM-(1)": (67.79, 0.65, 68.75, 1.14), + "LEM-(1)": (65.05, 0.63, 66.05, 0.87), + "LCM-(1)": (66.68, 0.71, 68.52, 0.66), + "AIM-(1.5)": (68.16, 0.68, 69.25, 1.46), + "LCM-(0.5)": (70.84, 0.92, 72.27, 1.01), + }, + "AFEW": {}, +} + +radar_methods = list(radar_configs.keys()) +hdm05_methods = list(hdm05_configs.keys()) +afew_methods = list(afew_configs.keys()) + +###################################################################### +# Results Tables +# ~~~~~~~~~~~~~~ +# + + +def _print_table(dataset, methods, our_results, paper): + """Print comparison table for one dataset.""" + hdr = ( + f"{'Method':<14} | {'Fit Time':>8} | " + f"{'Mean+-STD (Ours)':>18} {'Max (Ours)':>10} | " + f"{'Mean+-STD (Paper)':>18} {'Max (Paper)':>11}" + ) + sep = "=" * len(hdr) + print(f"\n{dataset}") + print(sep) + print(hdr) + print(sep) + for m in methods: + ours = our_results.get(m, {}) + p = paper.get(m) + ft = f"{ours.get('fit_time', 0):.2f}" if ours else "---" + o_str = f"{ours['mean']:.2f}+-{ours['std']:.2f}" if ours else "---" + o_max = f"{ours['max']:.2f}" if ours else "---" + if p: + p_str = f"{p[0]:.2f}+-{p[1]:.2f}" + p_max = f"{p[2]:.2f}" + else: + p_str, p_max = "---", "---" + print(f"{m:<14} | {ft:>8} | {o_str:>18} {o_max:>10} | {p_str:>18} {p_max:>11}") + print(sep) + + +_print_table("Radar", radar_methods, radar_results, paper_results["Radar"]) +_print_table("HDM05", hdm05_methods, hdm05_results, paper_results["HDM05"]) +_print_table("AFEW", afew_methods, afew_results, paper_results["AFEW"]) + +###################################################################### +# Save results to JSON for reproducibility. 
+# + +try: + results_path = os.path.join(os.path.dirname(__file__), "liebn_table4_results.json") +except NameError: + results_path = os.path.join(tempfile.gettempdir(), "liebn_table4_results.json") +results_to_save = { + "radar": {k: dict(v) for k, v in radar_results.items()}, + "hdm05": {k: dict(v) for k, v in hdm05_results.items()}, + "afew": {k: dict(v) for k, v in afew_results.items()}, + "paper_results": { + ds: { + m: {"mean": v[0], "std": v[1], "max": v[2], "fit_time": v[3]} + for m, v in mv.items() + } + for ds, mv in paper_results.items() + if mv # skip empty (AFEW has no paper numbers) + }, +} +with open(results_path, "w") as f: + json.dump(results_to_save, f, indent=2) +print(f"\nResults saved to {results_path}") + +###################################################################### +# Comparison Bar Chart +# ~~~~~~~~~~~~~~~~~~~~ +# + +fig, axes = plt.subplots(1, 3, figsize=(22, 5)) + +dataset_info = [ + ("Radar", radar_results, radar_methods), + ("HDM05", hdm05_results, hdm05_methods), + ("AFEW", afew_results, afew_methods), +] + +for ax, (dataset, our_results, methods) in zip(axes, dataset_info): + x_pos = np.arange(len(methods)) + has_paper = bool(paper_results.get(dataset)) + + ours_means = [our_results[m]["mean"] for m in methods] + ours_stds = [our_results[m]["std"] for m in methods] + + if has_paper: + width = 0.35 + paper_means = [paper_results[dataset][m][0] for m in methods] + paper_stds = [paper_results[dataset][m][1] for m in methods] + + ax.bar( + x_pos - width / 2, + ours_means, + width, + yerr=ours_stds, + label="Ours", + capsize=3, + color="#3498db", + alpha=0.85, + ) + ax.bar( + x_pos + width / 2, + paper_means, + width, + yerr=paper_stds, + label="Paper", + capsize=3, + color="#e74c3c", + alpha=0.85, + ) + all_vals = ours_means + paper_means + else: + width = 0.5 + ax.bar( + x_pos, + ours_means, + width, + yerr=ours_stds, + label="Ours", + capsize=3, + color="#3498db", + alpha=0.85, + ) + all_vals = ours_means + + ax.set_xticks(x_pos) + ax.set_xticklabels(methods, rotation=30, ha="right") + ax.set_ylabel("Accuracy (%)") + ax.set_title(dataset) + ax.legend() + ax.grid(axis="y", alpha=0.3) + ymin = min(all_vals) - 5 + ymax = max(all_vals) + 3 + ax.set_ylim(ymin, ymax) + +plt.suptitle("LieBN Batch Normalization: SPDNet Results", fontweight="bold") +plt.tight_layout() +plt.show() + +###################################################################### +# References +# ---------- +# +# .. bibliography:: +# :filter: docname in docnames +# diff --git a/examples/applied_examples/plot_source_free_domain.py b/examples/applied_examples/plot_source_free_domain.py index eb53aee..4868a31 100644 --- a/examples/applied_examples/plot_source_free_domain.py +++ b/examples/applied_examples/plot_source_free_domain.py @@ -96,75 +96,10 @@ # SPDIM Geometric Operations # -------------------------- # -# We define two core geometric operations needed for the SPDIM pipeline. -# These will be included in a future release of ``spd_learn.functional``. +# The Fréchet mean and geodesic distances used by SPDIM are available +# directly from ``spd_learn.functional``. # -from spd_learn.functional import ( - get_epsilon, - matrix_exp, - matrix_log, - matrix_sqrt_inv, -) - - -def frechet_mean(X, max_iter=50, return_distances=False): - r"""Compute the Fréchet mean under the AIRM. - - .. math:: - - \bar{X} = \arg\min_{G \in \mathcal{S}_{++}^n} - \sum_{i=1}^{N} d_{\text{AIRM}}^2(G, X_i) - - Uses adaptive step-size Karcher flow. 
- """ - eps = get_epsilon(X.dtype, "eigval_log") - n_samples = X.shape[0] - - if n_samples == 1: - mean = X[:1] - if return_distances: - return mean, torch.zeros(X.shape[:-2], dtype=X.dtype, device=X.device) - return mean - - w = torch.ones((*X.shape[:-2], 1, 1), dtype=X.dtype, device=X.device) - w = w / n_samples - G = (X * w).sum(dim=0, keepdim=True) - - nu = 1.0 - tau = float("inf") - - for _ in range(max_iter): - G_sqrt, G_invsqrt = matrix_sqrt_inv.apply(G) - X_tangent = matrix_log.apply(G_invsqrt @ X @ G_invsqrt) - G_tangent = (X_tangent * w).sum(dim=0, keepdim=True) - - crit = torch.norm(G_tangent, p="fro", dim=(-2, -1)).max().item() - if crit <= eps: - break - - G = G_sqrt @ matrix_exp.apply(nu * G_tangent) @ G_sqrt - - h = nu * crit - if h < tau: - nu = 0.95 * nu - tau = h - else: - nu = 0.5 * nu - - if nu <= eps: - break - - if return_distances: - G_sqrt, G_invsqrt = matrix_sqrt_inv.apply(G) - X_tangent = matrix_log.apply(G_invsqrt @ X @ G_invsqrt) - G_tangent = (X_tangent * w).sum(dim=0, keepdim=True) - distances = torch.norm(X_tangent - G_tangent, p="fro", dim=(-2, -1)) - return G, distances - - return G - - ###################################################################### # Loading the Dataset # ------------------- @@ -179,12 +114,13 @@ def frechet_mean(X, max_iter=50, return_distances=False): # - **Source domain**: Session A (training with labels) # - **Target domain**: Session B (adaptation without labels) # - from braindecode.datasets import create_from_X_y from moabb.datasets import BNCI2015_001 from moabb.paradigms import MotorImagery from sklearn.preprocessing import LabelEncoder +from spd_learn.functional import frechet_mean + dataset = BNCI2015_001() paradigm = MotorImagery( diff --git a/examples/howto/README.txt b/examples/howto/README.txt new file mode 100644 index 0000000..f213a19 --- /dev/null +++ b/examples/howto/README.txt @@ -0,0 +1,8 @@ +.. _howto_guides: + +How-to Guides +============= + +Task-oriented guides for solving specific problems with SPD Learn. Each guide +addresses one focused question and assumes you already understand the basics +(see :ref:`tutorials` if you're just getting started). diff --git a/examples/howto/plot_howto_add_batchnorm.py b/examples/howto/plot_howto_add_batchnorm.py new file mode 100644 index 0000000..1fb9a3a --- /dev/null +++ b/examples/howto/plot_howto_add_batchnorm.py @@ -0,0 +1,128 @@ +""" +.. _howto-add-batchnorm: + +How to Add Batch Normalization to an SPDNet +=========================================== + +Insert Riemannian batch normalization into an existing SPDNet pipeline +to stabilize training and improve convergence. + +**Prerequisites**: Familiarity with SPDNet building blocks +(see :ref:`tutorial-building-blocks`). + +""" + +###################################################################### +# The Problem +# ----------- +# +# You have a working SPDNet but training is unstable or converges slowly. +# Adding batch normalization after each ``BiMap`` layer can help. +# + +import torch +import torch.nn as nn + +from spd_learn.modules import BiMap, LogEig, ReEig, SPDBatchNormLie + + +###################################################################### +# Step 1: Choose Your Normalization Layer +# ---------------------------------------- +# +# spd_learn provides three batch normalization modules: +# +# .. list-table:: +# :header-rows: 1 +# :widths: 30 70 +# +# * - Module +# - When to Use +# * - :class:`~spd_learn.modules.SPDBatchNormMeanVar` +# - Standard choice. AIRM Frechet mean + variance scaling. 
+# * - :class:`~spd_learn.modules.SPDBatchNormLie` +# - Multiple metrics (AIM, LEM, LCM). Based on Lie group +# structure :cite:p:`chen2024liebn`. +# * - :class:`~spd_learn.modules.SPDBatchNormMean` +# - Mean-only centering (no variance scaling). Simplest option. +# + +###################################################################### +# Step 2: Insert After BiMap, Before ReEig +# ----------------------------------------- +# +# The standard placement is ``BiMap -> BN -> ReEig``. The final BiMap +# uses BN but skips ReEig: +# + +dims = [64, 32, 16] # your SPD matrix dimensions +layers = [] +for i in range(len(dims) - 1): + layers.append(BiMap(dims[i], dims[i + 1])) + layers.append(SPDBatchNormLie(dims[i + 1], metric="LEM")) + if i < len(dims) - 2: # no ReEig after last BiMap + layers.append(ReEig()) + +features = nn.Sequential(*layers) +print(features) + +###################################################################### +# Step 3: Complete Network +# ------------------------- +# +# Wrap the features with ``LogEig`` and a linear classifier: +# + + +class SPDNetWithBN(nn.Module): + """SPDNet with configurable batch normalization.""" + + def __init__(self, dims, n_classes, metric="LEM"): + super().__init__() + layers = [] + for i in range(len(dims) - 1): + layers.append(BiMap(dims[i], dims[i + 1])) + layers.append(SPDBatchNormLie(dims[i + 1], metric=metric)) + if i < len(dims) - 2: + layers.append(ReEig()) + self.features = nn.Sequential(*layers) + self.logeig = LogEig(upper=False, flatten=True) + self.classifier = nn.Linear(dims[-1] ** 2, n_classes) + + def forward(self, x): + return self.classifier(self.logeig(self.features(x))) + + +model = SPDNetWithBN([64, 32, 16], n_classes=4, metric="LEM") + +###################################################################### +# Verify it works with a dummy forward pass: + +X = torch.randn(8, 64, 64) +X = X @ X.mT + 0.01 * torch.eye(64) +out = model(X) +print(f"Input: {X.shape} -> Output: {out.shape}") # [8, 64, 64] -> [8, 4] + +###################################################################### +# Key Points +# ---------- +# +# - Place BN **after** ``BiMap`` and **before** ``ReEig`` +# - Use ``model.train()`` / ``model.eval()`` -- BN uses running stats +# at inference time +# - ``momentum=0.1`` (default) works well in most cases +# - Consider ``float64`` for numerical stability with Riemannian operations +# +# .. seealso:: +# +# - :ref:`tutorial-batch-normalization` -- Learn how BN works on SPD manifolds +# - :ref:`howto-choose-metric` -- Choosing between AIM, LEM, and LCM +# - :class:`~spd_learn.modules.SPDBatchNormLie` -- API reference +# - :class:`~spd_learn.modules.SPDBatchNormMeanVar` -- API reference +# +# References +# ---------- +# +# .. bibliography:: +# :filter: docname in docnames +# diff --git a/examples/howto/plot_howto_choose_metric.py b/examples/howto/plot_howto_choose_metric.py new file mode 100644 index 0000000..533c313 --- /dev/null +++ b/examples/howto/plot_howto_choose_metric.py @@ -0,0 +1,267 @@ +""" +.. _howto-choose-metric: + +How to Choose a Metric for Batch Normalization +=============================================== + +Select the right Riemannian metric for :class:`~spd_learn.modules.SPDBatchNormLie`. +Each metric trades off speed, invariance, and numerical stability. + +**Prerequisites**: Familiarity with SPD batch normalization +(see :ref:`tutorial-batch-normalization`). + +""" + +###################################################################### +# Quick Decision Guide +# --------------------- +# +# .. 
list-table:: +# :header-rows: 1 +# :widths: 12 12 15 61 +# +# * - Metric +# - Speed +# - Invariance +# - Use When +# * - **LEM** +# - Fastest +# - Orthogonal +# - Default choice. Closed-form mean, good general performance. +# * - **AIM** +# - Slowest +# - Full affine +# - Data has varying scale (e.g., cross-subject EEG). +# * - **LCM** +# - Fast +# - Lower-triangular +# - Speed matters and you want Cholesky stability. +# + +import time + +import matplotlib.pyplot as plt +import numpy as np +import torch + +from pyriemann.datasets import make_gaussian_blobs + +from spd_learn.modules import SPDBatchNormLie + + +torch.manual_seed(42) + +# Generate SPD data using pyriemann (2-class, 2*n_matrices total samples) +n_matrices = 32 +n_dim = 8 +X_np, y = make_gaussian_blobs( + n_matrices=n_matrices, + n_dim=n_dim, + class_sep=1.5, + class_disp=0.5, + random_state=42, +) +X = torch.from_numpy(X_np).float() # shape: (64, 8, 8) + +metric_colors = {"LEM": "#2ecc71", "LCM": "#3498db", "AIM": "#e74c3c"} + +###################################################################### +# Comparing Forward Pass Speed +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The cost depends on the metric's mean computation. AIM requires an +# iterative Karcher mean, while LEM and LCM use closed-form means. +# We use larger 32x32 matrices here so timing differences are visible. +# + +n_bench = 32 +batch_size = 64 +A = torch.randn(batch_size, n_bench, n_bench) +X_bench = A @ A.mT + 0.01 * torch.eye(n_bench) + +timings = {} +for metric in ["LEM", "LCM", "AIM"]: + bn = SPDBatchNormLie(n_bench, metric=metric) + bn.train() + _ = bn(X_bench) # warmup + t0 = time.time() + for _ in range(20): + _ = bn(X_bench) + elapsed = (time.time() - t0) / 20 + timings[metric] = elapsed * 1000 + print( + f"{metric}: {elapsed * 1000:.1f} ms/batch ({n_bench}x{n_bench}, batch={batch_size})" + ) + +fig, ax = plt.subplots(figsize=(6, 4)) +ax.bar( + timings.keys(), + timings.values(), + color=[metric_colors[m] for m in timings], +) +ax.set_ylabel("Time (ms)") +ax.set_title("Forward Pass Speed by Metric") +ax.grid(axis="y", alpha=0.3) +plt.tight_layout() +plt.show() + +###################################################################### +# Comparing Output Properties +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Each metric normalizes eigenvalues differently. Inspect the +# eigenvalue distribution after normalization: +# + +fig, axes = plt.subplots(1, 3, figsize=(14, 4), sharey=True) +for ax, metric in zip(axes, ["LEM", "LCM", "AIM"]): + bn = SPDBatchNormLie(n_dim, metric=metric) + bn.train() + Y = bn(X) + eigvals = torch.linalg.eigvalsh(Y.detach()) + ax.boxplot( + [eigvals[:, i].numpy() for i in range(n_dim)], + positions=range(n_dim), + ) + ax.set_title(f"{metric}", color=metric_colors[metric], fontweight="bold") + ax.set_xlabel("Eigenvalue index") + ax.grid(True, alpha=0.3) + +axes[0].set_ylabel("Eigenvalue") +plt.suptitle("Eigenvalue Distribution After Normalization", fontweight="bold") +plt.tight_layout() +plt.show() + +###################################################################### +# Effect of Each Metric on an SPD Matrix +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# To build visual intuition, let's examine how each metric transforms a +# single SPD matrix. The top row shows the matrix entries as a heatmap +# (before and after normalization), while the bottom row compares the +# eigenvalue spectrum. 
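# Each BN layer below is freshly constructed, so its learnable bias and
# dispersion are still at their initial values; the figures therefore mainly
# show the batch-statistics normalization rather than any learned shift.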
+# + +sample_idx = 0 +X_sample = X[sample_idx] # shape (8, 8) + +fig, axes = plt.subplots(2, 4, figsize=(16, 7)) + +# Top row: matrix heatmaps +im = axes[0, 0].imshow(X_sample.numpy(), cmap="RdBu_r", aspect="auto") +axes[0, 0].set_title("Original", fontweight="bold") +plt.colorbar(im, ax=axes[0, 0], shrink=0.8) + +normalized = {} +for col, metric in enumerate(["LEM", "LCM", "AIM"], start=1): + bn = SPDBatchNormLie(n_dim, metric=metric) + bn.train() + Y = bn(X) + Y_sample = Y[sample_idx].detach().numpy() + normalized[metric] = Y_sample + + ax = axes[0, col] + im = ax.imshow(Y_sample, cmap="RdBu_r", aspect="auto") + ax.set_title(f"After {metric}", fontweight="bold", color=metric_colors[metric]) + plt.colorbar(im, ax=ax, shrink=0.8) + +# Bottom row: eigenvalue bar charts +eigvals_orig = np.sort(np.linalg.eigvalsh(X_sample.numpy()))[::-1] +axes[1, 0].bar(range(n_dim), eigvals_orig, color="gray", alpha=0.8) +axes[1, 0].set_title("Eigenvalues", fontweight="bold") +axes[1, 0].set_xlabel("Index") +axes[1, 0].set_ylabel("Value") + +for col, metric in enumerate(["LEM", "LCM", "AIM"], start=1): + eigvals_after = np.sort(np.linalg.eigvalsh(normalized[metric]))[::-1] + ax = axes[1, col] + ax.bar( + range(n_dim), + eigvals_after, + color=metric_colors[metric], + alpha=0.8, + ) + ax.set_title( + f"{metric} Eigenvalues", fontweight="bold", color=metric_colors[metric] + ) + ax.set_xlabel("Index") + +plt.suptitle( + "Effect of Each Metric on a Single SPD Matrix", + fontweight="bold", + fontsize=13, +) +plt.tight_layout() +plt.show() + +###################################################################### +# Tuning the Theta Parameter +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The ``theta`` parameter controls the power deformation :math:`P^\theta`. +# It adjusts how strongly scale differences are compressed or amplified: +# +# - ``theta=0.5``: Square root -- compresses large scale differences +# - ``theta=1.0``: No deformation (default) +# - ``theta=1.5``: Amplifies scale differences +# +# The violin plots below show how the eigenvalue distribution changes +# with each theta value under the AIM metric: +# + +thetas = [0.5, 1.0, 1.5] +theta_colors = ["#3498db", "#9b59b6", "#f39c12"] + +fig, axes = plt.subplots(1, 3, figsize=(14, 4), sharey=True) + +for ax, theta, color in zip(axes, thetas, theta_colors): + bn = SPDBatchNormLie(n_dim, metric="AIM", theta=theta) + bn.train() + Y = bn(X) + eigvals = torch.linalg.eigvalsh(Y.detach()).numpy() + + parts = ax.violinplot( + [eigvals[:, i] for i in range(n_dim)], + positions=range(n_dim), + showmedians=True, + showextrema=True, + ) + for pc in parts["bodies"]: + pc.set_facecolor(color) + pc.set_alpha(0.6) + for key in ("cbars", "cmins", "cmaxes", "cmedians"): + parts[key].set_color(color) + + ax.set_title(f"θ = {theta}", fontweight="bold", fontsize=12) + ax.set_xlabel("Eigenvalue index") + ax.grid(True, alpha=0.3) + + print( + f"AIM (theta={theta}): eigval range [{eigvals.min():.3f}, {eigvals.max():.3f}]" + ) + +axes[0].set_ylabel("Eigenvalue") +plt.suptitle( + "Theta Parameter Effect on Eigenvalue Distribution (AIM Metric)", + fontweight="bold", +) +plt.tight_layout() +plt.show() + +###################################################################### +# Recommendations +# ---------------- +# +# 1. **Start with LEM** -- fastest, closed-form mean, works well in most cases +# 2. **Try AIM** if your data has varying scale across subjects or sessions +# 3. **Use LCM** when you need speed similar to LEM with Cholesky stability +# 4. 
**Tune theta** on a validation set when using AIM or LCM +# (see :ref:`liebn-batch-normalization` for a multi-dataset benchmark) +# +# .. seealso:: +# +# - :ref:`tutorial-batch-normalization` -- Detailed tutorial on SPD BN +# - :ref:`howto-add-batchnorm` -- How to add BN to your pipeline +# - :ref:`liebn-batch-normalization` -- Full benchmark reproduction +# - :class:`~spd_learn.modules.SPDBatchNormLie` -- API reference +# diff --git a/examples/tutorials/tutorial_05_batch_normalization.py b/examples/tutorials/tutorial_05_batch_normalization.py new file mode 100644 index 0000000..9f01ebe --- /dev/null +++ b/examples/tutorials/tutorial_05_batch_normalization.py @@ -0,0 +1,374 @@ +""" +.. _tutorial-batch-normalization: + +Batch Normalization on SPD Manifolds +===================================== + +This tutorial teaches how batch normalization works on SPD matrices and +why it matters for training SPD neural networks. You will train a simple +SPDNet with and without normalization, observe the impact, and compare +different Riemannian metrics. + +By the end, you will understand: + +- Why standard Euclidean batch normalization doesn't apply to SPD matrices +- How Riemannian and Lie group batch normalization work +- When to use each metric (AIM, LEM, LCM) + +.. contents:: This tutorial covers: + :local: + :depth: 2 + +""" + +###################################################################### +# Setup and Imports +# ----------------- +# + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn + +from spd_learn.modules import ( + BiMap, + LogEig, + ReEig, + SPDBatchNormLie, + SPDBatchNormMeanVar, +) + + +torch.manual_seed(42) +np.random.seed(42) + +###################################################################### +# Creating Synthetic SPD Data +# ---------------------------- +# +# We generate a 3-class classification problem with 8x8 SPD matrices. +# Each class has a distinct eigenvalue profile, simulating how real-world +# signals (EEG, radar) produce covariance matrices with different spectral +# characteristics. +# + + +def make_spd_dataset(n_samples_per_class=100, n=8, n_classes=3, seed=42): + """Generate synthetic SPD classification data.""" + rng = np.random.RandomState(seed) + X_list, y_list = [], [] + for c in range(n_classes): + eigvals = np.exp(rng.randn(n) * 0.5 + c * 0.3) + for _ in range(n_samples_per_class): + Q, _ = np.linalg.qr(rng.randn(n, n)) + S = Q @ np.diag(eigvals + rng.rand(n) * 0.1) @ Q.T + S = (S + S.T) / 2 + X_list.append(S) + y_list.append(c) + X = torch.from_numpy(np.stack(X_list)).float() + y = torch.from_numpy(np.array(y_list)) + perm = torch.randperm(len(X), generator=torch.Generator().manual_seed(seed)) + return X[perm], y[perm] + + +X, y = make_spd_dataset() +n_train = int(0.7 * len(X)) +X_train, y_train = X[:n_train], y[:n_train] +X_test, y_test = X[n_train:], y[n_train:] +print( + f"Dataset: {len(X)} samples ({n_train} train, {len(X) - n_train} test), " + f"{len(torch.unique(y))} classes, {X.shape[1]}x{X.shape[2]} SPD matrices" +) + + +###################################################################### +# A Simple SPDNet +# ---------------- +# +# We define a minimal SPDNet with one BiMap layer, optional batch +# normalization, ReEig activation, and a LogEig + linear classifier. 
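# The normalization layer is injected through the ``bn`` argument, so we can
# swap strategies later without touching the rest of the network.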
+# + + +class SimpleSPDNet(nn.Module): + """Minimal SPDNet with optional batch normalization.""" + + def __init__(self, n_in, n_out, n_classes, bn=None): + super().__init__() + self.bimap = BiMap(n_in, n_out) + self.bn = bn + self.reeig = ReEig() + self.logeig = LogEig(upper=False, flatten=True) + self.classifier = nn.Linear(n_out**2, n_classes) + + def forward(self, x): + x = self.bimap(x) + if self.bn is not None: + x = self.bn(x) + x = self.reeig(x) + x = self.logeig(x) + return self.classifier(x) + + +def train_model(model, X_train, y_train, X_test, y_test, epochs=150, lr=5e-3): + """Train a model and record loss and accuracy curves.""" + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + criterion = nn.CrossEntropyLoss() + train_losses, test_accs = [], [] + + for epoch in range(epochs): + model.train() + optimizer.zero_grad(set_to_none=True) + loss = criterion(model(X_train), y_train) + loss.backward() + optimizer.step() + train_losses.append(loss.item()) + + if (epoch + 1) % 5 == 0: + model.eval() + with torch.no_grad(): + acc = (model(X_test).argmax(1) == y_test).float().mean().item() + test_accs.append((epoch + 1, acc)) + + return train_losses, test_accs + + +###################################################################### +# Baseline: No Batch Normalization +# ---------------------------------- +# +# First, let's train without any normalization. Notice the loss curve +# and final accuracy. +# + +torch.manual_seed(42) +model_none = SimpleSPDNet(8, 4, 3, bn=None) +losses_none, accs_none = train_model(model_none, X_train, y_train, X_test, y_test) +print(f"No BN: final accuracy = {accs_none[-1][1]:.1%}") + + +###################################################################### +# Adding Riemannian Batch Normalization +# ---------------------------------------- +# +# :class:`~spd_learn.modules.SPDBatchNormMeanVar` normalizes using the +# Frechet mean and dispersion under the Affine-Invariant Riemannian +# Metric (AIRM). This is the Riemannian analogue of standard batch +# normalization. +# + +torch.manual_seed(42) +model_rbn = SimpleSPDNet(8, 4, 3, bn=SPDBatchNormMeanVar(4, momentum=0.1)) +losses_rbn, accs_rbn = train_model(model_rbn, X_train, y_train, X_test, y_test) +print(f"SPDBatchNormMeanVar: final accuracy = {accs_rbn[-1][1]:.1%}") + +###################################################################### +# Notice the improvement! Batch normalization stabilizes the loss +# trajectory and allows the network to converge to a better solution. +# + +###################################################################### +# Lie Group Batch Normalization +# ------------------------------ +# +# :class:`~spd_learn.modules.SPDBatchNormLie` :cite:p:`chen2024liebn` +# exploits the Lie group structure of :math:`\spd`. Unlike +# ``SPDBatchNormMeanVar`` which only supports AIRM, LieBN supports three +# Riemannian metrics: +# +# - **AIM** (Affine-Invariant Metric): Iterative Karcher mean, full +# affine invariance. +# - **LEM** (Log-Euclidean Metric): Closed-form mean, fast computation. +# - **LCM** (Log-Cholesky Metric): Cholesky-based, numerically stable. 
+# +# Let's try all three: +# + +liebn_results = {} +for metric in ["AIM", "LEM", "LCM"]: + torch.manual_seed(42) + model = SimpleSPDNet(8, 4, 3, bn=SPDBatchNormLie(4, metric=metric)) + losses, accs = train_model(model, X_train, y_train, X_test, y_test) + liebn_results[metric] = (losses, accs) + print(f"LieBN ({metric}): final accuracy = {accs[-1][1]:.1%}") + + +###################################################################### +# Comparing Training Dynamics +# ---------------------------- +# +# Let's visualize how each normalization strategy affects convergence. +# + +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)) + +# Loss curves +ax1.plot(losses_none, label="No BN", alpha=0.7, color="gray") +ax1.plot(losses_rbn, label="SPDBatchNormMeanVar", alpha=0.7, color="black") +colors = {"AIM": "#e74c3c", "LEM": "#2ecc71", "LCM": "#3498db"} +for metric, (losses, _) in liebn_results.items(): + ax1.plot(losses, label=f"LieBN ({metric})", alpha=0.7, color=colors[metric]) +ax1.set_xlabel("Epoch") +ax1.set_ylabel("Training Loss") +ax1.set_title("Loss Convergence") +ax1.legend() +ax1.grid(True, alpha=0.3) + +# Accuracy curves +for label, accs, color in [ + ("No BN", accs_none, "gray"), + ("SPDBatchNormMeanVar", accs_rbn, "black"), +]: + epochs, vals = zip(*accs) + ax2.plot(epochs, vals, "o-", label=label, color=color, markersize=3) +for metric, (_, accs) in liebn_results.items(): + epochs, vals = zip(*accs) + ax2.plot( + epochs, + vals, + "o-", + label=f"LieBN ({metric})", + color=colors[metric], + markersize=3, + ) +ax2.set_xlabel("Epoch") +ax2.set_ylabel("Test Accuracy") +ax2.set_title("Accuracy Over Training") +ax2.legend() +ax2.grid(True, alpha=0.3) +ax2.set_ylim(0, 1.05) + +plt.tight_layout() +plt.show() + +###################################################################### +# Inspecting the LieBN Pipeline +# -------------------------------- +# +# LieBN normalizes SPD matrices through five geometric steps: +# +# 1. **Deformation** -- map SPD matrices to a codomain via the chosen +# metric (e.g., matrix log for LEM, Cholesky + log-diag for LCM) +# 2. **Centering** -- translate the batch to zero/identity mean +# 3. **Scaling** -- normalize variance by a learnable dispersion parameter +# 4. **Biasing** -- translate by a learnable location parameter +# 5. **Inverse Deformation** -- map back to the SPD manifold +# +# The three metrics differ in *how* they perform deformation and centering: +# +# .. 
list-table:: +# :header-rows: 1 +# :widths: 15 25 25 25 +# +# * - Metric +# - Deformation +# - Mean +# - Group Action +# * - **LEM** +# - :math:`\log(X)` +# - Euclidean (closed-form) +# - Additive +# * - **LCM** +# - Cholesky + log-diag +# - Euclidean (closed-form) +# - Additive +# * - **AIM** +# - :math:`X^\theta` +# - Karcher (iterative) +# - Cholesky congruence +# +# Let's watch the running variance converge during training: +# + +torch.manual_seed(42) +A = torch.randn(64, 8, 8) +X_demo = A @ A.mT + 0.1 * torch.eye(8) + +fig, ax = plt.subplots(figsize=(8, 4)) +for metric in ["AIM", "LEM", "LCM"]: + bn = SPDBatchNormLie(8, metric=metric, momentum=0.1) + bn.train() + variances = [] + for epoch in range(50): + perm = torch.randperm(len(X_demo)) + for i in range(0, len(X_demo), 16): + batch = X_demo[perm[i : i + 16]] + if batch.shape[0] < 2: + continue + _ = bn(batch) + variances.append(bn.running_var.item()) + ax.plot(variances, label=metric, color=colors[metric]) + +ax.set_xlabel("Epoch") +ax.set_ylabel("Running Variance") +ax.set_title("Running Variance Convergence Across Metrics") +ax.legend() +ax.grid(True, alpha=0.3) +plt.tight_layout() +plt.show() + +###################################################################### +# Notice that all three metrics converge, but LEM and LCM converge +# faster thanks to their closed-form mean computation. +# + +###################################################################### +# Verifying SPD Output and Gradient Flow +# ---------------------------------------- +# +# A critical property: batch normalization must produce valid SPD output +# with flowing gradients. Let's verify: +# + +torch.manual_seed(42) +A = torch.randn(8, 4, 4) +X_check = (A @ A.mT + 0.1 * torch.eye(4)).requires_grad_(True) + +for metric in ["AIM", "LEM", "LCM"]: + bn = SPDBatchNormLie(4, metric=metric) + bn.train() + out = bn(X_check) + loss = (out * out).sum() + loss.backward() + eigvals = torch.linalg.eigvalsh(out.detach()) + print( + f"{metric}: min_eigval={eigvals.min():.2e}, grad_norm={X_check.grad.norm():.4f}" + ) + X_check.grad = None + +###################################################################### +# All eigenvalues are positive (valid SPD) and gradients flow +# correctly through the normalization layer. +# + +###################################################################### +# Summary +# ------- +# +# In this tutorial you learned: +# +# - **Why**: Batch normalization stabilizes SPDNet training by +# normalizing the distribution of SPD activations +# - **How**: Riemannian BN uses the Frechet mean and variance; +# Lie group BN generalizes this to multiple metrics +# - **Which metric**: LEM for speed, AIM for invariance, LCM for +# Cholesky stability +# +# Next steps: +# +# .. seealso:: +# +# - :ref:`howto-add-batchnorm` -- Add BN to an existing pipeline +# - :ref:`howto-choose-metric` -- Decision guide for metric selection +# - :ref:`liebn-batch-normalization` -- Full benchmark reproduction +# across HDM05, Radar, and AFEW datasets +# - :class:`~spd_learn.modules.SPDBatchNormLie` -- API reference +# - :class:`~spd_learn.modules.SPDBatchNormMeanVar` -- API reference +# +# References +# ---------- +# +# .. 
bibliography:: +# :filter: docname in docnames +# diff --git a/spd_learn/__init__.py b/spd_learn/__init__.py index 9b68aad..9a02dcf 100644 --- a/spd_learn/__init__.py +++ b/spd_learn/__init__.py @@ -22,6 +22,7 @@ PatchEmbeddingLayer, ReEig, Shrinkage, + SPDBatchNormLie, SPDBatchNormMean, SPDBatchNormMeanVar, SPDDropout, @@ -50,13 +51,14 @@ "BatchReNorm", "BiMap", "BiMapIncreaseDim", - "SPDBatchNormMean", "CovLayer", "ExpEig", "LogEig", "PatchEmbeddingLayer", "ReEig", "Shrinkage", + "SPDBatchNormLie", + "SPDBatchNormMean", "SPDBatchNormMeanVar", "SPDDropout", "TraceNorm", diff --git a/spd_learn/functional/__init__.py b/spd_learn/functional/__init__.py index 97b4f00..0182bac 100644 --- a/spd_learn/functional/__init__.py +++ b/spd_learn/functional/__init__.py @@ -16,8 +16,11 @@ from .autograd import modeig_backward, modeig_forward from .batchnorm import ( + frechet_mean, karcher_mean_iteration, + lie_group_variance, spd_centering, + spd_cholesky_congruence, spd_rebiasing, tangent_space_variance, ) @@ -155,8 +158,11 @@ "ledoit_wolf", "shrinkage_covariance", # Batch normalization + "frechet_mean", "karcher_mean_iteration", + "lie_group_variance", "spd_centering", + "spd_cholesky_congruence", "spd_rebiasing", "tangent_space_variance", # Bilinear operations diff --git a/spd_learn/functional/batchnorm.py b/spd_learn/functional/batchnorm.py index 16c4851..71eae7e 100644 --- a/spd_learn/functional/batchnorm.py +++ b/spd_learn/functional/batchnorm.py @@ -1,5 +1,6 @@ # Copyright (c) 2024-now SPD Learn Developers # SPDX-License-Identifier: BSD-3-Clause + """Functional operations for SPD batch normalization. This module provides stateless mathematical operations for Riemannian batch @@ -8,12 +9,18 @@ Functions --------- +frechet_mean + Fréchet mean of SPD matrices under the AIRM via Karcher flow. karcher_mean_iteration Single iteration of the Karcher (Fréchet) mean algorithm. spd_centering Center SPD matrices around a given mean via congruence transformation. +spd_cholesky_congruence + Congruence transformation using the Cholesky factor of an SPD matrix. tangent_space_variance Compute variance of SPD matrices in the tangent space. +lie_group_variance + Fréchet variance under a Lie group structure on the SPD manifold. See Also -------- @@ -21,16 +28,20 @@ :class:`~spd_learn.modules.SPDBatchNormMeanVar` : Full Riemannian batch normalization. """ +from typing import Optional, Tuple, Union + import torch from .core import matrix_exp, matrix_log, matrix_sqrt_inv +from .utils import ensure_sym def karcher_mean_iteration( X: torch.Tensor, current_mean: torch.Tensor, detach: bool = True, -) -> torch.Tensor: + return_tangent: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Perform one iteration of the Karcher mean algorithm. The Karcher (Fréchet) mean on the SPD manifold is the minimizer of the sum @@ -54,11 +65,15 @@ def karcher_mean_iteration( If True, detaches ``current_mean`` from the computational graph before computing the update. Set to False when gradients with respect to the mean are needed. + return_tangent : bool, default=False + If True, also returns the mean tangent update used in this Karcher step. Returns ------- - torch.Tensor - Updated Karcher mean estimate with shape `(1, ..., n, n)`. + torch.Tensor or Tuple[torch.Tensor, torch.Tensor] + Updated Karcher mean estimate with shape `(1, ..., n, n)`. When + ``return_tangent=True``, also returns the mean tangent update with the + same shape. 
Notes ----- @@ -85,9 +100,85 @@ def karcher_mean_iteration( mean_tangent = X_tangent.mean(dim=0, keepdim=True) # Map back to manifold new_mean = mean_sqrt @ matrix_exp.apply(mean_tangent) @ mean_sqrt + if return_tangent: + return new_mean, mean_tangent return new_mean +def frechet_mean( + X: torch.Tensor, + max_iter: int = 1, + weights: Optional[torch.Tensor] = None, + return_distances: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + r"""Fréchet mean of SPD matrices under the AIRM via Karcher flow. + + Computes the minimizer of the sum of squared geodesic distances: + + .. math:: + + \bar{X} = \arg\min_{G \in \mathcal{S}_{++}^n} + \sum_{i=1}^{N} w_i \, d_{\text{AIRM}}^2(G, X_i) + + using iterative Karcher flow initialized from the (weighted) Euclidean mean. + + Parameters + ---------- + X : torch.Tensor + Batch of SPD matrices with shape ``(batch_size, ..., n, n)``. + max_iter : int, default=1 + Number of Karcher flow iterations. A single iteration is often + sufficient for batch normalization; use more (e.g. 50) when a + high-accuracy mean is needed. + weights : torch.Tensor, optional + Per-sample weights with shape broadcastable to ``X``. When ``None``, + uniform weights ``1/N`` are used. + return_distances : bool, default=False + If True, also returns the geodesic distances from each sample to + the mean. + + Returns + ------- + mean : torch.Tensor + Fréchet mean with shape ``(1, ..., n, n)``. + distances : torch.Tensor + Only returned when ``return_distances=True``. Geodesic distances + from each sample to the mean, with shape ``(batch_size, ...)``. + + See Also + -------- + :func:`karcher_mean_iteration` : Single Karcher step (lower-level). + :func:`~spd_learn.functional.airm_distance` : Pairwise AIRM distance. + + References + ---------- + See :cite:p:`pennec2006riemannian` for details on Karcher mean computation. + """ + batch = X.detach() + + if weights is None: + mean = batch.mean(dim=0, keepdim=True) + else: + mean = (batch * weights).sum(dim=0, keepdim=True) + + for _ in range(max_iter): + mean_sqrt, mean_invsqrt = matrix_sqrt_inv.apply(mean) + X_tangent = matrix_log.apply(mean_invsqrt @ batch @ mean_invsqrt) + if weights is None: + mean_tangent = X_tangent.mean(dim=0, keepdim=True) + else: + mean_tangent = (X_tangent * weights).sum(dim=0, keepdim=True) + mean = mean_sqrt @ matrix_exp.apply(mean_tangent) @ mean_sqrt + + if return_distances: + mean_sqrt, mean_invsqrt = matrix_sqrt_inv.apply(mean) + X_tangent = matrix_log.apply(mean_invsqrt @ batch @ mean_invsqrt) + distances = torch.norm(X_tangent, p="fro", dim=(-2, -1)) + return mean, distances + + return mean + + def spd_centering( X: torch.Tensor, mean_invsqrt: torch.Tensor, @@ -209,9 +300,126 @@ def tangent_space_variance( return variance +def spd_cholesky_congruence( + X: torch.Tensor, + P: torch.Tensor, + inverse: bool = False, +) -> torch.Tensor: + r"""Congruence transformation using the Cholesky factor of an SPD matrix. + + Given an SPD matrix :math:`P = LL^T`, applies: + + .. math:: + + \text{forward: } Y = LXL^T, \qquad + \text{inverse: } Y = L^{-1}X L^{-T} + + This implements the Lie group action of ``GL(n)`` on the SPD manifold and + is used for centering and biasing under the affine-invariant metric. + + Parameters + ---------- + X : torch.Tensor + Batch of SPD matrices with shape `(..., n, n)`. + P : torch.Tensor + SPD matrix whose Cholesky factor defines the transformation, + with shape broadcastable to ``X``. 
+ inverse : bool, default=False + If True, applies the inverse congruence :math:`L^{-1}X L^{-T}`. + + Returns + ------- + torch.Tensor + Transformed SPD matrices with the same shape as ``X``. + + See Also + -------- + :func:`spd_centering` : Eigendecomposition-based centering (uses :math:`M^{-1/2}`). + """ + L = torch.linalg.cholesky(P) + if inverse: + Y = torch.linalg.solve_triangular(L, X, upper=False) + return ensure_sym(torch.linalg.solve_triangular(L, Y.mT, upper=False).mT) + return ensure_sym(L @ X @ L.mT) + + +def lie_group_variance( + X_centered: torch.Tensor, + metric: str, + alpha: float = 1.0, + beta: float = 0.0, + theta: float = 1.0, +) -> torch.Tensor: + r"""Fréchet variance under a Lie group structure on the SPD manifold. + + Computes the scalar dispersion of centered data in the Lie algebra, + using the bi-invariant distance of Chen et al. :cite:p:`chen2024liebn`: + + .. math:: + + \sigma^2 = \frac{1}{N} \sum_i + \bigl(\alpha \lVert V_i \rVert_F^2 + \beta \, g(V_i)^2\bigr) + \;/\; \theta^2 + + where the auxiliary term :math:`g` depends on the metric: + + - **AIM**: :math:`V_i = \log(X_i)`, :math:`g(V) = \log\det(X)` + - **LEM**: :math:`V_i = X_i` (already in log space), :math:`g(V) = \operatorname{tr}(V)`, + no :math:`\theta` scaling + - **LCM**: same as LEM but with :math:`\theta` scaling + + Parameters + ---------- + X_centered : torch.Tensor + Centered data in the Lie algebra with shape `(batch_size, ..., n, n)`. + For AIM these are SPD matrices (centered around identity); for LEM/LCM + these are symmetric / lower-triangular matrices. + metric : {"AIM", "LEM", "LCM"} + Lie group structure. + alpha : float, default=1.0 + Frobenius-norm weight. + beta : float, default=0.0 + Trace / log-determinant weight. + theta : float, default=1.0 + Power deformation parameter. + + Returns + ------- + torch.Tensor + Scalar variance (0-d tensor). + + See Also + -------- + :func:`tangent_space_variance` : Unweighted tangent-space dispersion used + by :class:`~spd_learn.modules.SPDBatchNormMeanVar`. 
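+
+    Examples
+    --------
+    A minimal sketch: a batch that is already centered at the group
+    identity has zero dispersion (values are illustrative only).
+
+    >>> import torch
+    >>> from spd_learn.functional import lie_group_variance
+    >>> X = torch.eye(3).expand(4, 3, 3).clone()  # identity-centered batch
+    >>> float(lie_group_variance(X, metric="AIM"))
+    0.0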
+ """ + if metric == "AIM": + logX = matrix_log.apply(X_centered) + frob_sq = (logX * logX).sum(dim=(-2, -1)) + dists = alpha * frob_sq + if beta != 0: + dists = dists + beta * torch.logdet(X_centered).square() + return dists.mean() / (theta**2) + + frob_sq = (X_centered * X_centered).sum(dim=(-2, -1)) + dists = alpha * frob_sq + if beta != 0: + trace = X_centered.diagonal(dim1=-2, dim2=-1).sum(dim=-1) + dists = dists + beta * trace.square() + var = dists.mean() + if metric == "LCM": + var = var / (theta**2) + elif metric != "LEM": + raise ValueError(f"metric must be 'AIM', 'LEM', or 'LCM', got '{metric}'") + return var + + __all__ = [ + "frechet_mean", "karcher_mean_iteration", + "lie_group_variance", "spd_centering", + "spd_cholesky_congruence", "spd_rebiasing", "tangent_space_variance", ] diff --git a/spd_learn/modules/__init__.py b/spd_learn/modules/__init__.py index 36a5b85..8183fd5 100644 --- a/spd_learn/modules/__init__.py +++ b/spd_learn/modules/__init__.py @@ -4,6 +4,7 @@ from .bilinear import BiMap, BiMapIncreaseDim from .covariance import CovLayer from .dropout import SPDDropout +from .liebn import SPDBatchNormLie from .manifold import PositiveDefiniteScalar, SymmetricPositiveDefinite from .modeig import ExpEig, LogEig, ReEig from .regularize import Shrinkage, TraceNorm @@ -29,6 +30,7 @@ "SPDBatchNormMean", "BatchReNorm", "SPDBatchNormMeanVar", + "SPDBatchNormLie", # dropout "SPDDropout", # residual diff --git a/spd_learn/modules/batchnorm.py b/spd_learn/modules/batchnorm.py index c87871c..e2c051c 100644 --- a/spd_learn/modules/batchnorm.py +++ b/spd_learn/modules/batchnorm.py @@ -13,7 +13,7 @@ matrix_sqrt, ) from ..functional.batchnorm import ( - karcher_mean_iteration, + frechet_mean, spd_centering, spd_rebiasing, tangent_space_variance, @@ -194,10 +194,7 @@ def forward(self, input): """ if self.training: - mean = input.mean(dim=0, keepdim=True) - if input.shape[0] > 1: - for _ in range(self.n_iter): - mean = karcher_mean_iteration(input, mean) + mean = frechet_mean(input, max_iter=self.n_iter) with torch.no_grad(): self.running_mean = airm_geodesic( self.running_mean, mean, self.momentum @@ -478,14 +475,8 @@ def forward(self, input): Normalized tensor of the same shape as the input. """ - n_samples = input.shape[0] if self.training: - # Kobler et al. SPDMBN/SPDBN: estimate batch Fréchet mean via Karcher step - batch_mean = input.mean(dim=0, keepdim=True) - if n_samples > 1: - for _ in range(self.n_iter): - # Kobler et al. (Eq. 4): P2 L132-145; Karcher flow note: P2 L163-165 - batch_mean = karcher_mean_iteration(input, batch_mean) + batch_mean = frechet_mean(input, max_iter=self.n_iter) # Scalar dispersion: mean squared Frobenius norm of log at the mean (a single scalar, not variance matrix) mean_inv_sqrt = matrix_inv_sqrt.apply(batch_mean) diff --git a/spd_learn/modules/liebn.py b/spd_learn/modules/liebn.py new file mode 100644 index 0000000..1c8fdc0 --- /dev/null +++ b/spd_learn/modules/liebn.py @@ -0,0 +1,380 @@ +# Copyright (c) 2024-now SPD Learn Developers +# SPDX-License-Identifier: BSD-3-Clause + +"""Lie Group Batch Normalization for SPD matrices. + +This module implements SPDBatchNormLie based on: +Ziheng Chen, Yue Song, Yunmei Liu, and Nicu Sebe, +"A Lie Group Approach to Riemannian Batch Normalization," ICLR 2024. 
+ +The implementation is integrated into ``spd_learn`` from the original LieBN +repository: https://github.com/GitZH-Chen/LieBN/tree/main/LieBN +""" + +import torch + +from torch import nn +from torch.nn.utils.parametrize import register_parametrization + +from ..functional import ( + airm_geodesic, + ensure_sym, + matrix_exp, + matrix_inv_sqrt, + matrix_log, + matrix_power, + matrix_sqrt, +) +from ..functional.batchnorm import ( + frechet_mean, + lie_group_variance, + spd_centering, + spd_cholesky_congruence, + spd_rebiasing, +) +from ..functional.metrics import ( + cholesky_exp, + cholesky_log, + log_euclidean_scalar_multiply, +) +from .manifold import PositiveDefiniteScalar, SymmetricPositiveDefinite + + +class SPDBatchNormLie(nn.Module): + r"""Lie Group Batch Normalization for SPD matrices. + + Implements the LieBN framework :cite:p:`chen2024liebn` for SPD manifolds. + Unlike :class:`SPDBatchNormMeanVar`, which normalizes under a single + Riemannian metric (AIRM), this layer exploits the **Lie group structure** + of three classical SPD geometries to define centering, scaling, and biasing + as group-theoretic operations with formal statistical guarantees. + + **Algorithm.** + Given a batch :math:`\{P_i\}_{i=1}^N \subset \mathcal{S}_{++}^n`, the + forward pass applies three steps in the Lie algebra selected by ``metric``: + + 1. **Centering** -- translate the batch mean :math:`M` to the group + identity :math:`E` via the inverse left translation: + + .. math:: + + \bar{P}_i = L_{M_\odot^{-1}}(P_i) + + 2. **Scaling** -- normalize the Fréchet variance :math:`v^2` with a + learnable shift :math:`s \in \mathbb{R}_{>0}`: + + .. math:: + + \hat{P}_i = \operatorname{Exp}_E + \!\left[\frac{s}{\sqrt{v^2 + \epsilon}}\, + \operatorname{Log}_E(\bar{P}_i)\right] + + 3. **Biasing** -- translate to the learnable SPD parameter :math:`B`: + + .. math:: + + \tilde{P}_i = L_B(\hat{P}_i) + + **Theoretical guarantees** (Proposition 4.2 of the paper): + + * *Mean control*: after centering and biasing with :math:`B = E`, + the Fréchet mean of the output batch equals :math:`E`. + * *Variance control*: after scaling, the output dispersion satisfies + :math:`\sum_i w_i\,d^2(\hat{P}_i, E) = s^2`. + + **Supported metrics.** + The ``metric`` parameter selects one of three Lie group structures, each + inducing a family of parameterized metrics via the power deformation + :math:`\mathrm{P}_\theta`. The table below summarizes how each step is + realized (see Table 2 in :cite:p:`chen2024liebn`): + + .. list-table:: + :header-rows: 1 + :widths: 25 25 25 25 + + * - Operation + - :math:`(\theta,\alpha,\beta)`-AIM + - :math:`(\alpha,\beta)`-LEM + - :math:`\theta`-LCM + * - Pullback map + - :math:`\mathrm{P}_\theta` + - :math:`\operatorname{mlog}` + - :math:`\psi_{\mathrm{LC}} \circ \mathrm{P}_\theta` + * - Left translation :math:`L_Q(P)` + - :math:`Q^{1/2} P\, Q^{1/2}` + - :math:`P + Q` + - :math:`P + Q` + * - Scaling + - :math:`\operatorname{Exp}_I[s\,\operatorname{Log}_I(P)]` + - :math:`s \cdot P` + - :math:`s \cdot P` + * - Fréchet mean + - Karcher flow + - Arithmetic mean + - Arithmetic mean + * - Running mean update + - AIRM geodesic + - Linear interpolation + - Linear interpolation + + **Bi-invariant distance.** + The Fréchet variance uses the :math:`(\alpha, \beta)` bi-invariant metric + (Definition 3 and Eq. 3 of the paper): + + .. 
math::
+
+        d^2(P, Q) = \alpha \lVert V \rVert_F^2
+            + \beta \, g(V)^2
+
+    where :math:`V` is the tangent representation (log-map) and
+    :math:`g(V) = \log\det(P)` for AIM or :math:`\operatorname{tr}(V)`
+    for LEM/LCM. The variance is normalized by :math:`\theta^2` for AIM
+    and LCM.
+
+    Parameters
+    ----------
+    num_features : int
+        Size of the SPD matrices (:math:`n \times n`).
+    metric : {"AIM", "LEM", "LCM"}, default="AIM"
+        Lie group invariant metric.
+    theta : float, default=1.0
+        Power deformation parameter :math:`\theta`. When
+        :math:`\theta = 1`, no deformation is applied.
+    alpha : float, default=1.0
+        Frobenius norm weight :math:`\alpha` in the bi-invariant distance.
+    beta : float, default=0.0
+        Trace / log-determinant weight :math:`\beta` in the bi-invariant
+        distance. Must satisfy :math:`\min(\alpha, \alpha + n\beta) > 0`.
+    momentum : float, default=0.1
+        Momentum :math:`\gamma` for the exponential moving average of
+        running statistics.
+    eps : float, default=1e-5
+        Numerical stability constant :math:`\epsilon` added to the variance
+        before taking the square root.
+    n_iter : int, default=1
+        Number of Karcher flow iterations for the AIM Fréchet mean.
+        Ignored by LEM and LCM (which use arithmetic means).
+    congruence : {"cholesky", "eig"}, default="cholesky"
+        Implementation of the AIM congruence action (centering/biasing).
+        ``"cholesky"`` uses the Cholesky factor :math:`L` of :math:`P` to
+        compute :math:`L X L^\top` (as in the original LieBN paper).
+        ``"eig"`` uses the eigendecomposition-based :math:`M^{-1/2} X M^{-1/2}`
+        (matching :func:`~spd_learn.functional.spd_centering`).
+        Both yield the same normalization guarantees (output Fréchet mean
+        and dispersion), but their outputs differ by an orthogonal
+        conjugation, since the Cholesky factor equals :math:`M^{1/2}` only
+        up to an orthogonal factor. Cholesky is typically faster, while
+        eigendecomposition reuses the infrastructure of
+        :class:`~spd_learn.modules.SPDBatchNormMeanVar`.
+        Only affects the AIM metric.
+    device : torch.device or str, optional
+        Device on which to create parameters and buffers.
+    dtype : torch.dtype, optional
+        Data type of parameters and buffers.
+
+    Attributes
+    ----------
+    bias : nn.Parameter
+        Learnable SPD bias matrix :math:`B \in \mathcal{S}_{++}^n`,
+        parametrized via :class:`~spd_learn.modules.SymmetricPositiveDefinite`.
+        Initialized to the identity.
+    shift : nn.Parameter
+        Learnable positive scalar :math:`s > 0`,
+        parametrized via :class:`~spd_learn.modules.PositiveDefiniteScalar`.
+        Initialized to 1.
+    running_mean : torch.Tensor
+        Exponential moving average of the batch Fréchet mean.
+    running_var : torch.Tensor
+        Exponential moving average of the batch variance.
+
+    See Also
+    --------
+    :class:`SPDBatchNormMean` :
+        Mean-only Riemannian batch normalization (AIRM centering without
+        variance normalization) :cite:p:`brooks2019riemannian`.
+    :class:`SPDBatchNormMeanVar` :
+        Full Riemannian batch normalization under the AIRM
+        :cite:p:`kobler2022spd`.
+    :func:`~spd_learn.functional.frechet_mean` :
+        Fréchet mean via Karcher flow (used internally for AIM).
+    :func:`~spd_learn.functional.lie_group_variance` :
+        Bi-invariant Fréchet variance computation.
+
+    References
+    ----------
+    .. 
bibliography:: + :filter: key == "chen2024liebn" + + Examples + -------- + >>> import torch + >>> from spd_learn.modules import SPDBatchNormLie + >>> bn = SPDBatchNormLie(num_features=4, metric="AIM") + >>> X = torch.randn(8, 4, 4, dtype=torch.float64) + >>> X = X @ X.mT + 0.1 * torch.eye(4, dtype=torch.float64) + >>> bn = bn.to(dtype=torch.float64) + >>> Y = bn(X) + >>> Y.shape + torch.Size([8, 4, 4]) + """ + + def __init__( + self, + num_features, + metric="AIM", + theta=1.0, + alpha=1.0, + beta=0.0, + momentum=0.1, + eps=1e-5, + n_iter=1, + congruence="cholesky", + device=None, + dtype=None, + ): + super().__init__() + supported_metrics = ("AIM", "LEM", "LCM") + if metric not in supported_metrics: + raise ValueError( + f"metric must be one of {supported_metrics}, got '{metric}'" + ) + if congruence not in ("cholesky", "eig"): + raise ValueError( + f"congruence must be 'cholesky' or 'eig', got '{congruence}'" + ) + self.num_features = num_features + self.metric = metric + self.theta = theta + self.alpha = alpha + self.beta = beta + self.momentum = momentum + self.eps = eps + self.n_iter = n_iter + self.congruence = congruence + + self.bias = nn.Parameter( + torch.empty(1, num_features, num_features, device=device, dtype=dtype) + ) + self.shift = nn.Parameter(torch.empty((), device=device, dtype=dtype)) + + if metric == "AIM": + self.register_buffer( + "running_mean", + torch.eye(num_features, device=device, dtype=dtype).unsqueeze(0), + ) + else: + self.register_buffer( + "running_mean", + torch.zeros(1, num_features, num_features, device=device, dtype=dtype), + ) + self.register_buffer("running_var", torch.ones((), device=device, dtype=dtype)) + + self.reset_parameters() + self._parametrize() + + @torch.no_grad() + def reset_parameters(self): + self.bias.zero_() + self.bias[0].fill_diagonal_(1.0) + self.shift.fill_(1.0) + + def _parametrize(self): + register_parametrization(self, "bias", SymmetricPositiveDefinite()) + register_parametrization(self, "shift", PositiveDefiniteScalar()) + + # ------------------------------------------------------------------ + # Thin dispatch helpers — each delegates to existing functional ops. 
+ # ------------------------------------------------------------------ + + def _deform(self, X): + """Map SPD matrices to the Lie algebra.""" + if self.metric == "AIM": + return X if self.theta == 1.0 else matrix_power.apply(X, self.theta) + if self.metric == "LEM": + return matrix_log.apply(X) + # LCM + Xp = X if self.theta == 1.0 else matrix_power.apply(X, self.theta) + return cholesky_log.apply(Xp) + + def _inv_deform(self, S): + """Map from the Lie algebra back to SPD matrices.""" + if self.metric == "AIM": + return S if self.theta == 1.0 else matrix_power.apply(S, 1.0 / self.theta) + if self.metric == "LEM": + return matrix_exp.apply(S) + # LCM + spd = ensure_sym(cholesky_exp.apply(S)) + return spd if self.theta == 1.0 else matrix_power.apply(spd, 1.0 / self.theta) + + def _translate(self, X, P, inverse=False): + """Group translation (centering / biasing) in the Lie algebra.""" + if self.metric == "AIM": + if self.congruence == "cholesky": + return spd_cholesky_congruence(X, P, inverse=inverse) + # Eigendecomposition path + if inverse: + return spd_centering(X, matrix_inv_sqrt.apply(P)) + return spd_rebiasing(X, matrix_sqrt.apply(P)) + return X - P if inverse else X + P + + def _frechet_mean(self, X_def): + """Fréchet mean in the deformed space.""" + if self.metric == "AIM": + return frechet_mean(X_def, max_iter=self.n_iter) + return X_def.detach().mean(dim=0, keepdim=True) + + def _scale(self, X, var): + """Variance normalization in the Lie algebra.""" + factor = self.shift / (var + self.eps).sqrt() + if self.metric == "AIM": + return log_euclidean_scalar_multiply(factor, X) + return X * factor + + # ------------------------------------------------------------------ + # Running statistics & forward + # ------------------------------------------------------------------ + + def _update_running_stats(self, batch_mean, batch_var): + with torch.no_grad(): + if self.metric == "AIM": + self.running_mean = airm_geodesic( + self.running_mean, batch_mean, self.momentum + ) + else: + self.running_mean = ( + 1 - self.momentum + ) * self.running_mean + self.momentum * batch_mean + self.running_var = ( + 1 - self.momentum + ) * self.running_var + self.momentum * batch_var + + def forward(self, X): + X_def = self._deform(X) + bias_def = self._deform(self.bias) + + if self.training: + batch_mean = self._frechet_mean(X_def) + X_centered = self._translate(X_def, batch_mean, inverse=True) + if X.shape[0] > 1: + batch_var = lie_group_variance( + X_centered.detach(), + self.metric, + self.alpha, + self.beta, + self.theta, + ) + X_scaled = self._scale(X_centered, batch_var) + else: + batch_var = self.running_var.clone() + X_scaled = X_centered + self._update_running_stats(batch_mean.detach(), batch_var.detach()) + else: + X_centered = self._translate(X_def, self.running_mean, inverse=True) + X_scaled = self._scale(X_centered, self.running_var) + + X_biased = self._translate(X_scaled, bias_def, inverse=False) + return self._inv_deform(X_biased) + + def extra_repr(self): + return ( + f"num_features={self.num_features}, metric={self.metric}, " + f"theta={self.theta}, alpha={self.alpha}, beta={self.beta}, " + f"momentum={self.momentum}, congruence={self.congruence}" + ) diff --git a/tests/test_integration.py b/tests/test_integration.py index 0e10aa1..092d54b 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -21,6 +21,7 @@ "SPDBatchNormMean": dict(num_features=10), "BatchReNorm": dict(num_features=10), "SPDBatchNormMeanVar": dict(num_features=10), + "SPDBatchNormLie": 
dict(num_features=10), "PatchEmbeddingLayer": dict(n_chans=10, n_patches=2), "BiMapIncreaseDim": dict(in_features=10, out_features=20), "Shrinkage": dict(n_chans=10), @@ -141,6 +142,8 @@ def test_module_dtype(module_name, dtype, device): pytest.skip( "PositiveDefiniteScalar is a scalar parametrization, not a matrix layer." ) + if module_name == "SPDBatchNormLie" and dtype.is_complex: + pytest.skip("SPDBatchNormLie only supports real-valued SPD matrices.") module_class = getattr(spd_learn.modules, module_name) mandatory_param = mandatory_parameters_per_module.get(module_name, {}) diff --git a/tests/test_liebn.py b/tests/test_liebn.py new file mode 100644 index 0000000..8822f0d --- /dev/null +++ b/tests/test_liebn.py @@ -0,0 +1,313 @@ +"""Tests for Lie Group Batch Normalization (LieBN) for SPD matrices. + +Verifies the theoretical guarantees from Chen et al., ICLR 2024 +(Proposition 4.2): + - Mean property: after centering+biasing with bias=I, output Frechet mean ≈ I + - Variance property: after scaling with shift=1, dispersion ≈ 1.0 + - Running statistics converge to population statistics +""" + +from math import sqrt + +import pytest +import torch + +from spd_learn.functional import ( + ensure_sym, + matrix_exp, + matrix_log, + vec_to_sym, +) +from spd_learn.functional.batchnorm import karcher_mean_iteration +from spd_learn.modules import SPDBatchNormLie + + +DTYPE = torch.float64 + +# --------------------------------------------------------------------------- +# Data fixture +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def simulated_data(): + """Generate SPD data with known Frechet mean for testing. + + Strategy: zero-mean tangent vectors -> matrix_exp -> SPD at Identity, + then apply linear mixing x = A z A^T so Frechet mean = A A^T. 
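+
+    Why the mean is analytic: the zero-mean tangent vectors satisfy the
+    Karcher stationarity condition sum_i Log_I(z_i) = 0 at the identity,
+    and the AIRM Frechet mean is equivariant under congruence, so the
+    mixed batch x_i = A z_i A^T has Frechet mean A A^T.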
+ """ + ndim = 10 + nobs = 128 + generator = torch.Generator().manual_seed(42) + + # Zero-mean tangent vectors -> SPD matrices centered at Identity + logz = vec_to_sym( + torch.randn((nobs, ndim * (ndim + 1) // 2), generator=generator, dtype=DTYPE) + ) + logz = logz - logz.mean(dim=0, keepdim=True) + z = matrix_exp.apply(logz) + + # Linear mixing model: shifts Frechet mean to A @ A^T + eps = 0.1 + forward_model = ( + torch.rand((ndim, ndim), generator=generator, dtype=DTYPE) - 0.5 + ) * (1 - eps) + eps * torch.eye(ndim, dtype=DTYPE) + x = forward_model @ z @ forward_model.mT + + # Analytic Frechet mean (by invariance) + x_mean_expected = (forward_model @ forward_model.mT).unsqueeze(0) + + return x, x_mean_expected, ndim, nobs + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +METRICS = ["AIM", "LEM", "LCM"] +CONGRUENCES = ["cholesky", "eig"] + + +@pytest.mark.parametrize( + "metric,theta,atol", + [ + ("AIM", 1.0, 1e-10), + ("LCM", 1.0, 1e-5), # Cholesky + log/exp introduces numerical error + ("LCM", 0.5, 1e-4), # Additional matrix_power roundtrip error + ], +) +def test_deform_inv_deform_roundtrip(simulated_data, metric, theta, atol): + """_inv_deform(_deform(X)) should recover X.""" + x, _, ndim, _ = simulated_data + layer = SPDBatchNormLie(ndim, metric=metric, theta=theta, dtype=DTYPE) + + X_def = layer._deform(x) + X_recovered = layer._inv_deform(X_def) + + assert torch.allclose(X_recovered, x, atol=atol, rtol=0.0) + + +@pytest.mark.parametrize("congruence", CONGRUENCES) +@pytest.mark.parametrize("metric", METRICS) +def test_post_normalization_mean(simulated_data, metric, congruence): + """After LieBN forward (bias=I, shift=1), codomain mean should be neutral. + + - AIM: Karcher mean of output ≈ Identity + - LEM/LCM: arithmetic mean of deformed output ≈ zero matrix + """ + x, _, ndim, nobs = simulated_data + layer = SPDBatchNormLie( + ndim, metric=metric, n_iter=64, congruence=congruence, dtype=DTYPE + ) + layer.train() + + with torch.no_grad(): + output = layer(x) + + tol = 2 * sqrt(1.0 / nobs) + + if metric == "AIM": + # Compute Karcher mean of output + mean = output.mean(dim=0, keepdim=True) + for _ in range(64): + mean = karcher_mean_iteration(output, mean, detach=True) + identity = torch.eye(ndim, dtype=DTYPE).unsqueeze(0) + assert torch.allclose(mean, identity, atol=tol, rtol=0.0), ( + f"AIM: Karcher mean of output deviates from Identity by " + f"{(mean - identity).abs().max().item():.6f}" + ) + else: + # In codomain, mean should be ≈ zero + output_def = layer._deform(output) + codomain_mean = output_def.mean(dim=0, keepdim=True) + zeros = torch.zeros_like(codomain_mean) + assert torch.allclose(codomain_mean, zeros, atol=tol, rtol=0.0), ( + f"{metric}: codomain mean deviates from zero by " + f"{codomain_mean.abs().max().item():.6f}" + ) + + +@pytest.mark.parametrize("metric", METRICS) +def test_post_normalization_variance(simulated_data, metric): + """After LieBN forward (shift=1), output variance should be ≈ 1.0. + + Theoretical: shift^2 * v^2 / (v^2 + eps). With shift=1 and large v^2, + this is close to 1.0. 
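+
+    Concretely, v^2 / (v^2 + eps) = 1 - eps / (v^2 + eps); with eps = 1e-5
+    and a dispersion well above eps for this fixture, the deviation from
+    1.0 is orders of magnitude smaller than the sampling tolerance below.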
+ """ + x, _, ndim, nobs = simulated_data + layer = SPDBatchNormLie(ndim, metric=metric, n_iter=64, dtype=DTYPE) + layer.train() + + with torch.no_grad(): + output = layer(x) + + # Compute variance of output in the same way as SPDBatchNormLie._frechet_variance + # but on the re-centered output + output_def = layer._deform(output) + if metric == "AIM": + output_mean = output_def.mean(dim=0, keepdim=True) + for _ in range(64): + output_mean = karcher_mean_iteration(output_def, output_mean, detach=True) + L = torch.linalg.cholesky(output_mean) + Y = torch.linalg.solve_triangular(L, output_def, upper=False) + centered = ensure_sym(torch.linalg.solve_triangular(L, Y.mT, upper=False).mT) + logX = matrix_log.apply(centered) + frob_sq = (logX * logX).sum(dim=(-2, -1)) + output_var = frob_sq.mean() + else: + centered = output_def - output_def.mean(dim=0, keepdim=True) + frob_sq = (centered * centered).sum(dim=(-2, -1)) + output_var = frob_sq.mean() + + # Expected: shift^2 * v^2 / (v^2 + eps) ≈ 1.0 + tol = 3 * sqrt(1.0 / nobs) + assert abs(output_var.item() - 1.0) < tol, ( + f"{metric}: output variance = {output_var.item():.6f}, expected ≈ 1.0" + ) + + +@pytest.mark.parametrize("metric", METRICS) +def test_running_stats_single_batch(simulated_data, metric): + """With momentum=1.0, running stats should match batch stats exactly.""" + x, _, ndim, nobs = simulated_data + layer = SPDBatchNormLie(ndim, metric=metric, momentum=1.0, n_iter=64, dtype=DTYPE) + layer.train() + + with torch.no_grad(): + layer(x) + + # Independently compute batch statistics + X_def = layer._deform(x) + if metric == "AIM": + expected_mean = X_def.mean(dim=0, keepdim=True) + for _ in range(64): + expected_mean = karcher_mean_iteration(X_def, expected_mean, detach=True) + else: + expected_mean = X_def.mean(dim=0, keepdim=True) + + tol = sqrt(1.0 / nobs) + assert torch.allclose(layer.running_mean, expected_mean, atol=tol, rtol=0.0), ( + f"{metric}: running_mean deviates from batch mean by " + f"{(layer.running_mean - expected_mean).abs().max().item():.6f}" + ) + + # Compute expected variance from centered data + if metric == "AIM": + L = torch.linalg.cholesky(expected_mean) + Y = torch.linalg.solve_triangular(L, X_def, upper=False) + centered = ensure_sym(torch.linalg.solve_triangular(L, Y.mT, upper=False).mT) + logX = matrix_log.apply(centered) + frob_sq = (logX * logX).sum(dim=(-2, -1)) + expected_var = frob_sq.mean() + else: + centered = X_def - expected_mean + frob_sq = (centered * centered).sum(dim=(-2, -1)) + expected_var = frob_sq.mean() + + assert torch.allclose(layer.running_var, expected_var, atol=tol, rtol=0.0), ( + f"{metric}: running_var = {layer.running_var.item():.6f}, " + f"expected = {expected_var.item():.6f}" + ) + + +@pytest.mark.parametrize("metric", METRICS) +def test_running_stats_convergence(simulated_data, metric): + """Running stats should converge to population stats over mini-batches.""" + x, _, ndim, nobs = simulated_data + layer = SPDBatchNormLie(ndim, metric=metric, n_iter=1, dtype=DTYPE) + + # Full-batch reference statistics (high precision) + with torch.no_grad(): + ref_layer = SPDBatchNormLie( + ndim, metric=metric, momentum=1.0, n_iter=64, dtype=DTYPE + ) + ref_layer.train() + ref_layer(x) + ref_mean = ref_layer.running_mean.clone() + ref_var = ref_layer.running_var.clone() + + # Train with mini-batches and decaying momentum + ds = torch.utils.data.TensorDataset(x) + loader = torch.utils.data.DataLoader(ds, batch_size=nobs // 4, drop_last=True) + + layer.train() + n_epochs = 64 // len(loader) * 4 + 
for epoch in range(n_epochs): + layer.momentum = 1 / (epoch + 1) + for batch in loader: + with torch.no_grad(): + layer(batch[0]) + + tol = 5 * sqrt(1.0 / nobs) + assert torch.allclose(layer.running_mean, ref_mean, atol=tol, rtol=0.0), ( + f"{metric}: running_mean did not converge. " + f"Max deviation: {(layer.running_mean - ref_mean).abs().max().item():.6f}" + ) + assert torch.allclose(layer.running_var, ref_var, atol=tol, rtol=0.05), ( + f"{metric}: running_var did not converge. " + f"running={layer.running_var.item():.4f}, ref={ref_var.item():.4f}" + ) + + +@pytest.mark.parametrize("metric", METRICS) +def test_gradient_flow(simulated_data, metric): + """Verify gradients flow through LieBN to input and parameters.""" + x, _, ndim, _ = simulated_data + # Use a small batch to keep computation fast + x_small = x[:8].clone().requires_grad_(True) + + layer = SPDBatchNormLie(ndim, metric=metric, n_iter=1, dtype=DTYPE) + layer.train() + + output = layer(x_small) + loss = output.sum() + loss.backward() + + # Input gradient + assert x_small.grad is not None, f"{metric}: no gradient on input" + assert x_small.grad.abs().sum() > 0, f"{metric}: zero gradient on input" + + # Bias parameter gradient (underlying unconstrained parameter) + bias_param = layer.parametrizations.bias.original + assert bias_param.grad is not None, f"{metric}: no gradient on bias" + assert bias_param.grad.abs().sum() > 0, f"{metric}: zero gradient on bias" + + # Shift parameter gradient + shift_param = layer.parametrizations.shift.original + assert shift_param.grad is not None, f"{metric}: no gradient on shift" + assert shift_param.grad.abs().sum() > 0, f"{metric}: zero gradient on shift" + + +@pytest.mark.parametrize("metric", METRICS) +def test_default_initialization(metric): + """Verify default parameter initialization.""" + ndim = 4 + layer = SPDBatchNormLie(ndim, metric=metric, dtype=DTYPE) + + # Bias should be Identity + identity = torch.eye(ndim, dtype=DTYPE).unsqueeze(0) + assert torch.allclose(layer.bias, identity, atol=1e-10), ( + f"{metric}: bias not initialized to Identity" + ) + + # Shift should be 1.0 + assert torch.allclose(layer.shift, torch.ones((), dtype=DTYPE), atol=1e-10), ( + f"{metric}: shift not initialized to 1.0" + ) + + # Running mean: Identity for AIM, zeros for LEM/LCM + if metric == "AIM": + assert torch.allclose(layer.running_mean, identity, atol=1e-10) + else: + assert torch.allclose( + layer.running_mean, torch.zeros(1, ndim, ndim, dtype=DTYPE), atol=1e-10 + ) + + # Running var should be 1.0 + assert torch.allclose(layer.running_var, torch.ones((), dtype=DTYPE), atol=1e-10) + + # Verify dtype propagated correctly + assert layer.bias.dtype == DTYPE + assert layer.shift.dtype == DTYPE + assert layer.running_mean.dtype == DTYPE + assert layer.running_var.dtype == DTYPE
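+
+
+# The check below is an editorial sketch, not part of the original LieBN
+# test suite: it verifies that eval mode is deterministic (it uses only
+# the frozen running statistics) and still produces valid SPD output.
+@pytest.mark.parametrize("metric", METRICS)
+def test_eval_mode_uses_running_stats(simulated_data, metric):
+    """Eval mode should be deterministic and preserve SPD structure."""
+    x, _, ndim, _ = simulated_data
+    layer = SPDBatchNormLie(ndim, metric=metric, momentum=1.0, dtype=DTYPE)
+
+    # One training pass populates the running statistics (momentum=1.0)
+    layer.train()
+    with torch.no_grad():
+        layer(x)
+
+    layer.eval()
+    with torch.no_grad():
+        out1 = layer(x[:16])
+        out2 = layer(x[:16])
+
+    # Same input and same frozen stats -> (numerically) identical output
+    assert torch.allclose(out1, out2, atol=1e-12, rtol=0.0)
+
+    # Output must remain symmetric positive definite
+    eigvals = torch.linalg.eigvalsh(out1)
+    assert (eigvals > 0).all(), f"{metric}: eval output is not SPD"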