# Computer-Assisted Proof: Lie Algebra Generation for Lorenz 96 Model

**Context:** This notebook implements the computer-assisted proof (CAP) detailed in Appendix C of the paper *"Non-uniqueness of stationary measures for stochastic systems with almost surely invariant manifolds"* by J. Bedrossian, A. Blumenthal, and S. Punshon-Smith.

**Objective:** The primary goal is to verify a key algebraic property needed for the theoretical results in the paper. Specifically, we aim to show that the Lie algebra generated by the matrices $M_k = DB(e_k)|_{H_I^\perp}$ (representing the linearization of the Lorenz 96 dynamics restricted to the transverse space $H_I^\perp$) is the special linear Lie algebra $\mathfrak{sl}(H_I^\perp)$. This verification is crucial for establishing the hypoellipticity of related stochastic processes discussed in the paper.

**Methodology:**
1.  **Matrix Representation:** Define the matrices $M_k$ from the L96 equations, using SymPy for exact rational arithmetic.
2.  **Lie Brackets:** Compute iterated Lie brackets $[M_k, M_l]$ up to a specified depth (depth 5 is used here).
3.  **Span Verification:** Use symbolic linear algebra (specifically, Reduced Row Echelon Form - RREF) on the vectorized forms of the generated matrices (original $M_k$ and their brackets) to determine the basis of the spanned space.
4.  **Elementary Matrix Check:** Identify if specific elementary matrices, which form a known generating set for $\mathfrak{sl}(H_I^\perp)$ when combined with shift-invariance properties (see Lemma C.1 and C.3 in the paper), are present in the span.

**Implementation Details:**
*   The calculations are performed for $N=9$, i.e. $K=3$, the smallest case covered by the result, which gives transverse dimension $2K=6$. The paper's appendix notes that, because the brackets are local and the system is shift-invariant, the result extends to all $N=3K$ with $K \ge 3$.
*   SymPy is used for symbolic computation with exact arithmetic, so the rank and span computations are free of floating-point error.
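
As a toy illustration of steps 2 and 3 (not the paper's computation), the following sketch applies the same bracket-then-rank idea to $\mathfrak{sl}(2)$, where the answer can be checked by hand. The names `bracket`, `vec_row`, `e`, `f`, and `h` are chosen here for illustration and do not appear in the notebook.

```python
from sympy import Matrix

def bracket(A, B):
    """Lie bracket [A, B] = AB - BA."""
    return A * B - B * A

def vec_row(M):
    """Flatten a matrix row by row into a 1 x (rows*cols) row vector."""
    return M.reshape(1, M.rows * M.cols)

# Two generators of sl(2): the raising and lowering matrices.
e = Matrix([[0, 1], [0, 0]])
f = Matrix([[0, 0], [1, 0]])

# One bracket already produces the missing diagonal direction.
h = bracket(e, f)

# Stack the vectorized matrices and compute the exact rank of the span.
stack = Matrix.vstack(vec_row(e), vec_row(f), vec_row(h))
rank = stack.rank()

print(h)
print("rank of span:", rank, "| dim sl(2):", 2**2 - 1)
```

The rank equals $n^2 - 1$ for $n = 2$, which is the dimension of $\mathfrak{sl}(2)$; the notebook applies the same kind of exact linear algebra to the $6 \times 6$ transverse matrices.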


In [None]:
%pip install sympy numpy matplotlib > /dev/null 2>&1

In [None]:
import sympy as sp
from sympy import Matrix, zeros, Rational
import numpy as np
from IPython.display import display, Math

sp.init_printing()  # Enable pretty-printed output (exactness comes from Rational/integer entries)



## Matrix Generation

First, we implement the L96 matrices $M_k$ (named `H` matrices, $H_k$, in the code) using exact rational arithmetic.

In [None]:

def discrete_delta(i, j):
    """2D delta function that returns 1 only when both inputs are 0."""
    return Rational(1 if i == 0 and j == 0 else 0)

def B_coefficient(k, j, l, n):
    """Entry (j, l) of the matrix for index k; site indices are taken mod n."""
    j_k_plus_1 = (j - k + 1) % n
    l_k_plus_2 = (l - k + 2) % n
    j_k_plus_2 = (j - k + 2) % n
    l_k_plus_1 = (l - k + 1) % n
    j_k_minus_1 = (j - k - 1) % n
    l_k_minus_2 = (l - k - 2) % n
    
    return (discrete_delta(j_k_plus_1, l_k_plus_2) -
            discrete_delta(j_k_plus_2, l_k_plus_1) +
            discrete_delta(j_k_minus_1, l_k_minus_2) -
            discrete_delta(j_k_minus_1, l_k_plus_2))

def get_transverse_indices(n):
    """Get indices not divisible by 3."""
    return [i for i in range(n) if i % 3 != 0]

def generate_h_matrix(k, trans_index, n):
    """Generate H matrix for index k."""
    size = len(trans_index)
    H = zeros(size, size)
    
    for i, j in enumerate(trans_index):
        for m, l in enumerate(trans_index):
            H[i, m] = B_coefficient(k, j, l, n)
            
    return H

def compute_bracket(A, B):
    """Compute [A,B] using rational arithmetic."""
    return A * B - B * A

# Generate matrices
n=9 # System size big enough to avoid the boundary
depth = 5 # Maximum depth for bracket computation
trans_index = get_transverse_indices(n)
h_indices = [3, 6, 9] # Indices for H matrices

if n % 3 != 0:
    raise ValueError("n must be a multiple of 3")

if not all(k % 3 == 0 and k <= n for k in h_indices):
    raise ValueError("All h_indices must be multiples of 3 and less than or equal to n")

print(f"System size n = {n}")
print(f"Matrix size = {len(trans_index)}×{len(trans_index)}")

print(f"Generating H matrices: {', '.join([f'H_{k}' for k in h_indices])}")
h_matrices = [generate_h_matrix(k, trans_index, n) for k in h_indices]
# Print generated H_k matrices
for idx, H in enumerate(h_matrices):
    display(Math(f"H_{{{h_indices[idx]}}} ="))
    display(H)



System size n = 9
Matrix size = 6×6
Generating H matrices: H_3, H_6, H_9


H_3 =

⎡0   -1  0  0  0  0⎤
⎢                  ⎥
⎢1   0   0  0  0  0⎥
⎢                  ⎥
⎢-1  0   0  1  0  0⎥
⎢                  ⎥
⎢0   0   0  0  0  0⎥
⎢                  ⎥
⎢0   0   0  0  0  0⎥
⎢                  ⎥
⎣0   0   0  0  0  0⎦

H_6 =

⎡0  0  0   0   0  0⎤
⎢                  ⎥
⎢0  0  0   0   0  0⎥
⎢                  ⎥
⎢0  0  0   -1  0  0⎥
⎢                  ⎥
⎢0  0  1   0   0  0⎥
⎢                  ⎥
⎢0  0  -1  0   0  1⎥
⎢                  ⎥
⎣0  0  0   0   0  0⎦

H_9 =

⎡0  1  0  0  -1  0 ⎤
⎢                  ⎥
⎢0  0  0  0  0   0 ⎥
⎢                  ⎥
⎢0  0  0  0  0   0 ⎥
⎢                  ⎥
⎢0  0  0  0  0   0 ⎥
⎢                  ⎥
⎢0  0  0  0  0   -1⎥
⎢                  ⎥
⎣0  0  0  0  1   0 ⎦
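
The three displayed matrices look like cyclic shifts of one another, which is the shift invariance mentioned in the methodology. The sketch below checks this by hard-coding $H_3$ from the output above and conjugating it with a permutation matrix; `P` and `size` are hypothetical helper names, and the shift by two transverse slots is an assumption read off from the displayed entries.

```python
from sympy import Matrix, zeros

# H_3 copied from the displayed output (transverse indices 1, 2, 4, 5, 7, 8).
H3 = Matrix([
    [ 0, -1, 0, 0, 0, 0],
    [ 1,  0, 0, 0, 0, 0],
    [-1,  0, 0, 1, 0, 0],
    [ 0,  0, 0, 0, 0, 0],
    [ 0,  0, 0, 0, 0, 0],
    [ 0,  0, 0, 0, 0, 0],
])

# Permutation matrix sending transverse slot j to slot (j + 2) mod 6,
# i.e. moving one block of three ring sites forward.
size = 6
P = zeros(size, size)
for j in range(size):
    P[(j + 2) % size, j] = 1

# Conjugation shifts both row and column indices by two slots.
H6 = P * H3 * P.T
H9 = P * H6 * P.T

# All three generators have zero trace, so they lie in sl(6).
print([M.trace() for M in (H3, H6, H9)])
```

Comparing `H6` and `H9` entry by entry with the displayed $H_6$ and $H_9$ gives a quick consistency check on the generated matrices.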

## Generate and Compute Brackets

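We compute all left-nested brackets $[H_{k_1}, [H_{k_2}, [\dots, H_{k_d}]]]$ up to depth 5 (the default), discarding any bracket that vanishes identically.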
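
Bracketing only the generators against the previous level is enough: by the Jacobi identity, a bracket of two brackets is a linear combination of left-nested ones. Every bracket is also traceless, so the span stays inside $\mathfrak{sl}$. The sketch below checks these identities on three arbitrary integer matrices; `A`, `B`, and `C` are illustrative and unrelated to the L96 matrices.

```python
from sympy import Matrix, zeros

def bracket(A, B):
    """Lie bracket [A, B] = AB - BA."""
    return A * B - B * A

# Arbitrary 3x3 integer matrices, chosen only for illustration.
A = Matrix([[0, 1, 0], [0, 0, 2], [1, 0, 0]])
B = Matrix([[1, 0, 0], [0, -1, 3], [0, 0, 0]])
C = Matrix([[0, 0, 1], [2, 0, 0], [0, 1, 0]])

# Antisymmetry: [A, B] = -[B, A].
antisymmetric = bracket(A, B) == -bracket(B, A)

# Jacobi identity: [A, [B, C]] + [B, [C, A]] + [C, [A, B]] = 0.
jacobi_sum = (bracket(A, bracket(B, C))
              + bracket(B, bracket(C, A))
              + bracket(C, bracket(A, B)))
jacobi = jacobi_sum == zeros(3, 3)

# Every commutator has zero trace, since tr(AB) = tr(BA).
traceless = bracket(A, B).trace() == 0

print(antisymmetric, jacobi, traceless)
```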

In [9]:
def compute_all_brackets(h_matrices, max_depth=5):
    """Compute all brackets up to given depth."""
    current = h_matrices.copy()
    all_matrices = current.copy()
    ops = [f"H_{3*(i+1)}" for i in range(len(h_matrices))]  # same labels as in the bracket strings below
    all_ops = ops.copy()
    
    print(f"Computing brackets up to depth {max_depth}...")
    
    for depth in range(2, max_depth + 1):
        print(f"\nDepth {depth}:")
        new_matrices = []
        new_ops = []
        
        # Compute new brackets
        for i, H in enumerate(h_matrices):
            for j, M in enumerate(current):
                bracket = compute_bracket(H, M)
                if not bracket.is_zero_matrix:
                    new_matrices.append(bracket)
                    new_ops.append(f"[H_{3*(i+1)}, {ops[j]}]")
        
        if not new_matrices:
            print("No new brackets found")
            break
            
        print(f"Found {len(new_matrices)} new brackets")
        
        # Add new matrices
        all_matrices.extend(new_matrices)
        all_ops.extend(new_ops)
        
        # Update for next iteration
        current = new_matrices
        ops = new_ops
    
    return all_matrices, all_ops



# Compute all brackets
all_matrices, all_ops = compute_all_brackets(h_matrices, max_depth=depth)
print(f"\nTotal matrices: {len(all_matrices)}")

Computing brackets up to depth 5...

Depth 2:
Found 6 new brackets

Depth 3:
Found 18 new brackets

Depth 4:
Found 54 new brackets

Depth 5:
Found 162 new brackets

Total matrices: 243


## Find Elementary Matrices

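We vectorize every matrix into a row of length 36, row-reduce the resulting stack, and read off the rows of the RREF that have a single nonzero entry. Each such row is an elementary matrix $E_{i,j}$ lying in the span.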
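
To see how row reduction exposes elementary matrices, here is a minimal $2 \times 2$ sketch. The matrices `S` and `A` are made up for illustration; neither is elementary, but their span contains $E_{1,2}$ and $E_{2,1}$.

```python
from sympy import Matrix

# S = E_12 + E_21 and A = E_12 - E_21: neither is an elementary matrix.
S = Matrix([[0, 1], [1, 0]])
A = Matrix([[0, 1], [-1, 0]])

# Vectorize row by row and stack, as in the 6x6 computation.
stack = Matrix.vstack(S.reshape(1, 4), A.reshape(1, 4))

# Exact reduced row echelon form and the pivot column indices.
rref, pivots = stack.rref()

# Turn each RREF row back into a 2x2 matrix.
rows = [rref.row(i).reshape(2, 2) for i in range(rref.rows)]
for M in rows:
    print(M)
```

A standard basis vector lies in a row space exactly when it appears as a row of the RREF, so the single-nonzero test does not miss any elementary matrix that is in the span.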

In [None]:
def find_elementary_matrices(matrices, ops):
    """Find elementary matrices using row reduction."""
    matrix_size = matrices[0].rows
    
    # Stack matrices as vectors
    stack = zeros(len(matrices), matrix_size * matrix_size)
    for i, M in enumerate(matrices):
        for r in range(matrix_size):
            for c in range(matrix_size):
                stack[i, r*matrix_size + c] = M[r,c]
    
    print(f"Matrix stack shape: {stack.rows} × {stack.cols}")
    
    # Compute RREF
    print("\nComputing row reduction...")
    rref, pivots = stack.rref()
    
    # Check each row for elementary matrices
    elementary = []
    print("\nChecking for elementary matrices:")
    
    for i in range(rref.rows):
        # Convert row back to matrix
        M = zeros(matrix_size, matrix_size)
        for r in range(matrix_size):
            for c in range(matrix_size):
                M[r,c] = rref[i, r*matrix_size + c]
        
        # Count non-zero entries
        non_zero = [(r, c, M[r,c]) for r in range(matrix_size)
                    for c in range(matrix_size) if M[r,c] != 0]
        
        if len(non_zero) == 1:
            r, c, val = non_zero[0]
            elementary.append((r, c, val))
            print(f"Found E_{{{r+1},{c+1}}}")
    
    return elementary

# Find elementary matrices
elementary_matrices = find_elementary_matrices(all_matrices, all_ops)

# Check if we found our targets,
# i.e. E_{3,2}, E_{4,3}, E_{5,4}
# Note that the indices are 0-based in the code, so we need to add 1
# e.g. (2,1) corresponds to E_{3,2} etc.
targets = [(2,1), (3,2), (4,3)]
print("\nChecking target matrices:")
for i, j in targets:
    found = any(r == i and c == j for r, c, _ in elementary_matrices)
    print(f"{'✓' if found else '✗'} E_{{{i+1},{j+1}}}")

Matrix stack shape: 243 × 36

Computing row reduction...

Checking for elementary matrices:
Found E_{1,2}
Found E_{1,3}
Found E_{1,4}
Found E_{1,5}
Found E_{1,6}
Found E_{2,1}
Found E_{2,3}
Found E_{2,4}
Found E_{2,5}
Found E_{2,6}
Found E_{3,1}
Found E_{3,2}
Found E_{3,4}
Found E_{3,5}
Found E_{3,6}
Found E_{4,1}
Found E_{4,2}
Found E_{4,3}
Found E_{4,5}
Found E_{4,6}
Found E_{5,1}
Found E_{5,2}
Found E_{5,3}
Found E_{5,4}
Found E_{5,6}
Found E_{6,1}
Found E_{6,2}
Found E_{6,3}
Found E_{6,4}
Found E_{6,5}

Checking target matrices:
✓ E_{3,2}
✓ E_{4,3}
✓ E_{5,4}
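
The output lists all 30 off-diagonal elementary matrices. As a side check for this $6 \times 6$ case, separate from the paper's lemmas, the sketch below confirms the standard fact that the off-diagonal $E_{i,j}$ together with the brackets $[E_{i,j}, E_{j,i}] = E_{i,i} - E_{j,j}$ span a space of dimension $35 = 6^2 - 1$. The helper `E` and the variable names are assumptions introduced here.

```python
from sympy import Matrix, zeros

n = 6

def E(i, j):
    """Elementary matrix with a single 1 in position (i, j), 0-based."""
    M = zeros(n, n)
    M[i, j] = 1
    return M

# The 30 off-diagonal elementary matrices.
off_diagonal = [E(i, j) for i in range(n) for j in range(n) if i != j]

# Brackets [E_ij, E_ji] = E_ii - E_jj supply the traceless diagonal part.
diagonal_brackets = [E(i, j) * E(j, i) - E(j, i) * E(i, j)
                     for i in range(n) for j in range(i + 1, n)]

# Vectorize, stack, and compute the exact rank of the span.
stack = Matrix.vstack(*[M.reshape(1, n * n)
                        for M in off_diagonal + diagonal_brackets])
rank = stack.rank()

print("rank:", rank, "| dim sl(6):", n * n - 1)
```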
