In [2]:
from itertools import chain, combinations, combinations_with_replacement, product

import jax.numpy as jnp
from jax import vmap


In [3]:
cell_basis = jnp.array([[-0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0.5, 0.5, -0.5]])
# `jnp.split` is the documented function; an `Array.split` method is not available in every JAX version.
cv1, cv2, cv3 = jnp.split(cell_basis, 3, axis=1)  # cell vectors as (3, 1) columns

n = 2  # maximum number of ±cell-vector steps combined into one translation
vector_set = [cv1, -cv1, cv2, -cv2, cv3, -cv3]

# Sum every multiset of 1..n steps, in the order itertools yields them. Different
# multisets can reduce to the same translation (cv1 - cv1 + cv2 == cv2 once n >= 3),
# and opposite steps can cancel to zero. So deduplicate on the exact integer
# coefficients along (cv1, cv2, cv3) rather than on float values.
seen_coeffs = set()
lattice_vectors = []
for combo in chain.from_iterable(
        combinations_with_replacement(range(len(vector_set)), i) for i in range(1, n + 1)):
    coeffs = [0, 0, 0]
    for idx in combo:
        coeffs[idx // 2] += 1 if idx % 2 == 0 else -1  # even index: +cv, odd index: -cv
    coeffs = tuple(coeffs)
    if coeffs == (0, 0, 0) or coeffs in seen_coeffs:
        continue
    seen_coeffs.add(coeffs)
    lattice_vectors.append(jnp.sum(jnp.concatenate([vector_set[idx] for idx in combo], axis=-1), axis=-1))
lattice_vectors = jnp.array(lattice_vectors)
print(lattice_vectors.shape)

(24, 3)


In [26]:
print(cv2.shape)  # each cell vector is a (3, 1) column

# Cell volume as the scalar triple product |cv1 . (cv2 x cv3)|.
cross = jnp.cross(cv2, cv3, axisa=0, axisb=0)  # vectors read along axis 0; result is a (1, 3) row
print(cross)
box = jnp.matmul(cross, cv1)  # (1, 3) @ (3, 1) -> (1, 1)
print(box)
volume = jnp.abs(jnp.squeeze(box))
print(volume)
print(volume)

(3, 1)
[[0.  0.5 0.5]]
[[0.5]]
0.5


In [52]:
# Show how many translations were generated, then each one on its own line.
print(len(lattice_vectors))
for translation in lattice_vectors:
    print(translation)
    

24
[-0.5  0.5  0.5]
[ 0.5 -0.5 -0.5]
[ 0.5 -0.5  0.5]
[-0.5  0.5 -0.5]
[ 0.5  0.5 -0.5]
[-0.5 -0.5  0.5]
[-1.  1.  1.]
[0. 0. 1.]
[-1.  1.  0.]
[0. 1. 0.]
[-1.  0.  1.]
[ 1. -1. -1.]
[ 1. -1.  0.]
[ 0.  0. -1.]
[ 1.  0. -1.]
[ 0. -1.  0.]
[ 1. -1.  1.]
[1. 0. 0.]
[ 0. -1.  1.]
[-1.  1. -1.]
[ 0.  1. -1.]
[-1.  0.  0.]
[ 1.  1. -1.]
[-1. -1.  1.]


In [45]:
# Toy illustration of the enumeration used to build lattice_vectors: integers stand in
# for the six ±cell-vector steps. Separate names keep this cell from overwriting the
# real `n`, `vector_set` and `lattice_vectors` defined earlier in the notebook.
toy_n = 2
toy_steps = [1, 2, 3, 4, 5, 6]
toy_combos = list(chain.from_iterable(
    combinations_with_replacement(toy_steps, i) for i in range(1, toy_n + 1)))
print(len(toy_combos))
for combo in toy_combos:
    print(combo)

27
(1,)
(2,)
(3,)
(4,)
(5,)
(6,)
(1, 1)
(1, 2)
(1, 3)
(1, 4)
(1, 5)
(1, 6)
(2, 2)
(2, 3)
(2, 4)
(2, 5)
(2, 6)
(3, 3)
(3, 4)
(3, 5)
(3, 6)
(4, 4)
(4, 5)
(4, 6)
(5, 5)
(5, 6)
(6, 6)


In [None]:

def create_potential_energy(mol):
    if mol.periodic_boundaries:
        cv1, cv2, cv3 = mol.cell_basis.split(3, axis=1)
        # translations = all_translations(cv1, cv2, cv3)
        # translations.extend(all_translations(-cv1, cv2, cv3))
        # translations.extend([cv2, cv2 + cv3, cv2 - cv3, -cv2, -cv2 + cv3, -cv2 - cv3])
        # translations.extend([cv3, -cv3])
        # translation_vectors = jnp.concatenate(translations, axis=-1).transpose()
        # translation_vectors = jnp.expand_dims(translation_vectors, axis=0)
        #
        n = 3
        vector_set = [cv1, -cv1, cv2, -cv2, cv3, -cv3]
        lattice_vectors = list(
            chain.from_iterable(combinations_with_replacement(vector_set, i) for i in range(1, n + 1)))
        lattice_vectors = [jnp.sum(jnp.concatenate(x, axis=-1), axis=-1) for x in lattice_vectors]
        lattice_vectors = jnp.array([x for x in lattice_vectors if not jnp.sum(jnp.zeros(3, ) == x) == 3])

        def compute_potential_energy_solid_i(walkers, r_atoms, z_atoms):

            """
            :param walkers (n_el, 3):
            :param r_atoms (n_atoms, 3):
            :param z_atoms (n_atoms, ):

            Pseudocode:
                - compute the potential energy (pe) of the cell
                - compute the pe of the cell electrons with electrons outside
                - compute the pe of the cell electrons with nuclei outside
                - compute the pe of the cell nuclei with nuclei outside
            """

            ex_walkers = (jnp.expand_dims(walkers, axis=1) + lattice_vectors).reshape(-1, 3)  # (n_el * 26, 3)
            ex_r_atoms = (jnp.expand_dims(r_atoms, axis=1) + lattice_vectors).reshape(-1, 3)  # (n_atom * 26, 3)
            ex_z_atoms = jnp.expand_dims(z_atoms, axis=0).repeat(len(lattice_vectors), axis=0)  # (n_atom * 26, 1)

            potential_energy = compute_potential_energy_i(walkers, r_atoms, z_atoms)

            ex_e_e_dist = batched_cdist_l2(walkers, ex_walkers)
            potential_energy += jnp.sum(1. / ex_e_e_dist)

            ex_a_e_dist = batched_cdist_l2(walkers, ex_r_atoms)
            potential_energy -= jnp.sum(ex_z_atoms / ex_a_e_dist)

            ex_a_a_dist = batched_cdist_l2(r_atoms, ex_r_atoms)
            potential_energy += (z_atoms[None, :] * ex_z_atoms) / ex_a_a_dist

            return potential_energy

        # return vmap(compute_potential_energy_i, in_axes=(0, None, None))
        return vmap(compute_potential_energy_solid_i, in_axes=(0, None, None))

    return vmap(compute_potential_energy_i, in_axes=(0, None, None))


def batched_cdist_l2(x1, x2):
    """Pairwise Euclidean distances between two point sets.

    :param x1 (..., n, d): first point set.
    :param x2 (..., m, d): second point set; leading (batch) dims must broadcast with x1's.
    :return (..., m, n): entry [..., j, i] = ||x1[..., i, :] - x2[..., j, :]||.
        Rows index x2 and columns index x1, transposed relative to scipy's cdist(x1, x2).
    """
    # Direct differences. The (..., m, n, d) intermediate is the same size the
    # |a|^2 + |b|^2 - 2 a.b expansion already materialised, but there is no cancellation.
    # Close points keep their precision, and the argument of sqrt can never go negative,
    # which previously gave NaN for coincident points.
    diff = jnp.expand_dims(x2, axis=-2) - jnp.expand_dims(x1, axis=-3)
    return jnp.sqrt(jnp.sum(diff ** 2, axis=-1))





def compute_potential_energy_i(walkers, r_atoms, z_atoms):
    """Coulomb potential energy of a single configuration (no periodic images).

    :param walkers (n_el, 3): electron positions.
    :param r_atoms (n_atoms, 3): nuclear positions.
    :param z_atoms (n_atoms, ): nuclear charges.
    :return: scalar potential energy.

    pseudocode:
        - compute potential energy contributions
            - electron - electron interaction (each unordered pair once)
            - atom - electron interaction
            - atom - atom interaction (each unordered pair once)
    """
    # Same-species terms only need the strictly-lower-triangular pairs. Gathering those
    # positions directly, rather than masking a full distance matrix with tril, never
    # evaluates the zero self-distances. So no 1/0 or sqrt(0) appears in the values or
    # their gradients, which would otherwise turn into NaN.
    el_i, el_j = jnp.tril_indices(walkers.shape[0], k=-1)
    e_e_dist = jnp.linalg.norm(walkers[el_i] - walkers[el_j], axis=-1)
    potential_energy = jnp.sum(1. / e_e_dist)

    a_e_dist = jnp.linalg.norm(walkers[:, None, :] - r_atoms[None, :, :], axis=-1)  # (n_el, n_atoms)
    potential_energy -= jnp.sum(z_atoms / a_e_dist)

    # Empty index arrays for a single atom, so this term is then 0.
    at_i, at_j = jnp.tril_indices(r_atoms.shape[0], k=-1)
    a_a_dist = jnp.linalg.norm(r_atoms[at_i] - r_atoms[at_j], axis=-1)
    potential_energy += jnp.sum(z_atoms[at_i] * z_atoms[at_j] / a_a_dist)

    return potential_energy