In [4]:
import jax.numpy as np
from jax import grad, jit, vmap, jacfwd
from functools import partial

In [27]:
from math import sqrt, exp



In [112]:
### Forces
class ForceModel(object):
    """ Defines a force model to use for integration of trajectories or
    stms

    Args:
        force_list (list[functions]): list of force functions to use for
            model

    """
    def __init__(self, force_list):
        self.force_list = force_list


    def ode(self, t, state_vec):
        """ Differential Equation for State vector
        """
        xddot, yddot, zddot = map(sum, zip(*[fxn(state_vec) for
                                             fxn in self.force_list]))

        out_state = [state_vec[3], state_vec[4], state_vec[5],
                     xddot, yddot, zddot]

        if len(state_vec) > 6:
            out_state.append(-BETA * state_vec[6])
            out_state.append(-BETA * state_vec[7])
            out_state.append(-BETA * state_vec[8])

        return out_state


def point_mass(state_vec):
        """Calculates the x, y, z accelerations due to point
            mass gravity model

        """
        mu = set_mu(state_vec)
        x, y, z = state_vec[0:3]
        r = norm(state_vec[0:3])

        return  [-mu * coord / r**3 for coord in state_vec[0:3]]

def set_mu(state_vec):
    """ """
    mu = state_vec[6] if 6 < len(state_vec) else MU

    return mu

def norm(vec):
    """ Computes the 2 norm of a vector or vector slice """
    return sqrt(sum([i**2 for i in vec]))

MU = 3.986004415e+05 # gravitational parameter of earth


In [148]:
### Forces
class AltForceModel(object):
    def __init__(self, force_list):
        self.force_list = force_list


    def ode(self, t, state_vec):
        xddot, yddot, zddot = map(sum, zip(*[fxn(state_vec) for
                                             fxn in self.force_list]))

        out_state = np.array([state_vec[3], state_vec[4], state_vec[5],
                              xddot, yddot, zddot])

        return out_state


def point_mass(state_vec):
        """Calculates the x, y, z accelerations due to point
            mass gravity model

        """
        mu = MU
        x, y, z = state_vec[0:3]
        r = np.linalg.norm(state_vec[0:3])
    
        vels = state_vec[3:]
        accels = np.array([-mu * coord / r**3 for coord in state_vec[0:3]])
        return np.concatenate((vels, accels))
        
        


MU = 3.986004415e+05 # gravitational parameter of earth


In [149]:
force_model = ForceModel([point_mass])

In [209]:
istate = np.array([1.0,2.0,3.0,4.0,5.0,100.0])

In [158]:
x = jacfwd(point_mass)


In [177]:
from jax import random
def predict(W, b, inputs):
    return sigmoid(np.dot(inputs, W) + b)

def sigmoid(x):
    return 0.5 * (np.tanh(x / 2) + 1)

def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)

key = random.PRNGKey(0)
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
inputs = np.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])

J_dict = jacfwd(predict_dict)({'W': W, 'b': b}, inputs)
J_dict

{'W': array([[ 0.05981753,  0.12883775,  0.08857595],
        [ 0.04015912, -0.0492862 ,  0.00684531],
        [ 0.1218829 ,  0.01406341, -0.30470726],
        [ 0.00140427, -0.00472519,  0.00263776]], dtype=float32),
 'b': array([0.11503371, 0.04563536, 0.2343902 , 0.00189767], dtype=float32)}

In [218]:
def pm(state_vec):
    r = np.linalg.norm(state_vec[0:3])
    x, y, z, dx, dy, dz = state_vec
    
    return np.array([dx, dy, dz, 
                           -MU * x / r**3, 
                           -MU * y / r**3, 
                           -MU * z / r**3])
    


In [224]:
pm(istate)
istate

array([  1.,   2.,   3.,   4.,   5., 100.], dtype=float32)

In [220]:
jacfwd(pm)(istate)

RuntimeError: Invalid argument: Operands to select must be the same shape; got f32[] and f32[6].: 

In [119]:
import ad
from ad import jacobian
ad_state = ad.adnumber(istate)
state_deriv = force_model.ode(0, ad_state)
a_matrix = jacobian(state_deriv, ad_state)
a_matrix

[[0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
 [-7609.317787295047, 0.0, 0.0, 0.0, 0.0, 0.0],
 [0.0, -7609.317787295047, 0.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, -7609.317787295047, 0.0, 0.0, 0.0]]