In [1]:
import jax
import jax.numpy as jnp

In [2]:
N = 3
R = 1.4632
ZETA1 = 2.0925
ZETA2 = 1.24
ZA = 2.0
ZB = 1.0

In [3]:
COEFFICIENTS = jnp.array([[1.000000, 0.000000, 0.000000],
                          [0.678914, 0.430129, 0.000000],
                          [0.444635, 0.535328, 0.154329]])

In [4]:
EXPONENTS = jnp.array([[0.270950, 0.000000, 0.00000],
                       [0.151623, 0.851819, 0.00000],
                       [0.109818, 0.405771, 2.22766]])

In [5]:
A1 = EXPONENTS[N-1, :] * jnp.power(ZETA1, 2)

In [6]:
A2 = EXPONENTS[N-1, :] * jnp.power(ZETA2, 2)

In [7]:
D1 = COEFFICIENTS[N-1, :] * jnp.power(2 * A1 / jnp.pi, 0.75)

In [8]:
D2 = COEFFICIENTS[N-1, :] * jnp.power(2 * A2 / jnp.pi, 0.75)

In [9]:
def overlap_integral(a, b, r):
    return jnp.power(jnp.pi / (a + b), 1.5) * jnp.exp((-1 * a * b * r) / (a + b))

In [10]:
S12 = jnp.sum(overlap_integral(A1, A2[:, jnp.newaxis], R*R) * jnp.einsum('i,j', D1, D2))
S12

DeviceArray(0.45077044, dtype=float32)

In [11]:
def kinetic_energy(a, b, r):
    return (a * b) / (a + b) * (3 - (2 * a * b * r) / (a + b)) * overlap_integral(a, b, r)

In [12]:
T11 = jnp.sum(kinetic_energy(A1, A1[:, None], 0.0) * jnp.einsum('i,j', D1, D1))
T12 = jnp.sum(kinetic_energy(A1, A2[:, None], R*R) * jnp.einsum('i,j', D1, D2))
T22 = jnp.sum(kinetic_energy(A2, A2[:, None], 0.0) * jnp.einsum('i,j', D2, D2))
(T11, T12, T22)

(DeviceArray(2.1643124, dtype=float32),
 DeviceArray(0.16701286, dtype=float32),
 DeviceArray(0.76003295, dtype=float32))

In [13]:
def nuclear_attraction(a, b, r1, r2, z):
    return -1 * z * 2 * jnp.pi / (a + b) * f0((a + b) * r2) * jnp.exp(-1 * a * b * r1 / (a + b))

In [14]:
def f0(x):
    return jnp.where(x < 1e-6,
                     1.0 - x/3.0,
                     jnp.sqrt(jnp.pi / x) * jax.scipy.special.erf(jnp.sqrt(x)) / 2)

In [29]:
rAP = A2[:, None] * R / (A1 + A2[:, None])
rBP = R - rAP
V11a = jnp.sum(nuclear_attraction(A1, A1[:, None],   0,       0, ZA) * jnp.einsum('i,j', D1, D1))
V12a = jnp.sum(nuclear_attraction(A1, A2[:, None], R*R, rAP*rAP, ZA) * jnp.einsum('i,j', D1, D2))
V22a = jnp.sum(nuclear_attraction(A2, A2[:, None],   0,     R*R, ZA) * jnp.einsum('i,j', D2, D2))
V11b = jnp.sum(nuclear_attraction(A1, A1[:, None],   0,     R*R, ZB) * jnp.einsum('i,j', D1, D1))
V12b = jnp.sum(nuclear_attraction(A1, A2[:, None], R*R, rBP*rBP, ZB) * jnp.einsum('i,j', D1, D2))
V22b = jnp.sum(nuclear_attraction(A2, A2[:, None],   0,       0, ZB) * jnp.einsum('i,j', D2, D2))
(V11a, V12a, V22a, V11b, V12b, V22b)

(DeviceArray(-4.139827, dtype=float32),
 DeviceArray(-1.1029125, dtype=float32),
 DeviceArray(-1.2652458, dtype=float32),
 DeviceArray(-0.6772301, dtype=float32),
 DeviceArray(-0.4113055, dtype=float32),
 DeviceArray(-1.2266155, dtype=float32))

In [30]:
def two_electron_integral(a, b, c, d, r1, r2, r3):
    return 2 * (jnp.power(jnp.pi, 2.5) / ((a + b) * (c + d) * jnp.sqrt(a + b + c + d))
                * f0((a + b) * (c + d) * r3 / (a + b + c + d))
                * jnp.exp(-1 * a * b * r1 / (a + b) - c * d * r2 / (c + d)))

