In [78]:
import jax
import jax.numpy as jnp
from abc import ABC, abstractmethod


# we will abuse terminology and take frame and coordinate system to be roughly synonymous

class CoordinateSystem(ABC):
    @abstractmethod
    def to_local(self, coords):
        pass

    @abstractmethod
    def to_global(self, coords):
        pass

class AffineFrame(CoordinateSystem):
    def __init__(self, R, t):
        self.R = R # 3*3 
        self.Ri = jnp.linalg.inv(R)
        self.t = t # 3*1

    def to_local(self, coords):
        return jnp.matmul(coords - self.t, self.Ri)

    def to_global(self, coords):
        """
        coords: n_batch * n_points * 3 
        """
        return self.t + jnp.matmul(coords, self.R)

# class CartesianFrame(AffineFrame):
  
class SphericalFrame(CoordinateSystem):
    def __init__(self):
        pass
    
    def to_local(self, coords):
        x, y, z = coords[..., 0:1], coords[..., 1:2], coords[..., 2:3]
        r = jnp.sqrt(x*x+y*y+z*z)
        theta = jnp.arctan(y/x)
        phi = jnp.arccos(z/r)
        return jnp.concatenate([r, theta, phi], axis=-1)
    
    def to_global(self, coords):
        r, theta, phi = coords[..., 0:1], coords[..., 1:2], coords[..., 2:3]
        x = r * jnp.cos(theta) * jnp.sin(phi)
        y = r * jnp.sin(theta) * jnp.sin(phi)
        z = r * jnp.cos(phi)
        return jnp.concatenate([x, y, z], axis=-1)
    

In [12]:
import jax.numpy as jnp

# Create two matrices of shape (batch_size, m, n) and (batch_size, n, p)
batch_size = 10
m, n, p = 3, 4, 5
a = jnp.ones((batch_size, m, n))
b = jnp.ones((batch_size, n, p))

# Perform batch multiplication between a and b
c = jnp.matmul(a, b)

# The resulting array c will have shape (batch_size, m, p)
print(c.shape)


(10, 3, 5)


In [7]:
n_batch = 
n_points = 100

rng_key = jax.random.PRNGKey(42)
Xlocal = jax.random.uniform(rng_key, (n_batch, n_points, 3))
Xlocal

