In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from functools import reduce 
from torch import nn
from functorch import vmap
torch.set_default_dtype(torch.float64)
torch.manual_seed(1)
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = 

debug = False

def logabssumdet(xs):
		
	dets = [x.reshape(-1) for x in xs if x.shape[-1] == 1]						# in case n_u or n_d=1, no need to compute determinant
	dets = reduce(lambda a,b: a*b, dets) if len(dets)>0 else 1.					# take product of these cases
	maxlogdet = 0.																# initialised for sumlogexp trick (for stability)
	det = dets																	# if both cases satisfy n_u or n_d=1, this is the determinant
	
	slogdets = [torch.linalg.slogdet(x) for x in xs if x.shape[-1]>1] 			# otherwise take slogdet
	if len(slogdets)>0: 
		sign_in, logdet = reduce(lambda a,b: (a[0]*b[0], a[1]+b[1]), slogdets)  # take product of n_u or n_d!=1 cases
		maxlogdet = torch.max(logdet)											# adjusted for new inputs
		det = sign_in * dets * torch.exp(logdet-maxlogdet)						# product of all these things is determinant
	
	psi_ish = torch.sum(det)
	sgn_psi = torch.sign(psi_ish)
	log_psi = torch.log(torch.abs(psi_ish)) + maxlogdet
	return log_psi, sgn_psi