In [31]:
rAP = A2 * R / (A2 + A1[:, None])
rAQ = A2[:, None, None] * R / (A2[:, None, None] + A1[:, None, None, None])
rPQ = rAP - rAQ
rBQ = R - rAQ
V1111 = jnp.sum(two_electron_integral(A1, A1[:, None], A1[:, None, None], A1[:, None, None, None],   0,   0,       0) * jnp.einsum('i,j,k,l->ijkl', D1, D1, D1, D1))
V2111 = jnp.sum(two_electron_integral(A2, A1[:, None], A1[:, None, None], A1[:, None, None, None], R*R,   0, rAP*rAP) * jnp.einsum('i,j,k,l->ijkl', D2, D1, D1, D1))
V2121 = jnp.sum(two_electron_integral(A2, A1[:, None], A2[:, None, None], A1[:, None, None, None], R*R, R*R, rPQ*rPQ) * jnp.einsum('i,j,k,l->ijkl', D2, D1, D2, D1))
V2211 = jnp.sum(two_electron_integral(A2, A2[:, None], A1[:, None, None], A1[:, None, None, None],   0,   0,     R*R) * jnp.einsum('i,j,k,l->ijkl', D2, D2, D1, D1))
V2221 = jnp.sum(two_electron_integral(A2, A2[:, None], A2[:, None, None], A1[:, None, None, None],   0, R*R, rBQ*rBQ) * jnp.einsum('i,j,k,l->ijkl', D2, D2, D2, D1))
V2222 = jnp.sum(two_electron_integral(A2, A2[:, None], A2[:, None, None], A2[:, None, None, None],   0,   0,       0) * jnp.einsum('i,j,k,l->ijkl', D2, D2, D2, D2))
(V1111, V2111, V2121, V2211, V2221, V2222)

(DeviceArray(1.3071517, dtype=float32),
 DeviceArray(0.43727934, dtype=float32),
 DeviceArray(0.1772671, dtype=float32),
 DeviceArray(0.6057035, dtype=float32),
 DeviceArray(0.31179467, dtype=float32),
 DeviceArray(0.7746078, dtype=float32))

In [18]:
H0 = jnp.array([[T11 + V11a + V11b, T12 + V12a + V12b],
                [T12 + V12a + V12b, T22 + V22a + V22b]])
H0

DeviceArray([[-2.6527445, -1.3472052],
             [-1.3472052, -1.7318285]], dtype=float32)

In [19]:
S = jnp.array([[  1, S12],
               [S12,   1]])
S

DeviceArray([[1.        , 0.45077044],
             [0.45077044, 1.        ]], dtype=float32)

In [20]:
X = jnp.array([[1 / jnp.sqrt(2 * (1 + S12)),   1 / jnp.sqrt(2 * (1 - S12))],
               [1 / jnp.sqrt(2 * (1 + S12)), - 1 / jnp.sqrt(2 * (1 - S12))]])
X

DeviceArray([[ 0.58706427,  0.95413107],
             [ 0.58706427, -0.95413107]], dtype=float32)

In [21]:
TT = jnp.zeros((2, 2, 2, 2))
TT = TT.at[0, 0, 0, 0].set(V1111)
TT = TT.at[1, 0, 0, 0].set(V2111)
TT = TT.at[0, 1, 0, 0].set(V2111)
TT = TT.at[0, 0, 1, 0].set(V2111)
TT = TT.at[0, 0, 0, 1].set(V2111)
TT = TT.at[1, 0, 1, 0].set(V2121)
TT = TT.at[0, 1, 1, 0].set(V2121)
TT = TT.at[1, 0, 0, 1].set(V2121)
TT = TT.at[0, 1, 0, 1].set(V2121)
TT = TT.at[1, 1, 0, 0].set(V2211)
TT = TT.at[0, 0, 1, 1].set(V2211)
TT = TT.at[1, 1, 1, 0].set(V2221)
TT = TT.at[1, 1, 0, 1].set(V2221)
TT = TT.at[1, 0, 1, 1].set(V2221)
TT = TT.at[0, 1, 1, 1].set(V2221)
TT = TT.at[1, 1, 1, 1].set(V2222)
TT

DeviceArray([[[[1.3071517 , 0.43727934],
               [0.43727934, 0.6057035 ]],

              [[0.43727934, 0.1772671 ],
               [0.1772671 , 0.31179467]]],


             [[[0.43727934, 0.1772671 ],
               [0.1772671 , 0.31179467]],

              [[0.6057035 , 0.31179467],
               [0.31179467, 0.7746078 ]]]], dtype=float32)

In [22]:
MAX_ITER = 15

In [23]:
def G(p, tt):
    return jnp.einsum('kl,ijkl->ij', p, tt) - jnp.einsum('kl,ilkj->ij', p, 0.5 * tt)

In [24]:
P = jnp.zeros((2, 2))
for n_iter in range(MAX_ITER):
    g = G(P, TT)
    F = H0 + g
    
    # Electronic energy
    e_E = jnp.sum(0.5 * P * (H0 + F))

    Fp = X.T @ F @ X
    E, Cp = jnp.linalg.eig(Fp.T)
    C = X @ Cp
    
    old_P = P.copy()
    P = 2 * jnp.einsum('i,j', C[:, 0], C[:, 0])

    delta = jnp.sqrt(jnp.power(P - old_P, 2).sum() / 4)
    
    if delta < 1e-6:
        break

In [25]:
TOTAL_E = e_E + ZA * ZB / R

In [26]:
print(TOTAL_E)

(-2.860662+0j)
