In [1]:
# This example successfully demonstrates the conversion of JAX functions to TensorFlow and subsequently to ONNX, 
# using TensorFlow's tf2onnx tool for ONNX export and ONNX Runtime for inference. 

from jax.experimental import jax2tf
from jax import numpy as jnp
from jax import tree_util as jtu, vmap, jit, lax, nn
from jax.experimental import sparse

from pymdp.jax.agent import Agent

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns

from functools import partial
from typing import Optional, Tuple, List

import onnx
import tf2onnx
import onnxruntime as ort
import netron
import time

from opt_einsum import contract


In [10]:
# Machine epsilon of Python's `float` type; used as the floor applied before taking logs.
MINVAL = jnp.finfo(float).eps

def log_stable(x):
    """Elementwise natural log of `x` with the input floored at `MINVAL`.

    Entries smaller than `MINVAL` are raised to `MINVAL` before the log is taken,
    so a zero entry yields log(MINVAL) rather than -inf.
    """
    floored = jnp.clip(x, min=MINVAL)
    return jnp.log(floored)

def add(x, y):
    """Return `x + y`; a named binary sum that can be handed to `tree_map`."""
    total = x + y
    return total

def compute_log_likelihood_single_modality(o_m, A_m, distr_obs=True):
    """Log-likelihood of a single modality's observation under its likelihood array.

    Parameters
    ----------
    - `o_m` - observation for this modality; weights over the leading axis of `A_m`
      when `distr_obs` is True, otherwise an index into that axis
    - `A_m` - likelihood array whose leading axis is reduced (or indexed) by `o_m`
    - `distr_obs` [bool] - True selects the weighted-sum path, False the indexing path

    Returns
    -------
    - `log_stable` of the resulting likelihood (leading axis of `A_m` removed)
    """
    if not distr_obs:
        # Observation is used directly as an index into the leading axis.
        return log_stable(A_m[o_m])

    # Append one singleton axis per trailing axis of A_m so the observation
    # broadcasts against the whole array, then sum out the leading axis.
    trailing_axes = tuple(range(1, A_m.ndim))
    weights = jnp.expand_dims(o_m, trailing_axes)
    weighted_sum = (weights * A_m).sum(axis=0)
    return log_stable(weighted_sum)

def compute_log_likelihood_per_modality(obs, A, distr_obs=True):
    """Apply `compute_log_likelihood_single_modality` leaf-by-leaf over `obs` and `A`.

    `obs` and `A` are walked together with `tree_map`, so the result has the
    same tree structure as `obs`, with one log-likelihood per modality.
    """
    single_modality = partial(compute_log_likelihood_single_modality, distr_obs=distr_obs)
    return jtu.tree_map(single_modality, obs, A)

def marginal_log_likelihood(qs, log_likelihood, i):
    """Contract `log_likelihood` with every entry of `qs` except the one at position `i`.

    The axis at position `i` is kept (passed as `keep_dims=(i,)` to `factor_dot`).
    """
    others = []
    for j, q in enumerate(qs):
        if j == i:
            continue
        others.append(q)
    return factor_dot(log_likelihood, others, keep_dims=(i,))


@partial(jit, static_argnames=['keep_dims'])
def factor_dot(M, xs, keep_dims: Optional[Tuple[int]] = None):
    """ Dot product of a multidimensional array with the arrays in `xs`.

    Each axis of `M` that is not listed in `keep_dims` is paired, in increasing
    axis order, with the next array in `xs` and contracted with it.

    Parameters
    ----------
    - `M` [array] - tensor to contract
    - `xs` [list of arrays] - one array per contracted axis of `M`
    - `keep_dims` [tuple of int, optional] - axes of `M` to leave uncontracted;
      must be hashable because it is a static jit argument

    Returns
    -------
    - `Y` [array] - the result of `factor_dot_flex` for the derived axis pairing
    """
    kept = () if keep_dims is None else keep_dims
    # Every axis of M must be either kept or matched by exactly one array in xs.
    assert M.ndim == len(xs) + len(kept)
    contracted = tuple((axis,) for axis in range(M.ndim) if axis not in kept)
    return factor_dot_flex(M, xs, contracted, keep_dims=kept)

