In [1]:
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

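**Claim:** $(yx^\top) \odot A = D_y A D_x^\top$, where $D_v = \operatorname{diag}(v)$ and $\odot$ is the elementwise (Hadamard) product.

- Okay, this works! The cells below check it numerically for the case $y = x$, and for the sum $xx^\top + yy^\top$.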

In [2]:
A = jnp.array([
    [2.0, -1.0, 0.5],
    [-1.0, 3.0, 1.5],
    [0.5, 1.5, 4.0]
])

In [3]:
x = jnp.array([4, 5, 2])
x[..., None] @ x[..., None].T

Array([[16, 20,  8],
       [20, 25, 10],
       [ 8, 10,  4]], dtype=int64)

In [4]:
(x[..., None] @ x[..., None].T) * A

Array([[ 32., -20.,   4.],
       [-20.,  75.,  15.],
       [  4.,  15.,  16.]], dtype=float64)

In [5]:
jnp.diag(x) @ A @ jnp.diag(x).T

Array([[ 32., -20.,   4.],
       [-20.,  75.,  15.],
       [  4.,  15.,  16.]], dtype=float64)

In [6]:
y = jnp.array([3, 8, 1])

In [7]:
(x[..., None] @ x[..., None].T + y[..., None] @ y[..., None].T) * A

Array([[ 50. , -44. ,   5.5],
       [-44. , 267. ,  27. ],
       [  5.5,  27. ,  20. ]], dtype=float64)

In [8]:
jnp.diag(x) @ A @ jnp.diag(x).T + jnp.diag(y) @ A @ jnp.diag(y).T

Array([[ 50. , -44. ,   5.5],
       [-44. , 267. ,  27. ],
       [  5.5,  27. ,  20. ]], dtype=float64)

In [9]:
jnp.diag(x), A

(Array([[4, 0, 0],
        [0, 5, 0],
        [0, 0, 2]], dtype=int64),
 Array([[ 2. , -1. ,  0.5],
        [-1. ,  3. ,  1.5],
        [ 0.5,  1.5,  4. ]], dtype=float64))

In [10]:
jnp.diag(x) @ A @ jnp.diag(x), x[..., None] @ x[..., None].T * A

(Array([[ 32., -20.,   4.],
        [-20.,  75.,  15.],
        [  4.,  15.,  16.]], dtype=float64),
 Array([[ 32., -20.,   4.],
        [-20.,  75.,  15.],
        [  4.,  15.,  16.]], dtype=float64))

simplifying inverses of a sum

$$(A+B)^{-1} = A^{-1} - A^{-1} (I + B A^{-1})^{-1} B A^{-1}$$

Or Hua's Identity

$$(A+B)^{-1} = A^{-1} - (A + A B^{-1} A)^{-1}$$

For us, we are interested in 

$$(D_{W_0} V D_{W_0} + D_{W_1} V D_{W_1})^{-1}$$ 

Note that with $A = D_{W_0} V D_{W_0}$ and $B = D_{W_1} V D_{W_1}$ we have $A^{-1} = D_{W_0}^{-1} V^{-1} D_{W_0}^{-1}$ and $B^{-1} = D_{W_1}^{-1} V^{-1} D_{W_1}^{-1}$.

Using Hua's Identity
$$(A+B)^{-1} = D_{W_0}^{-1} V^{-1} D_{W_0}^{-1} - (D_{W_0} V D_{W_0} + D_{W_0} V D_{W_0} D_{W_1}^{-1} V^{-1} D_{W_1}^{-1} D_{W_0} V D_{W_0})^{-1}$$

In [13]:
# Dimensions for the random integer test data: F is drawn with shape (T, K)
# and W with shape (L, K).
T = 120
K = 5 
L = 36
key = jax.random.key(12)

# NOTE(review): both draws below consume the same `key`, so F and W are not
# independent samples; JAX recommends splitting the key (jax.random.split)
# and using a fresh subkey per draw. Left unchanged here so the recorded
# outputs of the following cells stay valid.
F = jax.random.randint(key=key, shape=(T, K), minval=1, maxval=5)
W = jax.random.randint(key=key, shape=(L, K), minval=1, maxval=3)

# NOTE(review): `A` is still the 3x3 example matrix defined in an earlier cell
# at this point; `F.shape` may have been intended here -- confirm.
A.shape, W.shape

((3, 3), (36, 5))

In [14]:
cov = jnp.cov(F.T)
cov.shape

(5, 5)

In [15]:
# Eigendecomposition of the symmetric matrix `cov`: `eigvals` holds the
# eigenvalues and the columns of `Q` the corresponding eigenvectors.
# The eigenvalues get their own name so that the constant `L` (the number of
# rows of W) is not overwritten; otherwise re-running the sampling cell after
# this one would fail because `L` would no longer be an integer.
eigvals, Q = jnp.linalg.eigh(cov)

In [17]:
X = jnp.diag(W[0]) @ Q
X[:, 0]

Array([ 1.22080624, -0.51296942, -0.34413815,  0.55836835,  0.7250306 ],      dtype=float64)

In [18]:
Q, jnp.diag(W[0])

(Array([[ 0.61040312, -0.39172893,  0.44928992,  0.20488956,  0.47970336],
        [-0.25648471,  0.51133311, -0.00220414, -0.14396639,  0.80747933],
        [-0.34413815,  0.18772717,  0.89013769, -0.02765554, -0.23068921],
        [ 0.55836835,  0.3567835 ,  0.0721401 , -0.72403853, -0.17746644],
        [ 0.3625153 ,  0.65003918, -0.02417472,  0.64210383, -0.18207133]],      dtype=float64),
 Array([[2, 0, 0, 0, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 2]], dtype=int64))

In [19]:
(Q.T @ jnp.diag(W[0])[:, 0])

Array([ 1.22080624, -0.78345786,  0.89857983,  0.40977913,  0.95940672],      dtype=float64)

In [20]:
(Q.T @ jnp.diag(W[0]))

Array([[ 1.22080624, -0.51296942, -0.34413815,  0.55836835,  0.7250306 ],
       [-0.78345786,  1.02266622,  0.18772717,  0.3567835 ,  1.30007837],
       [ 0.89857983, -0.00440827,  0.89013769,  0.0721401 , -0.04834944],
       [ 0.40977913, -0.28793279, -0.02765554, -0.72403853,  1.28420766],
       [ 0.95940672,  1.61495866, -0.23068921, -0.17746644, -0.36414267]],      dtype=float64)

In [21]:
jnp.diag(W[0]) @ Q

Array([[ 1.22080624, -0.78345786,  0.89857983,  0.40977913,  0.95940672],
       [-0.51296942,  1.02266622, -0.00440827, -0.28793279,  1.61495866],
       [-0.34413815,  0.18772717,  0.89013769, -0.02765554, -0.23068921],
       [ 0.55836835,  0.3567835 ,  0.0721401 , -0.72403853, -0.17746644],
       [ 0.7250306 ,  1.30007837, -0.04834944,  1.28420766, -0.36414267]],      dtype=float64)

In [25]:
# Hadamard form: accumulate the outer products w w^T over the first five rows
# of W, then multiply elementwise with cov.
curr = jnp.zeros((K, K))
for w_row in W[:5]:
    curr = curr + jnp.outer(w_row, w_row)
curr * cov

Array([[19.42738095,  2.13298319,  0.34621849, -1.9697479 , -0.54530812],
       [ 2.13298319, 10.97983193, -0.8       ,  0.05378151, -0.68634454],
       [ 0.34621849, -0.8       , 12.7070028 ,  0.73389356,  0.34537815],
       [-1.9697479 ,  0.05378151,  0.73389356, 13.48347339, -1.77478992],
       [-0.54530812, -0.68634454,  0.34537815, -1.77478992, 13.49502801]],      dtype=float64)

In [26]:
# Diagonal-sandwich form: accumulate D_w cov D_w over the first five rows of W,
# where D_w = diag(w). Should match the Hadamard form computed above.
curr2 = jnp.zeros((K, K))
for w_row in W[:5]:
    D_w = jnp.diag(w_row)
    curr2 = curr2 + D_w @ cov @ D_w
curr2

Array([[19.42738095,  2.13298319,  0.34621849, -1.9697479 , -0.54530812],
       [ 2.13298319, 10.97983193, -0.8       ,  0.05378151, -0.68634454],
       [ 0.34621849, -0.8       , 12.7070028 ,  0.73389356,  0.34537815],
       [-1.9697479 ,  0.05378151,  0.73389356, 13.48347339, -1.77478992],
       [-0.54530812, -0.68634454,  0.34537815, -1.77478992, 13.49502801]],      dtype=float64)

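Numerical check of the two inverse-of-a-sum identities above, using $A = D_{W_0} V D_{W_0}$ and $B = D_{W_1} V D_{W_1}$ with $V$ = `cov`.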

In [21]:
# Rebinds `A` (until now the 3x3 example matrix from the top of the notebook)
# to D_{W_0} cov D_{W_0}; `B` is the same construction for the second row of W.
# The transpose of a diagonal matrix equals the matrix itself, so the trailing
# `.T` does not change the value.
A = jnp.diag(W[0]) @ cov @ jnp.diag(W[0]).T
B = jnp.diag(W[1]) @ cov @ jnp.diag(W[1]).T

In [14]:
jnp.isnan(A).any(), jnp.isnan(B).any(), jnp.isnan(cov).any()

(Array(False, dtype=bool), Array(False, dtype=bool), Array(False, dtype=bool))

In [15]:
one = jnp.linalg.inv(A + B)

In [16]:
inv_A = jnp.linalg.inv(A)
two = inv_A - inv_A @ jnp.linalg.inv(jnp.eye(K) + B @ inv_A) @ B @ inv_A

In [17]:
jnp.allclose(one, two)

Array(True, dtype=bool)

In [18]:
# Hua's identity
three = inv_A - jnp.linalg.inv(A + A @ jnp.linalg.inv(B) @ A)


In [19]:
jnp.allclose(one, three)

Array(True, dtype=bool)

sums of PSD

In [None]:
import jax.numpy as jnp
from jax import random
import numpy as np

def explicit_form(A, B, X):
    """Evaluate A^T X A + B^T X B entry by entry, reading only the diagonal of X.

    Entry (i, j) of the result is
        sum_k X[k, k] * (A[k, i] * A[k, j] + B[k, i] * B[k, j]).

    The loops are deliberately explicit: this is the reference implementation
    that the matrix-product form is compared against.

    Args:
        A: 2-D array of shape (m, n).
        B: 2-D array with the same shape as A.
        X: 2-D array whose first m diagonal entries are used; off-diagonal
            entries are never read.

    Returns:
        An (n, n) array created with jnp.zeros' default dtype.

    Raises:
        ValueError: if A and B do not have the same shape.
    """
    if A.shape != B.shape:
        raise ValueError(
            f"A and B must have the same shape, got {A.shape} and {B.shape}"
        )

    # k runs over rows, i and j over columns. Keeping the two extents separate
    # matters for non-square inputs: JAX does not raise on out-of-range
    # indices, so a single shared extent would silently give a wrong result.
    n_rows, n_cols = A.shape
    result = jnp.zeros((n_cols, n_cols))

    for i in range(n_cols):
        for j in range(n_cols):
            # Sum_k x_k * (a_ki * a_kj + b_ki * b_kj)
            element = 0
            for k in range(n_rows):
                element += X[k, k] * (A[k, i] * A[k, j] + B[k, i] * B[k, j])
            result = result.at[i, j].set(element)

    return result

def matrix_form(A, B, X):
    """Return A^T X A + B^T X B, evaluated with matrix products.

    Each sandwich is evaluated left to right, i.e. (M^T X) M.
    """
    sandwich_a = (A.T @ X) @ A
    sandwich_b = (B.T @ X) @ B
    return sandwich_a + sandwich_b

# Test cases
def run_tests():
    """Compare explicit_form against matrix_form on random inputs and print the results.

    Three cases are run: 3x3 matrices with a random diagonal X, 5x5 matrices
    with a random diagonal X, and the same 5x5 matrices with X = identity.
    For each case the maximum absolute difference and the jnp.allclose
    verdict are printed. Nothing is returned.
    """

    def _report(title, A, B, X):
        # Evaluate both forms on the same inputs and print the comparison.
        result_explicit = explicit_form(A, B, X)
        result_matrix = matrix_form(A, B, X)
        print(title)
        print("Maximum difference:", jnp.max(jnp.abs(result_explicit - result_matrix)))
        print("Are results equal?", jnp.allclose(result_explicit, result_matrix))

    key = random.PRNGKey(0)

    # Test case 1: Small matrices (3x3)
    # Every draw gets its own subkey; `key` itself is only ever split, never
    # used for sampling, so no key is consumed twice.
    n = 3
    key, subkey_a, subkey_b, subkey_x = random.split(key, 4)
    A = random.normal(subkey_a, (n, n))
    B = random.normal(subkey_b, (n, n))
    X = jnp.diag(random.uniform(subkey_x, (n,)))  # diagonal matrix
    _report("Test case 1 (3x3 matrices):", A, B, X)

    # Test case 2: Larger matrices (5x5)
    n = 5
    key, subkey_a, subkey_b, subkey_x = random.split(key, 4)
    A = random.normal(subkey_a, (n, n))
    B = random.normal(subkey_b, (n, n))
    X = jnp.diag(random.uniform(subkey_x, (n,)))
    _report("\nTest case 2 (5x5 matrices):", A, B, X)

    # Test case 3: Special case with identity matrix for X (reuses the 5x5 A and B)
    X = jnp.eye(n)
    _report("\nTest case 3 (X = Identity):", A, B, X)

run_tests()

Test case 1 (3x3 matrices):
Maximum difference: 9.536743e-07
Are results equal? True

Test case 2 (5x5 matrices):
Maximum difference: 4.7683716e-07
Are results equal? True

Test case 3 (X = Identity):
Maximum difference: 9.536743e-07
Are results equal? True


In [3]:
import jax.numpy as jnp
from jax import random
import numpy as np
from typing import List

def explicit_form_n_matrices(matrices: List[jnp.ndarray], X: jnp.ndarray):
    """Evaluate sum_M M^T X M entry by entry, reading only the diagonal of X.

    Entry (i, j) of the result is
        sum_k X[k, k] * sum_M M[k, i] * M[k, j].

    The loops are deliberately explicit: this is the reference implementation
    that the matrix-product form is compared against.

    Args:
        matrices: non-empty list of 2-D arrays that all share one shape (m, n).
        X: 2-D array whose first m diagonal entries are used; off-diagonal
            entries are never read.

    Returns:
        An (n, n) array created with jnp.zeros' default dtype.

    Raises:
        IndexError: if `matrices` is empty.
        ValueError: if the matrices do not all have the same shape.
    """
    shape = matrices[0].shape
    if any(matrix.shape != shape for matrix in matrices):
        raise ValueError("all matrices must have the same shape")

    # k runs over rows, i and j over columns. Keeping the two extents separate
    # matters for non-square inputs: JAX does not raise on out-of-range
    # indices, so a single shared extent would silently give a wrong result.
    n_rows, n_cols = shape
    result = jnp.zeros((n_cols, n_cols))

    for i in range(n_cols):
        for j in range(n_cols):
            # Sum_k x_k * (sum_m m_ki * m_kj) where m is each matrix
            element = 0
            for k in range(n_rows):
                term = 0
                for matrix in matrices:
                    term += matrix[k, i] * matrix[k, j]
                element += X[k, k] * term
            result = result.at[i, j].set(element)

    return result

def matrix_form_n_matrices(matrices: List[jnp.ndarray], X: jnp.ndarray):
    """Compute sum(M^T X M) over all matrices M using matrix multiplication.

    Args:
        matrices: non-empty list of 2-D arrays that all share one shape (m, n).
        X: (m, m) array placed between M^T and M.

    Returns:
        The (n, n) sum of M^T X M over `matrices`.

    Raises:
        IndexError: if `matrices` is empty.
    """
    # Each M^T X M term is (n, n) with n the number of columns of M, so the
    # accumulator is sized from the column count. For square inputs this is
    # the same array as jnp.zeros_like(matrices[0]).
    n_cols = matrices[0].shape[1]
    result = jnp.zeros((n_cols, n_cols), dtype=matrices[0].dtype)
    for matrix in matrices:
        result = result + matrix.T @ X @ matrix
    return result

def run_extended_tests():
    """Compare the explicit and matrix forms for lists of matrices and print the results.

    Four cases are run: two 3x3 matrices, three 3x3 matrices, five 4x4
    matrices (each with a random diagonal X), and the same five 4x4 matrices
    with X = identity. For each case the maximum absolute difference and the
    jnp.allclose verdict are printed. Nothing is returned.
    """

    def _report(title, matrices, X):
        # Evaluate both forms on the same inputs and print the comparison.
        result_explicit = explicit_form_n_matrices(matrices, X)
        result_matrix = matrix_form_n_matrices(matrices, X)
        print(title)
        print("Maximum difference:", jnp.max(jnp.abs(result_explicit - result_matrix)))
        print("Are results equal?", jnp.allclose(result_explicit, result_matrix))

    key = random.PRNGKey(0)

    # Test case 1: Original two matrices
    # Every draw gets its own subkey; `key` itself is only ever split, never
    # used for sampling, so no key is consumed twice.
    n = 3
    key, subkey_x, subkey1, subkey2 = random.split(key, 4)
    A = random.normal(subkey1, (n, n))
    B = random.normal(subkey2, (n, n))
    X = jnp.diag(random.uniform(subkey_x, (n,)))
    _report("Test case 1 (Two matrices):", [A, B], X)

    # Test case 2: Three matrices
    key, subkey_x, subkey1, subkey2, subkey3 = random.split(key, 5)
    A = random.normal(subkey1, (n, n))
    B = random.normal(subkey2, (n, n))
    C = random.normal(subkey3, (n, n))
    X = jnp.diag(random.uniform(subkey_x, (n,)))
    _report("\nTest case 2 (Three matrices):", [A, B, C], X)

    # Test case 3: Five matrices
    n = 4  # slightly larger matrices
    key, subkey_x, *subkeys = random.split(key, 7)
    matrices = [random.normal(subkey, (n, n)) for subkey in subkeys]
    X = jnp.diag(random.uniform(subkey_x, (n,)))
    _report("\nTest case 3 (Five matrices):", matrices, X)

    # Test case 4: Special case with identity matrix for X (reuses the five 4x4 matrices)
    X = jnp.eye(n)
    _report("\nTest case 4 (Five matrices, X = Identity):", matrices, X)

if __name__ == "__main__":
    run_extended_tests()

Test case 1 (Two matrices):
Maximum difference: 2.3841858e-07
Are results equal? True

Test case 2 (Three matrices):
Maximum difference: 2.3841858e-07
Are results equal? True

Test case 3 (Five matrices):
Maximum difference: 1.9073486e-06
Are results equal? True

Test case 4 (Five matrices, X = Identity):
Maximum difference: 1.9073486e-06
Are results equal? True