class FermiNetTorch(nn.Module):
	def __init__(self, *, n_e=None, n_u=None, n_d=None, n_det=None, n_fb=None, n_pv=None, n_sv=None, a=None, with_sign=False, **kw):
		super(FermiNetTorch, self).__init__()
		self.n_e = n_e                  # number of electrons
		self.n_u = n_u                  # number of up electrons
		self.n_d = n_d                  # number of down electrons
		self.n_det = n_det              # number of determinants
		self.n_fb = n_fb+1              # number of feedforward blocks
		self.n_pv = n_pv                # latent dimension for 2-electron
		self.n_sv = n_sv                # latent dimension for 1-electron
		self.a = torch.tensor(a)                      # nuclei positions
		self.with_sign = with_sign      # return sign of wavefunction

		self.n1 = [4*self.a.shape[0]] + [self.n_sv]*self.n_fb
		self.n2 = [4] + [self.n_pv]*(self.n_fb - 1)
		assert (len(self.n1) == self.n_fb+1) and (len(self.n2) == self.n_fb)
	
		self.Vs = nn.ModuleList([nn.Linear(3*self.n1[i]+2*self.n2[i], self.n1[i+1]) for i in range(self.n_fb)])
		self.Ws = nn.ModuleList([nn.Linear(self.n2[i], self.n2[i+1]) for i in range(self.n_fb-1)])

		self.V_half_u = nn.Linear(self.n_sv, self.n_sv // 2)
		self.V_half_d = nn.Linear(self.n_sv, self.n_sv // 2)

		self.wu = nn.Linear(self.n_sv // 2, self.n_u)
		self.wd = nn.Linear(self.n_sv // 2, self.n_d)

		# TODO: Multideterminant. If n_det > 1 we should map to n_det*n_u (and n_det*n_d) instead,
		#  and then split these outputs in chunks of n_u (n_d)
		# TODO: implement layers for sigma and pi

	def forward(self, r: torch.Tensor):
		"""
		Batch dimension is not yet supported.
		"""

		if len(r.shape) == 1:
			r = r.reshape(self.n_e, 3) # (n_e, 3)

		dtype=r.dtype

		eye = torch.eye(self.n_e, device=r.device, dtype=dtype).unsqueeze(-1)

		ra = r[:, None, :] - self.a[None, :, :] # (n_e, n_a, 3)
		ra_len = torch.norm(ra, dim=-1, keepdim=True) # (n_e, n_a, 1)

		rr = r[None, :, :] - r[:, None, :] # (n_e, n_e, 1)
		rr_len = torch.norm(rr + eye, dim=-1, keepdim=True) * (torch.ones((self.n_e, self.n_e, 1), dtype=dtype)-eye) # (n_e, n_e, 1) 
		# TODO: Just remove '+eye' from above, it's unnecessary

		s_v = torch.cat([ra, ra_len], dim=-1).reshape(self.n_e, -1) # (n_e, n_a*4)
		p_v = torch.cat([rr, rr_len], dim=-1) # (n_e, n_e, 4)

		for l, (V, W) in enumerate(zip(self.Vs, self.Ws)):
			sfb_v = [torch.tile(_v.mean(dim=0)[None, :], (self.n_e, 1)) for _v in torch.split(s_v, 2, dim=0)]
			pfb_v = [_v.mean(dim=0) for _v in torch.split(p_v, self.n_u, dim=0)]
			
			s_v = torch.cat(sfb_v+pfb_v+[s_v,], dim=-1) # s_v = torch.cat((s_v, sfb_v[0], sfb_v[1], pfb_v[0], pfb_v[0]), dim=-1)
			s_v = torch.tanh(V(s_v)) + (s_v if (s_v.shape[-1]==self.n_sv) else 0.)
			
			if not (l == (self.n_fb-1)):
				p_v = torch.tanh(W(p_v)) + (p_v if (p_v.shape[-1]==self.n_pv) else 0.)
		
		s_u, s_d = torch.split(s_v, self.n_u, dim=0)

		s_u = torch.tanh(self.V_half_u(s_u)) # spin dependent size reduction
		s_d = torch.tanh(self.V_half_d(s_d))

		s_wu = self.wu(s_u) # map to phi orbitals
		s_wd = self.wd(s_d)

		assert s_wd.shape == (self.n_d, self.n_d)

		ra_u, ra_d = torch.split(ra, self.n_u, dim=0)

		# TODO: implement sigma = nn.Linear() before this
		exp_u = torch.norm(ra_u, dim=-1, keepdim=True)
		exp_d = torch.norm(ra_d, dim=-1, keepdim=True)

		assert exp_d.shape == (self.n_d, self.a.shape[0], 1)

		# TODO: implement pi = nn.Linear() before this
		orb_u = (s_wu * (torch.exp(-exp_u).sum(axis=1)))[None, :, :]
		orb_d = (s_wd * (torch.exp(-exp_d).sum(axis=1)))[None, :, :]

		assert orb_u.shape == (1, self.n_u, self.n_u)

		log_psi, sgn = logabssumdet([orb_u, orb_d])

		if self.with_sign:
			return log_psi, sgn
		else:
			return log_psi.squeeze()


from pyfig import Pyfig
from utils import flat_any
from functorch import make_functional
from torch.autograd import gradgradcheck
from functools import partial
from functorch import hessian, grad, jacfwd, jacrev

arg = dict(
	n_b = 2,
	n_sv = 8,
	n_pv = 8,
	n_fb = 1,
)

c = Pyfig(wb_mode='disabled', arg=arg, submit=False, run_sweep=False)
d = flat_any(c.d)

r = torch.randn((c.data.n_b, c.data.n_e, 3))
model_b = FermiNetTorch(**d)
log_psi_b = model_b(r[0])
print(r[0], log_psi_b)

init PlugIn classes
updating configuration
{}
exp_path: dump/exp/junk-0/hDHVncD
running script
tensor([[-0.3113, -0.7130, -0.7291],
        [-0.2992, -0.2529, -0.3602],
        [ 0.9394,  1.1614,  0.0941],
        [ 0.5119,  0.5962,  1.2911]]) tensor(-19.2714, grad_fn=<SqueezeBackward0>)


In [5]:
def second_order(r_flat, step):
    step = torch.eye(n_e*3, dtype=r_flat.dtype).unsqueeze(0)*step
    return torch.stack([r_flat-step, r_flat+step])

r = torch.randn((8, c.data.n_e, 3))
n_b, n_e, _ = r.shape

r_flat = r.reshape(n_b, -1)
step = 0.00001
r_num = second_order(r_flat.unsqueeze(1), step=step)
print(r_num.shape)
    
r = torch.randn((8, c.data.n_e, 3))
n_b, n_e, _ = r.shape
r_flat = r.reshape(n_b, -1)
    

torch.Size([2, 8, 12, 12])


In [7]:
import difftorch
from functorch import hessian

model_fn, params = make_functional(model_b)
model_r = lambda _r: model_fn(params, _r)

b = 6
diff = difftorch.laplacian(model_r, r_flat[b])
print(diff)

def sum_diag(x):
    return (x*torch.eye(x.shape[-1])).sum()

from functorch import vjp
n_jvp = n_e*3

model_fn, params = make_functional(model_b)
model_fnv = vmap(model_fn, in_dims=(None, 0))
model_v = lambda _r: model_fnv(params, _r).sum()

grad_fn = grad(model_v)
grads = grad_fn(r_flat)
g, fn = vjp(grad_fn, r_flat)
e = torch.stack([fn(torch.eye(n_jvp).unsqueeze(0).repeat((n_b, 1, 1))[..., i])[0] for i in range(n_jvp)], -1)
e = torch.stack([fn(torch.eye(n_jvp).unsqueeze(0).repeat((n_b, 1, 1))[..., i])[0][:, i] for i in range(n_jvp)]).sum(0)

# jvp
eyes = torch.eye(n_jvp, )[None].repeat((n_b, 1, 1))
print(eyes.shape)
from functorch import jvp 
grad_fn = grad(model_v)
grads = grad_fn(r_flat)
jvp_all = [jvp(grad_fn, (r_flat,), (eyes[:, i],)) for i in range(n_jvp)]  # grad out, jvp
e_jvp = torch.stack([a[:, i] for i, (a,b) in enumerate(jvp_all)]).sum(0)[b]
print(e_jvp, grads[b].sum())


tensor(-14.2550)
torch.Size([8, 12, 12])
tensor(-0.8507, grad_fn=<SelectBackward0>) tensor(-0.8507, grad_fn=<SumBackward0>)


In [None]:
from pathlib import Path
Path('')

PosixPath('.')

In [None]:
print(e[b], e_jvp[b], diff)

tensor(-14.2550, grad_fn=<SelectBackward0>) tensor(-4.5207, grad_fn=<SelectBackward0>) tensor(-14.2550)


In [None]:

# hess = vmap(hessian(model_j), in_dims=(0,), out_dims=(0,))(r_flat)
# print((hess[b]*torch.eye(n_e*3)).sum())
# print(hess.shape)
# def grad_grad():
    
#     xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
#     xs_flat = torch.stack(xis, dim=1)
#     sign, ys = network(xs_flat.view_as(xs))
    
#     ones = torch.ones_like(ys)
#     (dy_dxs,) = torch.autograd.grad(ys, xs_flat, ones, retain_graph=True, create_graph=True)

#     lap_ys = sum(
#       torch.autograd.grad(
#           dy_dxi, xi, ones, retain_graph=True, create_graph=False)[0]
#       for xi, dy_dxi in zip(xis, (dy_dxs[..., i] for i in range(len(xis))))
#     )

# def inter(r):
#     return jacrev(model_r)

# model_v = vmap(model_fn, in_dims=(None, 0))

# def per_b(r):
#     return jacrev(model_j)(r).sum()

# print(per_b(r).shape)
# jac = jacfwd(per_b)(r)
# print(jac.shape)

# jac = vmap(jacrev(jacfwd(model_j)))(r_flat)
# r_flat = r.reshape(n_b, -1)
# print(r_flat.shape)
# vmap(jacfwd(model_j))(r_flat).shape


In [None]:
len(x)

NameError: name 'x' is not defined

In [None]:
model_fn, params = make_functional(model_b)

def fd2(r_flat, step=0.00001):
    model_fd2 = lambda _r: model_fn(params, _r)
    model_fd2 = vmap(vmap(vmap(model_fd2)))
    r_flat = r_flat.unsqueeze(1)
    step_eye = (torch.eye(n_e*3, dtype=r_flat.dtype)*step).unsqueeze(0)
    r_flat0 = torch.stack([r_flat-step_eye, torch.tile(r_flat, (1, n_e*3, 1)), r_flat+step_eye], dim=0)
    print(r_flat0.shape)
    log_psi = model_fd2(r_flat0)
    factor = torch.tensor([+1, -2, +1]).unsqueeze(-1).unsqueeze(-1)
    print(factor.shape, log_psi.shape)
    return (factor * log_psi).sum(0) / step**2

lap_fd2_diag = fd2(r_flat)

grad_fn = grad(lambda _r: model_fn(params, _r))

num_grad_fn = vmap(grad_fn)

r_num_flat = r_num.reshape(-1, n_e*3)
lap = num_grad_fn(r_num_flat)
lap = lap.reshape(2, n_b, n_e*3, n_e*3)
lap = (lap[1] - lap[0]) / (2.*step)
lap_diag = torch.diagonal(lap, dim2=1, dim1=2)

hess = vmap(hessian(model_fn, argnums=1), in_dims=(None, 0))(params, r_flat)
hess_diag = torch.diagonal(hess, dim2=1, dim1=2)

grad_fn = grad(lambda _r: model_fn(params, _r), )


# model_diff = vmap(lambda _r: model_fn(params, _r))

jac = vmap(jacfwd(jacrev(model_fndiff)))(r_flat)
print('jac', jac.shape)
jac_diag = torch.diagonal(jac, dim2=2, dim1=1)
# jac_diag = jac[:, :, 0]

# for b in range(n_b):
#     for h,j,n, fd in zip(hess_diag[b], jac_diag[b], lap_diag[b], lap_fd2_diag[b]):
#         print(h.item(), j.item(), n.item(), fd.item())
#     print()
# https://github.com/metaopt/torchopt

# https://github.com/DiffEqML/torchdyn/search?q=numerical

# https://pytorch.org/docs/stable/generated/torch.gradient.html

[print(jac_diag.sum(0)[i], hess_diag.sum(0)[i], lap_fd2_diag.sum(0)[i]) for i in range(n_e*3)]

from functorch import jvp, grad, vjp

def hvp(f, primals, tangents):
  return jvp(vmap(grad(f)), primals, tangents)

tangent = torch.tile(torch.eye(r_flat.shape[-1]).unsqueeze(0), (n_b, 1, 1))
res, hess_jvp = hvp(model_fndiff, (r_flat,), (tangent[..., b],))


print(diff, jac_diag[b].sum(), lap_diag[b].sum(), hess_jvp.sum(-1))
print(res.shape, hess_jvp.shape)

torch.Size([3, 8, 12, 12])
torch.Size([4, 3]) torch.float64
torch.Size([3, 1, 1]) torch.Size([3, 8, 12])
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
jac torch.Size([8, 12, 12])
tensor(8.8215, grad_fn=<SelectBackward0>) tensor(8.8215, grad_fn=<SelectBackward0>) tensor(-20.6967, grad_fn=<SelectBackward0>)
tensor(-15.3759, grad_fn=<SelectBackward0>) tensor(-15.3759, grad_fn=<SelectBackward0>) tensor(-797.2138, grad_fn=<SelectBackward0>)
tensor(25.2503, grad_fn=<SelectBackward0>) tensor(25.2503, grad_fn=<SelectBackward0>) tensor(-189.4682, grad_fn=<SelectBackward0>)
tensor(0.5131, grad_fn=<SelectBackward0>) tensor(0.5131, grad_fn=<SelectBackward0>) tensor(-49.3467, grad_fn=<SelectBackward0>)
tensor(0.9627, grad_fn=<SelectBackward0>) tensor(0.9627, grad_fn=<SelectBackward0>) tensor(-96.5911, grad_fn=<SelectBackward0>)
tensor(8.6673, grad_fn=<SelectBackward0>) tensor(8.6673, grad_fn=<SelectBackward0>) tensor(-47.2147, grad_fn=<SelectBack

In [None]:
print(jac[0].detach().numpy())

[[-0.73735464 -0.04523876  0.18050537 -0.01403398  0.13979426  0.05114419
   0.06405265 -0.01023321 -0.04863514  0.1416318  -0.03594161 -0.04314019]
 [-0.04523876 -1.20164201 -0.17737489  1.36905609 -1.50789619  0.94856467
   0.2367494   0.45422666  0.35480547 -0.32755232  0.76913821  0.00713631]
 [ 0.18050537 -0.17737489 -0.36160326 -0.12719147  0.00840817  0.09334353
  -0.01633981  0.00306447 -0.04423238 -0.00429295 -0.06415162 -0.05690243]
 [-0.01403398  1.36905609 -0.12719147 -1.68073671  2.73843645 -2.27850799
  -0.36497873 -0.8104578   0.02620897  0.1911337  -1.36621224 -0.08696205]
 [ 0.13979426 -1.50789619  0.00840817  2.73843645 -2.61672075  2.33323866
   0.23590811  0.28771572  0.5011895  -0.5732838   1.48474532  0.33816889]
 [ 0.05114419  0.94856467  0.09334353 -2.27850799  2.33323866 -0.72148591
   0.41583662 -0.513669   -0.6055781   0.38099589 -0.88829983 -0.42335363]
 [ 0.06405265  0.2367494  -0.01633981 -0.36497873  0.23590811  0.41583662
  -1.09309984 -0.11688901 -0.702

In [None]:
print(lap_diag.shape)


torch.Size([2, 12])


In [None]:
print(hess.shape, lap.shape)
print(hess_diag.shape, lap_diag.shape)

torch.Size([2, 12, 12]) torch.Size([2, 12, 12])
torch.Size([12, 2]) torch.Size([12, 2])


In [None]:
print(lap_diag[:, 1], hess_diag[:, 1])

tensor([-2.4185e+00, -4.2200e+00, -1.3236e+00, -2.4633e+00, -4.1678e+00,
        -2.5564e+00,  4.8379e+00, -6.6815e-01, -5.2904e-01, -1.9204e-03,
         9.8526e+00,  1.5134e+01], grad_fn=<SelectBackward0>) tensor([-0.0375, -0.6429,  0.0996,  0.0645, -0.1103,  0.0589,  0.4182,  0.2292,
         0.1855,  0.1374,  2.2264,  0.1242], grad_fn=<SelectBackward0>)


In [None]:
# does not work - batching rule not implemented for concatenate


torch.Size([4, 3]) torch.float64
torch.Size([2, 12, 12])


tensor([[-1.4323e+00,  2.4399e+00],
        [ 2.4498e-01, -3.4506e+00],
        [-2.1750e-01, -9.1636e-01],
        [-3.2714e-02,  4.0783e-01],
        [ 1.6882e-02,  9.7052e-01],
        [ 5.7901e-03,  2.7131e+00],
        [ 3.9595e-02, -2.4113e-01],
        [ 1.7171e-03, -9.2003e-03],
        [ 3.6682e-02,  2.6252e-02],
        [ 7.5842e-04, -1.7302e-03],
        [ 8.9144e-03,  2.2968e-01],
        [-2.1564e-02, -2.6287e-01]], grad_fn=<DiagonalBackward0>)

In [None]:
# does not work - batching rule not implemented for concatenate
from functorch import hessian, grad, jacfwd

r = torch.randn((c.data.n_b, c.data.n_e, 3))
model_fn, params = make_functional(model_b)
model_fn = vmap(model_fn, in_dims=(None, 0))

grad_fn = grad(lambda r: model_fn(params, r).sum())
gr = grad_fn(r)
# 
print(gr)
print(gr)

def stack_grad(*r_jvp):
    print(r_jvp)
    r = torch.concatenate(r_jvp, dim=-1)
    return grad_fn(r).sum(0)
    
r_jvp = r.reshape(c.data.n_b, -1).split(1, dim=-1)
n_jvp = len(r_jvp)
print(len(r_jvp), )
lap = None
for i in range(n_jvp):
    jac_vp = jacfwd(vmap(stack_grad), argnums=i)(*r_jvp)
    print(jac_vp.shape)
    lap = jac_vp if lap is None else torch.cat([lap, jac_vp], dim=-1)

torch.Size([4, 3]) torch.float64
tensor([[[ 1.3037,  0.2695, -0.1894],
         [-1.0599, -0.0398, -0.7739],
         [ 0.8220, -0.3895, -0.6427],
         [-0.1901, -0.9625,  0.9182]],

        [[-1.1996, -0.2136, -0.9104],
         [-0.0142, -0.1147,  0.8459],
         [ 1.6933,  0.0769,  3.4752],
         [-0.2476, -0.1269, -1.4170]]], grad_fn=<AddBackward0>)
tensor([[[ 1.3037,  0.2695, -0.1894],
         [-1.0599, -0.0398, -0.7739],
         [ 0.8220, -0.3895, -0.6427],
         [-0.1901, -0.9625,  0.9182]],

        [[-1.1996, -0.2136, -0.9104],
         [-0.0142, -0.1147,  0.8459],
         [ 1.6933,  0.0769,  3.4752],
         [-0.2476, -0.1269, -1.4170]]], grad_fn=<AddBackward0>)
12
(BatchedTensor(lvl=3, bdim=0, value=
    GradTrackingTensor(lvl=2, value=
        tensor([[-0.6687],
                [ 0.2903]])
    )
), BatchedTensor(lvl=3, bdim=0, value=
    GradTrackingTensor(lvl=2, value=
        tensor([[-0.3548],
                [ 0.7108]])
    )
), BatchedTensor(lvl=3, bdim

RuntimeError: Batching rule not implemented for aten::concatenate. We could not generate a fallback.

In [None]:
print(lap.shape)


torch.Size([2, 12, 2, 12])


In [None]:
torch.concatenate(gr, -1).shape

torch.Size([2, 12])

In [None]:
# jac = laplacian(), argnums=1)


# print(grad_fn(r))




# log_psi = model_fn(params)


# def model(r):
#     return model_fn(params, r)
    
# model = vmap(model)
# print(r.shape)
# r_hess = r.reshape(c.data.n_b, -1)
# print(r_hess.shape)

# log_psi = model(r)
# log_psi_hess = model(r_hess)
# print(log_psi)
# print(log_psi_hess)


# r = torch.randn((c.data.n_b, c.data.n_e, 3))
# r_hess = r.reshape(c.data.n_b, -1)
# model_fn, params = make_functional(model_b)
# hess = vmap(hessian(model))(r_hess)
# print(hess.shape)

In [None]:
# https://pytorch.org/functorch/stable/ux_limitations.html#vmap-limitations
# finite difference jacobian matrix

# https://pytorch.org/docs/stable/generated/torch.autograd.gradgradcheck.html#torch.autograd.gradgradcheck

_r = r.reshape(c.data.n_b, -1).requires_grad_()
dpsidr2 = gradgradcheck(model, inputs=_r[[0]])  #  The gradients with respect to the function’s outputs
print(dpsidr2)

True


In [None]:
https://discuss.pytorch.org/t/fast-computation-of-the-hessian-diagonal/143145

from time import time

import torch
from torch._vmap_internals import _vmap as vmap

def grad(outputs, inputs, **kwargs):
    """
    Compute the trace of the jacobian of a 3D tensor.

    grad_{ijk} = d (outputs_{ijk}) / d (inputs_{ijk})

    See: https://discuss.pytorch.org/t/jacobian-functional-api-batch-respecting-jacobian/84571/7?u=amerlo94
    """

    shape = outputs.shape
    bs = shape[0]
    n = shape[1] * shape[2]

    outputs = outputs.view(bs, n)
    outputs = outputs.sum(axis=0)
    grad_outputs = torch.eye(n, dtype=outputs.dtype, device=outputs.device)

    def get_vjp(v):
        return torch.autograd.grad(outputs, inputs, v, **kwargs)[0]

    vjp = vmap(get_vjp)
    grad = vjp(grad_outputs)

    return grad.T.view(shape)


n = 16
b = torch.randn(1, n, n)


def f(x):
    return x ** 2 * b


bs = 128
iterations = 1000

x = torch.rand(bs).requires_grad_(True)
x = x.view(-1, 1, 1)
y = f(x)

#  Check gradients
jac = grad(y, x, create_graph=True)
hes = grad(jac, x)
assert torch.allclose(jac, 2 * x * b)
assert torch.allclose(hes, 2 * b)

#  Time gradients
t0 = time()
for _ in range(iterations):
    jac = grad(y, x, create_graph=True)
    grad(jac, x)

tottime = time() - t0

time_per_batch = tottime / iterations / bs
print(f"Time per batch: {time_per_batch * 1e6:.2f} us")

In [None]:
# https://github.com/amirgholami/adahessian
# second order opt

# jacobian of the jacobian
# https://gist.github.com/apaszke/226abdf867c4e9d6698bd198f3b45fb7

In [None]:

n_jvp = c.data.n_e * 3
laplacian = torch.zeros(r.shape[0]) #array to store values of laplacian

for i, xi in enumerate(r):
	hess = torch.autograd.functional.hessian(model, xi.unsqueeze(0), create_graph=True)
	laplacian[i] = torch.diagonal(hess.view(n_jvp, n_jvp), offset=0).sum()

In [None]:
from functools import jacrev

def kinetic_functorch(r):
  calc_jacobian = jacrev(log_psi, argnums=1) #do once, and re-use?
  calc_hessian = jacrev(jacrev(log_psi, argnums=1), argnums=1)
  return -0.5*torch.sum(calc_hessian(params, r).squeeze(-3).squeeze(-1).diagonal(0,-2,-1) + calc_jacobian(params, r).squeeze(-1).pow(2), dim=-1)

#per-sample gradients for local energy w.r.t params via FuncTorch
elocal_grad_ft = vmap(grad(kinetic_functorch, argnums=0), in_dims=(None, 0))(params, x)
elocal_grad_ft = [p.clone().mean(dim=0) for p in elocal_grad_ft]


def kinetic_pytorch(xs: Tensor) -> Tensor:
  """Method to calculate the local kinetic energy values of a netork function, f, for samples, x.
  The values calculated here are 1/f d2f/dx2 which is equivalent to d2log(|f|)/dx2 + (dlog(|f|)/dx)^2
  within the log-domain (rather than the linear-domain).
  :param xs: The input positions of the many-body particles
  :type xs: class: `torch.Tensor`
  """
  xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
  xs_flat = torch.stack(xis, dim=1)

  _, ys = net(xs_flat.view_as(xs))
  #print("pytorch logabs: ",ys)
  ones = torch.ones_like(ys)

  #df_dx calculation
  (dy_dxs, ) = torch.autograd.grad(ys, xs_flat, ones, retain_graph=True, create_graph=True)


  #d2f_dx2 calculation
  lay_ys = sum(torch.autograd.grad(dy_dxi, xi, ones, retain_graph=True, create_graph=False)[0] \
                for xi, dy_dxi in zip(xis, (dy_dxs[..., i] for i in range(len(xis))))
  )
  
  ek_local_per_walker = -0.5 * (lay_ys + dy_dxs.pow(2).sum(-1)) #move const out of loop?
  return ek_local_per_walker

#PyTorch gradients via reverse-mode AD
net.zero_grad()
kin_pt = kinetic_pytorch(x)
loss = torch.mean(kin_pt)
loss.backward()
elocal_grad_pt = [param.grad for param in net.parameters()]

energy_ft = kinetic_functorch(params, x)
energy_pt = kinetic_pytorch(x)

print(energy_ft)

torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size([4, 3]) torch.float64
torch.Size

KeyboardInterrupt: 

In [None]:
# https://discuss.pytorch.org/t/how-to-calculate-laplacian-for-multiple-batches-in-parallel/104888

y = model(x) #where model is an R^N to R^1 function

laplacian = torch.zeros(x.shape[0]) #array to store values of laplacian

for i, xi in enumerate(x):
    hess = torch.autograd.functional.hessian(model, xi.unsqueeze(0), create_graph=True)
    laplacian[i] = torch.diagonal(hess.view(N, N) offset=0).sum()

In [None]:
# https://pytorch.org/docs/stable/generated/torch.autograd.functional.hvp.html
# Hessian vector product vs grad * jvp 

# https://colab.research.google.com/drive/13xuYhabehAqJAalT-wufTxlCOr3VmiRU
# faster hessian vector product

# https://discuss.pytorch.org/t/explicitly-calculate-jacobian-matrix-in-simple-neural-network/133670/4
# discussion on how to get the fwdjac 
# It is technically possible now, depending on the operators that your model uses. 
# However, the real speed-up of forward-over-reverse Hessian comes from being able to vectorize over the forward 
# (otherwise you’d have to compute the forward O(numel) times). The ability to compute vectorized jvp should be in master soon, but is not ready at the moment.

# $ https://discuss.pytorch.org/t/notimplementederror-you-must-implement-the-jvp-function-for-custom-autograd-function-to-use-it-with-forward-mode-ad/138245/2
# forward mode ad jvp 

# 
# def laplacian(xs, f, create_graph=False, keep_graph=None, return_grad=False):
    # xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
    # xs_flat = torch.stack(xis, dim=1)
    # ys = f(xs_flat.view_as(xs))
    # (ys_g, *other) = ys if isinstance(ys, tuple) else (ys, ())
    # ones = torch.ones_like(ys_g)
    # (dy_dxs,) = torch.autograd.grad(ys_g, xs_flat, ones, create_graph=True)
    # lap_ys = sum(
    #     torch.autograd.grad(
    #         dy_dxi, xi, ones, retain_graph=True, create_graph=create_graph
    #     )[0]
    #     for xi, dy_dxi in zip(xis, (dy_dxs[..., i] for i in range(len(xis))))
    # )
    # if not (create_graph if keep_graph is None else keep_graph):
    #     ys = (ys_g.detach(), *other) if isinstance(ys, tuple) else ys.detach()
    # result = lap_ys, ys
    # if return_grad:
    #     result += (dy_dxs.detach().view_as(xs),)
    # return result
    
    
# https://pytorch.org/docs/stable/autograd.html#torch.autograd.functional.hessian
# pytorch hessian vector product