@partial(jit, static_argnames=['dims', 'keep_dims'])
def factor_dot_flex(M, xs, dims: List[Tuple[int]], keep_dims: Optional[Tuple[int]] = None):
    """ Contract a multidimensional array with a list of arrays over chosen axes.

    Builds an interleaved `operand, axes, operand, axes, ..., output_axes`
    argument list and hands it to `contract` with the JAX backend.

    Parameters
    ----------
    - `M` [array] - tensor to contract
    - `xs` [list of arrays] - arrays to contract `M` with
    - `dims` [sequence of tuples] - `dims[f]` names the axes of `M` that `xs[f]` spans
    - `keep_dims` [tuple] - axes passed as the final (output) entry of the argument list

    Returns
    -------
    - `Y` [array] - the result of the contraction
    """
    operands = [M, tuple(range(M.ndim))]
    for f in range(len(xs)):
        operands.append(xs[f])
        operands.append(dims[f])

    operands.append(keep_dims)
    return contract(*operands, backend='jax')

def mll_factors(qs, ll_m, factor_list_m) -> List:
    """Marginal log-likelihoods of one modality with respect to each factor it depends on.

    Picks the entries of `qs` named in `factor_list_m`, then calls
    `marginal_log_likelihood` once per local position 0..len(factor_list_m)-1.
    Returns a list with one entry per position in `factor_list_m`.
    """
    selected = [qs[f] for f in factor_list_m]
    local_positions = list(range(len(factor_list_m)))
    return jtu.tree_map(
        lambda i: marginal_log_likelihood(selected, ll_m, i),
        local_positions,
    )


def all_marginal_log_likelihood(qs, log_likelihoods, all_factor_lists):
    """Sum, per factor, the marginal log-likelihoods contributed by every modality.

    `all_factor_lists[m]` lists the factors modality `m` depends on. Returns a
    list with one accumulated array per entry of `qs`.
    """
    def per_modality(ll_m, factor_list_m):
        return mll_factors(qs, ll_m, factor_list_m)

    marginals = jtu.tree_map(per_modality, log_likelihoods, all_factor_lists)

    # Each factor starts from a length-1 zero array that broadcasts against the
    # first contribution added to it. A modality-to-factor mapping could replace
    # this double loop with a single tree_map.
    totals = [jnp.zeros(1) for _ in range(len(qs))]
    for m, factor_list_m in enumerate(all_factor_lists):
        for local_pos, f in enumerate(factor_list_m):
            totals[f] = totals[f] + marginals[m][local_pos]

    return totals

def run_factorized_fpi(A, obs, prior, A_dependencies, num_iter=1):
    """
    Run the fixed point iteration algorithm with sparse dependencies between factors and outcomes (stored in `A_dependencies`)

    Parameters
    ----------
    - `A` - likelihood array(s), passed with `obs` to `compute_log_likelihood_per_modality`
    - `obs` - observation(s) matching the structure of `A`
    - `prior` - pytree of prior arrays, one per factor; fixes the structure of the result
    - `A_dependencies` [list of lists of int] - for each modality, the factors it depends on
    - `num_iter` [int] - number of fixed point updates to run; must be a static
      Python integer because it sets the scan length

    Returns
    -------
    - `qs` - pytree with the structure of `prior` holding the softmax-normalised posteriors
    """
    # Step 1: Compute log likelihoods for each factor
    # Ad-hoc adaptation: the single result is wrapped in a list so it lines up
    # with `A_dependencies`, which holds one factor list per modality.
    log_likelihoods = [compute_log_likelihood_per_modality(obs, A)]

    # Step 2: Map prior to log space and create initial log-posterior
    log_prior = jtu.tree_map(log_stable, prior)
    log_q = jtu.tree_map(jnp.zeros_like, prior)

    # Step 3: Iterate the update a fixed number of times
    def scan_fn(carry, t):
        log_q = carry
        q = jtu.tree_map(nn.softmax, log_q)
        marginal_ll = all_marginal_log_likelihood(q, log_likelihoods, A_dependencies)
        log_q = jtu.tree_map(add, marginal_ll, log_prior)

        return log_q, None

    # The scan length comes from `num_iter` rather than a fixed-size dummy array,
    # so the argument actually controls how many updates are performed.
    res, _ = lax.scan(scan_fn, log_q, xs=None, length=num_iter)

    # Step 4: Map result to factorised posterior
    qs = jtu.tree_map(nn.softmax, res)
    return qs

def top_function(A, obs, prior):
    """Batched entry point: `run_factorized_fpi` vmapped over the leading axis of each input.

    `A_dependencies` and `num_iter` are fixed here instead of being arguments:
    traced arrays give problems when used to index Python lists, so the
    dependency structure has to be a plain Python constant.
    """
    def infer_states(A_single, obs_single, prior_single):
        return run_factorized_fpi(
            A_single,
            obs_single,
            prior_single,
            A_dependencies=[[0, 1]],  # ad-hoc: one modality depending on factors 0 and 1
            num_iter=16,              # ad-hoc
        )

    return vmap(infer_states)(A, obs, prior)


