In [1]:
import jax
from jax import numpy as jnp

In [2]:
class PositionalQuaternionNN(eqx.Module):
    layers: list
    final_laryer_r: eqx.Module
    final_laryer_theta: eqx.Module
    final_laryer_v: eqx.Module
        
    
    def __init__(self, key, hidden_dim = 128, hidden_num = 4):
        # Create random keys for initializing weights
        key1, key2, key3, key4, key5 = jax.random.split(5)
        
        self.layers = []
        self.layers.append(eqx.nn.Linear(1, hidden_dim, key = key1))
        for hidden_idx in range(hidden_num):
            key2, subkey = jax.random.split(key2)
            self.layers.append(eqx.nn.Linear(hidden_dim, hidden_dim, key = key2))
        
        # Define the final layer for r(t), theta, and v
        self.final_laryer_r = eqx.nn.Linear(hidden_dim, 3, key = key3)
        self.final_laryer_theta = eqx.nn.Linear(hidden_dim, 1, key = key4)
        self.final_laryer_v = eqx.nn.Linear(hidden_dim, 3, key = key5)
        
        
    def __call__(self, t):
        # Forward pass through shared layers
        
        x = t
        for layer in self.layers:
            x = jax.nn.sigmoid(layer(x))
            
        # Compute r(t)
        r_t = self.final_laryer_r(x)
        
        
        # Compute theta
        theta = self.final_laryer_theta(x)
        
        # Compute v (unit vector part of the quaternion)
        v = self.final_laryer_v(x)
        v /= jnp.linalg.norm(v, axis = -1, keepdims=True) # Normalize to ensure a unit vector
        
        # Angle components of the quaternion
        scalar_part = jnp.cos(theta)
        vector_part = v * jnp.sin(theta)
        
        q_t = jnp.concatenate([scalar_part, vector_part], axis = -1)
        
        return r_t, q_t
    

# Define the physical law for gyroscope
def gyroscope_model(pose, t):
    
    dq_dt = jax.jacrev(lambda t: pose(t)[1])
    d_dot = dq_dt(t)
    
    _, q_t = pose(t)
    q_t_conjugate = quaternion_conjugate(q_t)
    
    q_omega = quaternion_product(q_dot, q_t_conjugate)
    _, vec_omega = quaternion_scalar_and_vector(q_omega)
    omega = 2 * vec_omega
    
    return omega

# Define the physical law for accelerometer
def accelerometer_model(pose, t):
    
    r_q, q_t = pose(t)
    q_t_conjugate = quaternion_conjugate(q_t)
    
    d2r_dt2 = jax.jacrev(jax.jacrev(lambda t: pose(t)[0]))
    gravity_acc = jnp.array([0, 0, -9.81])
    
    vec_acc_true = d2r_dt2(t).squeeze() - gravity_acc
    q_acc_true = jnp.concatenate([jnp.array([0]), vel_acc_true], axis = -1)
    
    q_acc = quaternion_product(q_t_conjugate, quaternion_product(q_acc_true, q_t))
    
    _, acc = quaternion_scalar_and_vector(q_acc)
    
    return acc

NameError: name 'eqx' is not defined