In [1]:
import jax
import jax.numpy as jnp
from jax import core
import numpy as np

import casadi as cs

In [2]:
def jaxpr_to_casadi(jaxpr, inputs):
    def interpret_jaxpr(jaxpr, *args):
        invars = args

        if 'constvars' in jaxpr:
            return ca.MX(jaxpr['constvars'][0])

        elif jaxpr['primitive'] == 'add':
            return interpret_jaxpr(jaxpr['invars'][0], *args) + interpret_jaxpr(jaxpr['invars'][1], *args)

        elif jaxpr['primitive'] == 'mul':
            return interpret_jaxpr(jaxpr['invars'][0], *args) * interpret_jaxpr(jaxpr['invars'][1], *args)

        # Add more cases for other JAX primitives as needed.

        else:
            raise NotImplementedError(f"Unsupported JAX primitive: {jaxpr['primitive']}")

    # Create a dummy input signature based on the provided inputs
    # input_signature = [jax.core.ShapedArray(val.shape, jax.core.Literal(aval=val)) for val in inputs]

    # Trace the function to get jaxpr
    # jaxpr_maker = jax.core.TraceableFunction(interpret_jaxpr, False)
    # tracers_in = [core.Tracer(jaxpr_maker, core.TracerTuple(args, {})) for args in zip(*input_signature)]
    # result, _ = core.trace_to_jaxpr(jaxpr_maker, tracers_in, instantiate=True)

    # casadi_inputs = [ca.MX.sym(f"x{i}") for i in range(len(inputs))]

    # Convert the result to a CasADi MX expression.
    # casadi_expr = interpret_jaxpr(result, *tracers_in)

    # Create a CasADi Function to evaluate the expression.
    # casadi_func = ca.Function('jaxpr_func', casadi_inputs, [casadi_expr])

    #return casadi_func

In [3]:
# Example usage:
def example_function_jax(x, y):
    return jax.numpy.sin(x) + jax.numpy.cos(y)

# Convert the JAX function to a jaxpr.
jaxpr_maker = jax.make_jaxpr(example_function_jax)
jaxpr = jaxpr_maker(jnp.array([1.0]), jnp.array([2.0])).jaxpr

# Define CasADi symbols for the inputs.
casadi_x = cs.MX.sym('x')
casadi_y = cs.MX.sym('y')

# Convert the JAX jaxpr to a CasADi.MX expression.
casadi_expr = jaxpr_to_casadi(jaxpr, [casadi_x, casadi_y])

# Evaluate the CasADi expression.
# result = casadi_expr(1.0, 2.0)

#print(f"Result: {result}")

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [114]:
# based on https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html
def examine_jaxpr(closed_jaxpr):
  jaxpr = closed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)

