![QF-Logo](https://quantumformalism.academy/img/qf-up.png)

# Equivariant Neural Networks on Homogeneous Spaces
### Interactive Exploration of Group Actions and Equivariant Maps

This notebook supplements **Lecture 8: Geometric Deep Learning**, where we discussed:

- **Homogeneous spaces** as domains for GDL models (e.g., $S^2 \cong SO(3)/SO(2)$).
- **Group actions** on these spaces and associated feature maps (like vector fields).
- The concept of **induced representations** and **associated bundles**.
- **G-Equivariant Convolutional Neural Networks (G-CNNs)** on homogeneous spaces.

### What You Will Learn
- How to represent the sphere $S^2$ as a **homogeneous space** $SO(3)/SO(2)$.
- How the **group $SO(3)$ acts** on points and feature maps (vector fields) on the sphere (equivariance).
- The connection between **vector fields**, **associated bundles**, and **equivariant maps**.
- The role of **Spherical Harmonics** as an equivariant basis for functions on $S^2$.
- How **linear equivariant maps** in harmonic space relate to **spherical convolutions** (Schur's Lemma & Theorem 3.1 from Cohen et al. 2019).
- A practical example of building and testing an **SO(3)-equivariant layer** using the `e3nn` library.

Use the code and visualizations to build intuition for these geometric concepts and their application in Geometric Deep Learning!

## Why This Matters

Many real-world datasets possess underlying symmetries captured by Lie groups acting on non-Euclidean domains (like spheres, hyperbolic spaces, or other manifolds). Geometric Deep Learning aims to build neural networks that respect these symmetries.

- **Homogeneous spaces** provide a powerful framework for describing these symmetric domains.
- Understanding **group actions on feature maps** (like scalar or vector fields) is crucial for defining equivariant operations (like convolutions).
- **Spherical CNNs**, used in areas like astrophysics (CMB analysis), molecular modeling, and 360° vision, are a prime example of G-CNNs on the homogeneous space $S^2$.
- The theory connects abstract concepts like **induced representations** and **Schur's Lemma** directly to the practical implementation of **equivariant convolutions**.

This notebook uses the sphere as a concrete example to illustrate the core mathematical ideas behind these advanced GDL models.

## Setup: Importing Libraries
We'll use standard scientific libraries for computation and visualization, along with specialized libraries for equivariant deep learning.

In [None]:
import numpy as np # For linear algebra
from scipy.linalg import expm, block_diag # Matrix exponentials and block diagonal matrices
from scipy.spatial.transform import Rotation as R # For handling SO(3) rotations
from scipy.special import sph_harm # For spherical harmonics
import matplotlib.pyplot as plt # For visualizations
from mpl_toolkits.mplot3d import Axes3D # For 3D plotting
import torch # PyTorch for e3nn example
import e3nn # Equivariant neural networks library
from e3nn import o3

%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid') # Updated style
plt.rcParams['figure.figsize'] = (8, 8) # Default figure size

## Section 1: Homogeneous Spaces - The Sphere $S^2$ as $SO(3)/SO(2)$

A **homogeneous space** is a manifold $M$ on which a Lie group $G$ acts transitively. This means any point can be reached from any other point via a group action. Many important homogeneous spaces can be represented as quotient spaces $G/H$, where $H$ is the **stabilizer subgroup** of a point $p \in M$ (i.e., $H = \{ h \in G \mid h \cdot p = p \}$).

The 2-sphere $S^2$ is a classic example. The rotation group $SO(3)$ acts transitively on $S^2$. The stabilizer of the north pole $p = (0, 0, 1)$ is the group of rotations around the z-axis, which is isomorphic to $SO(2)$. Therefore, we have the identification:

$$ S^2 \cong SO(3) / SO(2) $$

This means every point $x \in S^2$ can be written as $x = g \cdot p$ for some $g \in SO(3)$. The set of all such points generated from $p$ is the **orbit** of $p$, which is the entire sphere $S^2$ due to transitivity.
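The quotient picture can be checked directly in a few lines. This is a minimal NumPy sketch that assumes the parametrization $g = R_z(\phi) R_y(\theta)$ used in the orbit plot below. It builds one rotation mapping the north pole to a given point, then checks that composing with a stabilizer element (a z-rotation) does not change the image. The helper names `rot_z`, `rot_y` and `coset_representative` are illustrative and do not come from any library.

```python
import numpy as np

def rot_z(a):
    """Rotation by angle a about the z-axis."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_y(a):
    """Rotation by angle a about the y-axis."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

def coset_representative(x):
    """Return one g in SO(3) with g @ p = x, where p is the north pole.

    Uses g = Rz(phi) @ Ry(theta), with (theta, phi) the polar/azimuthal angles of x.
    Any g @ h with h in the stabilizer SO(2) would work equally well.
    """
    theta = np.arccos(np.clip(x[2], -1.0, 1.0))
    phi = np.arctan2(x[1], x[0])
    return rot_z(phi) @ rot_y(theta)

p = np.array([0.0, 0.0, 1.0])
rng = np.random.default_rng(0)
x = rng.normal(size=3)
x /= np.linalg.norm(x)              # a random point on S^2

g = coset_representative(x)
h = rot_z(0.7)                      # an element of the stabilizer H = SO(2)

print(np.allclose(g @ p, x))        # transitivity: g maps p to x
print(np.allclose(h @ p, p))        # h fixes the north pole
print(np.allclose((g @ h) @ p, x))  # every element of the coset gH maps p to x
```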

Below, we visualize the orbit of the north pole under a family of $SO(3)$ rotations.

In [None]:
# Define the north pole
north_pole = np.array([0, 0, 1])

# Generate a grid of rotation matrices g = Rz(phi) @ Ry(theta) in SO(3).
# num_points // 3 = 5 samples per angle, so 5 x 5 = 25 rotations in total.
num_points = 15
orbit_points = []
rotations_list = []
for theta in np.linspace(0, np.pi, num_points // 3):       # polar angle
    for phi in np.linspace(0, 2*np.pi, num_points // 3):   # azimuthal angle
        # Tilt the north pole by theta about the y-axis, then rotate by phi about the z-axis.
        # (theta = 0 or pi gives the same point for every phi, and phi = 0, 2*pi coincide, so some points repeat.)
        rot_y = R.from_euler('y', theta).as_matrix()
        rot_z = R.from_euler('z', phi).as_matrix()
        g = rot_z @ rot_y # Combined rotation
        orbit_points.append(g @ north_pole)
        rotations_list.append(g)

# Visualize the orbit
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111, projection='3d')

# Draw the unit sphere for context
u_sphere = np.linspace(0, 2 * np.pi, 50)
v_sphere = np.linspace(0, np.pi, 50)
x_sphere = np.outer(np.cos(u_sphere), np.sin(v_sphere))
y_sphere = np.outer(np.sin(u_sphere), np.sin(v_sphere))
z_sphere = np.outer(np.ones(np.size(u_sphere)), np.cos(v_sphere))
ax.plot_surface(x_sphere, y_sphere, z_sphere, color='lightgray', alpha=0.2, zorder=1)

# Plot orbit points (generated by rotating the north pole)
ax.scatter(north_pole[0], north_pole[1], north_pole[2], color='k', s=150, label='North Pole (p)', zorder=5)
orbit_x = [pt[0] for pt in orbit_points]
orbit_y = [pt[1] for pt in orbit_points]
orbit_z = [pt[2] for pt in orbit_points]
ax.scatter(orbit_x, orbit_y, orbit_z, color='r', s=50, label='Orbit points g.p', zorder=4)

# Add arrows from origin to points
for pt in orbit_points:
    ax.quiver(0, 0, 0, pt[0], pt[1], pt[2], color='b', alpha=0.5, length=0.9, arrow_length_ratio=0.1, zorder=3)

ax.set_title(r"$S^2$ as the Orbit $SO(3) \cdot p$")
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.legend()
ax.set_box_aspect([1,1,1]) # Equal aspect ratio
plt.show()

## Section 2: Feature Maps on Homogeneous Spaces - Tangent Vectors & Associated Bundles

Feature maps on homogeneous spaces, such as scalar fields or vector fields, transform in specific ways under the group action. A map between feature spaces that commutes with these transformations is called **equivariant**.

A **vector field** on $S^2$ assigns a tangent vector $s(x) \in T_x S^2$ to each point $x \in S^2$. The collection of all tangent spaces forms the **tangent bundle** $TS^2$.

For $S^2 = G/H = SO(3)/SO(2)$, the tangent space at the north pole $T_p S^2$ is isomorphic to $\mathbb{R}^2$ (the xy-plane perpendicular to $p$). The tangent bundle $TS^2$ can be constructed as an **associated vector bundle**:

$$ TS^2 \cong SO(3) \times_{SO(2)} \mathbb{R}^2 $$

Here, $\mathbb{R}^2$ is the **fiber**, representing the tangent space at the identity coset (north pole), and $H=SO(2)$ acts on this fiber $\mathbb{R}^2$ via its standard 2D rotation representation $\rho$. This means the fiber carries a *representation* of the stabilizer subgroup $H$.

A vector field $s: S^2 \to TS^2$ is a **section** of this associated bundle. It can be **lifted** to an $SO(2)$-equivariant map $\tilde{s}: SO(3) \to \mathbb{R}^2$ satisfying $\tilde{s}(gh) = \rho(h^{-1}) \tilde{s}(g)$ for $h \in H=SO(2)$.

We can visualize this concept by defining tangent vectors at the north pole (in the fiber $\mathbb{R}^2$) and then "pushing" them to other points on the sphere using the differential of the group action $d_p g: T_p S^2 \to T_{g \cdot p} S^2$. This differential essentially tells us how the tangent space itself transforms under the group action.
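To make the lifting condition concrete, here is a small NumPy check under stated assumptions. The field $s(x) = e_z - (e_z \cdot x)\,x$ (the projection of the z-axis onto each tangent plane) is a hypothetical example, chosen because it is well defined at every point. The lift is computed as $\tilde{s}(g) = g^{-1} s(g \cdot p)$, read off as its xy-components in the fiber $\mathbb{R}^2$. The check confirms $\tilde{s}(gh) = \rho(h^{-1})\tilde{s}(g)$ for $h$ a z-rotation.

```python
import numpy as np

def rot_z(a):
    """3D rotation by angle a about the z-axis (an element of the stabilizer H)."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rho(a):
    """Standard representation of SO(2) on the fiber R^2: rotation by angle a."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s], [s, c]])

def field(x):
    """Hypothetical tangent field: projection of e_z onto the tangent plane at x."""
    e_z = np.array([0.0, 0.0, 1.0])
    return e_z - np.dot(e_z, x) * x

def lift(g):
    """Lifted map s~(g) = g^{-1} s(g.p), expressed in the fiber R^2 (xy-components)."""
    p = np.array([0.0, 0.0, 1.0])
    v = g.T @ field(g @ p)  # g^{-1} = g^T; v is orthogonal to p, so v[2] is ~0
    return v[:2]

# A random rotation from the QR decomposition of a Gaussian matrix
rng = np.random.default_rng(1)
q, r = np.linalg.qr(rng.normal(size=(3, 3)))
g = q * np.sign(np.diag(r))
if np.linalg.det(g) < 0:
    g[:, 0] = -g[:, 0]

alpha = 0.9
lhs = lift(g @ rot_z(alpha))   # s~(g h) with h = Rz(alpha)
rhs = rho(-alpha) @ lift(g)    # rho(h^{-1}) s~(g)
print(np.allclose(lhs, rhs))   # True
```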

In [None]:
# We construct tangent vectors at the north pole in the tangent plane
# T_p S^2 is the xy-plane when p = (0, 0, 1)

# Define two orthonormal tangent vectors at the north pole (these live in the fiber R^2)
v1_north_pole = np.array([1, 0, 0]) # Corresponds to e_x in T_p S^2
v2_north_pole = np.array([0, 1, 0]) # Corresponds to e_y in T_p S^2

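# For a rotation g ∈ SO(3), the differential d_p g maps v ∈ T_p S^2 to T_{g·p} S^2.
# Because SO(3) acts linearly on R^3, the differential is the matrix g itself:
# d_p g (v) = g @ v, a tangent vector at g·p.
# We define s(g·p) = d_p g (v1_north_pole) using the representative g = Rz(phi) @ Ry(theta).
# Caveat: this depends on the choice of g (replacing g by g @ h, h ∈ SO(2), rotates the arrow),
# so at the poles, which several grid rotations map to, multiple arrows appear.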

def pushforward_vector(rotation_matrix, vector_at_p):
    """Computes d_p g (v) = g @ v for the SO(3) action on R^3"""
    return rotation_matrix @ vector_at_p

# Compute the tangent vectors at the previously calculated orbit points
# This represents evaluating our simple vector field s at those points.
tangent_vectors_on_orbit = [pushforward_vector(rot, v1_north_pole) for rot in rotations_list]

# Visualize the tangent vectors attached to the orbit points
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(x_sphere, y_sphere, z_sphere, color='lightgray', alpha=0.2, zorder=1)

ax.scatter(north_pole[0], north_pole[1], north_pole[2], color='k', s=150, label='North Pole (p)', zorder=5)
ax.quiver(north_pole[0], north_pole[1], north_pole[2],
          v1_north_pole[0], v1_north_pole[1], v1_north_pole[2],
          color='darkgreen', length=0.3, label=r'$v_1 \in T_p S^2$', zorder=6)

# Plot orbit points and their corresponding tangent vectors
for i, (pt, vec) in enumerate(zip(orbit_points, tangent_vectors_on_orbit)):
    label = r'Tangent Vectors $s(g \cdot p) = d_p g(v_1)$' if i == 0 else ""
    ax.quiver(pt[0], pt[1], pt[2], vec[0], vec[1], vec[2], 
              color='g', length=0.3, normalize=False, label=label, zorder=4)
    ax.scatter(pt[0], pt[1], pt[2], color='r', s=50, zorder=3)

ax.set_title("Tangent Vectors (Features) on $S^2$")
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.legend()
ax.set_box_aspect([1,1,1])
plt.show()

## Section 3: SO(3) Action on Vector Fields (Equivariance / Push-Pull)

Now, let's see how the entire vector field $s$ transforms under the action of an element $g \in SO(3)$. An equivariant map (like a G-CNN layer) must respect this transformation.

The group action on sections (like vector fields), also called the **induced representation** on the space of sections $\Gamma(TS^2)$, is defined using the "push-pull" mechanism discussed in the lecture:

$$ (g \cdot s)(x) = d_{g^{-1}x} g \left( s(g^{-1}x) \right) $$

For matrix groups like $SO(3)$ acting linearly on vectors in $\mathbb{R}^3$, the differential $d_{y}g$ is just the matrix $g$ itself. This simplifies the action to:

$$ (g \cdot s)(x) = g \cdot s(g^{-1}x) $$

This formula tells us how to find the vector for the *transformed* field $(g \cdot s)$ at point $x$:
1. Find the point $y = g^{-1}x$ where the field originated before the transformation.
2. Get the original vector field's value at that point: $s(y) = s(g^{-1}x)$.
3. Apply the group transformation $g$ (matrix multiplication) to that vector: $g \cdot s(g^{-1}x)$. This aligns the vector correctly in the tangent space $T_x S^2$.
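The three steps above translate almost literally into code. This sketch assumes a vector field is represented as a Python function from points to vectors, and reuses the hypothetical field $s(x) = e_z - (e_z\cdot x)\,x$. The function `act` (an illustrative name) implements $(g\cdot s)(x) = g\, s(g^{-1}x)$. We check that the result is still tangent to the sphere, and that acting twice agrees with acting by the product, as a representation should.

```python
import numpy as np

def random_rotation(rng):
    """Random 3x3 rotation from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

def act(g, field):
    """Induced action on a vector field: (g.s)(x) = g @ s(g^{-1} x), with g^{-1} = g^T."""
    return lambda x: g @ field(g.T @ x)

def s(x):
    """Hypothetical tangent field: projection of e_z onto the tangent plane at x."""
    e_z = np.array([0.0, 0.0, 1.0])
    return e_z - np.dot(e_z, x) * x

rng = np.random.default_rng(2)
g1, g2 = random_rotation(rng), random_rotation(rng)
x = rng.normal(size=3)
x /= np.linalg.norm(x)

moved = act(g1, s)
print(np.isclose(np.dot(moved(x), x), 0.0))                     # (g.s)(x) is tangent at x
print(np.allclose(act(g1, act(g2, s))(x), act(g1 @ g2, s)(x)))  # g1.(g2.s) = (g1 g2).s
```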

We simulate this by applying a new rotation $g$ (let's call it `g_action`) to the vector field we visualized above. We compute the *new* vectors $(g \cdot s)(g \cdot x_i)$ at the *new* points $g \cdot x_i$.

In [None]:
# Choose a new rotation g_act ∈ SO(3) for the group action on the field
g_action = R.from_euler('y', np.pi/4).as_matrix()

# The action transforms both the base point and the vector.
# For a point x and vector s(x), the transformed point is g_act . x 
# and the transformed vector at that new point is g_act . s(x).
# Note: This simplified view works because d_x g = g for linear actions.
# The formula (g.s)(x) = g . s(g^{-1}x) finds the vector at point x in the *new* field.
# Here we visualize the vectors g . s(x_i) attached to the points g . x_i

def apply_group_action_to_field(g_act, points, vectors):
    """
    Applies the SO(3) action g_act to a vector field represented by points and vectors.
    Input: g_act, points x_i, and corresponding vectors s(x_i).
    Output: Transformed points y_i = g_act . x_i and corresponding transformed vectors 
            v_i = g_act . s(x_i) at those new points y_i.
            This represents the field (g_act s) evaluated at the transformed points.
    """
    transformed_points = [g_act @ pt for pt in points]
    transformed_vectors = [g_act @ vec for vec in vectors] # Apply g_act to s(x_i)
    return transformed_points, transformed_vectors

# Apply SO(3) action g_action to the previously defined vector field
# orbit_points are the original x_i
# tangent_vectors_on_orbit are the original s(x_i)
transformed_orbit_points, transformed_tangent_vectors = apply_group_action_to_field(
    g_action, orbit_points, tangent_vectors_on_orbit
)

# Visualize the transformed vector field
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(x_sphere, y_sphere, z_sphere, color='lightgray', alpha=0.2, zorder=1)

# Plot the new basepoints and the transformed vectors
for i, (pt, vec) in enumerate(zip(transformed_orbit_points, transformed_tangent_vectors)):
    label_pt = r'Transformed points $g \cdot x$' if i == 0 else ""
    label_vec = r'Transformed Vectors $(g \cdot s)(g \cdot x) = g \cdot s(x)$' if i == 0 else ""
    ax.quiver(pt[0], pt[1], pt[2], vec[0], vec[1], vec[2], 
              color='orange', length=0.3, normalize=False, label=label_vec, zorder=4)
    ax.scatter(pt[0], pt[1], pt[2], color='purple', s=50, label=label_pt, zorder=3)

ax.set_title(r"Action of $g \in SO(3)$ on the Vector Field $s$")
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.legend()
ax.set_box_aspect([1,1,1])
plt.show()

## Section 4: Harmonic Analysis on the Sphere - The Equivariant Basis

Features on the sphere (scalar fields, vector fields, etc.) can be decomposed using **spherical harmonics**, $Y_l^m(\theta, \phi)$. These are special functions which form an orthonormal basis for the space of square-integrable complex-valued functions on the sphere, $L^2(S^2,\mathbb{C})$, and are eigenfunctions of the Laplacian on the sphere.

Crucially, for a fixed degree $l$ (a non-negative integer), the space spanned by the $2l+1$ harmonics $\{ Y_l^m \}_{m=-l}^l$ is mapped *into itself* by every rotation $R \in SO(3)$. This space carries an **irreducible representation** (irrep) of $SO(3)$, often denoted $D^l$: rotations mix harmonics of the same degree $l$ but never mix harmonics of different degrees, and no proper nonzero subspace is preserved by all rotations.

$$ (R \cdot Y_l^m)(\theta, \phi) = Y_l^m(R^{-1}(\theta, \phi)) = \sum_{m'=-l}^{l} D^l_{m'm}(R) Y_l^{m'}(\theta, \phi) $$

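where $D^l(R)$ is the $(2l+1) \times (2l+1)$ Wigner-D matrix for rotation $R$. Index and sign conventions for $D^l$ vary between references; here they are chosen so that $R \mapsto D^l(R)$ is a representation, i.e. $D^l(R_1 R_2) = D^l(R_1) D^l(R_2)$.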

This property makes spherical harmonics an **SO(3)-equivariant basis**. Any square-integrable function $f: S^2 \to \mathbb{C}$ can be written as a sum over these basis functions (with the series converging in the $L^2$ sense):
$$ f(\theta, \phi) = \sum_{l=0}^{\infty} \sum_{m=-l}^{l} \hat{f}_{lm} Y_l^m(\theta, \phi) $$
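The complex numbers $\hat{f}_{lm}$ are the **spherical harmonic coefficients** or the **spectrum** of the function $f$. The action of $SO(3)$ on $f$ becomes a block-wise linear action on its coefficients. Substituting the transformation rule above gives
$$ \widehat{(R \cdot f)}_{lm} = \sum_{m'=-l}^{l} D^l_{mm'}(R)\, \hat{f}_{lm'}, $$
so each degree-$l$ block of coefficients is multiplied by $D^l(R)$, and different degrees never mix.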
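A quick numerical check of how a rotation acts on coefficients, restricted to a case where everything is explicit. For a rotation by $\alpha$ about the z-axis, $Y_l^m$ depends on the azimuth only through $e^{im\phi}$, so $D^l$ is diagonal with entries $e^{-im\alpha}$. The sketch below uses hand-written degree-1 harmonics (Condon-Shortley phase, which we assume matches SciPy's convention) and hypothetical coefficients. It compares rotating the function directly with rotating its coefficients.

```python
import numpy as np

def Y1(m, theta, phi):
    """Degree-1 complex spherical harmonics (Condon-Shortley phase); theta polar, phi azimuthal."""
    if m == 0:
        return 0.5 * np.sqrt(3 / np.pi) * np.cos(theta)
    c = 0.5 * np.sqrt(3 / (2 * np.pi)) * np.sin(theta) * np.exp(1j * m * phi)
    return -c if m == 1 else c

def f(theta, phi, coeffs):
    """Synthesize f = sum_m f_1m Y_1^m at one point."""
    return sum(coeffs[m] * Y1(m, theta, phi) for m in (-1, 0, 1))

coeffs = {-1: 0.2 + 0.1j, 0: -0.3, 1: 0.2 - 0.1j}   # hypothetical l = 1 coefficients
alpha = 0.8                                          # rotation angle about the z-axis
theta, phi = 1.1, 2.3                                # an arbitrary evaluation point

# Function side: (R.f)(x) = f(R^{-1} x); for R = Rz(alpha) this shifts the azimuth by -alpha
rotated_direct = f(theta, phi - alpha, coeffs)

# Coefficient side: D^1(Rz(alpha)) is diagonal with entries exp(-i m alpha)
rotated_coeffs = {m: np.exp(-1j * m * alpha) * c for m, c in coeffs.items()}
rotated_spectral = f(theta, phi, rotated_coeffs)

print(np.isclose(rotated_direct, rotated_spectral))  # True
```

For a general rotation the blocks $D^l(R)$ are dense, which is why libraries compute them for you (as in the `e3nn` section).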

Below, we visualize a simple scalar field constructed from a low-degree spherical harmonic, representing a single basis function.

In [None]:
# Define spherical coordinates grid (azimuthal phi, polar theta)
phi_grid, theta_grid = np.meshgrid(u_sphere, v_sphere) # Reuse sphere grid from Sec 1

# Create a scalar field on the sphere using a low-degree spherical harmonic
l_vis = 2 # Degree
m_vis = 1 # Order

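# Evaluate the spherical harmonic Y_lm(theta, phi).
# SciPy's sph_harm takes (m, l, azimuthal angle, polar angle); recent SciPy releases
# deprecate it in favour of sph_harm_y(l, m, polar angle, azimuthal angle).
Y_lm = sph_harm(m_vis, l_vis, phi_grid, theta_grid)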
Y_lm_real = Y_lm.real # Take the real part for visualization

# Map spherical coordinates to Cartesian coordinates for plotting
X_cart = np.sin(theta_grid) * np.cos(phi_grid)
Y_cart = np.sin(theta_grid) * np.sin(phi_grid)
Z_cart = np.cos(theta_grid)

# Plot the spherical harmonic as a scalar field (color) on S^2
fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(111, projection='3d')

# Normalize colors for the surface plot
norm = plt.Normalize(vmin=Y_lm_real.min(), vmax=Y_lm_real.max())
cmap = plt.cm.coolwarm
surf = ax.plot_surface(X_cart, Y_cart, Z_cart, 
                       facecolors=cmap(norm(Y_lm_real)), 
                       rstride=1, cstride=1, antialiased=False, shade=False)

ax.set_title(f"Spherical Harmonic $Y_{{{l_vis}}}^{{{m_vis}}}$ (Real Part) on $S^2$")
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_box_aspect([1,1,1])
# Add a color bar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([]) # You need this line for the color bar to work properly
cbar = fig.colorbar(sm, ax=ax, shrink=0.6, aspect=10)
cbar.set_label('Value of $Re(Y_l^m)$')
plt.show()

## Section 5: Numerical Harmonic Representation & SHT (Conceptual)

For practical computation, we need to represent functions on the sphere and their harmonic transforms numerically.

1.  **Spatial Domain**: A function $f: S^2 \to \mathbb{R}$ (or $\mathbb{C}$) is typically represented by its values on a discrete grid of points $(\theta_i, \phi_i)$ on the sphere.
2.  **Harmonic Domain**: The function is represented by its spherical harmonic coefficients $\hat{f}_{lm}$ up to a maximum degree $L_{max}$.
    $$ f(\theta, \phi) \approx \sum_{l=0}^{L_{max}} \sum_{m=-l}^{l} \hat{f}_{lm} Y_l^m(\theta, \phi) $$
3.  **Spherical Harmonic Transform (SHT)**: This is the process of computing the coefficients $\hat{f}_{lm}$ from the function values $f(\theta_i, \phi_i)$. It's the spherical analog of the Fourier Transform.
4.  **Inverse SHT**: This is the process of reconstructing the function $f(\theta, \phi)$ from its coefficients $\hat{f}_{lm}$ (using the summation formula above).
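As an illustration of the forward and inverse SHT, the following sketch assumes a band-limited signal with $l \le 1$ (so the harmonics can be written out by hand) and a simple product quadrature: Gauss-Legendre nodes in $\cos\theta$ and equispaced nodes in $\phi$. The forward transform approximates $\hat{f}_{lm} = \int_{S^2} f\, \overline{Y_l^m}\, d\Omega$ by quadrature. For this band limit and grid size the quadrature is exact up to rounding, so the original coefficients are recovered. Practical code would use a dedicated SHT library rather than this hand-rolled version.

```python
import numpy as np

def Y(l, m, theta, phi):
    """Complex spherical harmonics for l <= 1 only (Condon-Shortley phase)."""
    if l == 0:
        return np.full(theta.shape, 0.5 / np.sqrt(np.pi), dtype=complex)
    if m == 0:
        return 0.5 * np.sqrt(3 / np.pi) * np.cos(theta) + 0j
    c = 0.5 * np.sqrt(3 / (2 * np.pi)) * np.sin(theta) * np.exp(1j * m * phi)
    return -c if m == 1 else c

# Quadrature grid: Gauss-Legendre in cos(theta), uniform in phi
n_theta, n_phi = 8, 16
x_gl, w_gl = np.polynomial.legendre.leggauss(n_theta)
theta, phi = np.meshgrid(np.arccos(x_gl), 2 * np.pi * np.arange(n_phi) / n_phi, indexing="ij")
weights = np.outer(w_gl, np.full(n_phi, 2 * np.pi / n_phi))   # total weight 4*pi

lm_list = [(0, 0), (1, -1), (1, 0), (1, 1)]
true_coeffs = np.array([0.5, 0.2 + 0.1j, -0.3, 0.2 - 0.1j])

# Inverse SHT: synthesize the signal on the grid from its coefficients
f_grid = sum(c * Y(l, m, theta, phi) for c, (l, m) in zip(true_coeffs, lm_list))

# Forward SHT: project onto each conj(Y_lm) using the quadrature weights
est_coeffs = np.array([np.sum(weights * f_grid * np.conj(Y(l, m, theta, phi)))
                       for l, m in lm_list])
print(np.allclose(est_coeffs, true_coeffs))  # True
```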

**Numerical Representation of Coefficients**: The set of coefficients $\{ \hat{f}_{lm} \}_{l=0...L_{max}, m=-l...l}$ can be stored as a flat vector. A standard ordering groups coefficients by degree $l$, and within each $l$, by order $m$ (e.g., $m=-l, ..., 0, ..., l$). The total number of complex coefficients up to $L_{max}$ is $\sum_{l=0}^{L_{max}} (2l+1) = (L_{max}+1)^2$.
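With this ordering, the block for degree $l$ starts at index $l^2$, so the position of $(l, m)$ is $l^2 + l + m$. Here is a small check that this closed form matches the nested-loop ordering used to build `coeff_indices` in this section (the helper name `flat_index` is illustrative):

```python
def flat_index(l, m):
    """Position of (l, m) in the ordering l = 0, 1, ...; m = -l..l within each l."""
    return l * l + l + m

lmax = 3
order = [(l, m) for l in range(lmax + 1) for m in range(-l, l + 1)]
print(len(order) == (lmax + 1) ** 2)                                  # True: 16 coefficients
print(all(flat_index(l, m) == i for i, (l, m) in enumerate(order)))  # True
print(flat_index(2, -1))                                              # 5
```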

Let's define a simple signal as a sum of a few harmonics and visualize it using the inverse SHT.

In [None]:
# Define a simple spherical function f as a weighted sum of known harmonics
# (Coefficients in the harmonic domain)
input_coeffs = {
    (0, 0): 0.5 + 0j,
    (1, -1): 0.2 + 0.1j,
    (1, 0): -0.3 + 0j,
    (1, 1): 0.2 - 0.1j,
    (2, -2): 0,
    (2, -1): 0.4 - 0.2j,
    (2, 0): 0.5 + 0j,
    (2, 1): 0,
    (2, 2): 0.3 + 0.3j,
    (3, -3): 0,
    (3, -2): 0,
    (3, -1): 0,
    (3, 0): 0,
    (3, 1): 0,
    (3, 2): 0,
    (3, 3): 0,
}
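# Note: f is real-valued only if f_{l,-m} = (-1)^m * conj(f_{lm}) (Condon-Shortley phase, as in SciPy).
# These coefficients do not satisfy that condition, so f is complex; we plot its real part below.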
lmax = max(l for l, m in input_coeffs.keys())
print(f"Signal defined with L_max = {lmax}")

def generate_signal_from_coeffs(theta, phi, coeffs):
    """Synthesizes a signal on the sphere from harmonic coefficients (Inverse SHT)."""
    signal = np.zeros(theta.shape, dtype=np.complex128)
    for (l, m), weight in coeffs.items():
        if weight != 0:
            signal += weight * sph_harm(m, l, phi, theta)
    # For visualization, we often take the real part, but keep complex for processing
    return signal

original_signal_complex = generate_signal_from_coeffs(theta_grid, phi_grid, input_coeffs)
original_signal_real = original_signal_complex.real

# --- Represent coefficients as a flat numpy vector ---
n_coeffs = (lmax + 1)**2
f_vec = np.zeros(n_coeffs, dtype=np.complex128)
coeff_indices = {}
idx = 0
for l in range(lmax + 1):
    for m in range(-l, l + 1):
        coeff_indices[(l, m)] = idx
        if (l, m) in input_coeffs:
            f_vec[idx] = input_coeffs[(l, m)]
        idx += 1
        
print(f"Total number of coefficients (up to L_max={lmax}): {n_coeffs}")
print(f"Shape of coefficient vector f_vec: {f_vec.shape}")
# print("Coefficient vector f_vec (first few elements):", f_vec[:5])

# --- Visualize the original signal (Inverse SHT) ---
fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(111, projection='3d')
norm = plt.Normalize(vmin=original_signal_real.min(), vmax=original_signal_real.max())
cmap = plt.cm.viridis
surf = ax.plot_surface(X_cart, Y_cart, Z_cart, facecolors=cmap(norm(original_signal_real)), 
                       rstride=1, cstride=1, antialiased=False, shade=False)
ax.set_title("Original Signal $f$ (Sum of Harmonics)")
ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
ax.set_box_aspect([1,1,1])
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, ax=ax, shrink=0.6, aspect=10)
cbar.set_label('Value of $Re(f)$')
plt.show()

# We can now think about linear maps acting on the coefficient vector f_vec.
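The cell above notes that conjugate symmetry is not enforced. This sketch shows what that symmetry means, assuming SciPy's Condon-Shortley convention $\overline{Y_l^m} = (-1)^m Y_l^{-m}$. Under that convention, $f$ is real exactly when $\hat f_{l,-m} = (-1)^m\overline{\hat f_{lm}}$, and the coefficients of $\mathrm{Re}(f)$ are $\tfrac12\big(\hat f_{lm} + (-1)^m\overline{\hat f_{l,-m}}\big)$. The helper names are hypothetical.

```python
import numpy as np

def real_part_coeffs(coeffs):
    """Coefficients of Re(f), given the coefficients {(l, m): f_lm} of f.

    Assumes conj(Y_l^m) = (-1)^m Y_l^{-m} (Condon-Shortley phase, as in SciPy).
    """
    keys = set(coeffs) | {(l, -m) for (l, m) in coeffs}
    return {(l, m): 0.5 * (coeffs.get((l, m), 0) + (-1) ** m * np.conj(coeffs.get((l, -m), 0)))
            for (l, m) in keys}

def is_real_signal(coeffs, tol=1e-12):
    """True if f_{l,-m} = (-1)^m conj(f_lm) for every stored coefficient."""
    return all(abs(coeffs.get((l, -m), 0) - (-1) ** m * np.conj(c)) < tol
               for (l, m), c in coeffs.items())

coeffs = {(1, -1): 0.2 + 0.1j, (1, 0): -0.3 + 0j, (1, 1): 0.2 - 0.1j}
print(is_real_signal(coeffs))                    # False
print(is_real_signal(real_part_coeffs(coeffs)))  # True
```

For these degree-1 values (the same ones used in `input_coeffs`), the $m = \pm 1$ terms contribute only an imaginary part, so they drop out of $\mathrm{Re}(f)$ entirely.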

## Section 6: Linear Maps & Equivariance in Harmonic Space

Consider a linear map $\Psi$ acting on functions on the sphere. In the harmonic domain, this corresponds to a linear map acting on the coefficient vectors. Let $f$ be the input signal with coefficient vector $\mathbf{f}$ and $g$ be the output signal with coefficient vector $\mathbf{g}$. The linear map is represented by a matrix $\mathbf{\Psi}$ such that:
$$ \mathbf{g} = \mathbf{\Psi} \mathbf{f} $$

**SO(3) Action**: A rotation $R \in SO(3)$ acts on the function $f$, which corresponds to transforming its coefficient vector $\mathbf{f}$. This transformation is block-diagonal in the harmonic basis, with blocks given by the Wigner-D matrices $D^l(R)$. Let $\mathbf{D}(R)$ be the matrix representing this action on the flat coefficient vector $\mathbf{f}$.
$$ (R \cdot f) \longleftrightarrow \mathbf{D}(R) \mathbf{f} $$

**Equivariance Condition**: The linear map $\mathbf{\Psi}$ is SO(3)-equivariant if applying the map and then rotating the output is the same as rotating the input and then applying the map:
$$ \mathbf{D}(R) (\mathbf{\Psi} \mathbf{f}) = \mathbf{\Psi} (\mathbf{D}(R) \mathbf{f}) \quad \forall R \in SO(3), \forall \mathbf{f} $$
This means the map matrix $\mathbf{\Psi}$ must commute with all rotation matrices $\mathbf{D}(R)$:
$$ \mathbf{D}(R) \mathbf{\Psi} = \mathbf{\Psi} \mathbf{D}(R) $$

**Schur's Lemma**: The Wigner-D matrices $D^l(R)$ define pairwise non-isomorphic irreducible representations of SO(3), and in our coefficient vector each degree $l$ appears exactly once. Schur's Lemma therefore tells us that any matrix $\mathbf{\Psi}$ commuting with all $\mathbf{D}(R)$ must be block-diagonal according to the irreps (the $l$-blocks), and each block must be a scalar multiple of the identity matrix within that block.
$$ (\mathbf{\Psi})_{lm, l'm'} = \delta_{ll'} \delta_{mm'} \hat{\psi}_l $$
Here, $(\mathbf{\Psi})_{lm, l'm'}$ denotes the entry of the matrix connecting input coefficient $(l', m')$ to output coefficient $(l, m)$. The map only connects coefficients with the same $l$ and $m$, and the scaling factor $\hat{\psi}_l$ depends only on the degree $l$.
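Here is a numerical spot check of the commutation condition. It uses only rotations about the z-axis, for which $\mathbf{D}(R)$ is diagonal with entries $e^{-im\alpha}$, so no Wigner-D library is needed. It rebuilds small stand-ins for $\mathbf{\Psi}_{eq}$ and $\mathbf{\Psi}_{neq}$ with $l \le 2$. This test has a limitation: commuting with z-rotations only forces $\mathbf{\Psi}$ to preserve $m$. It is the full $SO(3)$ symmetry, via Schur's Lemma, that also forbids mixing different $l$ and forces the same $\hat{\psi}_l$ for every $m$.

```python
import numpy as np

lmax = 2
# Degree l and order m of each slot in the flat (l, m) ordering
ls = np.array([l for l in range(lmax + 1) for m in range(-l, l + 1)])
ms = np.array([m for l in range(lmax + 1) for m in range(-l, l + 1)])

def D_z(alpha):
    """Action of Rz(alpha) on the flat coefficient vector: diag(exp(-i m alpha))."""
    return np.diag(np.exp(-1j * ms * alpha))

psi_hat = np.array([1.0, 0.5, 0.1])              # one scale factor per degree l
Psi_eq = np.diag(psi_hat[ls]).astype(complex)    # psi_hat[l] times identity on each l-block

rng = np.random.default_rng(0)
Psi_neq = Psi_eq.copy()
l2 = slice(4, 9)                                 # indices of the l = 2 block
Psi_neq[l2, l2] = 0.1 * (rng.random((5, 5)) + 1j * rng.random((5, 5)))  # mixes m within l = 2

D = D_z(0.6)
print(np.allclose(D @ Psi_eq, Psi_eq @ D))    # True
print(np.allclose(D @ Psi_neq, Psi_neq @ D))  # False
```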

**Demonstration**: Let's construct an equivariant matrix $\mathbf{\Psi}_{eq}$ with this structure and a non-equivariant matrix $\mathbf{\Psi}_{neq}$, apply them to our signal vector $\mathbf{f}$, and visualize the results.

In [None]:
# --- Construct Equivariant Linear Map Psi_eq ---
# Map depends only on degree l, acts as scalar multiplication within each l-block.
psi_hat = {0: 1.0, 1: 0.5, 2: 0.1, 3: 0.05} # Example scaling factors (filter coefficients)

Psi_eq = np.zeros((n_coeffs, n_coeffs), dtype=np.complex128)
for l in range(lmax + 1):
    scaling_factor = psi_hat.get(l, 0.0) # Default to 0 if l not in psi_hat
    if scaling_factor != 0:
        for m in range(-l, l + 1):
            idx = coeff_indices[(l, m)]
            Psi_eq[idx, idx] = scaling_factor

# --- Construct Non-Equivariant Linear Map Psi_neq ---
# Example 1: Random dense matrix (breaks all symmetries)
# Psi_neq = np.random.randn(n_coeffs, n_coeffs) + 1j * np.random.randn(n_coeffs, n_coeffs)
# Example 2: Mixes m within l=2, but doesn't mix l (partially breaks symmetry)
Psi_neq = np.zeros_like(Psi_eq)
idx_l0_start = coeff_indices[(0, 0)]
Psi_neq[idx_l0_start, idx_l0_start] = 1.0 # Keep l=0 
idx_l1_start = coeff_indices[(1, -1)]
idx_l1_end = coeff_indices[(1, 1)] + 1
Psi_neq[idx_l1_start:idx_l1_end, idx_l1_start:idx_l1_end] = 0.5 * np.eye(3) # Keep l=1 scaled
idx_l2_start = coeff_indices[(2, -2)]
idx_l2_end = coeff_indices[(2, 2)] + 1
Psi_neq[idx_l2_start:idx_l2_end, idx_l2_start:idx_l2_end] = 0.1 * (np.random.rand(5, 5) + 1j*np.random.rand(5,5)) # Mix m within l=2
idx_l3_start = coeff_indices[(3, -3)]
idx_l3_end = coeff_indices[(3, 3)] + 1
Psi_neq[idx_l3_start:idx_l3_end, idx_l3_start:idx_l3_end] = 0.05 * np.eye(7) # Keep l=3 scaled

# --- Apply Maps --- 
g_eq_vec = Psi_eq @ f_vec
g_neq_vec = Psi_neq @ f_vec

# --- Convert result vectors back to coefficient dictionaries ---
def coeffs_from_vec(vec, lmax, coeff_indices_map):
    coeffs = {}
    inv_coeff_indices = {v: k for k, v in coeff_indices_map.items()}
    for idx in range(len(vec)):
        if vec[idx] != 0:
            l, m = inv_coeff_indices[idx]
            coeffs[(l, m)] = vec[idx]
    return coeffs

g_eq_coeffs = coeffs_from_vec(g_eq_vec, lmax, coeff_indices)
g_neq_coeffs = coeffs_from_vec(g_neq_vec, lmax, coeff_indices)

# --- Visualize the results ---
g_eq_signal = generate_signal_from_coeffs(theta_grid, phi_grid, g_eq_coeffs)
g_neq_signal = generate_signal_from_coeffs(theta_grid, phi_grid, g_neq_coeffs)

# Equivariant Map Output
fig_eq = plt.figure(figsize=(8,6))
ax_eq = fig_eq.add_subplot(111, projection='3d')
norm_eq = plt.Normalize(vmin=g_eq_signal.real.min(), vmax=g_eq_signal.real.max())
surf_eq = ax_eq.plot_surface(X_cart, Y_cart, Z_cart, facecolors=cmap(norm_eq(g_eq_signal.real)), 
                           rstride=1, cstride=1, antialiased=False, shade=False)
ax_eq.set_title(r"Output Signal $g_{eq} = \Psi_{eq} f$ (Equivariant Map)")
ax_eq.set_xlabel('X'); ax_eq.set_ylabel('Y'); ax_eq.set_zlabel('Z')
ax_eq.set_box_aspect([1,1,1])
sm_eq = plt.cm.ScalarMappable(cmap=cmap, norm=norm_eq); sm_eq.set_array([])
cbar_eq = fig_eq.colorbar(sm_eq, ax=ax_eq, shrink=0.6, aspect=10); cbar_eq.set_label('Value of $Re(g_{eq})$')
plt.show()

# Non-Equivariant Map Output
fig_neq = plt.figure(figsize=(8,6))
ax_neq = fig_neq.add_subplot(111, projection='3d')
norm_neq = plt.Normalize(vmin=g_neq_signal.real.min(), vmax=g_neq_signal.real.max())
surf_neq = ax_neq.plot_surface(X_cart, Y_cart, Z_cart, facecolors=cmap(norm_neq(g_neq_signal.real)), 
                           rstride=1, cstride=1, antialiased=False, shade=False)
ax_neq.set_title(r"Output Signal $g_{neq} = \Psi_{neq} f$ (Non-Equivariant Map)")
ax_neq.set_xlabel('X'); ax_neq.set_ylabel('Y'); ax_neq.set_zlabel('Z')
ax_neq.set_box_aspect([1,1,1])
sm_neq = plt.cm.ScalarMappable(cmap=cmap, norm=norm_neq); sm_neq.set_array([])
cbar_neq = fig_neq.colorbar(sm_neq, ax=ax_neq, shrink=0.6, aspect=10); cbar_neq.set_label('Value of $Re(g_{neq})$')
plt.show()

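print("The equivariant map rescales each degree l by psi_hat[l], acting like a rotationally symmetric filter on the input.")
print("The non-equivariant map mixes the m components within l=2 with an arbitrary matrix, which does not commute with rotations.")
print("Note that a single plot cannot prove or disprove equivariance; the defining test is D(R) Psi = Psi D(R) for all rotations R.")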

## Section 7: Convolution Equivalence (Theorem 3.1)

A key result (Theorem 3.1 in Cohen et al., 2019; Aronsson, 2022) states that **any linear G-equivariant map between these function spaces can be expressed as a convolution**.

**Spherical Convolution**: The convolution $(\kappa \star f)$ of a filter $\kappa$ with a signal $f$ on the sphere is most easily computed in the harmonic domain. Suppose the filter is *rotationally symmetric* (zonal): $\kappa$ depends only on the angle from the north pole, so $\hat{\kappa}_{lm} = 0$ for $m \neq 0$, and we write $\hat{\kappa}_l$ for its single remaining coefficient per degree. If $\hat{f}_{lm}$ are the coefficients of the signal, then the coefficients of the convolution are given by the following formula (up to $l$-dependent normalization constants, which can be absorbed into $\hat{\kappa}_l$):
$$ (\widehat{\kappa \star f})_{lm} = \hat{\kappa}_l \hat{f}_{lm} $$
Notice this is exactly the operation performed by our equivariant matrix $\mathbf{\Psi}_{eq}$ from the previous section, where the filter coefficients $\hat{\kappa}_l$ correspond to the scaling factors $\hat{\psi}_l$ applied to each $l$-block.
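One spatial-domain picture of a rotationally symmetric filter is the operator $(Kf)(x) = \int_{S^2} \kappa(x \cdot y)\, f(y)\, dy$. By the Funk-Hecke theorem, this operator multiplies every $Y_l^m$ by $\lambda_l = 2\pi\int_{-1}^{1}\kappa(t)P_l(t)\,dt$. That factor depends on $l$ but not on $m$, which is exactly the structure of $\mathbf{\Psi}_{eq}$. The sketch below checks this by quadrature for $l = 1$ and the hypothetical kernel $\kappa(t) = e^t$, for which $\lambda_1 = 4\pi/e$. How $\lambda_l$ relates to $\hat{\kappa}_l$ depends on normalization conventions.

```python
import numpy as np

def Y1(m, theta, phi):
    """Degree-1 complex spherical harmonics (Condon-Shortley phase); theta polar, phi azimuthal."""
    if m == 0:
        return 0.5 * np.sqrt(3 / np.pi) * np.cos(theta) + 0j
    c = 0.5 * np.sqrt(3 / (2 * np.pi)) * np.sin(theta) * np.exp(1j * m * phi)
    return -c if m == 1 else c

def unit(theta, phi):
    """Unit vector(s) on S^2 from polar/azimuthal angles."""
    return np.stack([np.sin(theta) * np.cos(phi),
                     np.sin(theta) * np.sin(phi),
                     np.cos(theta)], axis=-1)

# Quadrature grid on S^2: Gauss-Legendre in cos(theta), uniform in phi
x_gl, w_gl = np.polynomial.legendre.leggauss(40)
th, ph = np.meshgrid(np.arccos(x_gl), 2 * np.pi * np.arange(80) / 80, indexing="ij")
w = np.outer(w_gl, np.full(80, 2 * np.pi / 80))
ys = unit(th, ph)             # quadrature nodes y, shape (40, 80, 3)

kernel = np.exp               # hypothetical zonal filter: kappa(x . y) = exp(x . y)

def filter_at(x, f_vals):
    """(K f)(x) = integral of kappa(x . y) f(y) dy, approximated by quadrature."""
    return np.sum(w * kernel(ys @ x) * f_vals)

theta0, phi0 = 1.0, 0.4
x0 = unit(theta0, phi0)
ratios = [filter_at(x0, Y1(m, th, ph)) / Y1(m, theta0, phi0) for m in (-1, 0, 1)]
print(np.allclose(ratios, 4 * np.pi / np.e))  # True: the same factor for every m
```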

**Equivalence Demonstration**: We will now explicitly show that the output vector $\mathbf{g}_{eq}$ obtained from the equivariant linear map $\mathbf{\Psi}_{eq}$ is identical to the output vector $\mathbf{g}_{conv}$ obtained by performing spherical convolution in the harmonic domain using $\hat{\kappa}_l = \hat{\psi}_l$.

In [None]:
# Define the kernel coefficients kappa_hat_l to be the same as psi_hat_l
kappa_hat = psi_hat 

# Compute the convolution output coefficients g_conv_lm = kappa_hat_l * f_lm
g_conv_vec = np.zeros_like(f_vec)
for l in range(lmax + 1):
    kernel_val = kappa_hat.get(l, 0.0)
    if kernel_val != 0:
        for m in range(-l, l + 1):
            idx = coeff_indices[(l, m)]
            g_conv_vec[idx] = kernel_val * f_vec[idx]

# --- Verify Equivalence Numerically ---
print(f"Comparing g_eq_vec and g_conv_vec:")
# print("g_eq_vec (first 5):", g_eq_vec[:5])
# print("g_conv_vec (first 5):", g_conv_vec[:5])
are_close = np.allclose(g_eq_vec, g_conv_vec, atol=1e-7)
print(f"Are the equivariant map output and convolution output numerically close? {are_close}")
assert are_close, "Equivariant map and convolution results should be identical!"

# --- Visualize the Convolution Result (should match g_eq) ---
g_conv_coeffs = coeffs_from_vec(g_conv_vec, lmax, coeff_indices)
g_conv_signal = generate_signal_from_coeffs(theta_grid, phi_grid, g_conv_coeffs)

fig_conv = plt.figure(figsize=(8,6))
ax_conv = fig_conv.add_subplot(111, projection='3d')
norm_conv = plt.Normalize(vmin=g_conv_signal.real.min(), vmax=g_conv_signal.real.max())
surf_conv = ax_conv.plot_surface(X_cart, Y_cart, Z_cart, facecolors=cmap(norm_conv(g_conv_signal.real)), 
                           rstride=1, cstride=1, antialiased=False, shade=False)
ax_conv.set_title(r"Output Signal $g_{conv} = \kappa \star f$ (Spherical Convolution)")
ax_conv.set_xlabel('X'); ax_conv.set_ylabel('Y'); ax_conv.set_zlabel('Z')
ax_conv.set_box_aspect([1,1,1])
sm_conv = plt.cm.ScalarMappable(cmap=cmap, norm=norm_conv); sm_conv.set_array([])
cbar_conv = fig_conv.colorbar(sm_conv, ax=ax_conv, shrink=0.6, aspect=10); cbar_conv.set_label('Value of $Re(g_{conv})$')
plt.show()

print("Conclusion: We have shown that the linear map Psi_eq, which respects SO(3) symmetry")
print("(due to its block-diagonal structure derived from Schur's Lemma), produces the exact")
print("same output as performing a spherical convolution with kernel coefficients kappa_hat_l = psi_hat_l.")
print("This illustrates Theorem 3.1 in our band-limited (l <= 3) setting: a linear SO(3)-equivariant map")
print("between these function spaces, written in harmonic coefficients, acts as a convolution.")

## Section 8: Practical Implementation with e3nn

Libraries like `e3nn` are designed to handle the complexities of representation theory and automatically construct equivariant layers.

**Goal**: Show how to build an SO(3)-equivariant linear layer using `e3nn` and numerically verify its equivariance.

**Setup**:
- **Irreps**: We define the input and output feature spaces using `e3nn.o3.Irreps`. An `Irreps` object specifies the irreducible representations present (e.g., `1x0e` for one scalar (l=0, even parity), `1x1o` for one vector (l=1, odd parity), `1x2e` for one l=2 tensor).
- **Layer**: We use `e3nn.o3.Linear`, which creates a linear map between the specified input and output irreps. It learns one weight for each pair of input and output channels carrying the same irrep (here one weight per $l$, playing the role of the $\hat{\psi}_l$ factors) and never connects different irreps, which is exactly the block structure required by equivariance.
- **Data**: We create random `torch` tensors compatible with the input `Irreps`.

**Equivariance Test**: We generate a random SO(3) rotation, apply it to the input tensor using the appropriate Wigner-D matrices (handled by `Irreps.D_from_matrix`), pass both original and rotated inputs through the layer, rotate the original output, and check if the results match.
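Before running the `e3nn` cell, it can help to see the structure `o3.Linear` enforces in a setting small enough to write by hand. This NumPy sketch is not e3nn's implementation. It assumes only $l=0$ and $l=1$ features, ignores parity, and takes $D^1(R) = R$ in the Cartesian basis (e3nn uses its own basis conventions for $D^l$). The weights mix channels of the same $l$ only, and the check mirrors the rotate-then-map versus map-then-rotate test in the cell below.

```python
import numpy as np

def random_rotation(rng):
    """Random 3x3 rotation from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

rng = np.random.default_rng(0)

# Hypothetical layer "2 scalars + 2 vectors -> 1 scalar + 1 vector"
W0 = rng.normal(size=(1, 2))   # weights between scalar (l=0) channels
W1 = rng.normal(size=(1, 2))   # weights between vector (l=1) channels, one per channel pair

def linear(scalars, vectors):
    """Equivariant linear map: channels are mixed only within the same degree l."""
    return W0 @ scalars, W1 @ vectors   # shapes (1,) and (1, 3)

R = random_rotation(rng)
s_in = rng.normal(size=2)        # two scalar channels
v_in = rng.normal(size=(2, 3))   # two vector channels, one 3-vector per row

# Rotate the inputs: scalars are invariant (D^0 = 1), vectors transform with D^1 = R
s_a, v_a = linear(s_in, v_in @ R.T)   # map(rotate(input))
s_b, v_b = linear(s_in, v_in)         # map(input), rotated below
print(np.allclose(s_a, s_b), np.allclose(v_a, v_b @ R.T))   # True True
```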

In [None]:
import torch
from e3nn import o3

# --- Setup e3nn ---
# Define input and output irreducible representations (Irreps).
# Example: map a scalar (0e), a vector (1o) and an l=2 tensor (2e)
# back to the same types of features.
irreps_in = o3.Irreps("1x0e + 1x1o + 1x2e") # l=0 (scalar), l=1 (vector), l=2 (tensor)
irreps_out = o3.Irreps("1x0e + 1x1o + 1x2e") # Map to the same space for simplicity
print(f"Input Irreps: {irreps_in}, Dimension: {irreps_in.dim}")
print(f"Output Irreps: {irreps_out}, Dimension: {irreps_out.dim}")

# Create an equivariant linear layer
# e3nn only connects irreps with the same l and parity, so here there is one weight per l
layer = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out)
print(f"\nCreated layer: {layer}")

# --- Create Random Data ---
batch_size = 2
input_tensor = irreps_in.randn(batch_size, -1)  # -1 is replaced by irreps_in.dim, giving shape (batch_size, 9)
print(f"Shape of input tensor: {input_tensor.shape}")

# --- Equivariance Test ---
# 1. Generate a random SO(3) rotation
R_mat = o3.rand_matrix() # Generates a 3x3 random rotation matrix
print("\nTesting equivariance with a random SO(3) rotation...")

# 2. Get the corresponding Wigner-D representation matrices for input/output irreps
D_in = irreps_in.D_from_matrix(R_mat)
D_out = irreps_out.D_from_matrix(R_mat)
print(f"Shape of D_in matrix: {D_in.shape}")
print(f"Shape of D_out matrix: {D_out.shape}")

# 3. Rotate the input tensor.
# Features are stored as rows of shape (batch, dim), so applying D_in to every row is x @ D_in.T.
input_tensor_rotated = input_tensor @ D_in.T

# 4. Apply the layer to the original input
output_original = layer(input_tensor)

# 5. Apply the layer to the rotated input
output_rotated_input = layer(input_tensor_rotated)

# 6. Rotate the original output
output_original_rotated = output_original @ D_out.T

# 7. Check whether the results agree; the tolerance leaves headroom for float32 round-off.
are_outputs_close = torch.allclose(output_rotated_input, output_original_rotated, atol=1e-5)
print(f"\nOutput from rotated input is close to rotated original output: {are_outputs_close}")
diff = torch.max(torch.abs(output_rotated_input - output_original_rotated))
print(f"Max absolute difference: {diff.item()}")

assert are_outputs_close, "Equivariance test failed!"

print("\nConclusion: The e3nn layer is SO(3)-equivariant for this rotation, up to floating-point tolerance.")
print("e3nn handles the representation theory, so the map commutes with rotations for any value of its weights.")

## References

* Cohen, T. S., Geiger, M., & Weiler, M. (2019). *A General Theory of Equivariant CNNs on Homogeneous Spaces*. Advances in Neural Information Processing Systems (NeurIPS), 32.
* Gerken, J. E., Aronsson, J., Carlsson, O., et al. (2023). *Geometric deep learning and equivariant neural networks*. Artificial Intelligence Review, 56, 14605–14662.

# Summary & Further Exploration
In this notebook, we explored key concepts underlying equivariant neural networks on homogeneous spaces, using the sphere $S^2 \cong SO(3)/SO(2)$ as our primary example:

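1.  **Homogeneous Space Representation**: Visualized how $S^2$ arises as the orbit of a point under the $SO(3)$ action; the stabilizer of that point is a copy of $SO(2)$ (rotations about the axis through it), which gives $S^2 \cong SO(3)/SO(2)$.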
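2.  **Associated Bundles & Equivariant Maps**: Showed how tangent vector fields on $S^2$ (sections of the tangent bundle $TS^2$) can be understood through the associated bundle $TS^2 \cong SO(3) \times_{SO(2)} \mathbb{R}^2$ and $H$-equivariant maps $f: SO(3) \to \mathbb{R}^2$ satisfying $f(gh) = \rho(h)^{-1} f(g)$ for $h \in H = SO(2)$, where $\rho$ is the rotation representation of $SO(2)$ on $\mathbb{R}^2$.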
3.  **Group Action on Fields**: Demonstrated the "push-pull" action $[\pi(R)v](x) = R\,v(R^{-1}x)$ of $SO(3)$ on vector fields (pull the point back by $R^{-1}$, push the vector forward by $R$), which is the core idea of equivariance for feature maps.
4.  **Harmonic Analysis**: Introduced spherical harmonics $Y_l^m$ as a basis for functions on $S^2$ adapted to the $SO(3)$ action: for each degree $l$, the $2l+1$ functions $Y_l^m$, $m = -l, \dots, l$, span an irreducible representation of $SO(3)$.
5.  **Harmonic Representation**: Showed how signals on the sphere can be represented by a vector of their harmonic coefficients $\hat{f}_{lm}$.
6.  **Linear Equivariant Maps & Schur's Lemma**: Demonstrated that linear SO(3)-equivariant maps acting on harmonic coefficients must be block-diagonal, with each block being a scalar multiple of the identity $(\mathbf{\Psi})_{lm, l'm'} = \delta_{ll'} \delta_{mm'} \hat{\psi}_l$.
7.  **Convolution Equivalence**: Explicitly verified Theorem 3.1 (Cohen et al., 2019), showing that such a linear equivariant map is equivalent to performing spherical convolution in the harmonic domain with filter coefficients $\hat{\kappa}_l = \hat{\psi}_l$.
8.  **Practical Implementation**: Used the `e3nn` library to construct an SO(3)-equivariant linear layer and numerically verified its equivariance property, demonstrating how libraries abstract the underlying theory.
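
The push-pull action of item 3 can be checked in a few lines. The sketch below (NumPy only; the field and helper names are illustrative choices, not the notebook's earlier code) implements $[\pi(R)v](x) = R\,v(R^{-1}x)$ for the field $v_c(x) = c - (c\cdot x)\,x$, the tangential part of a constant vector $c$. For this field the action has a closed form, $\pi(R)v_c = v_{Rc}$, which gives an exact check; the script also verifies that the result stays tangent to $S^2$ and that $\pi(R_1R_2) = \pi(R_1)\,\pi(R_2)$, i.e. that this is a genuine group action.

```python
import numpy as np

def random_rotation(rng):
    """Random 3x3 rotation matrix (orthogonal, det = +1) obtained from a QR factorisation."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q @ np.diag(np.sign(np.diag(r)))
    if np.linalg.det(q) < 0:
        q[:, 0] *= -1.0
    return q

def projected_field(c):
    """Tangent field v_c(x) = c - (c . x) x on S^2; points x are stored as rows."""
    def v(x):
        return c - np.sum(c * x, axis=-1, keepdims=True) * x
    return v

def act(R, v):
    """Push-pull action (pi(R) v)(x) = R v(R^{-1} x) for points stored as rows of x."""
    def rotated_field(x):
        # For row vectors, R^{-1} x = R^T x becomes x @ R, and R w becomes w @ R.T.
        return v(x @ R) @ R.T
    return rotated_field

rng = np.random.default_rng(2)
x = rng.normal(size=(200, 3))
x /= np.linalg.norm(x, axis=1, keepdims=True)  # sample points on S^2
c = np.array([0.0, 0.0, 1.0])
R1, R2 = random_rotation(rng), random_rotation(rng)
v = projected_field(c)

closed_form_err = np.max(np.abs(act(R1, v)(x) - projected_field(R1 @ c)(x)))
tangency_err = np.max(np.abs(np.sum(x * act(R1, v)(x), axis=-1)))
homomorphism_err = np.max(np.abs(act(R1 @ R2, v)(x) - act(R1, act(R2, v))(x)))
print("closed form:", closed_form_err, "tangency:", tangency_err, "homomorphism:", homomorphism_err)
```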

This provides a clear path from the geometric foundations of homogeneous spaces and group actions to the practical construction of equivariant neural networks, highlighting the crucial role of representation theory (spherical harmonics, Schur's lemma) in linking symmetry constraints to the form of learnable layers (convolution).

### Next Steps:
- Explore other homogeneous spaces (e.g., the hyperbolic plane $\mathbb{H}^2 \cong SL(2, \mathbb{R})/SO(2)$) and the corresponding group actions and harmonic bases.
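- Implement a full spherical convolution layer, for example by combining `e3nn`'s `o3.spherical_harmonics`, the grid transforms `o3.ToS2Grid`/`o3.FromS2Grid`, and tensor-product layers such as `o3.FullyConnectedTensorProduct`.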
- Investigate different types of feature maps (e.g., higher-order tensors) and their corresponding associated bundles and representations (`Irreps` in `e3nn`).
- Study different G-CNN architectures built using these equivariant layers for specific applications (e.g., point clouds, molecular data, physics simulations).

![QF-Mission](https://quantumformalism.academy/img/qf-down.png)

**Copyright © 2025 Quantum Formalism Academy. All rights reserved.**

This notebook is a product of **Quantum Formalism Academy** and is intended for educational purposes. Redistribution, modification, or commercial use of this material without prior written permission from Quantum Formalism is prohibited.