In [2]:
import torch
def f(x):
    return x**2+2*x+torch.exp(x)+torch.sin(x)

In [1]:
def hutch_trace(f,x,m=100):
    repeat_shape = [m]+[1]*(len(x.shape)-1)
    x_mult = x.repeat(repeat_shape).detach()
    x_mult.requires_grad_(True)
    eps = torch.randint(0,2,[m*x.shape[0]]+list(x.shape)[1:]).float()*2-1
    eps = eps.to(x.device)
    #eps = torch.randn([m*x.shape[0]]+list(x.shape)[1:])
    f_e = torch.sum(f(x_mult) * eps)
    grad_f_e = torch.autograd.grad(f_e, x_mult)[0]*eps
    grad_f_e = grad_f_e.reshape(m,x.shape[0],-1)
    return torch.mean(torch.sum(grad_f_e, dim=-1),dim=0)

In [19]:
def hutchpp(f, x, m):
    """https://arxiv.org/abs/2010.09649
    """
    def grad_eps(eps):
        #assume eps has shape (x.shape,m),
        #the first dimension is the batch dimension
        #return shape (x.shape,m)

        m = eps.shape[-1]
        # move the last dimension to the front
        eps = eps.permute(-1,0,*range(1,len(eps.shape)-1))
        eps = eps.reshape(m*x.shape[0],*x.shape[1:])
        repeat_shape = [m]+[1]*(len(x.shape)-1)
        x_mult = x.repeat(repeat_shape).detach()
        x_mult.requires_grad_(True)
        f_e = torch.sum(f(x_mult) * eps)
        grad_f_e = torch.autograd.grad(f_e, x_mult)[0]*eps
        grad_f_e = grad_f_e.reshape(m,x.shape[0],-1).permute(1,2,0)
        return grad_f_e

    def batch_trace(A):
        return torch.mean(torch.diagonal(A,dim1=-2,dim2=-1),dim=-1)

    d = torch.prod(torch.tensor(x.shape[1:]))
    S = torch.randn(d, m // 3)
    G = torch.randn(d, m // 3)
    f_e = f(x)
    eps = torch.randn(list(x.shape)+[m//3])
    grad_f_e = grad_eps(eps)
    
    Q, _ = torch.qr(grad_f_e)
    Q_T = Q.permute(0,2,1)
    proj = G - Q @ (Q_T @ G)
    proj_T = proj.permute(0,2,1)
    print((Q_T @ grad_eps(Q)).shape)
    return batch_trace(Q_T @ grad_eps(Q)) + (3./m)*batch_trace(proj_T @ grad_eps(proj))

In [3]:
import string
def exact_trace(f,x):
      alphabet = string.ascii_lowercase
      dydx = torch.vmap(torch.func.jacrev(f))(x)
      num_dims = len(x.shape)-1
      # Check if the number of dimensions is within the allowed range for einsum
      if num_dims > len(alphabet):
            raise ValueError("Number of dimensions exceeds einsum's capability")
      # Construct the einsum string for the given number of dimensions
      # The last two dimensions are supposed to be equal and summed over
      einsum_str = alphabet[num_dims] + alphabet[:num_dims] + alphabet[:num_dims] + '->' + alphabet[num_dims]
      return torch.einsum(einsum_str,dydx)

In [6]:
x = torch.randn(5,4,2,3)
hutch_trace(f,x,1)

tensor([ 98.9216,  85.3511, 102.1591, 147.3731, 102.6866])

In [5]:
exact_trace(f,x)

tensor([ 97.2051,  78.1516, 127.0718, 109.3323, 109.8828])

In [20]:
hutchpp(f,x,900)

torch.Size([5, 24, 24])


tensor([-0.0096, -0.1784,  0.8788, -0.4328, -0.3888])

In [7]:
print(2*x.sum(dim=(1,2))+12+torch.exp(x).sum(dim=(1,2))+torch.cos(x).sum(dim=(1,2)))

tensor([16.2197, 22.6289, 18.0896, 21.6987,  9.9557])