Array([[[0.56263244, 0.08879936, 0.3812467 ],
        [0.51138127, 0.6424105 , 0.75233936],
        [0.3138579 , 0.22232044, 0.23634052],
        ...,
        [0.29452014, 0.34971   , 0.03020287],
        [0.27510083, 0.6130575 , 0.07285404],
        [0.59405863, 0.16302896, 0.29600763]],

       [[0.56494856, 0.13090682, 0.16035461],
        [0.0570246 , 0.19405067, 0.6212238 ],
        [0.65211797, 0.4940759 , 0.0380615 ],
        ...,
        [0.16399229, 0.2171905 , 0.25922263],
        [0.9801954 , 0.3604412 , 0.6917523 ],
        [0.88947237, 0.06485546, 0.76671696]],

       [[0.69712794, 0.6879561 , 0.72874844],
        [0.48716807, 0.94218576, 0.91240275],
        [0.6258333 , 0.08147645, 0.40676272],
        ...,
        [0.600199  , 0.30211222, 0.3536942 ],
        [0.47230816, 0.2582667 , 0.6980766 ],
        [0.6558392 , 0.20156443, 0.12147629]],

       ...,

       [[0.58134127, 0.52670527, 0.84955776],
        [0.5949209 , 0.6576251 , 0.82705545],
        [0.42576528, 0

In [13]:
R = jnp.array([[1, 0, 0],
               [1, 1, 0],
               [1, 1, 1]])
t = jnp.array([0, 0, 0])
frame1 = AffineFrame(R, t)

In [18]:
jnp.isclose(frame1.to_local(frame1.to_global(Xlocal)), Xlocal).all()

Array(True, dtype=bool)

In [36]:
# testing two ways of xyz to IC

def dist(x1, x2):
    d = x2-x1
    d2 = jnp.sum(d*d, axis=-1)
    return jnp.sqrt(d2)

def angle(x1, x2, x3, degrees=True):
    ba = x1 - x2
    ba /= jnp.linalg.norm(ba, axis=-1, keepdims=True)
    bc = x3 - x2
    bc /= jnp.linalg.norm(bc, axis=-1, keepdims=True)
    cosine_angle = jnp.sum(ba*bc, axis=-1)
    if degrees:
        angle = jnp.degrees(jnp.arccos(cosine_angle)) # Range [0,180]
        return angle 
    else:  # Range [0, pi]
        return jnp.arccos(cosine_angle) 
    
def torsion(x1, x2, x3, x4):
    """Praxeolitic formula
    1 sqrt, 1 cross product"""
    b0 = -1.0*(x2 - x1)
    b1 = x3 - x2
    b2 = x4 - x3
    # normalize b1 so that it does not influence magnitude of vector
    # rejections that come next
    b1 /= jnp.linalg.norm(b1, axis=-1, keepdims=True)

    # vector rejections
    # v = projection of b0 onto plane perpendicular to b1
    #   = b0 minus component that aligns with b1
    # w = projection of b2 onto plane perpendicular to b1
    #   = b2 minus component that aligns with b1
    v = b0 - jnp.sum(b0*b1, axis=-1, keepdims=True) * b1
    w = b2 - jnp.sum(b2*b1, axis=-1, keepdims=True) * b1

    # angle between v and w in a plane is the torsion angle
    # v and w may not be normalized but that's fine since tan is y/x
    x = jnp.sum(v*w, axis=-1)
    b1xv = jnp.cross(b1, v, axisa=-1, axisb=-1)
    y = jnp.sum(b1xv*w, axis=-1)
    return jnp.degrees(jnp.arctan2(y, x))


In [201]:
x1 = jnp.array([[0,0,0]])
x2 = jnp.array([[1,1,0]])
x3 = jnp.array([[1,2,0]])
x4 = jnp.array([[1,2,1]])

x1_long = jnp.tile(jnp.array([[0,0,0]]), (100,1))
x2_long = jnp.tile(jnp.array([[1,1,0]]), (100,1))
x3_long = jnp.tile(jnp.array([[1,2,0]]), (100,1))
x4_long = jnp.tile(jnp.array([[1,2,1]]), (100,1))

In [228]:
@jax.jit
def compute_CIs_no_frames():
    r = dist(x4_long, x3_long)
    theta = angle(x4_long, x3_long, x2_long)
    phi = torsion(x4_long, x3_long, x2_long, x1_long)

In [229]:
%%timeit
compute_CIs_no_frames()

136 µs ± 855 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [240]:
## second way is to calculate frame explicitly

def build_frame(x1, x2, x3):
    """
    x1, x2, x3 are the three atoms before x4, whose coordinate we want to express 
    in the local frame computed from x1, x2, x3. This algorithm is based on NeRF.
    
    x1: n_atoms * 3 array
    """
    bc = x3 - x2
    bc = bc / jnp.linalg.norm(bc, axis=-1, keepdims=True) # normalize
    
    n = jnp.cross(bc, x1-x2)
    n = n / jnp.linalg.norm(n, axis=-1, keepdims=True)
    
    nxbc = jnp.cross(n, bc, axisa=-1, axisb=-1)
    
    coordaxes = jnp.concatenate([nxbc[..., None], n[..., None], bc[..., None]], axis=2)
        
    return AffineFrame(R=coordaxes, t=x3[..., None, :])
    

In [245]:
local_frame = build_frame(x1_long, x2_long, x3_long)

In [246]:
local_frame.t.shape

(100, 1, 3)

In [247]:
local_frame.R.shape

(100, 3, 3)

In [188]:
sframe = SphericalFrame()

In [189]:
jnp.isclose(sframe.to_global(sframe.to_local(Xlocal)), Xlocal, atol=1e-5).all()

Array(True, dtype=bool)

In [190]:
jnp.isclose(sframe.to_local(sframe.to_global(Xlocal)), Xlocal, atol=1e-3).all()

Array(True, dtype=bool)

In [None]:
x4_local = local_frame.to_local(x4_long)
x4_local.shape

In [192]:
sframe.to_local(x4_local)

Array([[[1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964],
        ...,
        [1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964]],

       [[1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964],
        ...,
        [1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964]],

       [[1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964],
        ...,
        [1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964]],

       ...,

       [[1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964],
        [1.       , 1.5707964, 1.5707964],
        ...,
        [1.       , 1.5707

In [224]:
@jax.jit
def compute_CIs():
    local_frame = AffineFrame(R=build_frame(x1, x2, x3), t=x3[None, ...])
    sframe.to_local(local_frame.to_local(x4))

In [226]:
%%timeit
compute_CIs()

138 µs ± 2.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [236]:
import cProfile

cProfile.run('compute_CIs()')


         370 function calls (369 primitive calls) in 0.002 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 <string>:1(<lambda>)
        1    0.000    0.000    0.001    0.001 <string>:1(<module>)
        2    0.000    0.000    0.000    0.000 <string>:2(__init__)
        2    0.000    0.000    0.000    0.000 abc.py:117(__instancecheck__)
        1    0.000    0.000    0.001    0.001 api.py:304(infer_params)
        1    0.000    0.000    0.000    0.000 api_util.py:273(_ensure_inbounds)
        1    0.000    0.000    0.000    0.000 api_util.py:288(argnums_partial_except)
        1    0.000    0.000    0.000    0.000 api_util.py:590(debug_info)
        1    0.000    0.000    0.000    0.000 api_util.py:600(fun_sourceinfo)
        1    0.000    0.000    0.000    0.000 api_util.py:611(_arg_names)
        1    0.000    0.000    0.000    0.000 api_util.py:616(<listcomp>)
        1    0.0

In [237]:
cProfile.run('compute_CIs_no_frames()')

         370 function calls (369 primitive calls) in 0.001 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 <string>:1(<lambda>)
        1    0.000    0.000    0.001    0.001 <string>:1(<module>)
        2    0.000    0.000    0.000    0.000 <string>:2(__init__)
        2    0.000    0.000    0.000    0.000 abc.py:117(__instancecheck__)
        1    0.000    0.000    0.001    0.001 api.py:304(infer_params)
        1    0.000    0.000    0.000    0.000 api_util.py:273(_ensure_inbounds)
        1    0.000    0.000    0.000    0.000 api_util.py:288(argnums_partial_except)
        1    0.000    0.000    0.000    0.000 api_util.py:590(debug_info)
        1    0.000    0.000    0.000    0.000 api_util.py:600(fun_sourceinfo)
        1    0.000    0.000    0.000    0.000 api_util.py:611(_arg_names)
        1    0.000    0.000    0.000    0.000 api_util.py:616(<listcomp>)
        1    0.0