In [11]:
# Export pipeline: JAX -> TensorFlow -> ONNX -> simplified ONNX -> Netron viewer.
# Every name bound here is a notebook global; `simplified_model_path` is read
# again by the inference cell further down.

# Convert JAX function to TF
# NOTE(review): enable_xla=False presumably keeps the graph to ops tf2onnx can
# translate - confirm against the jax2tf documentation.
top_function_tf = jax2tf.convert(top_function, enable_xla=False)
# Create a TF graph out of the TF function
function_graph = tf.function(top_function_tf, autograph=False)

# Input signature used for the ONNX export. The nesting mirrors the arguments
# of `top_function(A, obs, prior)`: two tensors followed by a two-element list.
input_signature = [
        tf.TensorSpec(shape=(2, 2, 3, 2), dtype=tf.float32), # A tensor
        #[tf.TensorSpec(shape=(2, 3, 2), dtype=tf.float32)], # A tensor
        tf.TensorSpec(shape=(2, 2), dtype=tf.float32),       # obs
        [                                                    # Prior
            tf.TensorSpec(shape=[2, 3], dtype=tf.float32),
            tf.TensorSpec(shape=[2, 2], dtype=tf.float32)
        ]
        #[tf.TensorSpec(shape=[2], dtype=tf.int32)],      # A_dependencies
        #tf.TensorSpec(shape=(1), dtype=tf.bool)           # Iter
    ]

# Export the module to ONNX (the second return value is discarded)
inference_onnx, _ = tf2onnx.convert.from_function(
    function_graph,
    input_signature=input_signature,
    opset=13
)

# Write the unsimplified model to disk so it can be reloaded with `onnx.load`.
onnx_model_path = "trial.onnx"
with open(onnx_model_path, "wb") as f:
    f.write(inference_onnx.SerializeToString())

# NOTE(review): this import sits mid-cell; consider moving it to the import cell.
from onnxsim import simplify

# Load your ONNX model
model = onnx.load(onnx_model_path)

# Simplify the model; `check` is the validation flag returned alongside it
model_simp, check = simplify(model)

# Ensure the simplified model is valid
assert check, "Simplified ONNX model could not be validated"

# Save the simplified model
simplified_model_path = "path_to_simplified_model.onnx"
onnx.save(model_simp, simplified_model_path)

print(f"Simplified model saved at: {simplified_model_path}")

# Serve the simplified graph in the Netron viewer for visual inspection.
netron.start(simplified_model_path)

Simplified model saved at: path_to_simplified_model.onnx
Serving 'path_to_simplified_model.onnx' at http://localhost:21771


('localhost', 21771)

In [4]:

num_states = [3, 2]
num_obs = [2]
n_batch = 2

# Per-factor pieces of the likelihood; the leading axis indexes the observation.
A_1 = np.array([[1.0, 1.0, 1.0], [0.0, 0.0, 1.0]])
A_2 = np.array([[1.0, 1.0], [1.0, 0.0]])

# Outer-combine the two pieces into (obs, state_1, state_2), then normalise
# over the observation axis.
A_tensor = A_1[:, :, np.newaxis] * A_2[:, np.newaxis, :]
A_tensor = A_tensor / A_tensor.sum(axis=0)

# Ad hoc: A is a single batched array here, not a list with one array per modality.
A = np.broadcast_to(A_tensor, (n_batch, num_obs[0], *num_states))

# Create two transition matrices, one for each state factor.
B_1 = np.broadcast_to(
    np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]),
    (n_batch, num_states[0], num_states[0]),
)
B_2 = np.broadcast_to(
    np.array([[0.0, 1.0], [1.0, 0.0]]),
    (n_batch, num_states[1], num_states[1]),
)
B = [factor_B[..., np.newaxis] for factor_B in (B_1, B_2)]

# For the single modality, a sequence of one-hot observations over time.
# Observation 0 is ambiguous with respect to the state factors; observation 1
# provides information about the exact state of both factors.
obs_sequence = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
obs = [
    np.broadcast_to(
        obs_sequence[:, np.newaxis], (len(obs_sequence), n_batch, num_obs[0])
    )
]

C = [np.zeros((n_batch, num_obs[0]))]  # flat preferences
D = [np.ones((n_batch, n_states)) / n_states for n_states in num_states]  # flat prior
E = np.ones((n_batch, 1))


