In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from functools import reduce 
from torch import nn
from functorch import vmap
torch.set_default_dtype(torch.float64)
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = 


def logabssumdet(xs):
	"""Return log|psi| and sign(psi) for psi = sum_k prod_s det(xs[s][k]).

	Each element of ``xs`` is a tensor whose last two dimensions form square
	matrices (as required by ``torch.linalg.slogdet``). Elements with a trailing
	size of 1 are treated as their own determinant and are flattened with
	``reshape(-1)``; larger matrices go through ``slogdet``. The per-element
	results are multiplied together and then summed over every remaining entry,
	so the result is a single scalar pair.

	The sum is evaluated with the log-sum-exp trick: the largest log-determinant
	is subtracted before exponentiating and added back after the logarithm.

	:param xs: iterable of tensors of square matrices, e.g. ``[orb_up, orb_down]``
	:return: ``(log_psi, sgn_psi)``; for an all-singular input this is ``(-inf, 0)``
	"""
	# 1x1 blocks need no determinant routine: the single entry is the determinant.
	scalar_dets = [x.reshape(-1) for x in xs if x.shape[-1] == 1]
	dets = reduce(lambda a, b: a * b, scalar_dets) if len(scalar_dets) > 0 else 1.
	maxlogdet = 0.		# log-sum-exp shift; stays 0 when only 1x1 blocks are present
	det = dets			# already the full product when every block is 1x1

	slogdets = [torch.linalg.slogdet(x) for x in xs if x.shape[-1] > 1]
	if len(slogdets) > 0:
		# Multiply signs and add log-magnitudes across the blocks.
		sign_in, logdet = reduce(lambda a, b: (a[0] * b[0], a[1] + b[1]), slogdets)
		maxlogdet = torch.max(logdet)
		# If every determinant is singular (or overflowed) the maximum is not finite and
		# `logdet - maxlogdet` would be NaN. Fall back to a zero shift in that case.
		# torch.where is used instead of a Python branch so the function stays vmap-friendly.
		maxlogdet = torch.where(torch.isfinite(maxlogdet), maxlogdet, torch.zeros_like(maxlogdet))
		det = sign_in * dets * torch.exp(logdet - maxlogdet)

	psi_ish = torch.sum(det)
	sgn_psi = torch.sign(psi_ish)
	log_psi = torch.log(torch.abs(psi_ish)) + maxlogdet
	return log_psi, sgn_psi


