# e3ferminet
E(3)-equivariant neural network ansatz for atomic and molecular VMC calculations

# Background

In computational chemistry, solving for the ground state wavefunction of a molecule allows us to predict its properties accurately.
However, an exact solution is computationally infeasible for these quantum many-body systems,
and various approximation methods have been proposed and studied over the last few decades.
A classic method is variational Monte Carlo (VMC).
VMC is a variant of the variational method, which identifies the ground state energy $E_0$ as the minimum of the energy functional
\begin{equation}
    E[\psi] = \frac{\langle \psi| \hat H |\psi\rangle}{\langle \psi | \psi \rangle} = \frac{\int dx\, |\psi(x)|^2 \frac{\hat H\psi(x)}{\psi(x)}}{\int dx\, |\psi(x)|^2}.
\end{equation}

The value of the energy functional is lower-bounded by the ground state energy, so the variational method seeks to find better and better approximations for $E_0$ by guessing and refining an ansatz. The energy functional can then be computed by sampling points from $p(x) \propto |\psi(x)|^2$ to estimate the expectation value of $\frac{\hat H\psi(x)}{\psi(x)}$. The main difficulty remains in finding a good ansatz.
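
As a small illustration of this sampling estimate, the sketch below evaluates the energy of a hand-picked one-parameter hydrogen trial function $\psi_\zeta(\mathbf r) = e^{-\zeta |\mathbf r|}$. This trial function, the exact Gamma-distributed radial sampler, and all function names are assumptions made for the example; they are not the network ansatz used later in this notebook. The local energy $\hat H\psi/\psi$ is computed with automatic differentiation and averaged over samples from $|\psi|^2$.

```python
import jax
import jax.numpy as jnp

def psi(zeta, r):
    # Illustrative one-electron trial wavefunction exp(-zeta * |r|) (an assumption).
    return jnp.exp(-zeta * jnp.linalg.norm(r))

def local_energy(zeta, r):
    # E_L = (H psi) / psi with H = -1/2 laplacian - 1/|r| (hydrogen, atomic units).
    laplacian = jnp.trace(jax.hessian(psi, argnums=1)(zeta, r))
    return -0.5 * laplacian / psi(zeta, r) - 1.0 / jnp.linalg.norm(r)

def sample_density(key, zeta, n):
    # |psi|^2 is proportional to exp(-2 zeta r), so the radius follows
    # Gamma(shape=3, rate=2 zeta) and the direction is uniform on the sphere.
    key_r, key_dir = jax.random.split(key)
    radii = jax.random.gamma(key_r, 3.0, shape=(n,)) / (2.0 * zeta)
    directions = jax.random.normal(key_dir, (n, 3))
    directions = directions / jnp.linalg.norm(directions, axis=1, keepdims=True)
    return radii[:, None] * directions

def vmc_energy(key, zeta, n=20000):
    # Plain average of the local energy over samples drawn from |psi|^2.
    samples = sample_density(key, zeta, n)
    return jnp.mean(jax.vmap(local_energy, in_axes=(None, 0))(zeta, samples))

key = jax.random.PRNGKey(0)
energy_exact_ansatz = vmc_energy(key, 1.0)  # zeta = 1 is the exact 1s state
energy_poor_ansatz = vmc_energy(key, 0.8)   # analytic value: zeta^2 / 2 - zeta = -0.48
```

For $\zeta = 1$ the local energy is constant, so the estimate has essentially no statistical noise; for any other $\zeta$ the estimate fluctuates around a value above $E_0 = -0.5$.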

One standard choice in VMC is a wavefunction of the Slater-Jastrow type: a Slater determinant multiplied by a Jastrow factor, typically of the form $e^{J}$ with $J= \sum_{ij} u_{ij}(|r_i - r_j|)$, i.e. a function of all pairwise distances that accounts for electron correlation. For example, the PauliNet ansatz developed by Noé et al. takes a preliminary Hartree-Fock calculation as input, and its expressiveness comes from the Jastrow factor and a backflow transformation, both represented as DNNs. However, one benefit of neural networks is that we need not restrict ourselves to wavefunctions of this form, which have limitations such as being constrained to a finite basis set.

In [this paper](https://journals.aps.org/prresearch/abstract/10.1103/PhysRevResearch.2.033429), the authors used an ansatz parameterized by a neural network and, by performing gradient descent with $E[\psi]$ as the loss function, accurately recovered the ground state wavefunctions of a few small but challenging molecules. In this project, we will explore the potential of using an $E(3)$-equivariant neural network to parameterize the wavefunction. We will start by investigating single atoms.


## Single atom wavefunction
Let's first study a single atom of atomic number $Z$. For $n$ electrons, the Hamiltonian is
$$
\hat H = -\frac{1}{2} \sum_i \nabla_i^2 - \sum_i \frac{Z}{r_i} + \sum_{i<j} \frac{1}{|\mathbf r_i - \mathbf r_j|}.
$$
We'll parameterize the multi-electron wavefunction $\psi(\mathbf r_1, \ldots, \mathbf r_n)$ with a neural network $\phi_\theta(\mathbf r_1, \ldots, \mathbf r_n)$ that is $SO(3)$-equivariant in each input $\mathbf r_i$, where we obtain $\psi$ by antisymmetrizing $\phi_\theta$:
$$
\psi(\mathbf r_1, \ldots, \mathbf r_n) = \sum_{\sigma \in S(n_\uparrow) \times S(n_\downarrow)} \mathrm{sgn}(\sigma) \phi_\theta(\mathbf r_{\sigma(1)}, \ldots, \mathbf r_{\sigma(n)}).
$$
The sum is over all products of a permutation of the $n_\uparrow$ spin-up electrons with a permutation of the $n_\downarrow$ spin-down electrons, where $n_\uparrow + n_\downarrow = n$ and, WLOG, we label the spin-up electrons as $1, \ldots, n_\uparrow$ and the spin-down electrons as $n_\uparrow + 1, \ldots, n$.
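
The antisymmetrization sum can be sketched directly with explicit permutations. The helper names `permutation_sign`, `antisymmetrize`, and `toy_phi` below are hypothetical, and `toy_phi` is only a stand-in for $\phi_\theta$ with no built-in symmetry. The brute-force sum costs $n_\uparrow!\, n_\downarrow!$ evaluations of $\phi$, so this is meant for small atoms only.

```python
import itertools
import jax.numpy as jnp

def permutation_sign(perm):
    # Parity of a permutation from its inversion count.
    inversions = sum(
        1 for i in range(len(perm)) for j in range(i + 1, len(perm)) if perm[i] > perm[j]
    )
    return -1 if inversions % 2 else 1

def antisymmetrize(phi, n_up, n_down):
    # phi maps coords of shape (n_up + n_down, 3) to a scalar.
    up_perms = list(itertools.permutations(range(n_up)))
    down_perms = list(itertools.permutations(range(n_up, n_up + n_down)))

    def psi(coords):
        total = 0.0
        for p_up in up_perms:
            for p_down in down_perms:
                sign = permutation_sign(p_up) * permutation_sign(p_down)
                # Reorder the electrons according to the combined permutation.
                total = total + sign * phi(coords[jnp.array(p_up + p_down)])
        return total

    return psi

def toy_phi(coords):
    # Hypothetical stand-in for phi_theta (three electrons, no symmetry).
    return coords[0, 0] * coords[1, 1] ** 2 * (1.0 + coords[2, 2] ** 2)

psi = antisymmetrize(toy_phi, n_up=2, n_down=1)
coords = jnp.array([[0.3, -0.1, 0.5], [1.0, 0.7, -0.2], [-0.4, 0.2, 0.9]])
swapped_same_spin = coords[jnp.array([1, 0, 2])]  # exchange the two spin-up electrons
```

Here `psi(swapped_same_spin)` equals `-psi(coords)`, which is the antisymmetry required for same-spin electrons.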

For computing the expected energy in the state $\psi$, we will for now use Monte Carlo integration since we don't yet have an efficient way of sampling from the distribution $p(X) \propto |\psi(X)|^2$. In other words, we sample many points $X_i = (\mathbf r_{i1}, \ldots, \mathbf r_{in})$ where each $\mathbf r_{ij}$ is chosen uniformly and independently from a ball of a certain radius $r_{max}$. We then approximate
$$
E[\psi] \approx \frac{\sum_i \psi(X_i)^* \hat H\psi(X_i)}{\sum_i \psi(X_i)^* \psi(X_i)}.
$$
To implement the boundary condition $\psi \to 0$ as $|X| \to \infty$, we penalize large values of $|\psi|$ near $r = r_{\max}$ by adding a regularization term $\lambda \sum_i \left|\frac{r_i}{r_{\max}}\right|^\beta$, for some large exponent $\beta$, to the Hamiltonian.

Finally, we perform gradient descent with
$$
\theta \gets \theta - \alpha \nabla_\theta E[\psi]
$$
to find the ground state.
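
To make the whole procedure concrete, here is a minimal sketch of uniform-ball sampling, the ratio estimator, and plain gradient descent for hydrogen. It assumes a one-parameter trial function $e^{-\zeta|\mathbf r|}$ with an analytically known $\hat H\psi$ instead of a neural network, a fixed set of sample points, and no regularization term. The values of `R_MAX`, `learning_rate`, and the number of steps are illustrative choices, not settings used elsewhere in this notebook.

```python
import jax
import jax.numpy as jnp

R_MAX = 6.0  # radius of the sampling ball (illustrative choice)

def psi(zeta, points):
    # One-parameter hydrogen trial function exp(-zeta * |r|), batched over points.
    return jnp.exp(-zeta * jnp.linalg.norm(points, axis=-1))

def h_psi(zeta, points):
    # Analytic H psi for this trial function: (-zeta^2 / 2 + (zeta - 1) / r) * psi.
    r = jnp.linalg.norm(points, axis=-1)
    return (-0.5 * zeta**2 + (zeta - 1.0) / r) * psi(zeta, points)

def energy(zeta, points):
    # Ratio estimator from the text: sum(psi * H psi) / sum(psi * psi).
    p = psi(zeta, points)
    return jnp.dot(p, h_psi(zeta, points)) / jnp.dot(p, p)

key = jax.random.PRNGKey(0)
points = R_MAX * jax.random.ball(key, 3, shape=(200000,))  # uniform in the ball

zeta = 0.6           # initial guess for the variational parameter
learning_rate = 0.2  # the step size written as alpha in the update rule
grad_energy = jax.jit(jax.grad(energy))
for _ in range(200):
    zeta = zeta - learning_rate * grad_energy(zeta, points)

final_energy = energy(zeta, points)
```

With these settings `zeta` should drift towards the exact value 1 and the energy towards $-0.5$. Because the samples are uniform rather than drawn from $|\psi|^2$, only the small fraction of points near the nucleus contributes significantly, which is the variance problem discussed later.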

In [3]:
%%capture
!pip install e3nn-jax

In [4]:
import numpy as np
from tqdm import tqdm
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, jacobian
import flax
import optax
import pandas as pd
import matplotlib.pyplot as plt
import plotly
import plotly.express as px

import e3nn_jax as e3nn  # import e3nn-jax

jnp.set_printoptions(precision=4, suppress=True)

print(jax.__version__)
print(flax.__version__)
print(optax.__version__)
print(e3nn.__version__)
print(jnp.ones(()).device())



0.4.8
0.6.9
0.1.5
0.17.4
TFRT_CPU_0


### Hydrogen

As a proof of concept, we'll first try predicting the hydrogen atom ground state wavefunction. We'll use a simple MLP with a single hidden layer with 5 neurons, where the input is $|\mathbf r|^2$.

### Helium

Now, let's try the helium atom, which is the most electrons we can have before we have to introduce antisymmetrization. We'll use a simple MLP with two hidden layers with 5 neurons each, where the inputs to the MLP are scalars $|\mathbf r_1|^2$, $|\mathbf r_2|^2$, and $\mathbf r_1 \cdot \mathbf r_2$.
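
The three scalar inputs are invariant under rotations and inversion of both electrons together, which is why an ordinary MLP on top of them yields an invariant wavefunction. The sketch below checks this numerically; `helium_features` and `rotation_about_z` are hypothetical helper names, and the coordinates are arbitrary test values.

```python
import math
import jax.numpy as jnp

def helium_features(coords):
    # coords has shape (2, 3): one row per electron.
    r1, r2 = coords[0], coords[1]
    return jnp.array([jnp.dot(r1, r1), jnp.dot(r2, r2), jnp.dot(r1, r2)])

def rotation_about_z(angle):
    # 3x3 rotation matrix for a rotation by `angle` radians about the z axis.
    c, s = math.cos(angle), math.sin(angle)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

coords = jnp.array([[0.3, -0.1, 0.5], [1.0, 0.7, -0.2]])
rotated = coords @ rotation_about_z(0.7).T  # rotate both electrons together
inverted = -coords                          # parity transformation r -> -r

features = helium_features(coords)
features_rotated = helium_features(rotated)
features_inverted = helium_features(inverted)
```

All three feature vectors agree up to floating-point error.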

In [8]:
class ToyAnsatz:
    def __init__(self, Z, N_up, N_down, config):
        self.Z = Z
        self.N_up = N_up
        self.N_down = N_down

        self.mlp = e3nn.flax.MultiLayerPerceptron([5, 5, 5, 1], act=jax.nn.gelu, output_activation=jax.nn.sigmoid)
        self.envelope = lambda zeta, coords: jnp.exp(-zeta * jnp.linalg.norm(coords, axis=-1))

        @jit
        def wavefunction(w, coords):  # coords can be unbatched or batched
            # TO-DO antisymmetrize the wavefunction at the end (spin-up and spin-down separately)
            x = e3nn.tensor_square(e3nn.IrrepsArray(f"{self.Z}x1o", coords)).filter(keep="0e")
            return self.mlp.apply(w["mlp"], x.array).squeeze(-1) * self.envelope(jnp.abs(w["envelope"]), coords)
        self.wavefunction = wavefunction

    def init_weights(self, random_key):  # builds a dummy input of the right shape to initialize the parameters
        subkey1, subkey2 = jax.random.split(random_key)
        coords = jnp.empty((3 * self.Z,))
        x = e3nn.tensor_square(e3nn.IrrepsArray(f"{self.Z}x1o", coords)).filter(keep="0e")
        return {
            "mlp": self.mlp.init(subkey1, x),
            "envelope": jnp.sqrt(jax.random.chisquare(subkey2, df=2)) * self.Z
        }

class ManualAnsatz:
    def __init__(self, Z, N_up, N_down, config):
        self.Z = Z
        self.N_up = N_up
        self.N_down = N_down
        self.N = N_up + N_down

        self.hidden_irreps = "5x0e+5x1e+5x1o+5x2e+5x2o"
        self.hidden_irreps_before_gate = "25x0e+5x1e+5x1o+5x2e+5x2o"
        self.lmax = 2
        # assert self.lmax + 1 == len(self.hidden_channels)

        self.tensor = lambda input: e3nn.tensor_square(input).filter(keep=e3nn.Irrep.iterator(lmax=self.lmax))
        self.linear1 = e3nn.flax.Linear(irreps_out=self.hidden_irreps_before_gate, biases=True)
        self.linear2 = e3nn.flax.Linear(irreps_out=self.hidden_irreps_before_gate, biases=True)
        self.linear_head = e3nn.flax.Linear(irreps_out="0e", biases=True)
        self.envelope = lambda zeta, coords: jnp.exp(-zeta * jnp.linalg.norm(coords, axis=-1))

        @jit
        def wavefunction(w, coords):  # coords can be unbatched or batched
            coords_irreps = e3nn.IrrepsArray(f"{self.N}x1o", coords)
            x = self.tensor(coords_irreps)
            x = self.linear1.apply(w["linear1"], x)
            x = e3nn.gate(x)
            x = self.linear2.apply(w["linear2"], x)
            x = e3nn.gate(x)
            x = x.filter(keep="0e")
            x = self.linear_head.apply(w["linear_head"], x).array.squeeze(-1)
            return x * self.envelope(jnp.abs(w["envelope"]), coords)
        self.wavefunction = wavefunction

    def init_weights(self, random_key):  # traces a dummy input through the layers to initialize each one
        w = {}
        subkey1, subkey2, subkey3, subkey4 = jax.random.split(random_key, num=4)
        coords = jnp.empty((3 * self.N,))  # one 3-vector per electron, matching the "{N}x1o" irreps below
        coords_irreps = e3nn.IrrepsArray(f"{self.N}x1o", coords)
        x = self.tensor(coords_irreps)
        w["linear1"] = self.linear1.init(subkey1, x)
        x = self.linear1.apply(w["linear1"], x)
        print(x.irreps)
        x = e3nn.gate(x)
        w["linear2"] = self.linear2.init(subkey2, x)
        x = self.linear2.apply(w["linear2"], x)
        x = e3nn.gate(x)
        x = x.filter(keep="0e")
        w["linear_head"] = self.linear_head.init(subkey3, x)
        x = self.linear_head.apply(w["linear_head"], x).array.squeeze(-1)
        w["envelope"] = jnp.sqrt(jax.random.chisquare(subkey4, df=2)) * self.Z
        return w

class FerminetAnsatz:
    def __init__(self, Z, N_up, N_down, config):
        self.Z = Z
        self.N_up = N_up
        self.N_down = N_down
        self.N = N_up + N_down

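        # Placeholder: reuse the ManualAnsatz hidden irreps, since linear1 and linear2 below need them
        self.hidden_irreps = "5x0e+5x1e+5x1o+5x2e+5x2o"
        self.hidden_irreps_before_gate = "25x0e+5x1e+5x1o+5x2e+5x2o"
        self.lmax = config.get("lmax", 2)  # default matches ManualAnsatz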

        self.tensor = lambda input: e3nn.tensor_square(input).filter(keep=e3nn.Irrep.iterator(lmax=self.lmax))
        self.linear1 = e3nn.flax.Linear(irreps_out=self.hidden_irreps_before_gate, biases=True)
        self.linear2 = e3nn.flax.Linear(irreps_out=self.hidden_irreps_before_gate, biases=True)
        self.linear_head = e3nn.flax.Linear(irreps_out="0e", biases=True)
        self.envelope = lambda zeta, coords: jnp.exp(-zeta * jnp.linalg.norm(coords, axis=-1))

        @jit
        def wavefunction(w, coords):  # coords can be unbatched or batched
            coords_irreps = e3nn.IrrepsArray(f"{self.N}x1o", coords)
            x = self.tensor(coords_irreps)
            x = self.linear1.apply(w["linear1"], x)
            x = e3nn.gate(x)
            x = self.linear2.apply(w["linear2"], x)
            x = e3nn.gate(x)
            x = x.filter(keep="0e")
            x = self.linear_head.apply(w["linear_head"], x).array.squeeze(-1)
            return x * self.envelope(jnp.abs(w["envelope"]), coords)
        self.wavefunction = wavefunction

    def init_weights(self, random_key):  # traces a dummy input through the layers to initialize each one
        w = {}
        subkey1, subkey2, subkey3, subkey4 = jax.random.split(random_key, num=4)
        coords = jnp.empty((3 * self.N,))  # one 3-vector per electron, matching the "{N}x1o" irreps below
        coords_irreps = e3nn.IrrepsArray(f"{self.N}x1o", coords)
        x = self.tensor(coords_irreps)
        w["linear1"] = self.linear1.init(subkey1, x)
        x = self.linear1.apply(w["linear1"], x)
        print(x.irreps)
        x = e3nn.gate(x)
        w["linear2"] = self.linear2.init(subkey2, x)
        x = self.linear2.apply(w["linear2"], x)
        x = e3nn.gate(x)
        x = x.filter(keep="0e")
        w["linear_head"] = self.linear_head.init(subkey3, x)
        x = self.linear_head.apply(w["linear_head"], x).array.squeeze(-1)
        w["envelope"] = jnp.sqrt(jax.random.chisquare(subkey4, df=2)) * self.Z
        return w

In [9]:
class E3FerminetAtom:
    def __init__(self, config):
        self.Z = config.get("Z", 1)
        self.N_up = config.get("N_up", 1)
        self.N_down = config.get("N_down", 1)
        self.sampler = config.get("sampler")  # use M-H if None
        self.sampling_dist = config.get("sampling_dist")  # use M-H if None
        assert((self.sampler is None) == (self.sampling_dist is None))
        self.N_samples = config.get("batch_size", 20000)
        self.num_batches = config.get("num_batches", 1000)
        self.lr = config.get("lr", 0.1)
        self.validate_every = config.get("validate_every", 2000)
        self.moving_avg_coeff = config.get("moving_avg_coeff", 0.1)
        self.regularize = "regularize" in config
        self.regularize_pow = config["regularize"].get("pow", 8) if self.regularize else None
        self.regularize_coeff = config["regularize"].get("coeff", 100) if self.regularize else None
        self.regularize_max_r = config["regularize"].get("max_r", 2) if self.regularize else None
        self.patience = config.get("patience", 200)
        self.random_key = jax.random.PRNGKey(config.get("random_seed", 0))

        self.ansatz = ManualAnsatz(self.Z, self.N_up, self.N_down, config["ansatz"])

        self.w = None
        self.w_list = None
        self.energy_moving_avgs = None

        @jit
        def local_kinetic_energy(w, coords):  # coords must be unbatched
            def laplacian(coords):
                return jnp.einsum('ii->', jacobian(jacobian(self.ansatz.wavefunction, argnums=1), argnums=1)(w, coords))
            return -0.5 * laplacian(coords) / self.ansatz.wavefunction(w, coords)
        self._local_kinetic_energy = local_kinetic_energy

        @jit
        def local_potential_energy(w, coords):  # coords must be unbatched
            coords = coords.reshape((-1, 3))
            V_e_p = -self.Z * jnp.sum(1 / jnp.linalg.norm(coords, axis=1), axis=0)
            relative_dists = jnp.linalg.norm(jnp.expand_dims(coords, axis=0) - jnp.expand_dims(coords, axis=1), axis=2)
            V_e_e = jnp.sum(1 / jnp.where(relative_dists == 0, np.inf, relative_dists)) / 2
            return V_e_p + V_e_e
        self._local_potential_energy = local_potential_energy

        @jit
        def local_energy(w, coords):  # coords must be unbatched
            return local_kinetic_energy(w, coords) + local_potential_energy(w, coords)
        self._local_energy = local_energy

        @jit
        def energy(w, coords_batch):
            # If sampling_dist is None, assume sampling from wavefunction
            local_energies = vmap(local_energy, in_axes=(None, 0))(w, coords_batch)
            if self.sampling_dist is None:
                return jnp.mean(local_energies)
            psi = self.ansatz.wavefunction(w, coords_batch)
            scaled_probs = psi ** 2 / vmap(self.sampling_dist)(coords_batch)
            return jnp.dot(scaled_probs, local_energies) / jnp.sum(scaled_probs)
        self._energy = energy

        if self.regularize:
            @jit
            def regularized_energy(w, coords_batch):
                reshaped_coords_batch = coords_batch.reshape((coords_batch.shape[0], -1, 3))
                penalty = jnp.sum((jnp.linalg.norm(reshaped_coords_batch, axis=2) / self.regularize_max_r) ** self.regularize_pow, axis=1)
                if self.sampling_dist is None:
                    cum_penalty = jnp.mean(penalty)
                else:
                    psi = self.ansatz.wavefunction(w, coords_batch)
                    scaled_probs = psi ** 2 / vmap(self.sampling_dist)(coords_batch)
                    cum_penalty = jnp.dot(scaled_probs, penalty) / jnp.sum(scaled_probs)
                return energy(w, coords_batch) + self.regularize_coeff * cum_penalty
            self._regularized_energy = regularized_energy
        
        if self.sampler is None:
            self.MH_stdev = config["MH"].get("stdev", 0.2)
            self.MH_warmup = config["MH"].get("warmup", 500)
            self.MH_interval = config["MH"].get("interval", 10)
            self.MH_batch_size = config["MH"].get("batch_size", 64)
            self.sampled_coords = None
            def sampler(random_key, Z, num_samples):
                # returns jnp array of shape (num_samples, 3*Z) sampled from the wavefunction
                if self.sampled_coords is None:
                    warmup = self.MH_warmup
                    random_key, subkey = jax.random.split(random_key)
                    self.sampled_coords = self.MH_stdev * jax.random.normal(subkey, (self.MH_batch_size, 3*Z))
                else:
                    warmup = self.MH_interval
                coords = []
                num_iters = warmup + (num_samples - 1) // self.MH_batch_size + 1
                num_coords_remaining = num_samples
                for i in range(num_iters):
                    random_key, subkey = jax.random.split(random_key)
                    proposal_coords = self.sampled_coords + self.MH_stdev * jax.random.normal(subkey, (self.MH_batch_size, 3*Z))
                    acceptance_ratios = (self.ansatz.wavefunction(self.w, proposal_coords) / self.ansatz.wavefunction(self.w, self.sampled_coords)) ** 2
                    random_key, subkey = jax.random.split(random_key)
                    self.sampled_coords = jnp.where(np.expand_dims(jax.random.uniform(subkey, (self.MH_batch_size,)) < acceptance_ratios, axis=1),
                                                    proposal_coords,
                                                    self.sampled_coords)
                    if i >= warmup:
                        if self.MH_batch_size <= num_coords_remaining:
                            coords_to_add = self.sampled_coords
                            num_coords_remaining -= self.MH_batch_size
                        else:
                            coords_to_add = self.sampled_coords[:num_coords_remaining]
                            num_coords_remaining = 0
                        coords.append(coords_to_add)
                return jnp.concatenate(coords)
                # self.sampled_coords = jax.random.normal(random_key, (3*Z,))
                # for i in range(500): #need time to stabilize chain, tune this number
                #     random_key, subkey = jax.random.split(random_key)
                #     proposal_coords = self.sampled_coords + 0.1*jax.random.normal(subkey, (3*Z,)) #need to tune stdev
                #     a = self.ansatz.wavefunction(self.w, proposal_coords)**2 / self.ansatz.wavefunction(self.w, x_init)**2
                #     if jax.random.uniform(subkey) < a:
                #         self.sample_coords = proposal_coords
                # coords = []
                # for i in range(num_samples):
                #     random_key, subkey = jax.random.split(random_key)
                #     proposal_coords = x_init + 0.3*jax.random.normal(subkey, (3*Z,))
                #     a = self.ansatz.wavefunction(self.w, proposal_coords)**2 / self.ansatz.wavefunction(self.w, x_init)**2
                #     if jax.random.uniform(subkey) < a:
                #         x_init = proposal_coords
                #     coords.append(x_init)
                # return jax.numpy.array(coords)
            self.sampler = sampler
        else:
            assert "MH" not in config
    
    def init_weights(self):
        self.random_key, subkey = jax.random.split(self.random_key)
        self.w = self.ansatz.init_weights(subkey)
        print("WEIGHTS:", self.w)
        # print("ENERGY:", self._energy(self.w, coords_batch))
        # if self.regularize:
        #     print("REGULARIZED ENERGY:", self._regularized_energy(self.w, coords_batch))
    
    def train_loop(self):
        # Training loop

        self.init_weights()

        grad_energy = jit(grad(self._regularized_energy)) if self.regularize else jit(grad(self._energy))

        optimizer = optax.adamw(learning_rate=self.lr)
        opt_state = optimizer.init(self.w)

        weights = [self.w]
        self.random_key, subkey = jax.random.split(self.random_key)
        coords_batch = self.sampler(subkey, self.Z, self.N_samples)
        energy = self._energy(self.w, coords_batch)
        energies = [energy]
        energy_moving_avgs = [energy]
        for step in tqdm(range(self.num_batches)):
            self.random_key, subkey = jax.random.split(self.random_key)
            coords_batch = self.sampler(subkey, self.Z, self.N_samples)
            grads = grad_energy(self.w, coords_batch)
            updates, opt_state = optimizer.update(grads, opt_state, self.w)
            self.w = optax.apply_updates(self.w, updates)
            weights.append(self.w)
            energy = self._energy(self.w, coords_batch)
            energies.append(energy)
            energy_moving_avgs.append(energy_moving_avgs[-1] * (1 - self.moving_avg_coeff) + energy * self.moving_avg_coeff)
            if step % self.validate_every == 0:
                self.test()
            if self.patience is not None and step - np.argmin(energy_moving_avgs) >= self.patience:
                break
        self.w_list = weights
        self.energy_moving_avgs = energy_moving_avgs
        learning_curve_df = pd.DataFrame({"Batch index": np.arange(len(energy_moving_avgs)), "Energy": energies})
        fig = px.line(learning_curve_df, x="Batch index", y="Energy")
        fig.show()
    
    def choose_weights(self, idx):
        if idx == "best":
            idx = jnp.argmin(self.energy_moving_avgs)
            print(f"BEST INDEX: {idx}")
        elif idx == "last":
            idx = -1
        self.w = self.w_list[idx]

    def test(self, test_N_samples=50000):
        self.random_key, subkey = jax.random.split(self.random_key)
        coords_batch = self.sampler(subkey, self.Z, test_N_samples)
        print("GROUND STATE ENERGY: {:.4f}".format(self._energy(self.w, coords_batch)))
    
    def plot_one_electron_radial(self, max_r, plot_samples=5000):
        radii = jnp.linspace(0, max_r, plot_samples+1)
        coords_batch = np.hstack((np.expand_dims(radii, axis=1), np.zeros((plot_samples+1, 3*self.Z - 1))))
        psi = self.ansatz.wavefunction(self.w, coords_batch)
        x_label = "$r$"
        y_label = "$\\psi(r\\hat e_x, 0, \\ldots, 0)$"  # radii fill the first (x) coordinate above
        df = pd.DataFrame({x_label: radii, y_label: psi})
        fig = px.line(df, x=x_label, y=y_label)
        fig.show()

    def plot_density_3D(self, plot_samples=5000):
        self.random_key, subkey = jax.random.split(self.random_key)
        coords_batch = self.sampler(subkey, self.Z, plot_samples)
        densities = self.ansatz.wavefunction(self.w, coords_batch) ** 2  # wavefunction already accepts batched coords
        max_density = jnp.max(densities)
        self.random_key, subkey = jax.random.split(self.random_key)
        coords_batch = coords_batch[max_density * jax.random.uniform(subkey, shape=(plot_samples,)) < densities]
        df = pd.DataFrame(coords_batch.reshape((-1, 3)), columns=['x', 'y', 'z'])
        print(df.head())
        fig = px.scatter_3d(df, x='x', y='y', z='z')
        fig.show()
    
    def plot_density_2D(self, pixel_size=0.01, step_size=0.1):
        pass

In [10]:
max_r_hydrogen = 4
hydrogen_config = {
    "random_seed": 1,
    "Z": 1,
    "N_up": 1,
    "N_down": 0,
    # "batch_size": 2000,
    # "num_batches": 50,
    "batch_size": 64,
    "num_batches": 1000,
    "patience": None,
    # "lr": 0.001,
    "lr": 1e-5,
    #"sampling_dist": lambda coords: 1,
    #"sampler": lambda random_key, Z, num_samples: jax.random.ball(random_key, 3, shape=(num_samples, Z)).reshape((num_samples, -1)) * max_r_hydrogen,
    "sampling_dist" : None,
    "sampler" : None,
    "moving_avg_coeff": 0.1,
    "ansatz": {},
    "MH": {
        "stdev": 0.2,
        "warmup": 500,
        "batch_size": 64
    }
    # "regularize": {
    #     "max_r": max_r_hydrogen,
    #     "pow": 8,
    #     "coeff": 1
    # }
}

max_r_helium = 2
helium_config = {
    "random_seed": 0,
    "Z": 2,
    "N_up": 1,
    "N_down": 1,
    # "batch_size": 2000,
    # "num_batches": 25,
    "batch_size": 128,
    "num_batches": 100000,
    "validate_every": 2000,
    "patience": None,
    "lr": optax.warmup_cosine_decay_schedule(5e-5, 5e-4, 100, 100000, end_value=5e-6, exponent=1.0),
    # "sampling_dist": lambda coords: 1,
    # "sampler": lambda random_key, Z, num_samples: jax.random.ball(random_key, 3, shape=(num_samples, Z)).reshape((num_samples, -1)) * max_r_helium,
    "sampling_dist" : None,
    "sampler" : None,
    "ansatz": {},
    "MH": {
        "stdev": 0.2,
        "warmup": 500,
        "interval": 10,
        "batch_size": 64
    }
    # "regularize": {
    #     "max_r": max_r_helium,
    #     "regularize_pow": 8,
    #     "regularize_coeff": 0,
    # }
}

atom_model = E3FerminetAtom(helium_config)
atom_model.train_loop()
atom_model.choose_weights("last")
atom_model.test()
atom_model.plot_one_electron_radial(4)
# atom_model.plot_density_3D()


In [49]:
e3nn.gate("25x0e+0x0o+5x1e+5x1o+5x2e+5x2o")

5x0e+5x1e+5x1o+5x2e+5x2o

# Importance Sampling

One source of inaccuracy in the Monte Carlo integration is that we were sampling from a uniform ball. But this over-prioritizes points far away from the origin, and points near the nucleus don't get sampled enough. This creates a large variance in the integral. We want to try importance sampling, where we sample from a known distribution and calculate the Hamiltonian at those points. Then each sample gets weighted. This works as follows:

Suppose we want the average of the energy $E(x)$ over a distribution $p(x)$, that is, we want to calculate
$$\langle E \rangle = \int E(x)p(x)dx$$
but $p(x)$ is difficult to sample from (or may not even be normalized!). We pick a sampling distribution $q(x)$, and calculate
$$\int E(x)p(x)dx=\int E(x)q(x)\frac{p(x)}{q(x)}dx \approx \frac{1}{N}\sum_{i=1}^N E(x_i) \frac{p(x_i)}{q(x_i)}$$
where the $N$ points $x_i$ are sampled from $q$. If $p$ is only known up to normalization, we divide by $\sum_i p(x_i)/q(x_i)$ instead of $N$. In the case of the hydrogen atom, the radial distribution function is close enough to a $\chi$-squared distribution, so we use that as a sampling distribution.
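
A minimal sketch of this weighting is shown below, in its self-normalized form where the weighted sum is divided by the sum of the weights so that the unknown normalization of $p$ cancels. The target is the hydrogen 1s density $p \propto e^{-2r}$. The proposal $q(\mathbf x) = e^{-r}/(8\pi)$ (a Gamma-distributed radius with a uniform direction) and all function names are assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp

def target_unnormalized(points):
    # p(x) proportional to |psi(x)|^2 for psi = exp(-r); deliberately left unnormalized.
    return jnp.exp(-2.0 * jnp.linalg.norm(points, axis=-1))

def proposal_density(points):
    # q(x) = exp(-r) / (8 pi): radius ~ Gamma(shape=3, rate=1), uniform direction.
    return jnp.exp(-jnp.linalg.norm(points, axis=-1)) / (8.0 * jnp.pi)

def sample_proposal(key, n):
    key_r, key_dir = jax.random.split(key)
    radii = jax.random.gamma(key_r, 3.0, shape=(n,))
    directions = jax.random.normal(key_dir, (n, 3))
    directions = directions / jnp.linalg.norm(directions, axis=1, keepdims=True)
    return radii[:, None] * directions

def importance_average(observable, points):
    # Self-normalized estimator: sum_i w_i E(x_i) / sum_i w_i with w_i = p(x_i) / q(x_i).
    weights = target_unnormalized(points) / proposal_density(points)
    return jnp.dot(weights, observable(points)) / jnp.sum(weights)

points = sample_proposal(jax.random.PRNGKey(0), 100000)
potential = importance_average(lambda x: -1.0 / jnp.linalg.norm(x, axis=-1), points)
# Analytic value for the 1s state: <-1/r> = -1 Hartree.
```

The proposal decays more slowly than the target, so the weights $p/q \propto e^{-r}$ stay bounded; a proposal with lighter tails than $p$ would give unbounded weights and a much noisier estimate.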

Although decently accurate for hydrogen, there's still some distance to the true ground state energy for helium! Here are some possible approaches for improving our accuracy:
- More expressive neural network, e.g. including biases, more layers, more neurons...
- More efficient and accurate sampling
- Better way to impose wavefunction boundary conditions

We have only been using real-valued wavefunctions so far, which is a big limitation as well. So we should probably allow complex-valued wavefunctions sometime soon.

The [paper](https://journals.aps.org/prresearch/abstract/10.1103/PhysRevResearch.2.033429) we're referencing has some quite sophisticated techniques for wavefunction parameterization and optimization, and we're considering incorporating some of them into our model.

# NEXT STEPS

Zed: Continue with current approach
- Refactor
- Implement MH and importance sampling
- Implement envelope

Alec: Equivariant basis function approach

Input $\mathbf R_A$ (positions of all nuclei) (or relative distances?).
Neural network outputs coefficients $c_i$ so that output wavefunction is
$$
\psi = \sum_i c_i \psi_i
$$
Approach 1:
Each electron has a molecular orbital $\phi_i$ written as a linear combination of equivariant basis functions
$$
\phi_i = \sum_j c_{ij}{}^l_m B_j{}^l_m
$$
where
$$
B_j{}^l_m = R_j(r) Y^l_m.
$$
(We let the origin be arbitrary, or some "important" atom of the molecule, for now.)
The $c_{ij}{}^l_m$ for fixed $i, j, l$ transforms as an $l$-irrep.

Radial basis $R(r)$ can be e.g. Bessel functions multiplied by an exponentially decaying envelope, or radial part of Slater-type orbitals, or radial part of Gaussian orbitals, or even a neural network. These choices may also involve a continuous parameter in the exponent that we optimize.

The final wavefunction is the Slater determinant of molecular orbitals.

Approach 2:
Do Approach 1 around every atom, and sum up all wavefunctions.
So we have a wavefunction $\psi_A$ derived from Approach 1 around each atom, and sum everything together: $\psi = \sum_A \psi_A$.

### Atoms with more electrons?

Beyond helium, we'll have to antisymmetrize the wavefunction for electrons with the same spin. We'll explore this in the next few days.

# Molecules with multiple atoms

If our $SO(3)$-equivariant neural network ansatz works well, we will try to extend our approach to molecules with multiple atoms. Exact details TBD.

# Implementing Equivariance
As a proof of concept, we will implement equivariance on the hydrogen atom. We take the total wavefunction to be the product of a radial part and an angular part, the latter consisting of spherical harmonics, which transform equivariantly. The radial part, as before, is invariant.

In [None]:
# Set up model

radial = e3nn.flax.MultiLayerPerceptron([5, 1], act=jax.nn.gelu, output_activation=jax.nn.sigmoid)
angular = e3nn.flax.MultiLayerPerceptron([5, 1], act=jax.nn.gelu, output_activation=jax.nn.sigmoid)
N_samples = 2000
max_r = 4
regularize_pow = 8
regularize_coeff = 100

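# Assumption: the invariant (radial) network stands in for mlp_H until the angular part is wired in
mlp_H = radial

def wavefunction_H(w, in_points):
  return mlp_H.apply(w, e3nn.tensor_square(e3nn.IrrepsArray("1o", in_points)).filter(keep="0e")).array.squeeze(-1)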

@jit
def energy_H(w, in_points):
  psi = wavefunction_H(w, in_points)
  @vmap
  def laplacian(in_points):
    return jnp.einsum('ii->', jacobian(jacobian(wavefunction_H, argnums=1), argnums=1)(w, in_points))
  laplacian_psi = laplacian(in_points)
  cum_K = -0.5 * jnp.dot(psi, laplacian_psi)
  distances = jnp.linalg.norm(in_points, axis=1)
  cum_V = -jnp.dot(psi, psi / distances)
  return (cum_K + cum_V) / jnp.dot(psi, psi)

@jit
def regularized_energy_H(w, in_points):
  psi = wavefunction_H(w, in_points)
  distances = jnp.linalg.norm(in_points, axis=1)
  return energy_H(w, in_points) + regularize_coeff * jnp.dot(psi, psi * (distances / max_r) ** regularize_pow) / jnp.dot(psi, psi)  # penalize high probability near max_r


In [None]:
random_key = jax.random.PRNGKey(0)
# Split the key for each random draw instead of incrementing it by hand
random_key, subkey = jax.random.split(random_key)
in_points = jax.random.ball(subkey, 3, shape=(N_samples,)) * max_r
x = e3nn.IrrepsArray("1o", in_points)
x = e3nn.tensor_square(x).filter(keep="0e")

random_key, subkey = jax.random.split(random_key)
w = mlp_H.init(subkey, x)
%timeit print(jit(regularized_energy_H)(w, in_points))

# Metropolis-Hastings
This is a technique for sampling directly from the wavefunction. Let $p(x) := |\psi (x)|^2$ be the (possibly unnormalized) target distribution. Given a current point $x$, we generate a proposal point $x'$ from some proposal distribution $q(x'; x)$, for example a Gaussian centered at $x$. We assume the proposal distribution is symmetric, that is, $q(x';x) = q(x; x')$. We then calculate $a=\min\{1, p(x')/p(x)\}$ and accept the proposal with probability $a$; otherwise we stay at $x$. This produces a random walk whose stationary distribution is $p$.
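
The accept/reject rule can be sketched as follows for the hydrogen 1s density $|\psi|^2 \propto e^{-2r}$. The step size, number of walkers, and number of steps are illustrative assumptions and would need tuning for a real ansatz; the function names are hypothetical. Working with $\log p$ avoids under- and overflow in the ratio $p(x')/p(x)$.

```python
import jax
import jax.numpy as jnp

def log_density(points):
    # log |psi|^2 for psi = exp(-r), up to an additive constant.
    return -2.0 * jnp.linalg.norm(points, axis=-1)

def metropolis_step(points, key, step_size=0.5):
    key_prop, key_accept = jax.random.split(key)
    proposal = points + step_size * jax.random.normal(key_prop, points.shape)
    # Symmetric Gaussian proposal, so the acceptance probability is min(1, p(x') / p(x)).
    log_ratio = log_density(proposal) - log_density(points)
    accept = jnp.log(jax.random.uniform(key_accept, log_ratio.shape)) < log_ratio
    return jnp.where(accept[:, None], proposal, points)

def run_chains(key, n_walkers=2048, n_steps=1000):
    key_init, key_steps = jax.random.split(key)
    points = jax.random.normal(key_init, (n_walkers, 3))  # arbitrary starting points

    def body(points, step_key):
        return metropolis_step(points, step_key), None

    # Each walker is an independent chain; all walkers advance in parallel.
    points, _ = jax.lax.scan(body, points, jax.random.split(key_steps, n_steps))
    return points

samples = run_chains(jax.random.PRNGKey(0))
mean_radius = jnp.mean(jnp.linalg.norm(samples, axis=-1))  # analytic value: 1.5 bohr
```

Only the final position of each walker is kept here, so the returned samples come from independent chains. Reusing intermediate positions of a single chain gives correlated samples, which is why a warmup and a sampling interval are commonly used.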

In [None]:
def energy_H(w, in_points):
  # Assumes in_points (shape (N_points, 3)) were drawn from |psi|^2, e.g. by Metropolis-Hastings,
  # so the energy is the plain average of the local energy (H psi) / psi.
  psi = wavefunction_H(w, in_points)
  @vmap
  def laplacian(in_points):
    return jnp.einsum('ii->', jacobian(jacobian(wavefunction_H, argnums=1), argnums=1)(w, in_points))
  laplacian_psi = laplacian(in_points)
  distances = jnp.linalg.norm(in_points, axis=1)
  local_energies = -0.5 * laplacian_psi / psi - 1 / distances
  return jnp.mean(local_energies)