In [1]:
import sympy as sp

# Define the dimensions
n = 2  # Example dimension for matrices

# Define real matrix symbols
U = sp.MatrixSymbol('U', n, n)
A = sp.MatrixSymbol('A', n, n)

# Convert MatrixSymbols to actual Matrices for operations
U = sp.Matrix(U, real=True)
A = sp.Matrix(A, real=True, symmetric=True)

# Perform matrix multiplication
UAU_T = U * A * U.T

# Extract diagonal elements
diag_UAU_T = sp.Matrix([UAU_T[i, i] for i in range(n)])

# Calculate Frobenius norm squared of the diagonal matrix
Err_U_A = diag_UAU_T.norm(2)**2

# Simplify the expression (optional)
Err_U_A_simplified = sp.simplify(Err_U_A)

# Compute the gradient of the error function with respect to each element of U
gradient_matrix = sp.Matrix(n, n, lambda i, j: Err_U_A.diff(U[i, j]))

print(gradient_matrix)

Matrix([[2*(((re(A[0, 0]*U[1, 0]) + re(A[1, 0]*U[1, 1]))*re(U[1, 0]) + (re(A[0, 1]*U[1, 0]) + re(A[1, 1]*U[1, 1]))*re(U[1, 1]) - (im(A[0, 0]*U[1, 0]) + im(A[1, 0]*U[1, 1]))*im(U[1, 0]) - (im(A[0, 1]*U[1, 0]) + im(A[1, 1]*U[1, 1]))*im(U[1, 1]))*((re(A[0, 0]*U[1, 0]) + re(A[1, 0]*U[1, 1]))*Derivative(re(U[1, 0]), U[0, 0]) + (re(A[0, 1]*U[1, 0]) + re(A[1, 1]*U[1, 1]))*Derivative(re(U[1, 1]), U[0, 0]) + (-im(A[0, 0]*U[1, 0]) - im(A[1, 0]*U[1, 1]))*Derivative(im(U[1, 0]), U[0, 0]) + (-im(A[0, 1]*U[1, 0]) - im(A[1, 1]*U[1, 1]))*Derivative(im(U[1, 1]), U[0, 0]) + (Derivative(re(A[0, 0]*U[1, 0]), U[0, 0]) + Derivative(re(A[1, 0]*U[1, 1]), U[0, 0]))*re(U[1, 0]) + (Derivative(re(A[0, 1]*U[1, 0]), U[0, 0]) + Derivative(re(A[1, 1]*U[1, 1]), U[0, 0]))*re(U[1, 1]) + (-Derivative(im(A[0, 0]*U[1, 0]), U[0, 0]) - Derivative(im(A[1, 0]*U[1, 1]), U[0, 0]))*im(U[1, 0]) + (-Derivative(im(A[0, 1]*U[1, 0]), U[0, 0]) - Derivative(im(A[1, 1]*U[1, 1]), U[0, 0]))*im(U[1, 1])) + ((re(A[0, 0]*U[1, 0]) + re(A[1, 0]

In [2]:
import numpy as np
n = 10
A = np.random.rand(n, n)
A = A @ A.T  # Make A symmetric positive definite

A_bk = A.copy()
print(A)
U = np.eye(n)

def ddiag(U, A):
    return np.linalg.norm(np.diag(U @ A @ U.T), ord=2)**2

print(ddiag(U, A))
print(A[0, 0]**2 + A[1, 1]**2 + A[2, 2]**2)

[[5.58153878 4.31269789 4.49011915 3.11223889 5.18888425 3.0673968
  2.99629604 3.52768532 3.52952311 3.60158616]
 [4.31269789 3.85561166 3.94656248 2.57155775 4.51936317 2.80794463
  2.48319523 2.65296933 2.93948563 3.1561755 ]
 [4.49011915 3.94656248 4.44620456 2.70741254 4.74597833 3.27606023
  2.79871468 3.16008534 3.2276456  3.40239203]
 [3.11223889 2.57155775 2.70741254 2.90750236 3.48660752 2.05317224
  2.74294468 2.19629001 2.69628908 2.60027145]
 [5.18888425 4.51936317 4.74597833 3.48660752 5.99771582 3.64777695
  3.4559984  3.52532724 3.9517924  4.07674025]
 [3.0673968  2.80794463 3.27606023 2.05317224 3.64777695 3.10588431
  2.45942048 2.31118833 2.67618365 2.34283798]
 [2.99629604 2.48319523 2.79871468 2.74294468 3.4559984  2.45942048
  3.19708447 2.41062455 3.10885656 2.38049571]
 [3.52768532 2.65296933 3.16008534 2.19629001 3.52532724 2.31118833
  2.41062455 3.59658346 2.83591242 2.26393606]
 [3.52952311 2.93948563 3.2276456  2.69628908 3.9517924  2.67618365
  3.10885656 

In [3]:
from scipy.linalg import polar

def diag_norm(A):
    return np.linalg.norm(np.diag(A), ord=2)

def riemannian_diagonalization(A, U = None, h = 1e-2, iterations = 100, out = False):
    n = A.shape[1]
    if U == None:
        U = np.eye(n)

    A = A.copy()
    Utot = U.copy()

    d1 = diag_norm(A)
    it = 0
    finish = False
    while not finish:
        # x' = 1 + h * df/dx
        U = np.eye(n)
        for p in range(n):
            for q in range(n):
                U[p, q] -= h * A[p, p] * (A[p, q] + A[q, p])

        # projection to unitary space
        U, _ = polar(U, side = 'left')

        # A' = U^T A U
        A2 = U.T @ A @ U
        d2 = diag_norm(A2)
        if out:
            print(it, h, d1, d2)

        if d2 > d1: # accept step
            A = A2.copy()
            h *= 1.5
            Utot = Utot @ U
            d1 = d2
            it += 1
            if it >= iterations:
                finish = True
        else: # reject step
            h /= 2
            if h < 1e-16:
                finish = True

    # make sure numerical errors are corrected
    Utot, _ = polar(Utot, side = 'left')
    return np.diagonal(A), Utot

d, Utot = riemannian_diagonalization(A)

A2 = Utot.T @ A @ Utot
A2[np.abs(A2) < 1e-10] = 0

print(sorted(d))
A = A_bk.copy()
d2, U2 = np.linalg.eigh(A)
print(d2)
print(np.linalg.norm(d2 - sorted(d)))
print(Utot.T @ A @ Utot)


[np.float64(0.0035153491067508257), np.float64(0.04470574164002243), np.float64(0.1418499749732432), np.float64(0.22900647231138657), np.float64(0.3661057582399805), np.float64(0.5206858463334868), np.float64(1.162927418465345), np.float64(1.434538924396561), np.float64(2.0831292580028347), np.float64(33.162273289800126)]
[3.51534911e-03 4.47057416e-02 1.41849975e-01 2.29006472e-01
 3.66105758e-01 5.20685846e-01 1.16292742e+00 1.43453892e+00
 2.08312926e+00 3.31622733e+01]
4.553107438230389e-13
[[ 1.43453892e+00  2.05160498e-15  6.71038977e-16  5.58936420e-16
  -9.36486882e-16 -1.58178833e-17 -3.52930210e-15 -4.17167312e-15
  -4.33195905e-17  1.15545508e-15]
 [ 2.15212003e-15  1.41849975e-01  4.76863651e-16 -1.10843043e-17
   4.84028039e-14  2.35316773e-17 -2.66170534e-15 -6.36998188e-15
   3.97677427e-16 -1.38129243e-15]
 [ 9.43072217e-16  2.68972977e-16  3.66105758e-01  2.52898753e-15
  -1.09635539e-14 -3.64707832e-16 -2.74829586e-15  1.24035903e-14
   7.94158121e-16 -3.20862938e-16]

In [4]:
# Simultaneous diagonalization
from scipy.linalg import polar

def diag_norm_k(A):
    d = 0
    for i in range(A.shape[0]):
        d += np.linalg.norm(np.diag(A[i, :, :]), ord = 2)**2
    return d

A = np.random.rand(3, 3, 3)

def transform_matrices(A, U):
    A = A.copy()
    for k in range(A.shape[0]):
        A[k, :, :] = U.T @ A[k, :, :] @ U
    return A

def riemannian_joint_diagonalization(A, U = None, h = 1e-2, iterations = 100, out = False):
    if U == None:
        U = np.eye(A.shape[1])

    A = A.copy()
    Utot = U.copy()
    n = A.shape[1]

    d1 = diag_norm_k(A)
    it = 0
    finish = False
    while not finish:
        # x' = 1 + h * df/dx
        U = np.eye(n)
        for p in range(n):
            for q in range(n):
                for k in range(A.shape[0]):
                    U[p, q] -= h * A[k, p, p] * (A[k, p, q] + A[k, q, p])

        # projection to unitary space
        U, _ = polar(U, side = 'left')

        # A' = U^T A U
        A2 = transform_matrices(A, U)
        d2 = diag_norm_k(A2)
        if out:
            print(it, h, d1, d2)

        if d2 > d1: # accept step
            A = A2.copy()
            h *= 1.5
            Utot = Utot @ U
            d1 = d2
            it += 1
            if it >= iterations:
                finish = True
        else: # reject step
            h /= 2
            if h < 1e-16:
                finish = True

    # make sure numerical errors are corrected
    Utot, _ = polar(Utot, side = 'left')
    d = np.zeros([A.shape[1], A.shape[2]])
    for k in range(A.shape[0]):
        d[k, :] = np.diagonal(A[k, :, :])
    return d, Utot

d, Utot = riemannian_joint_diagonalization(A)
print(d)
A2 = transform_matrices(A, Utot)
print(A2)

[[-0.01684373  1.54857056 -0.1148937 ]
 [ 0.59748363  1.5484856  -0.31583259]
 [ 0.37885789  0.98213691 -0.14411092]]
[[[-0.01684373  0.09546281  0.10073941]
  [-0.48786611  1.54857056  0.06984208]
  [ 0.13005235  0.02946492 -0.1148937 ]]

 [[ 0.59748363 -0.08714678 -0.14585769]
  [ 0.49407233  1.5484856  -0.19889301]
  [ 0.16146075 -0.13806576 -0.31583259]]

 [[ 0.37885789  0.21238217 -0.17296292]
  [ 0.16436888  0.98213691 -0.1099407 ]
  [ 0.10237953  0.52104566 -0.14411092]]]


In [5]:
# Weighted Simultaneous diagonalization
from scipy.linalg import polar
import numpy as np

def weighted_diagonal_norm(A, W):
    d = 0
    for k in range(A.shape[0]):
        for i in range(A.shape[1]):
            d += A[k, i, i]**2 / W[i, i]
    return d

def transform_matrices_w(A, W, U):
    A = A.copy()
    for k in range(A.shape[0]):
        A[k, :, :] = U.T @ A[k, :, :] @ U
    W = U.T @ W @ U
    return A, W

def weighted_diagonals(A, W):
    d = np.zeros([A.shape[1], A.shape[2]])
    for k in range(A.shape[0]):
        d[k, :] = np.diagonal(A[k, :, :]) / np.diagonal(W)
    return d

def weighted_symmetric(A, W):
    for k in range(A.shape[0]):
        A[k, :, :] = 0.5 * (A[k, :, :] @ W + W @ A[k, :, :])
    return A

def riemannian_joint_diagonalization(A, W, U = None, h = 1e-3, iterations = 1000, out = False):
    if U == None:
        U = np.eye(A.shape[1])

    A = A.copy()
    Utot = U.copy()
    n = A.shape[1]

    d1 = weighted_diagonal_norm(A, W)
    d0 = d1
    it = 0
    finish = False
    while not finish:
        # x' = 1 + h * df/dx
        G = np.zeros([n, n])
        G2 = np.zeros([n, n])
        for p in range(n):
            for q in range(n):
                for k in range(A.shape[0]):
                    G[p, q] += 2 * A[k, q, q] / W[q, q] * A[k, q, p] - A[k, q, q]**2 / W[q, q]**2 * W[q, p]
#                    G[p, q] += A[k, q, q] / W[q, q] * 2 * A[k, q, p] #- A[k, q, q]**2 / W[q, q]**2 * W[q, p] # like equations but seemingly problematic due to diagonal elements q, q
                    #G[p, q] += A[k, p, p] / W[p, p] * 2 * A[k, p, q] - A[k, p, p]**2 / W[p, p]**2 * W[p, q] # corrected version based on gut-feeling. multiplying second part with <1. seems to make it work
#                    G[p, q] += A[k, p, p] / W[p, p] * (2 * A[k, p, q] - 0.95 * A[k, p, p] * W[p, q] / W[p, p])
#                    G2[p, q] += A[k, p, p] * (A[k, p, q] + A[k, q, p])
                    G2[p, q] += A[k, q, q] * (A[k, q, p] + A[k, p, q])

#        if it % 1 == 0:
#            print(it, 'G-G2 = ', np.diag(G-G2)/np.diag(G))
#            print(it, 'G = ', G)
#            print(it, 'G2 = ', G2, '\n')
        U = np.eye(n)
        U += h * G

        # projection to unitary space
        U, _ = polar(U)
#        print('U=\n',U)
        U2 = np.eye(n) + h * G2
        U2, _ = polar(U2)
        U = U2
#        print('U2=\n',U2)

        # A' = U^T A U
        A2, W2 = transform_matrices_w(A, W, U)
        d2 = weighted_diagonal_norm(A2, W2)
        if out:
            print(it, h, d1, d2)

        if d2 > d1: # accept step
            A = A2.copy()
            W = W2.copy()
            if h < 1e1:
                h *= 1.5
            Utot = Utot @ U
            d1 = d2
            it += 1
            if it >= iterations:
                finish = True
        else: # reject step
            h /= 2
            if h < 1e-15:
                finish = True

    # make sure numerical errors are corrected
    Utot, _ = polar(Utot, side = 'left')
    d = weighted_diagonals(A, W)
    print('improvement: ', d1 / d0, '; d0 = ', d0, '; d1 = ', d1)
    print('convergence information: h = ', h, ', it = ', it)
    return d, Utot

def weighted_diag(A, W, U = None, h = 1e-1, micro_iterations = 100, macro_iterations = 3, out = False):
    Utot = np.eye(A.shape[1])
    for it in range(macro_iterations):
        d, U = riemannian_joint_diagonalization(A, W, h = 1e-2, iterations = micro_iterations, out = False)
        Utot = Utot @ U
        A, W = transform_matrices_w(A, W, U)
    return d, Utot
    

#def riemannian_joint_diagonalization(A, W, U = None, h = 1e-1, iterations = 1000, out = False):

def initial_A(n):
    A = np.random.rand(n, n, n)
    for k in range(n):
        A[k, :, :] = A[k, :, :] @ A[k, :, :].T
    A[1, :, :] = A[0, :, :] @ A[0, :, :]
    A[2, :, :] = A[0, :, :] @ A[1, :, :]
    for k in range(n):
        A[k, :, :] /= np.trace(A[k, :, :])
    return A

def initial_W(n):
    W = np.random.rand(n, n)
    W = W @ W.T
    for i in range(n):
        W[i, i] *= 5
    W /= np.trace(W)
    return W

def report(A, W, pre = None):
    print(pre + 'A =\n', A)
    print(pre + 'W =\n', W, np.trace(W))
    d = weighted_diagonals(A, W)
    print(pre + 'd =\n', d)

#W = initial_W(3)
W = np.eye(3)
A = initial_A(3)
A = weighted_symmetric(A, W)

report(A, W, pre = 'start ')

#d, Utot = riemannian_joint_diagonalization(A, W, out = False)
d, Utot = weighted_diag(A, W, macro_iterations = 5, micro_iterations=200)

A2, W2 = transform_matrices_w(A, W, Utot)

report(A2, W2, pre = 'end ')

start A =
 [[[0.44137516 0.22354484 0.28353952]
  [0.22354484 0.16625914 0.24639469]
  [0.28353952 0.24639469 0.3923657 ]]

 [[0.42867956 0.27116676 0.38425294]
  [0.27116676 0.18235184 0.26501015]
  [0.38425294 0.26501015 0.38896859]]

 [[0.42616414 0.27984045 0.40282577]
  [0.27984045 0.18557685 0.26820806]
  [0.40282577 0.26820806 0.38825901]]]
start W =
 [[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]] 3.0
start d =
 [[0.44137516 0.16625914 0.3923657 ]
 [0.42867956 0.18235184 0.38896859]
 [0.42616414 0.18557685 0.38825901]]
improvement:  2.4304817484462533 ; d0 =  1.11151960706552 ; d1 =  2.7015281180128974
convergence information: h =  8.203947965609308e-16 , it =  75
improvement:  1.0 ; d0 =  2.7015281180128965 ; d1 =  2.7015281180128965
convergence information: h =  5.684341886080802e-16 , it =  0
improvement:  1.0 ; d0 =  2.7015281180128965 ; d1 =  2.7015281180128965
convergence information: h =  5.684341886080802e-16 , it =  0
improvement:  1.0 ; d0 =  2.7015281180128965 ; d1 =  2.7015281