In [3]:
import sys
sys.path.append("../")
import jax
import jax.numpy as jnp
import numpy as np
from jrystal import crystal, energy, wave, occupation
from jrystal._src.grid import r_vectors
from jrystal._src.wave import ElectronWave
from jrystal._src.operators import (
  hartree_potential,
  external_potential,
)
import autofd.operators as o
from autofd import function, Grid
import jax_xc
from jaxtyping import Float64, Complex128, Array
# Switch JAX to 64-bit precision before any arrays are created.
jax.config.update("jax_enable_x64", True)

# Crystal description loaded from an xyz geometry file; its cell vectors,
# electron count, positions and charges are read by the code below.
diamond = crystal.Crystal(xyz_file="../geometries/diamond1.xyz")
cell_vectors = diamond.cell_vectors
num_elec = diamond.num_electrons
# Passed to ElectronWave as its third argument (presumably the k-point mesh
# — TODO confirm).
k_grid_sizes = [1, 1, 1]
# Passed as `grid_sizes` to the potential builders and to `r_vectors`; the
# product of its entries is used as the number of quadrature points.
g_grid_sizes = [32, 32, 32]


def potential_energy(density2, cell_vectors):
  """Return the sum of the Hartree, external and LDA-exchange energy terms.

  Reads the module-level ``g_grid_sizes`` and ``diamond`` in addition to its
  arguments.

  Args:
    density2: density function object. It is reduced with ``o.numpy.sum`` to
      obtain the density used in every bracket, and passed unreduced to
      ``jax_xc.experimental.lda_x`` (presumably it holds the per-spin
      densities — TODO confirm).
    cell_vectors: cell vectors forwarded to ``hartree_potential`` and
      ``external_potential``.

  Returns:
    ``Ehartree + Eext + Exc``, where each term is an ``o.braket`` against the
    summed density and the Hartree term carries a factor of 0.5.
  """
  total_density = o.numpy.sum(density2)

  # Hartree term: potential generated by the density itself, halved.
  hartree_pot = hartree_potential(
    total_density,
    cell_vectors=cell_vectors,
    grid_sizes=g_grid_sizes,
  )
  # External term: potential built from the crystal's positions and charges.
  external_pot = external_potential(
    positions=diamond.positions,
    charges=diamond.charges,
    cell_vectors=cell_vectors,
    grid_sizes=g_grid_sizes,
  )
  hartree_energy = 0.5 * o.braket(hartree_pot, total_density)
  external_energy = o.braket(external_pot, total_density)

  # Exchange term: the LDA exchange functional is built from `density2`
  # and bracketed with the summed density.
  xc_energy = o.braket(jax_xc.experimental.lda_x(density2), total_density)

  return hartree_energy + external_energy + xc_energy


# NOTE(review): this rebinds `wave`, which was imported above as a module
# from `jrystal`; from here on `wave` refers to the ElectronWave instance.
wave = ElectronWave(num_elec, g_grid_sizes, k_grid_sizes)
# Fixed PRNG seed so the initial variables are reproducible.
key = jax.random.PRNGKey(42)
# Initialise the model variables by tracing the "kinetic_energy" method.
variables = wave.init(key, cell_vectors, method="kinetic_energy")

def energy(variables, cell_vectors):
  """Return the kinetic energy plus ``potential_energy`` of the density.

  Reads the module-level ``wave`` and ``g_grid_sizes``.

  NOTE(review): this definition rebinds ``energy``, which was imported above
  as a module from ``jrystal``.

  Args:
    variables: variables passed to ``wave.apply``.
    cell_vectors: cell vectors; used by ``wave.apply``, to generate the
      real-space grid, and to compute the cell volume (the determinant
      requires a square matrix).

  Returns:
    ``kinetic + potential_energy(density2, cell_vectors)``.
  """
  kinetic = wave.apply(variables, cell_vectors, method="kinetic_energy")

  # Wrap the density as an autofd function of a single position so that
  # `potential_energy` can operate on it symbolically.
  @function
  def density2(r: Float64[Array, "3"]) -> Float64[Array, "2"]:
    return wave.apply(variables, r, cell_vectors, method="density")

  # Flatten the real-space grid into a list of 3-vectors used as
  # quadrature nodes.
  r_vector_grid = r_vectors(cell_vectors, g_grid_sizes)
  r_vec = r_vector_grid.reshape((-1, 3))

  # The determinant is signed: a left-handed set of cell vectors gives a
  # negative value. Take the absolute value so the quadrature weights (cell
  # volume per grid point) are never negative.
  vol = jnp.abs(jnp.linalg.det(cell_vectors))
  density2.grid = Grid(nodes=(r_vec,), weights=(vol / np.prod(g_grid_sizes),))

  return kinetic + potential_energy(density2, cell_vectors)

# Evaluate the energy once for the freshly initialised variables and show it.
e = energy(variables, cell_vectors)
print(e)

1.6902062758448961
1.6861259777961273
