In [41]:
import mlx.core as mx

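## Girard-Hutchinson Estimator

For a square matrix $A$ and a random probe vector $x$ with $E[xx^T] = I$, we have $E[x^T A x] = \operatorname{tr}(A\,E[xx^T]) = \operatorname{tr}(A)$. Averaging $x^T A x$ over many independent probes therefore gives an unbiased estimate of the trace.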

In [42]:
# Probe vectors must satisfy $E[xx^T] = I$: standard-normal (mean 0, variance 1) Gaussian vectors or Rademacher (+1/-1 with equal probability) vectors both work.
def gaussian_random_vec(shape):
    """Draw i.i.d. standard-normal (mean 0, variance 1) samples as float32.

    Args:
        shape: shape of the array to sample.

    Returns:
        An ``mx.array`` of the requested shape with dtype float32.
    """
    # mx.random.normal already uses loc=0 and scale=1 by default, so this is a
    # thin, explicitly-typed wrapper kept for symmetry with rademacher_random_vec.
    return mx.random.normal(shape=shape, dtype=mx.float32)

# The idea for Rademacher random vectors is from: https://docs.backpack.pt/en/master/use_cases/example_trace_estimation.html
def rademacher_random_vec(shape):
    """Draw i.i.d. Rademacher samples (+1 or -1, each with probability 1/2).

    Args:
        shape: shape of the array to sample.

    Returns:
        An ``mx.array`` of the requested shape with dtype float32.
    """
    coin_flips = mx.random.bernoulli(p=0.5, shape=shape).astype(mx.float32)
    # Affine map {0, 1} -> {-1, +1}.
    return coin_flips * 2 - 1


def hutch_tr(mat, num_samples, sampler=None):
    """Estimate trace(mat) with the Girard-Hutchinson estimator.

    Draws ``num_samples`` probe vectors x (rows of a matrix X) and returns the
    sample mean of the quadratic forms x^T mat x. When E[x x^T] = I this is an
    unbiased estimator of trace(mat).

    Args:
        mat: square 2-D ``mx.array``. Integer dtypes are promoted to float32
            by the matmul with the float32 probe vectors.
        num_samples: number of probe vectors to average over (must be >= 1).
        sampler: callable accepting ``shape=`` and returning probe vectors of
            that shape. Defaults to ``rademacher_random_vec``;
            ``gaussian_random_vec`` is the other option defined in this notebook.

    Returns:
        A scalar ``mx.array`` holding the trace estimate.

    Raises:
        ValueError: if ``mat`` is not a square 2-D array or ``num_samples < 1``.
    """
    if sampler is None:
        sampler = rademacher_random_vec

    if mat.ndim != 2 or mat.shape[0] != mat.shape[1]:
        raise ValueError(f"hutch_tr expects a square 2-D matrix, got shape {mat.shape}")
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    n = mat.shape[0]
    # One probe vector per row: probes has shape (num_samples, n). Sampling them
    # all at once avoids a Python loop that builds a huge lazy graph of
    # per-sample matmuls and indexed updates.
    probes = sampler(shape=(num_samples, n))
    # Row i of (probes @ mat) is x_i^T A; multiplying elementwise by x_i and
    # summing over columns yields x_i^T A x_i for every probe simultaneously.
    quad_forms = mx.sum(mx.matmul(probes, mat) * probes, axis=1)
    return mx.mean(quad_forms)
        
    

In [43]:
# Seed MLX's global PRNG so the test matrix and the Hutchinson probe vectors
# are reproducible under Restart & Run All.
SEED = 0
mx.random.seed(SEED)

N_DIM = 100  # side length of the square test matrix
# Random integer test matrix with entries drawn from the half-open range [-100, 100).
A = mx.random.randint(low=-100, high=100, shape=(N_DIM, N_DIM), dtype=mx.int16)

In [44]:
# Exact trace of A: the reference value for the Hutchinson estimate computed in the next cell.
A_ground_truth = mx.trace(A)
A_ground_truth

array(326, dtype=int16)

In [None]:
# Average the quadratic form x^T A x over 100,000 random probe vectors.
A_est = hutch_tr(mat=A, num_samples=100_000)
A_est