In [3]:
import mlx.core as mx

## Girard-Hutchinson Estimator

In [4]:
# The probe vectors x must satisfy $E[xx^T] = I$. We can use standard Gaussian N(0, 1) random vectors or Rademacher random vectors.
def gaussian_random(shape):
    """Sample a float32 array of the given shape from ``mx.random.normal``.

    The call relies on the library's default location and scale (standard
    normal), so this is only a thin wrapper that gives the Gaussian sampler
    the same ``shape``-only call signature as ``rademacher_random``.

    Args:
        shape: Shape of the array to draw.

    Returns:
        The float32 array produced by ``mx.random.normal``.
    """
    return mx.random.normal(shape=shape, dtype=mx.float32)

# The idea for Rademacher random vectors is from: https://docs.backpack.pt/en/master/use_cases/example_trace_estimation.html
def rademacher_random(shape):
    """Sample a float32 array whose entries are -1 or +1 with equal probability.

    Args:
        shape: Shape of the array to draw.

    Returns:
        A float32 array of the requested shape with entries in {-1, +1}.
    """
    coin_flips = mx.random.bernoulli(p=0.5, shape=shape)
    # Map the Bernoulli draws {0, 1} onto the Rademacher values {-1, +1}.
    return 2 * coin_flips.astype(mx.float32) - 1


def hutch_tr(mat, num_queries, sampler=None):
    """Estimate tr(mat) with the Girard-Hutchinson estimator.

    Averages the quadratic forms x^T mat x over ``num_queries`` random probe
    vectors x. The estimate is unbiased whenever the probes satisfy
    E[x x^T] = I.

    Args:
        mat: Square 2-D array whose trace is estimated.
        num_queries: Number of probe vectors to average over; must be >= 1.
        sampler: Optional callable invoked as ``sampler(shape=(n, num_queries))``
            that returns the probe vectors as columns. Defaults to
            ``rademacher_random``; ``gaussian_random`` is another valid choice.

    Returns:
        A scalar array holding the trace estimate.

    Raises:
        ValueError: If ``num_queries`` is smaller than 1 or ``mat`` is not
            square.
    """
    if num_queries < 1:
        raise ValueError(f"num_queries must be at least 1, got {num_queries}")

    n_rows, n_cols = mat.shape
    if n_rows != n_cols:
        raise ValueError(f"mat must be square, got shape {(n_rows, n_cols)}")

    if sampler is None:
        sampler = rademacher_random

    # Draw every probe at once: column j of `probes` is the probe vector x_j.
    probes = sampler(shape=(n_rows, num_queries))

    # Column j of (mat @ probes) is mat x_j, so the elementwise product summed
    # over all entries equals sum_j x_j^T mat x_j. This replaces the Python
    # loop of one-vector matmuls with a single batched product.
    quad_form_total = (probes * (mat @ probes)).sum()

    return quad_form_total / num_queries
        
    

In [5]:
# Build a random test matrix A = B^T B and record its exact trace as the
# reference value for the estimators.
B = mx.random.normal(shape=(100, 100))
A = B.T @ B
A_ground_truth = mx.trace(A)
A_ground_truth

array(9927.53, dtype=float32)

In [6]:
# Plain Hutchinson estimate of tr(A) from 10 probe vectors.
A_est = hutch_tr(mat=A, num_queries=10)
A_est

array(9940.48, dtype=float32)

## Hutch++ Estimator

In [1]:
"""
Recall that the rough design of the algorithm is as follows:
(credit to Dr. Meyer on his blog: ram900.com)
1. Find a good low-rank approximation $\tilde{A}_k$
2. Notice that tr($A$) = tr($\tilde{A}_k$) + tr($A - \tilde{A}_k$)
3. Compute tr($\tilde{A}_k$) exactly
4. Approximate tr($A - \tilde{A}_k$) with Hutchinson's Estimator
5. Return Hutch++: tr($\tilde{A}_k$) + $H_l(A - \tilde{A}_k)$
"""


# somewhere, this kernel is dying.
# new project: add support for general QR!
def hutchpp_tr(mat, num_queries):
    """Estimate tr(mat) with the Hutch++ estimator.

    The estimate is tr(Q^T mat Q) + H((I - Q Q^T) mat), where Q is an
    orthonormal basis for the range of ``mat @ S`` (S a Gaussian sketch) and
    H is the Hutchinson estimator from ``hutch_tr``. The first term is
    computed exactly; only the residual is estimated stochastically.

    The budget is split in thirds, as in the published algorithm:
    ``num_queries // 3`` sketch columns, the same number of products to form
    ``mat @ Q``, and ``num_queries // 3`` Hutchinson probes on the residual.

    Args:
        mat: Square 2-D array whose trace is estimated.
        num_queries: Total matrix-vector query budget; must be >= 3.

    Returns:
        A scalar array holding the trace estimate.

    Raises:
        ValueError: If ``num_queries`` is smaller than 3, since the sketch
            would then have no columns.
    """
    if num_queries < 3:
        raise ValueError(f"num_queries must be at least 3, got {num_queries}")

    n, _ = mat.shape
    third = num_queries // 3
    # A basis of the range of an n x n matrix has at most n columns; clamping
    # also keeps the zero-padding width below from going negative.
    k = min(third, n)

    # Sketch the range of mat with S in R^{n x k}, entries i.i.d. N(0, 1).
    S = gaussian_random((n, k))
    sketch = mx.matmul(mat, S)

    # The original pads the sketch with zero columns up to n x n before the
    # factorization -- presumably because the QR routine only accepted square
    # input at the time; TODO confirm and drop the padding if no longer needed.
    if k < n:
        padding = mx.zeros(shape=(n, n - k))
        sketch = mx.concatenate([sketch, padding], axis=1)

    # Pin the factorization to the CPU stream, as the MLX documentation
    # example for linalg.qr does.
    # NOTE(review): running QR on the default device may be what was killing
    # the kernel -- confirm.
    Q, _ = mx.linalg.qr(sketch, stream=mx.cpu)

    # Keep only the first k columns: they span the range of mat @ S. Using the
    # whole square Q (a full orthogonal matrix) would make tr(Q^T mat Q) the
    # exact trace and I - Q Q^T zero, i.e. no low-rank approximation at all.
    Q = Q[:, :k]

    # Q^T mat is shared by the exact low-rank trace and by the residual.
    Qt_mat = mx.matmul(mx.transpose(Q), mat)
    low_rank_trace = mx.trace(mx.matmul(Qt_mat, Q))

    # (I - Q Q^T) mat, formed without materialising the n x n projector.
    residual = mat - mx.matmul(Q, Qt_mat)

    return low_rank_trace + hutch_tr(residual, third)
    

In [2]:
# Hutch++ estimate of tr(A) with a budget of 100 queries. This needs the test
# matrix `A` built in the earlier cell; on a fresh kernel run that cell first,
# otherwise this line raises NameError.
A_better_est = hutchpp_tr(mat=A, num_queries=100)
A_better_est

NameError: name 'A' is not defined