In [1]:
import functools

import einops
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import netket as nk

from models.slaternet import SlaterNet
from models.psi_solid import PsiSolid

from systems.continuous import moire

# Single seed for the notebook: every later cell derives its randomness from
# `key` by calling jax.random.split, so changing SEED changes all samples.
SEED = 42
key = jax.random.PRNGKey(SEED)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Take a fresh subkey for parameter initialisation and rebind `key` so that
# later cells continue from the remaining half of the split.
subkey, key = jax.random.split(key)

# SlaterNet ansatz configured for 8 particles; `moire.G` is passed as
# `recip_latt_vecs` (presumably the moire reciprocal-lattice vectors — verify
# against systems.continuous.moire).
# NOTE(review): the following cell rebinds `wavefn` to a PsiSolid instance, so
# on a top-to-bottom run this SlaterNet is constructed but never evaluated.
# Confirm whether it should get its own name or be dropped.
wavefn = SlaterNet(
    num_particle=8, 
    recip_latt_vecs=moire.G, 
    hidden_dim=64,
    mlp_depth=8,
    key=subkey,
)


In [3]:
# Take a fresh subkey for parameter initialisation and rebind `key` so that
# later cells continue from the remaining half of the split.
subkey, key = jax.random.split(key)

# PsiSolid ansatz configured for 8 particles; this rebinding of `wavefn` is the
# object that the later cells in this notebook evaluate and differentiate.
# `moire.G` is passed as `recip_latt_vecs` (presumably the moire
# reciprocal-lattice vectors — verify against systems.continuous.moire).
wavefn = PsiSolid(
    num_particle=8, 
    recip_latt_vecs=moire.G, 
    hidden_dim=64,
    attention_dim=16,
    num_heads=6, 
    num_blocks=3, 
    num_det=4,
    key=subkey,
)

In [4]:
# Fresh subkey for sampling; `key` is rebound, so re-running this cell draws a
# different R each time.
subkey, key = jax.random.split(key)
# Standard-normal test configuration of shape (8, 2).
# presumably (num_particle, spatial_dim), matching num_particle=8 above — verify
R = jax.random.normal(subkey, (8, 2))
# Evaluate the wavefunction once at R.
psi = wavefn(R)
# Display dtype and shape; the recorded output is a complex128 scalar.
psi.dtype, psi.shape

(dtype('complex128'), ())

In [5]:
# Gradient of the wavefunction with respect to the coordinates R, taken with
# NetKet's grad wrapper.
# NOTE(review): presumably nk.jax.grad is used instead of jax.grad because the
# amplitude is complex-valued — confirm against the NetKet documentation.
dpsi_dx = nk.jax.grad(wavefn)(R)
# Display dtype and shape; the recorded output is complex128 with R's shape (8, 2).
dpsi_dx.dtype, dpsi_dx.shape

(dtype('complex128'), (8, 2))

In [6]:
def psi_r(x):
    """Return the real part of the wavefunction amplitude at coordinates ``x``."""
    return jnp.real(wavefn(x))


# Jacobian of the gradient of the real part, i.e. all second derivatives of
# Re(psi) with respect to pairs of input coordinates.
jacobian = jax.jacobian(jax.grad(psi_r))
jac = jacobian(R)
# The result carries R's shape twice; the recorded output is (8, 2, 8, 2).
jac.shape

(8, 2, 8, 2)