class FermiNetTorch(nn.Module):
	"""FermiNet-style wavefunction ansatz mapping one electron configuration to log|psi|.

	The network keeps a single-electron stream ``s_v`` (one row per electron) and a
	pairwise stream ``p_v`` (one row per electron pair). Each of the ``n_fb`` blocks
	concatenates per-spin means of both streams onto ``s_v`` and applies a ``tanh``
	linear layer from ``Vs``; all but the last block also update ``p_v`` through ``Ws``.
	A residual connection is added whenever a layer's input and output widths match.
	The final features are split by spin, reduced, mapped to orbitals, multiplied by
	an exponential envelope in the electron-nucleus distance and handed to
	``logabssumdet``.

	Only a single determinant is built; ``n_det`` is stored but not used yet.
	Electrons are assumed to be ordered spin-up first, then spin-down.
	"""

	def __init__(self, *, n_e=None, n_u=None, n_d=None, n_det=None, n_fb=None, n_pv=None, n_sv=None, a=None, with_sign=False, **kw):
		"""
		:param n_e: number of electrons
		:param n_u: number of spin-up electrons
		:param n_d: number of spin-down electrons
		:param n_det: number of determinants (stored only, see the TODO below)
		:param n_fb: number of feedforward blocks
		:param n_pv: width of the pairwise (2-electron) stream
		:param n_sv: width of the single-electron stream
		:param a: nuclei positions, convertible by ``torch.tensor`` to shape (n_a, 3)
		:param with_sign: if True, ``forward`` also returns the sign of psi
		:param kw: ignored, so a larger flattened config dict can be splatted in
		"""
		super().__init__()
		self.n_e = n_e                  # number of electrons
		self.n_u = n_u                  # number of up electrons
		self.n_d = n_d                  # number of down electrons
		self.n_det = n_det              # number of determinants
		self.n_fb = n_fb                # number of feedforward blocks
		self.n_pv = n_pv                # latent dimension for 2-electron
		self.n_sv = n_sv                # latent dimension for 1-electron
		# Kept as a plain attribute (not a buffer) so functorch.make_functional keeps working;
		# it is moved to the input's device/dtype inside `_orbitals`.
		self.a = torch.tensor(a)        # nuclei positions
		self.with_sign = with_sign      # return sign of wavefunction

		# Layer widths: n1 for the single stream (input + n_fb outputs), n2 for the pair stream.
		self.n1 = [4*self.a.shape[0]] + [self.n_sv]*self.n_fb
		self.n2 = [4] + [self.n_pv]*(self.n_fb - 1)
		assert (len(self.n1) == self.n_fb+1) and (len(self.n2) == self.n_fb)

		# V input = own features + two spin means of s_v + two spin means of p_v.
		self.Vs = nn.ModuleList([nn.Linear(3*self.n1[i]+2*self.n2[i], self.n1[i+1]) for i in range(self.n_fb)])
		self.Ws = nn.ModuleList([nn.Linear(self.n2[i], self.n2[i+1]) for i in range(self.n_fb-1)])

		self.V_half_u = nn.Linear(self.n_sv, self.n_sv // 2)
		self.V_half_d = nn.Linear(self.n_sv, self.n_sv // 2)

		self.wu = nn.Linear(self.n_sv // 2, self.n_u)
		self.wd = nn.Linear(self.n_sv // 2, self.n_d)

		# TODO: Multideterminant. If n_det > 1 we should map to n_det*n_u (and n_det*n_d) instead,
		#  and then split these outputs in chunks of n_u (n_d)
		# TODO: implement layers for sigma and pi

	def _orbitals(self, r: torch.Tensor):
		"""Build the spin-up and spin-down orbital matrices for one configuration.

		:param r: electron positions, shape (n_e, 3) or flat (n_e*3,)
		:return: ``(orb_u, orb_d)`` with shapes (1, n_u, n_u) and (1, n_d, n_d)
		"""
		if r.dim() == 1:
			r = r.reshape(self.n_e, 3) # (n_e, 3)

		dtype, device = r.dtype, r.device
		a = self.a.to(device=device, dtype=dtype) # (n_a, 3), follows the input tensor
		n_spin = [self.n_u, self.n_d] # explicit split sizes so n_u != n_d works

		eye = torch.eye(self.n_e, device=device, dtype=dtype).unsqueeze(-1) # (n_e, n_e, 1)

		ra = r[:, None, :] - a[None, :, :] # (n_e, n_a, 3)
		ra_len = torch.norm(ra, dim=-1, keepdim=True) # (n_e, n_a, 1)

		rr = r[None, :, :] - r[:, None, :] # (n_e, n_e, 3)
		# '+eye' shifts the (zero) self-displacement before the norm is taken and the
		# '(1 - eye)' factor then zeroes the diagonal again.
		# TODO: Just remove '+eye' from above, it's unnecessary
		# NOTE(review): presumably the shift keeps norm derivatives at zero distance well defined -
		# confirm before acting on the TODO.
		rr_len = torch.norm(rr + eye, dim=-1, keepdim=True) * (1. - eye) # (n_e, n_e, 1)

		s_v = torch.cat([ra, ra_len], dim=-1).reshape(self.n_e, -1) # (n_e, n_a*4)
		p_v = torch.cat([rr, rr_len], dim=-1) # (n_e, n_e, 4)

		# Iterate over every V layer. Ws has one layer fewer, so zipping the two lists
		# would silently drop the last block.
		for l, V in enumerate(self.Vs):
			sfb_v = [torch.tile(_v.mean(dim=0)[None, :], (self.n_e, 1)) for _v in torch.split(s_v, n_spin, dim=0)]
			pfb_v = [_v.mean(dim=0) for _v in torch.split(p_v, n_spin, dim=0)]

			s_in = torch.cat(sfb_v+pfb_v+[s_v,], dim=-1) # (n_e, 3*n1[l] + 2*n2[l])
			s_out = torch.tanh(V(s_in))
			# Residual on the block's own input features (not on the concatenated tensor).
			if s_v.shape[-1] == self.n_sv:
				s_out = s_out + s_v
			s_v = s_out

			if l < len(self.Ws): # no pairwise update in the last block
				p_out = torch.tanh(self.Ws[l](p_v))
				if p_v.shape[-1] == self.n_pv:
					p_out = p_out + p_v
				p_v = p_out

		s_u, s_d = torch.split(s_v, n_spin, dim=0)

		s_u = torch.tanh(self.V_half_u(s_u)) # spin dependent size reduction
		s_d = torch.tanh(self.V_half_d(s_d))

		s_wu = self.wu(s_u) # map to phi orbitals
		s_wd = self.wd(s_d)

		assert s_wd.shape == (self.n_d, self.n_d)

		# TODO: implement sigma = nn.Linear() before this
		# Electron-nucleus distances per spin block (same values as the norm of ra_u / ra_d).
		exp_u, exp_d = torch.split(ra_len, n_spin, dim=0)

		assert exp_d.shape == (self.n_d, self.a.shape[0], 1)

		# TODO: implement pi = nn.Linear() before this
		# Envelope: sum over nuclei of exp(-|r_i - a_m|), broadcast over the orbital index.
		orb_u = (s_wu * (torch.exp(-exp_u).sum(dim=1)))[None, :, :]
		orb_d = (s_wd * (torch.exp(-exp_d).sum(dim=1)))[None, :, :]

		assert orb_u.shape == (1, self.n_u, self.n_u)

		return orb_u, orb_d

	def forward(self, r: torch.Tensor):
		"""
		Evaluate log|psi| (and optionally sign(psi)) for a single configuration.

		Batch dimension is not yet supported.

		:param r: electron positions, shape (n_e, 3) or flat (n_e*3,)
		:return: ``log_psi`` or, when ``with_sign`` is set, ``(log_psi, sgn)``
		"""
		orb_u, orb_d = self._orbitals(r)

		log_psi, sgn = logabssumdet([orb_u, orb_d])

		if self.with_sign:
			return log_psi, sgn
		return log_psi.squeeze()


from pyfig import Pyfig
from utils import flat_any
from functorch import make_functional
from torch.autograd import gradgradcheck
from functools import partial

# Instantiate the project configuration with the keyword flags shown
# (their meaning is defined in pyfig, which is not visible here).
c = Pyfig(wb_mode='disabled', submit=False, run_sweep=False)
# Flatten `c.d` so it can be splatted as keyword arguments into the model below
# (presumably a flat {name: value} mapping - confirm against utils.flat_any).
d = flat_any(c.d)

# Random electron positions: one (n_e, 3) configuration per batch entry.
r = torch.randn((c.data.n_b, c.data.n_e, 3))
# The model picks the keywords it knows from `d`; its **kw swallows any others.
model_b = FermiNetTorch(**d)
# Smoke test on the first configuration only (forward does not take a batch dimension).
log_psi_b = model_b(r[0])

print(r[0], log_psi_b)

In [None]:

# Split the module into a stateless function and its parameters, then bind the parameters
# and vmap over the leading dimension so `model` accepts the whole batch `r`.
# Note that `model` is rebound here: from this line on it is the vmapped function, not the module.
model, params = make_functional(model_b)
model = vmap(partial(model, params))
log_psi = model(r)

# Number of scalar coordinates per configuration (n_e electrons x 3).
n_jvp = c.data.n_e * 3
laplacian = torch.zeros(r.shape[0]) #array to store values of laplacian

# Laplacian per sample as the trace of the full Hessian, one configuration at a time.
for i, xi in enumerate(r):
	# unsqueeze(0) re-adds the batch dimension the vmapped model expects
	hess = torch.autograd.functional.hessian(model, xi.unsqueeze(0), create_graph=True)
	laplacian[i] = torch.diagonal(hess.view(n_jvp, n_jvp), offset=0).sum()

In [None]:
# jacrev and grad are functorch transforms (functools does not provide them);
# grad is used by the per-sample gradient call further down in this cell.
from functorch import grad, jacrev

def kinetic_functorch(params, r):
  """Local kinetic energy of one sample in the log domain, via functorch transforms.

  Computes ``-0.5 * sum_i (d2 log|f| / dr_i^2 + (d log|f| / dr_i)^2)`` by taking the
  diagonal of the Hessian of ``logabs`` with respect to ``r`` plus the squared Jacobian.

  NOTE(review): ``logabs`` is read from module scope and is not defined in the visible
  part of this notebook; it is expected to be callable as ``logabs(params, r)``.
  The squeeze calls assume the Hessian has a size-1 axis at position -3 (for example
  ``r`` shaped (n, 1) with a scalar output) - confirm against the actual ``logabs``.

  :param params: parameters forwarded unchanged to ``logabs``
  :param r: positions of a single sample
  :return: kinetic energy of that sample
  """
  calc_jacobian = jacrev(logabs, argnums=1) #do once, and re-use?
  calc_hessian = jacrev(calc_jacobian, argnums=1)  # Hessian = Jacobian of the Jacobian w.r.t. r

  laplacian_terms = calc_hessian(params, r).squeeze(-3).squeeze(-1).diagonal(0, -2, -1)
  gradient_sq = calc_jacobian(params, r).squeeze(-1).pow(2)
  return -0.5*torch.sum(laplacian_terms + gradient_sq, dim=-1)

#per-sample gradients for local energy w.r.t params via FuncTorch
# in_dims=(None, 0): params are shared across samples, x is mapped over its leading dimension.
# NOTE(review): `x` is not defined in the visible part of this notebook, and `grad` is presumably
# functorch.grad - confirm both are in scope before this line runs.
elocal_grad_ft = vmap(grad(kinetic_functorch, argnums=0), in_dims=(None, 0))(params, x)
# Average each parameter's per-sample gradient over the sample dimension.
elocal_grad_ft = [p.clone().mean(dim=0) for p in elocal_grad_ft]


def kinetic_pytorch(xs: torch.Tensor) -> torch.Tensor:
  """Calculate the local kinetic energy of a network function, f, for a batch of samples.

  The value computed is -1/2 * (1/f) d2f/dx2, which in the log domain equals
  -1/2 * (d2log(|f|)/dx2 + (dlog(|f|)/dx)^2), summed over all input coordinates.

  ``net`` is read from module scope; it is called on a tensor shaped like ``xs`` and must
  return a pair whose second element is log|f| with one value per walker.

  NOTE(review): the second-derivative term is taken with ``create_graph=False``, so it is
  detached from the autograd graph. Calling ``backward()`` on the returned energies therefore
  only propagates through the squared-gradient term - confirm this is intended.

  :param xs: The input positions of the many-body particles, leading dimension = walkers
  :type xs: class: `torch.Tensor`
  :return: local kinetic energy per walker
  """
  # One leaf tensor per flattened coordinate (each of length n_walkers), so every
  # coordinate can be differentiated against individually below.
  xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
  xs_flat = torch.stack(xis, dim=1)

  _, ys = net(xs_flat.view_as(xs))
  ones = torch.ones_like(ys)

  #df_dx calculation
  (dy_dxs, ) = torch.autograd.grad(ys, xs_flat, ones, retain_graph=True, create_graph=True)

  #d2f_dx2 calculation: differentiate the i-th gradient column w.r.t. the i-th coordinate only
  lay_ys = sum(
    torch.autograd.grad(dy_dxs[..., i], xi, ones, retain_graph=True, create_graph=False)[0]
    for i, xi in enumerate(xis)
  )

  ek_local_per_walker = -0.5 * (lay_ys + dy_dxs.pow(2).sum(-1)) #move const out of loop?
  return ek_local_per_walker

#PyTorch gradients via reverse-mode AD
# NOTE(review): `net` and `x` are not defined in the visible part of this notebook;
# they must exist in the kernel before this section runs.
net.zero_grad()
kin_pt = kinetic_pytorch(x)
# Mean kinetic energy over the batch is used as the scalar to back-propagate.
loss = torch.mean(kin_pt)
loss.backward()
elocal_grad_pt = [param.grad for param in net.parameters()]

# Energies from both implementations for comparison.
# NOTE(review): kinetic_functorch is called here on the whole of `x` without vmap, whereas the
# gradient call above maps it over samples - confirm the intended input shape.
energy_ft = kinetic_functorch(params, x)
energy_pt = kinetic_pytorch(x)

print(energy_ft)

In [None]:
# https://pytorch.org/functorch/stable/ux_limitations.html#vmap-limitations

# https://pytorch.org/docs/stable/generated/torch.autograd.gradgradcheck.html#torch.autograd.gradgradcheck

# gradgradcheck differentiates with respect to its inputs, so it needs a double-precision
# copy of `r` that requires grad (the original `r` does not).
# It returns a bool - True when numerical and analytical second derivatives agree -
# not the second derivatives themselves.
r_check = r.detach().clone().requires_grad_(True)
dpsidr2 = gradgradcheck(model, (r_check,))

  


In [None]:
# https://discuss.pytorch.org/t/how-to-calculate-laplacian-for-multiple-batches-in-parallel/104888

# Reference snippet from the linked forum thread.
# NOTE(review): `x` (batch of inputs) and `N` (number of scalar inputs per sample) are not
# defined in the visible part of this notebook; define them before running this cell.
y = model(x) #where model is an R^N to R^1 function

laplacian = torch.zeros(x.shape[0]) #array to store values of laplacian

for i, xi in enumerate(x):
    hess = torch.autograd.functional.hessian(model, xi.unsqueeze(0), create_graph=True)
    # trace of the (N, N) Hessian = Laplacian of sample i
    laplacian[i] = torch.diagonal(hess.view(N, N), offset=0).sum()

In [None]:
# https://pytorch.org/docs/stable/generated/torch.autograd.functional.hvp.html
# Hessian vector product vs grad * jvp 