In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import sympy as sp

import e3nn_jax as e3nn

np.set_printoptions(precision=3, suppress=True)

In [3]:
# need last commit of e3nn_jax
from e3nn_jax._src.spherical_harmonics.recursive import recursive_spherical_harmonics

x = sp.symbols("x:3")
y = recursive_spherical_harmonics(2, {}, jnp.ones(3), "component", "dense")
y = sp.simplify(y)
display(y)
display(sp.Matrix(y).jacobian(x))

[0, 0, -1/2, 0, -sqrt(3)/2]

Matrix([
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]])

In [8]:
def y2(x):
    # explicit formula for e3nn.spherical_harmonics(2, x, False, "component")
    x0, x1, x2 = x
    a = jnp.sqrt(15)
    b = jnp.sqrt(5)
    return jnp.array(
        [
            a * x0 * x2,
            a * x0 * x1,
            b * (-(x0**2) + 2 * x1**2 - x2**2) / 2,
            a * x1 * x2,
            a * (-(x0**2) + x2**2) / 2,
        ]
    )


def y2_jac(x):
    # explicit formula for jax.jacfwd(y2)(x)
    x0, x1, x2 = x
    a = jnp.sqrt(15)
    b = jnp.sqrt(5)
    return jnp.array(
        [
            [a * x2, 0, a * x0],
            [a * x1, a * x0, 0],
            [-b * x0, 2 * b * x1, -b * x2],
            [0, a * x2, a * x1],
            [-a * x0, 0, a * x2],
        ]
    )


def _y2_inverse(y):
    # y2_inverse(y2(x)) == x or -x
    # y2(y2_inverse(y)) == y
    assert y.shape == (5,)
    A = e3nn.generators(2) @ y
    A = jnp.conj(A) @ A.T
    val, vec = jnp.linalg.eigh(A)
    x = vec.T[0]  # first is the smallest eigenvalue
    safe_sqrt = lambda x: jnp.sqrt(jnp.maximum(x, 1e-7))
    x = x * safe_sqrt(safe_sqrt(jnp.mean(y**2)))
    return x, val[0]


# Custom JVP rule for x_dot

# !!! Fall back to JVP rule of jnp.linalg.eigh for val_dot !!!

@jax.custom_jvp
def y2_inverse(y):
    return _y2_inverse(y)

@y2_inverse.defjvp
def y2_inverse_jvp(primals, tangents):
    (y,) = primals
    (y_dot,) = tangents
    x, val = y2_inverse(y)
    x_dot = jnp.linalg.lstsq(y2_jac(x), y_dot)[0]

    _, (_, val_dot) = jax.jvp(_y2_inverse, (y,), (y_dot,))
    return (x, val), (x_dot, val_dot)

In [9]:
x = np.random.randn(3)
print("x", x)

print("y2(x)", y2(x))
print(
    "e3nn.sh(2, x, False, 'component')",
    e3nn.sh(2, x, False, "component"),
)
print("y2_jac(x)\n", y2_jac(x))
print("jax.jacfwd(y2)(x)\n", jax.jacfwd(y2)(x))

x [ 0.655 -1.735  0.998]
y2(x) [ 2.53  -4.399  5.138 -6.704  1.098]
e3nn.sh(2, x, False, 'component') [ 2.53  -4.399  5.138 -6.704  1.098]
y2_jac(x)
 [[ 3.864  0.     2.535]
 [-6.719  2.535  0.   ]
 [-1.464 -7.758 -2.231]
 [ 0.     3.864 -6.719]
 [-2.535  0.     3.864]]
jax.jacfwd(y2)(x)
 [[ 3.864  0.     2.535]
 [-6.719  2.535  0.   ]
 [-1.464 -7.758 -2.231]
 [ 0.     3.864 -6.719]
 [-2.535  0.     3.864]]


In [10]:
x = np.random.randn(3)
print("x", x)

y_dot = np.random.randn(5)
x_dot = jnp.linalg.lstsq(y2_jac(x), y_dot)[0]

y_dot = y2_jac(x) @ x_dot - y_dot
x_dot = jnp.linalg.lstsq(y2_jac(x), y_dot)[0]
print("x_dot", x_dot)

x [-1.564  1.291 -1.147]
x_dot [ 0. -0. -0.]


In [15]:
x = np.random.randn(3)
print("x", x)
y = y2(x)

x2, val = y2_inverse(y)
print("x2", x2)
print("val", val)

x2, val = y2_inverse(y + 0.005)
print("x2", x2)
print("val", val)

jax.jvp(y2_inverse, (y + 0.015,), (np.random.randn(5),))

x [ 1.354 -0.047  0.979]
x2 [ 1.354 -0.047  0.979]
val 4.6748164e-06
x2 [ 1.353 -0.046  0.98 ]
val 0.00015707265


((Array([ 1.352, -0.043,  0.982], dtype=float32), Array(0.001, dtype=float32)),
 (Array([-0.287,  0.096,  0.04 ], dtype=float32), Array(0.278, dtype=float32)))

In [18]:
jax.grad(lambda y: y2_inverse(y)[0].sum(-1))(y + 0.015)

NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'integer_pow' not implemented