In [7]:
def wfn_laplacian(
    wavefn: eqx.Module,
    point
):
    """Laplacian of a complex-valued wavefunction at ``point`` via autodiff.

    The real and imaginary parts of ``wavefn`` are differentiated separately.
    For each part, the diagonal of the Hessian with respect to the flattened
    entries of ``point`` is summed, one coordinate at a time, so the full
    Hessian is never materialised.

    Args:
        wavefn: Callable evaluated as ``wavefn(r)`` on arrays shaped like
            ``point``. ``jax.grad`` is applied to its real and imaginary
            parts, so each call must produce a scalar.
        point: Array of coordinates at which the Laplacian is evaluated.

    Returns:
        ``jax.lax.complex(lap_re, lap_im)``: a complex scalar whose real and
        imaginary components are the Laplacians of the real and imaginary
        parts of ``wavefn``.

    Note:
        The cost is one Jacobian-vector product of the gradient per flattened
        coordinate, and the whole loop is executed twice (real, imaginary).
    """
    n_coords = point.size  # total number of scalar coordinates in `point`

    def hessian_trace(scalar_fn):
        """Sum the second derivatives of ``scalar_fn`` over every coordinate."""
        gradient = jax.grad(scalar_fn)

        def add_diagonal_entry(idx, running_sum):
            # Unit tangent along flattened coordinate `idx`, shaped like `point`.
            unit = jnp.zeros(n_coords, point.dtype).at[idx].set(1.)
            unit = unit.reshape(point.shape)
            # Forward-mode derivative of the gradient along `unit` yields
            # column `idx` of the Hessian (a Hessian-vector product).
            hess_column = jax.jvp(gradient, (point,), (unit,))[1]
            # Only the diagonal entry of that column enters the trace.
            return running_sum + hess_column.reshape(-1)[idx]

        # The accumulator uses the real dtype of `point` so the loop carry
        # keeps a fixed type across iterations.
        start = jnp.asarray(0., point.real.dtype)
        return jax.lax.fori_loop(0, n_coords, add_diagonal_entry, start)

    def real_part(r):
        return jnp.real(wavefn(r))

    def imag_part(r):
        return jnp.imag(wavefn(r))

    real_trace = hessian_trace(real_part)
    imag_trace = hessian_trace(imag_part)

    return jax.lax.complex(real_trace, imag_trace)

In [8]:
# Autodiff Laplacian of the current `wavefn` at the configuration R.
lap_psi = wfn_laplacian(wavefn, R)
# Display dtype and shape; the recorded output is a complex128 scalar.
lap_psi.dtype, lap_psi.shape

(dtype('complex128'), ())

In [9]:
def laplacian_fd(
    wavefn: eqx.Module,
    point,
    h: float = 1e-6
):
    """Finite-difference Laplacian of ``wavefn`` at ``point``.

    Each flattened coordinate of ``point`` is displaced by ``+h`` and ``-h``
    in turn, and the central second differences
    ``(f(x + h) - 2 f(x) + f(x - h)) / h**2`` are summed over all coordinates.

    Args:
        wavefn: Callable evaluated as ``wavefn(r)`` on arrays shaped like
            ``point``.
        point: Array of coordinates at which the Laplacian is estimated.
        h: Step size of the displacement along each coordinate.

    Returns:
        The accumulated sum of second differences, starting from the Python
        complex value ``0.0 + 0.0j``.

    Note:
        ``wavefn`` is evaluated ``2 * point.size + 1`` times.
        NOTE(review): with the default ``h = 1e-6`` the divisor ``h * h`` is
        1e-12, so floating-point cancellation in the numerator may dominate
        the estimate; the notebook's own comparison passes ``h=1e-4``.
    """
    n_coords = point.size
    centre = wavefn(point)  # shared by every second difference

    def second_difference(idx):
        # Displacement of size h along flattened coordinate `idx` only.
        step = jnp.zeros(n_coords, point.dtype).at[idx].set(h).reshape(point.shape)
        ahead = wavefn(point + step)
        behind = wavefn(point - step)
        return (ahead - 2 * centre + behind) / (h * h)

    # A plain Python for-loop is fine for small coordinate counts (d ≲ 1e3).
    total = 0.0 + 0.0j
    for idx in range(n_coords):
        total = total + second_difference(idx)
    return total

In [18]:
# Cross-check of the autodiff Laplacian against the finite-difference estimate
# at a newly sampled configuration.
# NOTE(review): this cell's execution counter (18) is not sequential with the
# cells above it. Because `key` is rebound on every run, R and the printed
# numbers depend on how many times key-splitting cells were executed; a
# Restart & Run All will print different values.
subkey, key = jax.random.split(key)
# Rebinds R (previously sampled in an earlier cell) to a new (8, 2) sample.
R = jax.random.normal(subkey, (8, 2))

lap_auto = wfn_laplacian(wavefn, R)
# h=1e-4 is passed explicitly instead of the function's default step of 1e-6.
lap_num  = laplacian_fd(wavefn, R, h=1e-4)

print("autodiff Δψ =", lap_auto)
print("finite-diff Δψ =", lap_num)
# Absolute difference only: read it against the magnitude of Δψ itself
# (recorded run: about 9e-9 difference on values of order 1e-4 to 1e-3).
print("difference =", jnp.abs(lap_auto - lap_num))

autodiff Δψ = (0.0002157376349394674-0.0006800012598117042j)
finite-diff Δψ = (0.00021574665170435558-0.0006799995963991008j)
difference = 9.168914359825699e-09