def _first_step_batch_major(x):
    """Keep only the first time step and move the time axis behind the batch axis."""
    return np.moveaxis(x[:1], 0, 1)


curr_obs = jtu.tree_map(_first_step_batch_major, obs)
        

In [5]:
# Load the simplified ONNX model with ONNX Runtime and look up its I/O names.
ort_session = ort.InferenceSession(simplified_model_path)
input_names = [model_input.name for model_input in ort_session.get_inputs()]
output_names = [model_output.name for model_output in ort_session.get_outputs()]

# Prepare the inputs. The model takes four tensors, fed in this order:
# A, the current observation, and the two prior arrays.
# NOTE: this rebinds `obs` (a list in the data-setup cell) to a single array
# with the singleton time axis removed.
obs = np.squeeze(curr_obs[0], axis=1)
feed_arrays = (A, obs, D[0], D[1])
#input_names[4]: [[0, 1]],
#input_names[5]: 1
inputs = {
    input_names[position]: feed_array.astype(np.float32)
    for position, feed_array in enumerate(feed_arrays)
}

# Run the model
outputs = ort_session.run(output_names, inputs)

# Echo exactly what was fed to the session, then the result.
print("Inputs")
for feed_label, feed_array in zip(("A", "Obs", "Prior 0", "Prior 1"), feed_arrays):
    print(feed_label, feed_array.astype(np.float32))

print("Output from ONNX Runtime:")
print(outputs)

Inputs
A [[[[1.  1. ]
   [1.  1. ]
   [0.5 1. ]]

  [[0.  0. ]
   [0.  0. ]
   [0.5 0. ]]]


 [[[1.  1. ]
   [1.  1. ]
   [0.5 1. ]]

  [[0.  0. ]
   [0.  0. ]
   [0.5 0. ]]]]
Obs [[1. 0.]
 [1. 0.]]
Prior 0 [[0.33333334 0.33333334 0.33333334]
 [0.33333334 0.33333334 0.33333334]]
Prior 1 [[0.5 0.5]
 [0.5 0.5]]
Output from ONNX Runtime:
[array([[0.36628395, 0.36628395, 0.2674321 ],
       [0.36628395, 0.36628395, 0.2674321 ]], dtype=float32), array([[0.45378983, 0.5462102 ],
       [0.45378983, 0.5462102 ]], dtype=float32)]


[0;93m2024-07-31 11:14:11.688623 [W:onnxruntime:, graph.cc:4093 CleanUnusedInitializersAndNodeArgs] Removing initializer '_v_92'. It is not used by any node and should be removed from the model.[m
[0;93m2024-07-31 11:14:11.688699 [W:onnxruntime:, graph.cc:4093 CleanUnusedInitializersAndNodeArgs] Removing initializer '_v_90'. It is not used by any node and should be removed from the model.[m
[0;93m2024-07-31 11:14:11.688717 [W:onnxruntime:, graph.cc:4093 CleanUnusedInitializersAndNodeArgs] Removing initializer 'Reshape__96_shape__111'. It is not used by any node and should be removed from the model.[m
[0;93m2024-07-31 11:14:11.688746 [W:onnxruntime:, graph.cc:4093 CleanUnusedInitializersAndNodeArgs] Removing initializer 'Reshape__97_shape__112'. It is not used by any node and should be removed from the model.[m
[0;93m2024-07-31 11:14:11.688766 [W:onnxruntime:, graph.cc:4093 CleanUnusedInitializersAndNodeArgs] Removing initializer 'const_fold_opt__70'. It is not used by any node

In [6]:
# Run the model N_RUNS times and measure the per-call latency.
N_RUNS = 100000

latencies = []
# Per-run results are deliberately not collected (that would hold N_RUNS copies
# in memory), so `outputs` stays empty and the final print shows `[]`.
outputs = []
for _ in range(N_RUNS):
    # perf_counter is a monotonic, high-resolution clock intended for timing
    # short intervals; time.time() can jump and has coarser resolution.
    start_time = time.perf_counter()
    output = ort_session.run(output_names, inputs)

    latency = time.perf_counter() - start_time
    latencies.append(latency)

# Calculate average latency and standard deviation (seconds; printed in ms)
average_latency = np.mean(latencies)
std_latency = np.std(latencies)

print(f"Average latency: {average_latency * 1000:.2f} ms")
print(f"Standard deviation: {std_latency * 1000:.2f} ms")
print(outputs)

Average latency: 1.70 ms
Standard deviation: 1.26 ms
[]