def examine_jaxpr_pure(jaxpr):
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  i=0
  for eqn in jaxpr.eqns:
    print(f"equation[{i}]:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
    i=i+1
  print()
  print("jaxpr:", jaxpr)

In [5]:
examine_jaxpr(jaxpr_maker(jnp.array([1.0]), jnp.array([2.0])))


invars: [a, b]
outvars: [e]
constvars: []
equation: [a] sin [c] {}
equation: [b] cos [d] {}
equation: [c, d] add [e] {}

jaxpr: { lambda ; a:f32[1] b:f32[1]. let
    c:f32[1] = sin a
    d:f32[1] = cos b
    e:f32[1] = add c d
  in (e,) }


In [6]:
def example_function_cs(x, y):
    return cs.sin(x) + cs.cos(y)

cs_a = cs.MX.sym('a',1,1)
cs_b = cs.MX.sym('b',1,1)

cs_e = cs.Function("e",[cs_a, cs_b], [example_function_cs(cs_a, cs_b)])

print(cs_e)

e:(i0,i1)->(o0) MXFunction


In [7]:

print(f"Result via direct casadi: {cs_e(1.0, 2.0)}")

Result via direct casadi: 0.425324


In [181]:
def generic_iota(shape, dimension=0, dtype=None):
    retval = np.zeros(shape, dtype=dtype)
    ran = np.arange(shape[dimension], dtype=dtype)
    if dimension == 0:
        for j in range(shape[1]):
            retval[:, j] = ran
    if dimension == 1:
        for j in range(shape[0]):
            retval[j, :] = ran
    return retval

# Matches the definition in https://www.tensorflow.org/xla/operation_semantics
def casadi_iota(shape, dimension=0, dtype=None):
    return cs.MX(generic_iota(shape, dimension, dtype))

def jaxeq_input2casadi(jaxeq_invar, jaxstr2cs):
    if type(jaxeq_invar) == jax._src.core.Literal:
        return cs.MX(jaxeq_invar.val)
    else:
        return jaxstr2cs[str(jaxeq_invar)]

def jaxpr2csmxfunction(jaxpr):
    jaxstr2cs = dict()
    cs_invars = []
    for jaxinvar in jaxpr.invars:
        csinvar = cs.MX.sym(str(jaxinvar), jaxinvar.aval.shape[0])
        jaxstr2cs[str(jaxinvar)] = csinvar
        cs_invars.append(csinvar)
    for jaxeq in jaxpr.eqns:
        print(f"jaxeq.primitive.name: {jaxeq.primitive.name}")
        if jaxeq.primitive.name == 'sin':
            jaxstr2cs[str(jaxeq.outvars[0])] = cs.sin(jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs))
        elif jaxeq.primitive.name == 'cos':
            jaxstr2cs[str(jaxeq.outvars[0])] = cs.cos(jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs))
        elif jaxeq.primitive.name == 'add':
            jaxstr2cs[str(jaxeq.outvars[0])] = jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs) + jaxeq_input2casadi(jaxeq.invars[1], jaxstr2cs)
        elif jaxeq.primitive.name == 'iota':
            jaxstr2cs[str(jaxeq.outvars[0])] = casadi_iota(**(jaxeq.params))
        elif jaxeq.primitive.name == 'eq':
            jaxstr2cs[str(jaxeq.outvars[0])] = (jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs) == jaxeq_input2casadi(jaxeq.invars[1], jaxstr2cs))
        elif jaxeq.primitive.name == 'convert_element_type':
            # We do not distinguish between different types in casadi, so we just copy the input in the output
            jaxstr2cs[str(jaxeq.outvars[0])] = jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs)
        elif jaxeq.primitive.name == 'copy':
            jaxstr2cs[str(jaxeq.outvars[0])] = jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs)
        elif jaxeq.primitive.name == 'squeeze':
            jaxstr2cs[str(jaxeq.outvars[0])] = jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs)
        elif jaxeq.primitive.name == 'reshape':
            # we support just dummy reshape for now
            if (jaxeq.params['new_sizes'] == () and jaxeq.params['dimensions'] is None):
                jaxstr2cs[str(jaxeq.outvars[0])] = jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs)
            else:
                raise Exception(f"jaxpr2csmxfunction unsupported arguments for reshape primitive {jaxeq.params}")
        elif jaxeq.primitive.name == 'broadcast_in_dim':
            # we support just dummy broadcast_in_dim for now'shape': , 'broadcast_dimensions': ()}
            if (jaxeq.params['shape'] == (1,) and jaxeq.params['broadcast_dimensions'] == ()):
                jaxstr2cs[str(jaxeq.outvars[0])] = jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs)
            else:
                raise Exception(f"jaxpr2csmxfunction unsupported arguments for broadcast_in_dim primitive {jaxeq.params}")
        elif jaxeq.primitive.name == 'dot_general':
            # dot_general is quite a generic primitive, for now we support just matrix mupliciation, that is when 
            # dimension_numbers is (((1,), (0,)), ((), ()))
            if (jaxeq.params['dimension_numbers'] == (((1,), (0,)), ((), ()))):
                jaxstr2cs[str(jaxeq.outvars[0])] = cs.mtimes(jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs),jaxeq_input2casadi(jaxeq.invars[1], jaxstr2cs))
            else:
                raise Exception(f"jaxpr2csmxfunction unsupported arguments for dot_general primitive {jaxeq.params}")
        elif jaxeq.primitive.name == 'slice':
            if jaxeq.params["strides"] is None:
                param_slice = slice(jaxeq.params["start_indices"][0],jaxeq.params["limit_indices"][0])
            else:
                param_slice = slice(jaxeq.params["start_indices"][0],jaxeq.params["limit_indices"][0],jaxeq.params["strides"][0])
            jaxstr2cs[str(jaxeq.outvars[0])] = jaxstr2cs[str(jaxeq.invars[0])][param_slice]
        elif jaxeq.primitive.name == 'concatenate':
            if jaxeq.params["dimension"]== 0:
                jaxstr2cs[str(jaxeq.outvars[0])] = cs.vertcat(jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs),jaxeq_input2casadi(jaxeq.invars[1], jaxstr2cs))
            else:
                assert(jaxeq.params["dimension"]==1)
                jaxstr2cs[str(jaxeq.outvars[0])] = cs.horzcat(jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs),jaxeq_input2casadi(jaxeq.invars[1], jaxstr2cs))
        elif jaxeq.primitive.name == 'scatter':
            if jaxeq.params["dimension"]== 0:
                jaxstr2cs[str(jaxeq.outvars[0])] = cs.vertcat(jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs),jaxeq_input2casadi(jaxeq.invars[1], jaxstr2cs))
            else:
                assert(jaxeq.params["dimension"]==1)
                jaxstr2cs[str(jaxeq.outvars[0])] = cs.horzcat(jaxeq_input2casadi(jaxeq.invars[0], jaxstr2cs),jaxeq_input2casadi(jaxeq.invars[1], jaxstr2cs))

        else:
            raise Exception(f"jaxpr2csmxfunction unsupported primitive {jaxeq.primitive.name}")
    return cs.Function(str(jaxpr.outvars[0]),cs_invars, [jaxstr2cs[str(jaxpr.outvars[0])]])
        
        

