In [None]:
!python -m pip install --upgrade setuptools
!pip install git+https://github.com/tky823/ssspy.git

In [1]:
import time

In [2]:
import numpy as np

In [3]:
from ssspy.linalg import invsqrtmh, sqrtmh, eigh

In [4]:
# Seeded generator so every random matrix in the notebook is reproducible.
rng = np.random.default_rng(seed=0)

# Sample shape used by `create_psd`. Axis 0 is averaged away, so this gives
# 100 Hermitian 4x4 matrices. Each is the mean of 16 outer products of
# 4x1 complex column vectors.
n_samples, n_matrices, n_dims = 16, 100, 4
size = (n_samples, n_matrices, n_dims, 1)

In [5]:
def create_psd(rng, shape=None):
    """Draw a batch of random Hermitian positive semi-definite matrices.

    Complex vectors ``x`` are sampled with real and imaginary parts drawn by
    ``rng.random`` (uniform on ``[0, 1)``). The Gram matrices ``x @ x^H`` are
    then averaged over axis 0.

    Args:
        rng (numpy.random.Generator): Random number generator to draw from.
        shape (tuple of int, optional): Sample shape
            ``(n_samples, *batch_shape, n_dims, n_cols)``. If omitted, the
            notebook-level ``size`` is used, as in the original version.

    Returns:
        numpy.ndarray: Hermitian PSD matrices of shape
        ``(*batch_shape, n_dims, n_dims)``.
    """
    if shape is None:
        # Backward-compatible fallback to the notebook configuration cell.
        shape = size

    x = rng.random(shape) + 1j * rng.random(shape)
    # Batched x @ x^H. For a trailing axis of length 1 this equals the
    # broadcasted outer product used before. Unlike that elementwise form, it
    # is also a correct Gram matrix for any number of columns and any batch rank.
    XX = x @ x.swapaxes(-2, -1).conj()

    return np.mean(XX, axis=0)

In [6]:
def gmeanmh1(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Matrix geometric mean of ``A`` and ``B`` via the congruence formula.

    Evaluates ``A^{1/2} (A^{-1/2} B A^{-1/2})^{1/2} A^{1/2}`` with ssspy's
    ``sqrtmh`` / ``invsqrtmh`` helpers. These are presumably the Hermitian
    matrix square root and inverse square root over leading batch axes
    (confirm against ssspy).

    Args:
        A (numpy.ndarray): Matrices to take roots of (assumed Hermitian
            positive definite).
        B (numpy.ndarray): Second operand (assumed Hermitian PSD).

    Returns:
        numpy.ndarray: The geometric mean ``A # B``.
    """
    inv_root_a = invsqrtmh(A)
    # Whiten B by A^{-1/2}, take its square root, then map back with A^{1/2}.
    whitened_root = sqrtmh(inv_root_a @ B @ inv_root_a)
    root_a = sqrtmh(A)

    return root_a @ whitened_root @ root_a

In [7]:
def gmeanmh2(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Matrix geometric mean ``A # B`` via a generalized eigendecomposition.

    ``eigh(B, A)`` presumably solves ``B z = lambda A z`` (ssspy argument
    order; confirm). With the eigenvectors as columns of ``Z``,
    ``Z diag(sqrt(lambda)) Z^{-1}`` is a square root of ``A^{-1} B``. The
    result is ``A (A^{-1} B)^{1/2}``.

    Args:
        A (numpy.ndarray): Matrices defining the generalized problem
            (assumed Hermitian positive definite).
        B (numpy.ndarray): Second operand (assumed Hermitian PSD).

    Returns:
        numpy.ndarray: The geometric mean ``A # B``.
    """
    lamb, Z = eigh(B, A)
    # Eigenvalues that should be 0 can come out slightly negative from
    # round-off. np.sqrt would map them to NaN and poison the whole product.
    lamb_sqrt = np.sqrt(np.maximum(lamb, 0))
    # Z @ diag(s) is Z with column j scaled by s[j]. Broadcasting does this
    # without materializing a diagonal matrix.
    AB_sqrt = (Z * lamb_sqrt[..., np.newaxis, :]) @ np.linalg.inv(Z)

    return A @ AB_sqrt

In [8]:
# Time both implementations on identical random inputs and check they agree.
# times1 / times2 accumulate total wall-clock seconds for gmeanmh1 / gmeanmh2.
n_trials = 1000
times1 = 0
times2 = 0

for _ in range(n_trials):
    A = create_psd(rng)
    B = create_psd(rng)

    t1 = time.perf_counter()
    G1 = gmeanmh1(A, B)
    t2 = time.perf_counter()
    G2 = gmeanmh2(A, B)
    t3 = time.perf_counter()

    times1 += t2 - t1
    times2 += t3 - t2

    # Same tolerances as np.allclose's defaults. Unlike a bare assert, this
    # reports the mismatching entries if the two methods disagree.
    np.testing.assert_allclose(G1, G2, rtol=1e-5, atol=1e-8)

print(f"gmeanmh1: {times1:.3f} s, gmeanmh2: {times2:.3f} s ({n_trials} trials)")

5.393258904000305 4.754580001000022
