First, let's install and import all the libraries we need.

In [None]:
!pip install autofd &> /dev/null
!pip install jax_xc &> /dev/null

In [None]:
import sys
sys.path.append("../")
import jax
from jax import lax
import numpy as np
import jax.numpy as jnp
import matplotlib
%matplotlib inline
from matplotlib import pyplot as plt
from functools import partial

from ase.dft.kpoints import monkhorst_pack
from jrystal import crystal, energy, wave, occupation
from jrystal._src.grid import g_vectors, r_vectors, half_frequency_mask
from jrystal._src.bloch import bloch_wave
jax.config.update("jax_enable_x64", True)

# Define the Crystal Structure

We then create a crystal structure. As illustrated below, a crystal is made of periodically repeated cells. (We only show a 3x3 tiling in 2D; in reality we consider 3D cells tiled infinitely.) Each cell contains atoms, and the electrons of these atoms are described by wave functions in 3D space.

![](./crystal.svg)

The variables that we need for calculating the electron wave functions include 
- The geometry of the unit cell. Unit cells are parallelepipeds, so they can be described by three basis vectors.
- The atoms' positions within the unit cell, and the type of the atoms.
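
As a quick sanity check on the geometry, the volume of a parallelepiped is the absolute value of the determinant of its three basis vectors. The sketch below uses only NumPy and assumes the same fcc primitive vectors and lattice constant `a` as in this tutorial; the name `cell_volume` is illustrative and not part of jrystal.

```python
import numpy as np

a = 3.5667  # lattice constant in angstrom (same value as in this tutorial)

# Rows are the three basis vectors of the fcc primitive cell.
cell_vectors = np.array(
  [(0, a / 2, a / 2), (a / 2, 0, a / 2), (a / 2, a / 2, 0)]
)

# Volume of the parallelepiped spanned by the three basis vectors.
cell_volume = abs(np.linalg.det(cell_vectors))

# The fcc primitive cell holds one quarter of the conventional cube a^3.
print(cell_volume, a**3 / 4)
```

Taking the absolute value keeps the volume positive even if the basis vectors happen to be listed in a left-handed order.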

In [None]:
a = 3.5667  # angstrom
cell_vectors = np.array(
  [(0, a / 2, a / 2), (a / 2, 0, a / 2), (a / 2, a / 2, 0)]
)
symbols = 'C2'
positions = np.array([(0, 0, 0), (a / 4, a / 4, a / 4)])
diamond = crystal.Crystal(
  symbols=symbols,
  positions=positions,
  cell_vectors=cell_vectors,
)

# Grid Sampling and Fourier Transformation

Without going into the physics, let's state some general facts about periodic functions defined on the cells.
If we sample the values of a periodic function $f$ on the red grid in the real-space unit cell shown below, we can apply the discrete Fourier transform to obtain the values of its transform $F$ on the corresponding grid in reciprocal (frequency) space.

Let

$$\boldsymbol{r} = \frac{n_1}{N_1}\boldsymbol{a_1} + \frac{n_2}{N_2}\boldsymbol{a_2} + \frac{n_3}{N_3}\boldsymbol{a_3}$$ 

Here $N_1, N_2, N_3$ are the numbers of points sampled along each cell vector, and $n_i\in\{0, \dots, N_i-1\}$. This grid of points is illustrated by the red dots in real space.

The corresponding grid in reciprocal space is

$$\boldsymbol{G} = {n_1}\boldsymbol{b_1} + {n_2}\boldsymbol{b_2} + {n_3}\boldsymbol{b_3}$$

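where $\boldsymbol{b_1}, \boldsymbol{b_2}, \boldsymbol{b_3}$ are the reciprocal basis vectors, e.g. $\boldsymbol{b_3}=2\pi \frac{\boldsymbol{a_1}\times \boldsymbol{a_2}}{\boldsymbol{a_1}\cdot(\boldsymbol{a_2}\times \boldsymbol{a_3})}$, with $\boldsymbol{b_1}$ and $\boldsymbol{b_2}$ obtained by cyclic permutation of the indices. These points are illustrated by the red dots in reciprocal space. They represent the frequency components of $f$ that we are able to capture with the real-space grid we sample. If we sample more densely in real space, we correspondingly get more grid points in reciprocal space, reaching higher frequencies.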
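
The reciprocal basis vectors can be computed directly from the cell vectors. Below is a minimal NumPy sketch, assuming the rows of `cell_vectors` are $\boldsymbol{a_1}, \boldsymbol{a_2}, \boldsymbol{a_3}$ and using the defining relation $\boldsymbol{a_i}\cdot\boldsymbol{b_j}=2\pi\delta_{ij}$. The name `reciprocal_vectors` is illustrative and not a jrystal API.

```python
import numpy as np

a = 3.5667
cell_vectors = np.array(
  [(0, a / 2, a / 2), (a / 2, 0, a / 2), (a / 2, a / 2, 0)]
)

# Rows of reciprocal_vectors are b_1, b_2, b_3, so that a_i . b_j = 2*pi*delta_ij.
reciprocal_vectors = 2 * np.pi * np.linalg.inv(cell_vectors).T

# Cross-check b_3 against the explicit cross-product formula,
# normalized by the triple product a_1 . (a_2 x a_3).
a1, a2, a3 = cell_vectors
b3 = 2 * np.pi * np.cross(a1, a2) / np.dot(a1, np.cross(a2, a3))

print(np.round(cell_vectors @ reciprocal_vectors.T / (2 * np.pi), 12))
```

The matrix-inverse form gives all three vectors at once and agrees with the cross-product formula.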

Conversely, we can also sample $F$ within the reciprocal unit cell, as shown by the green dots. Applying the inverse discrete Fourier transform then gives the values of $f$ on the green grid in real space. The denser the green dots are, the larger the region of real space they give us information about.

![](./grids.svg)

# Electron wave functions

Given the above crystal structure, we would like to calculate the electron wave functions. According to Bloch's theorem, the wave functions of the electrons can be parameterized in the following form.

$$\psi_{ik}=e^{-i\boldsymbol{k}\boldsymbol{r}}\sum_G c_{ikG} e^{i\boldsymbol{G}\boldsymbol{r}}$$

The $G$ are the red grid points in reciprocal space. More $G$ points mean denser red dots in real space, and therefore a higher resolution (higher approximation capability) within the real-space unit cell.

The $k$ are the green grid points in reciprocal space. As seen from the figure, more $k$ points mean that we cover more cells in real space. Since there are electrons in each cell, the more cells we cover, the more electrons we need to describe. Therefore the electron wave function $\psi_{ik}$ is indexed by $i$ and $k$, where $k$ is the index of the cell and $i$ is the index of the electron within the cell.

The $c$ are flexible parameters that can be learned to minimize the energy of the entire system.
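
To make the formula concrete, here is a hedged 1D sketch in plain NumPy (not the jrystal implementation). It evaluates $e^{-ikr}\sum_G c_G e^{iGr}$ on the real-space grid in two ways: by summing the plane waves directly, and with an inverse FFT. The cell length, the grid size and the value of `k` are arbitrary assumptions chosen for illustration.

```python
import numpy as np

n_grid = 8      # number of grid points along the 1D cell
length = 2.0    # length of the 1D cell (arbitrary units)
k = 0.3         # an arbitrary k point

rng = np.random.default_rng(0)
c = rng.normal(size=n_grid) + 1j * rng.normal(size=n_grid)

# Real-space grid r_n = n * length / N and the G grid in FFT ordering.
r = np.arange(n_grid) * length / n_grid
g = 2 * np.pi * np.fft.fftfreq(n_grid, d=length / n_grid)

# Direct evaluation of e^{-ikr} * sum_G c_G e^{iGr}.
psi_direct = np.exp(-1j * k * r) * (np.exp(1j * np.outer(r, g)) @ c)

# Same values from an inverse FFT; numpy's ifft divides by N, so we undo that.
psi_fft = np.exp(-1j * k * r) * n_grid * np.fft.ifft(c)

print(np.max(np.abs(psi_direct - psi_fft)))
```

The two results agree up to floating point error, which is why evaluating the wave function on the regular grid can be delegated to an FFT.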

In [None]:
k_grid_sizes = [1, 1, 1]  # for simplicity we sample just one k point
g_grid_sizes = [10, 10, 10]  # we sample 10 x 10 x 10 G grid.
num_k = np.prod(k_grid_sizes)
num_g = np.prod(g_grid_sizes)

# we use the following APIs to sample the grid of k and G points.
k_vector_grid = monkhorst_pack(k_grid_sizes)
g_vector_grid = g_vectors(cell_vectors, g_grid_sizes)
# the number of electrons per unit cell.
num_electrons = diamond.num_electrons
# each electron has two spins, up/down.
num_spins = 2

def volume(cell_vectors):
  return jnp.linalg.det(cell_vectors)

vol = volume(cell_vectors)

In [None]:
# Now let's build a wave function that depends on the parameter.

def electron_wave_function(parameter, r, force_fft=False):
  # jrystal provides the bloch_wave API to construct a Bloch wave function
  # from the cell vectors, the coefficients and the k points.
  psi = bloch_wave(cell_vectors, parameter, k_vector_grid)
  # Notice that the Bloch wave is not normalized on the unit cell,
  # so we divide by sqrt(vol) to normalize the electron wave function.
  # If a batch of r of shape (batch..., 3) is passed to bloch_wave,
  # the batch dimensions are returned last,
  # i.e. (num_spins, num_k, num_electrons, batch...).
  return psi(r, force_fft=force_fft) / jnp.sqrt(vol)


# Verifying the correctness of wave function

The entire DFT computation is based on the assumption that the single-electron wave functions are orthogonal to each other.
The `electron_wave_function` we defined above can take a batch of `parameter` and compute the outputs for a batch
of wave functions in parallel. To make the wave functions in the batch orthogonal to each other, we need to constrain the
batch of `parameter` to be orthogonal to each other.

Now let's verify that when we pass a batch of `parameter` that are orthogonal to each other,
the resulting wave functions are indeed orthogonal. More specifically,
consider the wave function $\psi(c, r)$, where $c$ is the parameter.
We sample a batch $C=\{c_0, c_1, \cdots\}$ with $c_i^H c_j=\delta_{ij}$.
We want to verify that

$$\int\psi(c_i,r)^*\psi(c_j,r) dr=\delta_{ij}$$
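
The reason this works is that plane waves on a periodic grid are themselves orthogonal, so orthonormal coefficients give orthonormal wave functions. The following is a minimal 1D analogue in plain NumPy, with $k=0$ and an arbitrary cell length; it is a sketch under these assumptions and does not use jrystal.

```python
import numpy as np

n_grid, n_waves, length = 16, 4, 2.0

rng = np.random.default_rng(0)
raw = rng.normal(size=(n_grid, n_waves)) + 1j * rng.normal(size=(n_grid, n_waves))

# Columns of q are orthonormal coefficient vectors c_i.
q, _ = np.linalg.qr(raw, mode="reduced")

# psi_i on the grid: (1 / sqrt(length)) * sum_G c_iG e^{iGr}, one row per wave function.
psi = n_grid * np.fft.ifft(q.T, axis=-1) / np.sqrt(length)

# Overlap integrals approximated by the grid mean times the cell length.
overlap = (np.conj(psi) @ psi.T) / n_grid * length

print(np.max(np.abs(overlap - np.eye(n_waves))))
```

The overlap matrix equals the identity up to floating point error, mirroring the 3D check performed in this section.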


In [None]:
# 1. create 32 random orthonormal parameters
# the QR decomposition ensures that random_param is orthonormal
random_param, _ = jnp.linalg.qr(
  jax.random.normal(
    jax.random.PRNGKey(0),
    (np.prod(g_grid_sizes), 32),
  ),
  mode="reduced",
)
random_param = jnp.reshape(random_param.T, (32, *g_grid_sizes))

# 2. we create the function that is able to compute the integration
# within the unit cell given the integrand.
def integrate_within_cell(integrand):
  r_vector_grid = r_vectors(cell_vectors, g_grid_sizes)
  out = integrand(r_vector_grid)
  return jnp.mean(out) * vol

# 3. given two parameters param_a, and param_b, we compute the
# overlap integral of wave functions defined by these two parameters.
def overlap_ab(param_a, param_b):
  return integrate_within_cell(
    lambda r: jnp.real(
      jnp.conj(
        electron_wave_function(param_a, r, force_fft=True)
      ) * electron_wave_function(param_b, r, force_fft=True)
    )
  )

# 4. for a batched param, we compute the overlap of every pair of parameters.
def overlap(param):
  return jax.vmap(
    jax.vmap(
      overlap_ab,
      in_axes=(None, 0),
    ),
    in_axes=(0, None),
  )(param, param)

# 5. we expect the overlap between two wave functions defined by different parameters
# to be 0. And the overlap of the wave function with itself to be 1.
psi_l2_norm = overlap(random_param)
target = jnp.eye(psi_l2_norm.shape[0])
print("Max error is: ", jnp.abs(psi_l2_norm - target).max())

# Occupation and its constraint

Bloch's theorem tells us that the eigenstates of the system have the form

$$\psi_{ik}(\boldsymbol{r})=e^{-i\boldsymbol{k}\boldsymbol{r}}\sum_G c_{ikG} e^{i\boldsymbol{G}\boldsymbol{r}}$$

with a corresponding eigenvalue $\epsilon_{ik}$, which is the energy of the wave function.
The lowest energies are always occupied first. Therefore, if we consider $N$ electrons,
they take the $N$ indices that have the lowest energies.

However, when we perform direct minimization, the wave functions we obtain are linear mixtures of the eigenstates.
Every wave function we obtain can contain a certain fraction of the lowest eigenstates.
Therefore, the occupation is no longer all-or-nothing, even at zero temperature.
Instead, the new rule for occupation is that every unique $\{i,k\}$ pair can be occupied, subject to two constraints:

1. Each of them can hold at most one electron.
2. The total occupation must add up to $N$.

As both constraints are convex, we can apply the "projection onto convex sets" algorithm,
in which we iteratively project the current occupation parameter, which may not be feasible, onto
the feasible set of each constraint in turn.
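
The alternating projection can be sketched in a few lines of plain NumPy. The function names, the tolerance `tol` and the iteration cap `max_iter` below are illustrative assumptions and not part of jrystal; the sketch stops once clipping changes the occupation by less than `tol`, rather than testing for exact equality.

```python
import numpy as np

def project_sum(occ, total):
  # Projection onto the hyperplane sum(occ) == total: shift all entries equally.
  return occ + (total - occ.sum()) / occ.size

def project_box(occ, upper):
  # Projection onto the box 0 <= occ <= upper.
  return np.clip(occ, 0.0, upper)

def constrain_occupation(raw_occ, total, upper=1.0, tol=1e-12, max_iter=1000):
  occ = project_sum(np.asarray(raw_occ, dtype=float), total)
  for _ in range(max_iter):
    clipped = project_box(occ, upper)
    if np.max(np.abs(clipped - occ)) < tol:
      break  # occ already lies (numerically) inside the box
    occ = project_sum(clipped, total)
  return occ

occ = constrain_occupation([2.0, 0.5, -1.0, 0.3], total=2.0)
print(occ, occ.sum())
```

Because the last step of every iteration is the sum projection, the returned occupation sums to `total` and violates the box by less than `tol`.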

In [None]:
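def normalize(raw_occ):
  # project onto sum(occ) == num_electrons by shifting all entries equally
  return raw_occ + (num_electrons - jnp.sum(raw_occ)) / raw_occ.size

def clip(raw_occ):
  # project onto the box 0 <= occ <= 1 / num_k
  return jnp.clip(raw_occ, 0, 1 / num_k)

def constrain_occ(raw_occ):
  cond_fn = lambda x: jnp.logical_not(jnp.array_equal(clip(x), x))
  body_fn = lambda x: normalize(clip(x))  # iteratively project
  # normalize once up front, so that the sum constraint also holds when
  # raw_occ already lies inside the box and the loop body never runs.
  return lax.while_loop(cond_fn, body_fn, normalize(raw_occ))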

# Half frequency constraint for the parameter

Recall that the wave function has the form

$$\psi_{ik}=e^{-i\boldsymbol{k}\boldsymbol{r}}\sum_G c_{ikG} e^{i\boldsymbol{G}\boldsymbol{r}}$$

It contains all the frequency components in the $G$ grid.
However, when we compute the density function via

$$\rho(r)=\sum_i \psi^*(c_i,r)\psi(c_i,r)$$

the wave function appears squared, so the density function contains twice the range
of frequency components. If the maximum frequency component of a wave function is $G$,
the maximum for the density function will be $2G$. In order to use the same grid for the Fourier transform
of both the wave function and the density function, we explicitly require the wave functions to use only
the lower half of the frequency components. We provide a convenient function to mask out the
high frequencies of the wave function.
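
The frequency doubling is easy to observe numerically. The 1D NumPy sketch below (grid size and cutoff `max_freq` are arbitrary assumptions) builds a wave function whose coefficients vanish above `max_freq`, squares it to obtain a density, and looks at the highest frequency present in the density.

```python
import numpy as np

n_grid, max_freq = 16, 2

# Integer frequencies in FFT ordering: 0, 1, ..., 7, -8, ..., -1.
freq = np.rint(np.fft.fftfreq(n_grid) * n_grid).astype(int)

rng = np.random.default_rng(0)
c = rng.normal(size=n_grid) + 1j * rng.normal(size=n_grid)
c = np.where(np.abs(freq) <= max_freq, c, 0.0)  # band-limit the wave function

psi = n_grid * np.fft.ifft(c)      # wave function on the real-space grid
rho = np.abs(psi) ** 2             # density on the same grid
rho_g = np.fft.fft(rho) / n_grid   # Fourier components of the density

# Highest frequency with a non-negligible density component.
highest = np.max(np.abs(freq[np.abs(rho_g) > 1e-10]))
print(max_freq, highest)
```

The density reaches exactly twice the cutoff of the wave function, so a wave function limited to the lower half of the grid yields a density that still fits on the same grid.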

In [None]:
mask = half_frequency_mask(g_grid_sizes)
mask_size = int(jnp.sum(mask))

## Visualizing the frequency mask

In [None]:
# show this mask
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection="3d")
colors = np.empty(g_grid_sizes, dtype=object)
colors[mask] = "lightblue"
ax.voxels(mask, facecolors=colors, linewidth=0.5, edgecolor="black")
ax.set_aspect('equal')

The above mask may seem unintuitive. This is because the FFT stores the frequency axis not as
`[-3, -2, -1, 0, 1, 2, 3]` but as `[0, 1, 2, 3, -3, -2, -1]`,
which places the lower half of the frequencies at the two ends of each axis. In 3D they end up at the corners.
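
This ordering can be inspected with `np.fft.fftfreq`. The short sketch below also builds a hypothetical 1D low-frequency mask (keeping $|\text{frequency}|\le 1$, a cutoff chosen only for illustration) to show why the kept entries sit at both ends of the axis.

```python
import numpy as np

n = 7
# Integer frequencies in the order used by the FFT.
freq = np.rint(np.fft.fftfreq(n) * n).astype(int)
print(freq.tolist())  # [0, 1, 2, 3, -3, -2, -1]

# A hypothetical 1D low-frequency mask keeping |frequency| <= 1.
low = np.abs(freq) <= 1
print(low.tolist())   # [True, True, False, False, False, False, True]

# fftshift reorders the axis so that frequency 0 sits in the middle.
print(np.fft.fftshift(freq).tolist())  # [-3, -2, -1, 0, 1, 2, 3]
```

Applying `np.fft.fftshift` to a 3D mask before plotting would move the kept region from the corners to the center of the cube.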

# Enforcing both orthonormal constraint and half frequency constraints

The parameters of the wave functions need to satisfy two constraints:

1. Orthogonal to each other.
2. Masked to have only the lower half of the frequencies.

Now we can put both constraints together and create a function that takes the raw parameter and
reparameterizes it into a valid parameter that satisfies both of the above.

Remember that our parameter specifies the linear coefficients of all the frequency components.
Therefore a single parameter has the shape `g_grid_sizes`, which is the
size of the grid in frequency space. As we need

1. one wave function for each electron,
2. independently parameterized wave functions for different $k$,
3. independently parameterized wave functions for different electron spins,

the shape of our batched parameter becomes `(num_spins, num_k, num_electrons, *g_grid_sizes)`.
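
The shape bookkeeping can be tried out on a small example without jrystal. The NumPy sketch below uses a hypothetical mask that keeps 8 of 64 grid points and illustrative sizes; it orthogonalizes the nonzero coefficients with a batched QR (supported by `np.linalg.qr` in recent NumPy versions) and scatters them into the full grid.

```python
import numpy as np

num_spins, num_k, num_orbitals = 2, 1, 3
grid = (4, 4, 4)

# A hypothetical mask that keeps 8 of the 64 grid points.
mask = np.zeros(grid, dtype=bool)
mask[:2, :2, :2] = True
mask_size = int(mask.sum())

rng = np.random.default_rng(0)
raw = rng.normal(size=(num_spins, num_k, mask_size, num_orbitals))

# Batched QR: the columns of q[s, k] are orthonormal.
q, _ = np.linalg.qr(raw, mode="reduced")

# Scatter the coefficients into the full grid; unmasked entries stay zero.
param = np.zeros((num_spins, num_k, num_orbitals, *grid))
param[..., mask] = np.swapaxes(q, -1, -2)

# Gram matrix of the flattened coefficients for every (spin, k) pair.
flat = param.reshape(num_spins, num_k, num_orbitals, -1)
gram = flat @ np.swapaxes(flat, -1, -2)
print(param.shape, gram.shape)
```

Each Gram matrix is the identity, and all entries outside the mask are zero, so both constraints hold at once.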

In [None]:
param_shape = (num_spins, num_k, num_electrons, *g_grid_sizes)

# We want the orthonormal constraint,
# < c[s, k, i], c[s, k, j] > = delta_{i,j}
# i.e. every wave function has a unit l2 norm, and
# any two electron wave functions with the same s and k are orthogonal to each other.

# For convenience, we can create a function like below to transform our raw parameter
# into a constrained parameter.
# It takes a raw_param of shape (num_spins, num_k, num_electrons, *g_grid_sizes).
def constrain_param_1(raw_param):
  # keep only the masked region
  masked_param = raw_param[..., mask]
  # orthogonalize the masked region
  masked_orthogonal_param = jnp.swapaxes(
    jnp.linalg.qr(
      jnp.swapaxes(masked_param, -1, -2), mode="reduced"
    )[0],
    -1,
    -2,
  )
  # set the unmasked region to zero, and the masked region to the
  # orthogonalized values.
  return jnp.zeros_like(raw_param, dtype=jnp.complex128).at[..., mask].set(
    masked_orthogonal_param,
  )

# We can further reduce the size of raw_param, as it only needs to encode the nonzero entries.
# In that case raw_param has shape (num_spins, num_k, mask_size, num_electrons),
# which saves some computation.

raw_param_shape = (num_spins, num_k, mask_size, num_electrons)

def constrain_param(raw_param):
  raw_param = jnp.swapaxes(
    jnp.linalg.qr(raw_param, mode="reduced")[0], 
    -1, -2,
  )
  return jnp.zeros(param_shape, dtype=jnp.complex128).at[..., mask].set(
    raw_param,
  )

# Testing the orthonormal constraint

Let's now test whether `constrain_param` generates the correct parameters.
That is, we check whether the resulting wave functions are orthonormal, using the `overlap` function defined above.

In [None]:
raw_param = jax.random.normal(
  key=jax.random.PRNGKey(0),
  shape=raw_param_shape,
  dtype=jnp.complex128,
) * 0.01
param = constrain_param(raw_param)  # (num_spins, num_k, num_electrons, *g_grid_sizes)
param = jnp.reshape(param, (-1, *param.shape[2:]))

out = jax.vmap(overlap)(param)
target = jnp.eye(out.shape[-1])
print("Max error is: ", jnp.abs(out - target).max())

# Testing orthonormal and occupation constraints

We know that when we integrate the density function over the unit cell, 
the result will be the number of electrons (number of charges) in that unit cell.

## Create density function from wave function

The electron density of the system depends on two things:
1. The orthogonal set of electron wave functions.
2. The occupation of each wave function.

$$
\rho(r)=\sum_i o_i\,\psi^*(c_i,r)\psi(c_i,r)
$$

where $o_i$ is the occupation of the $i$-th wave function. Therefore, we create a density function that takes `param` and `occ` as arguments.
`param` is required to satisfy the orthogonality constraint,
and `occ` is required to satisfy the two occupation constraints mentioned above.
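
Why the density integrates to the number of electrons can be seen in a small 1D example: each normalized wave function contributes exactly its occupation. The NumPy sketch below uses illustrative sizes and occupations and is not the jrystal implementation.

```python
import numpy as np

n_grid, n_waves, length = 16, 4, 2.0
occ = np.array([1.0, 1.0, 0.5, 0.5])  # occupations summing to 3 electrons

rng = np.random.default_rng(0)
raw = rng.normal(size=(n_grid, n_waves)) + 1j * rng.normal(size=(n_grid, n_waves))
q, _ = np.linalg.qr(raw, mode="reduced")                    # orthonormal coefficients
psi = n_grid * np.fft.ifft(q.T, axis=-1) / np.sqrt(length)  # shape (n_waves, n_grid)

# Density on the grid: sum over wave functions of occupation times |psi|^2.
rho = np.einsum("i,ir->r", occ, np.abs(psi) ** 2)

# Integrate over the cell: grid mean times cell length.
num_electrons_in_cell = rho.mean() * length
print(num_electrons_in_cell, occ.sum())
```

The integral matches the sum of the occupations, which is the check carried out for the diamond cell in this section.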

In [None]:
# With occupation and wave function, we can create a density function

def density(param, occ, r, force_fft=False):
  psi = electron_wave_function(param, r, force_fft=force_fft)
  dens = jnp.real(jnp.conj(psi) * psi)
  return jnp.einsum("ski...,ski->s...", dens, occ)
    

## Create param and occ that satisfy constraints

In [None]:
raw_param = jax.random.normal(
  key=jax.random.PRNGKey(0),
  shape=raw_param_shape,
  dtype=jnp.complex128,
) * 0.01
param = constrain_param(raw_param)

raw_occ = jax.random.normal(
  key=jax.random.PRNGKey(1),
  shape=param_shape[:3],
  dtype=jnp.float64,
)
occ = constrain_occ(raw_occ)

# get a scalar density by summing over spins.
rho = lambda r: jnp.sum(density(param, occ, r, force_fft=True), axis=0)

N = integrate_within_cell(rho)
print(f"Number of electrons per cell is {num_electrons}, "
      f"while density function integrates to {N}")

# Compute total energy given parameter and occupation

## Compute Kinetic, Hartree, External energies

## Compute exchange correlation energy with JAX-XC

# Optimize total energy

Looking at the bigger picture, we are building the following computational graph

`nonzero_param` $\overset{\text{orthogonalize}}{\longrightarrow}$ `nonzero_orthogonal_param` $\overset{\text{embed in the grid}}{\longrightarrow}$ `parameter` $\longrightarrow$ `wave_function` $\longrightarrow$ `energy`

What we need is the optimal wave function that minimizes the energy. However, we perform the minimization in the space of the raw parameter: we compute the gradient of the energy function with respect to `nonzero_param` and follow it to minimize the energy.
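
The same idea can be illustrated on a toy problem. In the JAX sketch below, a small diagonal matrix `h` stands in for the Hamiltonian (an assumption for illustration only), the raw parameter is orthogonalized with QR, and plain gradient descent is run on the raw parameter. The step size and the number of steps are arbitrary choices. For two orthonormal vectors the energy cannot drop below the sum of the two lowest eigenvalues, here $1 + 2 = 3$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Toy "Hamiltonian" with known eigenvalues 1, 2, ..., 6.
h = jnp.diag(jnp.arange(1.0, 7.0))

def toy_energy(raw):
  # Orthogonalize the raw parameter, then evaluate the energy.
  q, _ = jnp.linalg.qr(raw, mode="reduced")
  return jnp.trace(q.T @ h @ q)

# Gradient of the energy with respect to the raw (unconstrained) parameter.
grad_fn = jax.jit(jax.grad(toy_energy))

raw = jax.random.normal(jax.random.PRNGKey(0), (6, 2))
initial_energy = float(toy_energy(raw))
for _ in range(1000):
  raw = raw - 0.05 * grad_fn(raw)  # plain gradient descent on the raw parameter
final_energy = float(toy_energy(raw))

print(initial_energy, final_energy)
```

Since the constraint is built into the parameterization, an unconstrained optimizer can be used and every iterate remains orthonormal.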

In [None]:
# Let's now compute the energy.
import autofd
import jax_xc
from jaxtyping import Float64, Array

def diamond_energy(raw_param, raw_occ):
  param = constrain_param(raw_param)
  occ = constrain_occ(raw_occ)
  r_vector_grid = r_vectors(cell_vectors, g_grid_sizes)
  # r_vector_grid is a regular grid, so the density can be evaluated with FFT.
  density_grid = density(param, occ, r_vector_grid, force_fft=True)
  assert density_grid.shape == (num_spins, *g_grid_sizes)
  reciprocal_density_grid = jnp.fft.fftn(density_grid.sum(0), axes=(-3, -2, -1))

  # TODO: All the apis here we call are directly
  # implemented based on the reciprocal space formulation.
  # We would like to have the pure functional version, where the functionals
  # take only the wave functions as input.
  # 
  # E.g. in autofd api, something like the following,
  # 
  # psi.grid = r_vector_grid
  # rho = o.compose(jnp.sum, psi**2)
  # kinetic = energy.kinetic(psi)
  # hartree = energy.hartree(rho)
  
  kinetic = energy.kinetic(
    g_vector_grid,
    k_vector_grid,
    param,
  )
  hartree = energy.hartree(
    reciprocal_density_grid,
    g_vector_grid,
    vol,
  )
  external = energy.external(
    reciprocal_density_grid,
    diamond.positions,
    diamond.charges,
    g_vector_grid,
    vol,
  )

  r = r_vector_grid.reshape((-1, 3))

  @autofd.function
  def dens(r: Float64[Array, "3"]) -> Float64[Array, "2"]:
    return density(param, occ, r, force_fft=True)

  with jax.ensure_compile_time_eval():
    epsilon_xc = jax_xc.experimental.gga_x_pbe(dens)
    
  xc_grid = jax.vmap(
    lambda r: epsilon_xc(r) * dens(r).sum(),
    in_axes=0,
    out_axes=-1,
  )(r)
  # integrate over the unit cell: grid mean times the cell volume
  exc = xc_grid.sum() * vol / num_g
  return (kinetic, hartree, external, exc)

raw_param = jax.random.normal(
  key=jax.random.PRNGKey(0),
  shape=raw_param_shape,
  dtype=jnp.complex128,
) * 0.01

raw_occ = jax.random.normal(
  key=jax.random.PRNGKey(1), 
  shape=param_shape[:3]
) * 0.01

jax.jit(diamond_energy)(raw_param, raw_occ)