In [182]:
from adam.jax import KinDynComputations
import icub_models
import numpy as np
import scipy
        
model_path = icub_models.get_model_file("iCubGazeboV2_5")
# The joint list
joints_name_list = ['torso_pitch']
# Specify the root link
root_link = 'root_link'
kinDyn = KinDynComputations(model_path, joints_name_list, root_link)
nrOfJoints = len(joints_name_list)
w_H_b = np.eye(4)
joints = np.ones(nrOfJoints)

# Identity
jax_function = kinDyn.forward_kinematics_fun("torso_1")
jax_jaxpr = jax.make_jaxpr(jax_function)(jnp.array(w_H_b), jnp.array(joints)).jaxpr
#csfunc = jaxpr2csmxfunction(jax_jaxpr)
examine_jaxpr_pure(jax_jaxpr)#jax_jaxpr
csfunc = jaxpr2csmxfunction(jax_jaxpr)
type(jax_jaxpr.eqns[19].invars[0])

invars: [g, h]
outvars: [gk]
constvars: [a, b, c, d, e, f]
equation[0]: [] iota [i] {'dtype': dtype('int32'), 'shape': (4, 4), 'dimension': 0}
equation[1]: [i, 0] add [j] {}
equation[2]: [] iota [k] {'dtype': dtype('int32'), 'shape': (4, 4), 'dimension': 1}
equation[3]: [j, k] eq [l] {}
equation[4]: [l] convert_element_type [m] {'new_dtype': dtype('float32'), 'weak_type': False}
equation[5]: [g] copy [n] {}
equation[6]: [m, n] dot_general [o] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32')}
equation[7]: [h] slice [p] {'start_indices': (0,), 'limit_indices': (1,), 'strides': None}
equation[8]: [p] squeeze [q] {'dimensions': (0,)}
equation[9]: [] iota [r] {'dtype': dtype('int32'), 'shape': (4, 4), 'dimension': 0}
equation[10]: [r, 0] add [s] {}
equation[11]: [] iota [t] {'dtype': dtype('int32'), 'shape': (4, 4), 'dimension': 1}
equation[12]: [s, t] eq [u] {}
equation[13]: [u] convert_element_type [v] {'new_dtype': dtype('float

Exception: jaxpr2csmxfunction unsupported primitive scatter

In [None]:
generic_iota(**(jax_jaxpr.eqns[0].params))

In [None]:
print(f"Result via casadi converted from jax: {jaxpr2csmxfunction(jaxpr)(1.0, 2.0)}")

# Unit test

In [183]:
import unittest
import adam

class TestJaxpr2csmxfunction(unittest.TestCase):

    
    def test_simpl_fun(self):
        def example_function_jax(x, y):
            return jax.numpy.sin(x) + jax.numpy.cos(y)

        # Convert the JAX function to a jaxpr.
        jaxpr_maker = jax.make_jaxpr(example_function_jax)
        jaxpr = jaxpr_maker(jnp.array([1.0]), jnp.array([1.0])).jaxpr
        
        for i in range(0, 1):
            x = np.random.uniform(-20, 20)
            y = np.random.uniform(-20, 20)
            res_via_jax =  example_function_jax(x,y)
            print("Testing simple fun")
            res_via_casadi = jaxpr2csmxfunction(jaxpr)(x, y)
            self.assertAlmostEqual(float(res_via_casadi), float(res_via_jax),places=5)

    def test_adam_fwd_kin(self):
        from adam.jax import KinDynComputations
        import icub_models
        import numpy as np
        import scipy
        
        model_path = icub_models.get_model_file("iCubGazeboV2_5")
        # The joint list
        joints_name_list = ['torso_pitch']
        # Specify the root link
        root_link = 'root_link'
        kinDyn = KinDynComputations(model_path, joints_name_list, root_link)

        nrOfJoints = len(joints_name_list)
        w_H_b = np.eye(4)
        joints = np.ones(nrOfJoints)

        # Identity
        jax_function = kinDyn.forward_kinematics_fun("root_link")
        jax_jaxpr = jax.make_jaxpr(jax_function)(jnp.array(w_H_b), jnp.array(joints)).jaxpr
        csfunc = jaxpr2csmxfunction(jax_jaxpr)

        for i in range(0, 1):
            w_H_b[0:3,3] = np.random.uniform(-20, 20,3)
            w_H_b[0:3,0:3] = scipy.spatial.transform.Rotation.random().as_matrix()
            joints = np.random.uniform(-20, 20, nrOfJoints)
            res_via_jax =  jax_function(w_H_b,joints)
            res_via_casadi = jaxpr2csmxfunction(jax_jaxpr)(w_H_b, joints)
            np.testing.assert_array_almost_equal(np.array(res_via_casadi), np.array(res_via_jax),decimal=5)

        # Actual kin
        jax_function = kinDyn.forward_kinematics_fun("torso_1")
        jax_jaxpr = jax.make_jaxpr(jax_function)(jnp.array(w_H_b), jnp.array(joints)).jaxpr
        csfunc = jaxpr2csmxfunction(jax_jaxpr)

        for i in range(0, 1):
            w_H_b[0:3,3] = np.random.uniform(-20, 20,3)
            w_H_b[0:3,0:3] = scipy.spatial.transform.Rotation.random().as_matrix()
            joints = np.random.uniform(-20, 20, nrOfJoints)
            res_via_jax =  jax_function(w_H_b,joints)
            res_via_casadi = jaxpr2csmxfunction(jax_jaxpr)(w_H_b, joints)
            np.testing.assert_array_almost_equal(np.array(res_via_casadi), np.array(res_via_jax),decimal=5)
        

unittest.main(argv=[''], verbosity=2, exit=False)

test_adam_fwd_kin (__main__.TestJaxpr2csmxfunction.test_adam_fwd_kin) ... ERROR
test_simpl_fun (__main__.TestJaxpr2csmxfunction.test_simpl_fun) ... ok

ERROR: test_adam_fwd_kin (__main__.TestJaxpr2csmxfunction.test_adam_fwd_kin)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/ipykernel_4078/1173698699.py", line 56, in test_adam_fwd_kin
    csfunc = jaxpr2csmxfunction(jax_jaxpr)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_4078/2555924633.py", line 80, in jaxpr2csmxfunction
    raise Exception(f"jaxpr2csmxfunction unsupported primitive {jaxeq.primitive.name}")
Exception: jaxpr2csmxfunction unsupported primitive scatter

----------------------------------------------------------------------
Ran 2 tests in 0.142s

FAILED (errors=1)


jaxeq.primitive.name: iota
jaxeq.primitive.name: add
jaxeq.primitive.name: iota
jaxeq.primitive.name: eq
jaxeq.primitive.name: convert_element_type
jaxeq.primitive.name: copy
jaxeq.primitive.name: dot_general
jaxeq.primitive.name: iota
jaxeq.primitive.name: add
jaxeq.primitive.name: iota
jaxeq.primitive.name: eq
jaxeq.primitive.name: convert_element_type
jaxeq.primitive.name: copy
jaxeq.primitive.name: dot_general
jaxeq.primitive.name: iota
jaxeq.primitive.name: add
jaxeq.primitive.name: iota
jaxeq.primitive.name: eq
jaxeq.primitive.name: convert_element_type
jaxeq.primitive.name: copy
jaxeq.primitive.name: dot_general
jaxeq.primitive.name: slice
jaxeq.primitive.name: squeeze
jaxeq.primitive.name: iota
jaxeq.primitive.name: add
jaxeq.primitive.name: iota
jaxeq.primitive.name: eq
jaxeq.primitive.name: convert_element_type
jaxeq.primitive.name: iota
jaxeq.primitive.name: add
jaxeq.primitive.name: iota
jaxeq.primitive.name: eq
jaxeq.primitive.name: convert_element_type
jaxeq.primitive.nam

<unittest.main.TestProgram at 0x7f82e94c6a80>

In [11]:
x = np.random.uniform()
y = np.random.uniform()
jax_env = {'x': np.array(x), 'y': np.array(y)}  # Specify values for variables
res_via_jax =  example_function_jax(x,y)
res_via_casadi = jaxpr2csmxfunction(jaxpr)(x, y)
print(float(res_via_jax))
print(float(res_via_casadi))

1.3426449298858643
1.3426449731217036


## Test with adam-robotics

In [12]:
import adam
from adam.jax import KinDynComputations
import icub_models
import numpy as np

# if you want to icub-models https://github.com/robotology/icub-models to retrieve the urdf
model_path = icub_models.get_model_file("iCubGazeboV2_5")
# The joint list
joints_name_list = [
    'torso_pitch'
]
# Specify the root link
root_link = 'root_link'
kinDyn = KinDynComputations(model_path, joints_name_list, root_link)
# or, if you want to use the mixed representation (that is the default)
kinDyn.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION)
w_H_b = np.eye(4)
w_H_b[0:3,3] = [1,2,3]
print(w_H_b)
joints = np.ones(len(joints_name_list))
M = kinDyn.mass_matrix(w_H_b, joints)
print(M)

[[1. 0. 0. 1.]
 [0. 1. 0. 2.]
 [0. 0. 1. 3.]
 [0. 0. 0. 1.]]
[[ 3.3061672e+01 -1.3819033e-08 -9.5367432e-07 -1.1030352e-08
  -3.2419498e+00  1.3196170e-03 -1.2740911e+00]
 [ 5.3891540e-09  3.3061672e+01 -2.4181674e-08  3.2419500e+00
   1.6108035e-09 -1.3785756e+00  5.4081486e-08]
 [-9.5367432e-07 -2.0859062e-08  3.3061672e+01 -1.3196468e-03
   1.3785757e+00  7.4796844e-09 -1.5338850e+00]
 [-9.4005230e-09  3.2419500e+00 -1.3196766e-03  2.4266503e+00
   2.1983148e-04  2.2242114e-01 -2.3635932e-04]
 [-3.2419500e+00 -1.7974777e-08  1.3785754e+00  2.1983532e-04
   2.5952973e+00 -6.4104795e-05 -6.7038065e-01]
 [ 1.3196170e-03 -1.3785756e+00  4.2491592e-08  2.2242114e-01
  -6.4074993e-05  6.6768861e-01 -2.2677788e-04]
 [-1.2740911e+00  5.4081486e-08 -1.5338850e+00 -2.3635932e-04
  -6.7038065e-01 -2.2677788e-04  6.7038065e-01]]


In [13]:
H = kinDyn.forward_kinematics_fun(frame="l_hand")



In [14]:
print(w_H_b)

[[1. 0. 0. 1.]
 [0. 1. 0. 2.]
 [0. 0. 1. 3.]
 [0. 0. 0. 1.]]


In [15]:
H(w_H_b,joints)


Array([[ 9.5263904e-01, -8.0751107e-08, -3.0410326e-01,  1.1005967e+00],
       [-1.0949959e-07, -1.0000000e+00, -9.9345080e-08,  1.9101018e+00],
       [-3.0410320e-01,  8.5248118e-08, -9.5263904e-01,  2.9224308e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  1.0000000e+00]],      dtype=float32)

In [16]:
examine_jaxpr(jax.make_jaxpr(H)(jnp.array(w_H_b), jnp.array(joints)))


invars: [cj, ck]
outvars: [cnx]
constvars: [a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, ba, bb, bc, bd, be, bf, bg, bh, bi, bj, bk, bl, bm, bn, bo, bp, bq, br, bs, bt, bu, bv, bw, bx, by, bz, ca, cb, cc, cd, ce, cf, cg, ch, ci]
equation: [] iota [cl] {'dtype': dtype('int32'), 'shape': (4, 4), 'dimension': 0}
equation: [cl, 0] add [cm] {}
equation: [] iota [cn] {'dtype': dtype('int32'), 'shape': (4, 4), 'dimension': 1}
equation: [cm, cn] eq [co] {}
equation: [co] convert_element_type [cp] {'new_dtype': dtype('float32'), 'weak_type': False}
equation: [cj] copy [cq] {}
equation: [cp, cq] dot_general [cr] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32')}
equation: [ck] slice [cs] {'start_indices': (0,), 'limit_indices': (1,), 'strides': None}
equation: [cs] squeeze [ct] {'dimensions': (0,)}
equation: [] iota [cu] {'dtype': dtype('int32'), 'shape': (4, 4), 'dimension': 0}
equation: [cu, 0] add [cv

In [17]:
# Extract all the primitives used in a jax[r
def extract_primitives_fromjaxpr(jaxpr):
    all_primitives = set()
    for jaxeq in jaxpr.eqns:
        all_primitives.add(jaxeq.primitive.name)
    return all_primitives

In [18]:
extract_primitives_fromjaxpr(jax.make_jaxpr(H)(jnp.array(w_H_b), jnp.array(joints)))

{'add',
 'broadcast_in_dim',
 'concatenate',
 'convert_element_type',
 'copy',
 'cos',
 'dot_general',
 'eq',
 'iota',
 'mul',
 'neg',
 'pjit',
 'reshape',
 'scatter',
 'sin',
 'slice',
 'squeeze',
 'sub'}

In [19]:
# Try to extract the simplest jaxpr, a function that is an identity


In [20]:
simple_mat_identity = jax.make_jaxpr(kinDyn.forward_kinematics_fun("root_link"))(jnp.array(w_H_b), jnp.array(joints))
examine_jaxpr(simple_mat_identity)
extract_primitives_fromjaxpr(simple_mat_identity)

invars: [a, b]
outvars: [i]
constvars: []
equation: [] iota [c] {'dtype': dtype('int32'), 'shape': (4, 4), 'dimension': 0}
equation: [c, 0] add [d] {}
equation: [] iota [e] {'dtype': dtype('int32'), 'shape': (4, 4), 'dimension': 1}
equation: [d, e] eq [f] {}
equation: [f] convert_element_type [g] {'new_dtype': dtype('float32'), 'weak_type': False}
equation: [a] copy [h] {}
equation: [g, h] dot_general [i] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32')}

jaxpr: { lambda ; a:f32[4,4] b:f32[1]. let
    c:i32[4,4] = iota[dimension=0 dtype=int32 shape=(4, 4)] 
    d:i32[4,4] = add c 0
    e:i32[4,4] = iota[dimension=1 dtype=int32 shape=(4, 4)] 
    f:bool[4,4] = eq d e
    g:f32[4,4] = convert_element_type[new_dtype=float32 weak_type=False] f
    h:f32[4,4] = copy a
    i:f32[4,4] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] g h
  in (i,) }


{'add', 'convert_element_type', 'copy', 'dot_general', 'eq', 'iota'}

In [21]:
type(simple_mat_identity.eqns[0].outvars[0].aval)

jax._src.core.ShapedArray

In [22]:
help(jax.lax.iota)

Help on function iota in module jax._src.lax.lax:

iota(dtype: 'DTypeLike', size: 'int') -> 'Array'
    Wraps XLA's `Iota
    <https://www.tensorflow.org/xla/operation_semantics#iota>`_
    operator.



In [23]:
jax.lax.iota(dtype=int, size=4)

Array([0, 1, 2, 3], dtype=int32)

In [24]:
help(jax._src.lax.lax.iota)


Help on function iota in module jax._src.lax.lax:

iota(dtype: 'DTypeLike', size: 'int') -> 'Array'
    Wraps XLA's `Iota
    <https://www.tensorflow.org/xla/operation_semantics#iota>`_
    operator.



In [25]:
import itertools

# From https://github.com/google/jax/blob/ffb115bf2ee50d2553c3f0bbc149eb2933944546/jax/_src/lax_reference.py#L209
def dot_general(lhs, rhs, dimension_numbers):
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  new_id = itertools.count()
  lhs_axis_ids = [next(new_id) for _ in lhs.shape]
  rhs_axis_ids = [next(new_id) for _ in rhs.shape]
  lhs_out_axis_ids = lhs_axis_ids[:]
  rhs_out_axis_ids = rhs_axis_ids[:]

  for lhs_axis, rhs_axis in zip(lhs_contracting, rhs_contracting):
    shared_id = next(new_id)
    lhs_axis_ids[lhs_axis] = shared_id
    rhs_axis_ids[rhs_axis] = shared_id
    lhs_out_axis_ids[lhs_axis] = None
    rhs_out_axis_ids[rhs_axis] = None

  batch_ids = []
  for lhs_axis, rhs_axis in zip(lhs_batch, rhs_batch):
    shared_id = next(new_id)
    lhs_axis_ids[lhs_axis] = shared_id
    rhs_axis_ids[rhs_axis] = shared_id
    lhs_out_axis_ids[lhs_axis] = None
    rhs_out_axis_ids[rhs_axis] = None
    batch_ids.append(shared_id)

  not_none = lambda x: x is not None
  out_axis_ids = filter(not_none,
                        batch_ids + lhs_out_axis_ids + rhs_out_axis_ids)
  assert lhs.dtype == rhs.dtype
  # dtype = np.float32 if lhs.dtype == dtypes.bfloat16 else None
  dtype = None
  out = np.einsum(lhs, lhs_axis_ids, rhs, rhs_axis_ids, out_axis_ids,
                   dtype=dtype)
  #return out.astype(dtypes.bfloat16) if lhs.dtype == dtypes.bfloat16 else out
  return out

# Matches the definition in https://www.tensorflow.org/xla/operation_semantics
def iota(shape, dimension=0, dtype=None):
    retval = np.zeros(shape, dtype=dtype)
    ran = np.arange(shape[dimension], dtype=dtype)
    if dimension == 0:
        for j in range(shape[1]):
            retval[:, j] = ran
    if dimension == 1:
        for j in range(shape[0]):
            retval[j, :] = ran
    return retval

# Test https://www.tensorflow.org/xla/operation_semantics
w_H_b = np.eye(4)
w_H_b[0:3,3] = [1,2,3]
a = w_H_b.astype(np.float32)
desired_shape = (4,4)
c = iota(shape=desired_shape, dimension=0)
d = c + 0
e = iota(shape=desired_shape, dimension=1)
f = (d == e)
g = np.float32(f)
h = a 
print(f"h={h}")
print(f"g={g}")
print(f"h.dtype={h.dtype}")
print(f"gdtype={g.dtype}")
i = dot_general(g,h,dimension_numbers=(([1], [0]), ([], [])))
print(f"i={i}")




h=[[1. 0. 0. 1.]
 [0. 1. 0. 2.]
 [0. 0. 1. 3.]
 [0. 0. 0. 1.]]
g=[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]
h.dtype=float32
gdtype=float32
i=[[1. 0. 0. 1.]
 [0. 1. 0. 2.]
 [0. 0. 1. 3.]
 [0. 0. 0. 1.]]
