In [1]:
# This example demonstrates converting JAX functions to TensorFlow (via jax2tf) and then to ONNX
# (via the tf2onnx tool), and running inference on the exported model with ONNX Runtime.

from jax.experimental import jax2tf
from jax import numpy as jnp
from jax import tree_util as jtu, vmap, jit, lax, nn
from jax.experimental import sparse

from pymdp.jax.agent import Agent

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns

from functools import partial
from typing import Optional, Tuple, List
import time

import onnx
import tf2onnx
import onnxruntime as ort
from onnxsim import simplify
import netron

from opt_einsum import contract


In [2]:

# Machine epsilon of Python `float`; used as the lower bound applied before taking logs.
MINVAL = jnp.finfo(float).eps


def log_stable(x):
    """Elementwise natural log of `x` with values floored at `MINVAL`, so zeros do not give -inf."""
    floored = jnp.clip(x, min=MINVAL)
    return jnp.log(floored)

@partial(jit, static_argnames=['keep_dims'])
def factor_dot(M, xs, keep_dims: Optional[Tuple[int]] = None):
    """ Contract the tensor `M` with the arrays in `xs`, one array per non-kept axis of `M`.

    Parameters
    ----------
    - `M` [array] - tensor to contract
    - `xs` [list of arrays] - one entry for every axis of `M` that is not listed in
      `keep_dims`; the k-th entry is paired with the k-th such axis (in increasing order)
    - `keep_dims` [tuple of int, optional] - axes of `M` left uncontracted (static under `jit`)

    Returns
    -------
    - `Y` - the value returned by `factor_dot_flex(M, xs, dims, keep_dims=keep_dims)`
    """
    kept = keep_dims if keep_dims is not None else ()
    # Every axis of M is either paired with an entry of xs or kept.
    assert M.ndim == len(xs) + len(kept)
    contracted = tuple((axis,) for axis in range(M.ndim) if axis not in kept)
    return factor_dot_flex(M, xs, contracted, keep_dims=kept)

@partial(jit, static_argnames=['dims', 'keep_dims'])
def factor_dot_flex(M, xs, dims: List[Tuple[int]], keep_dims: Optional[Tuple[int]] = None):
    """ Contract the tensor `M` with the arrays in `xs` via `opt_einsum.contract`.

    Parameters
    ----------
    - `M` [array] - tensor; its axes are labelled `0 .. M.ndim - 1`
    - `xs` [list of arrays] - tensors to contract with `M`
    - `dims` [list of tuples] - `dims[f]` holds the axis labels of `M` that `xs[f]` is matched to
    - `keep_dims` [tuple] - axis labels of `M` passed as the output labels of the contraction

    Returns
    -------
    - `Y` - the result of `contract(...)` evaluated with the 'jax' backend
    """
    # Interleaved call form: operand, labels, operand, labels, ..., output labels.
    operands = [M, tuple(range(M.ndim))]
    for f in range(len(xs)):
        operands.append(xs[f])
        operands.append(dims[f])
    operands.append(keep_dims)

    return contract(*operands, backend='jax')

def get_likelihood_single_modality(o_m, A_m, distr_obs=True):
    """Return the observation likelihood for a single observation modality m.

    With `distr_obs=True`, `o_m` is used as weights over the leading axis of `A_m`: it is
    expanded with singleton axes 1 .. A_m.ndim - 1, multiplied into `A_m`, and the product
    is summed over axis 0. With `distr_obs=False`, `o_m` is used directly as an index into
    the leading axis of `A_m`.
    """
    if not distr_obs:
        return A_m[o_m]

    trailing_axes = tuple(range(1, A_m.ndim))
    weighted = jnp.expand_dims(o_m, trailing_axes) * A_m
    return weighted.sum(axis=0)

def compute_log_likelihood_single_modality(o_m, A_m, distr_obs=True):
    """Observation log-likelihood for one modality: `log_stable` applied to its likelihood."""
    likelihood = get_likelihood_single_modality(o_m, A_m, distr_obs=distr_obs)
    return log_stable(likelihood)

def compute_log_likelihood_per_modality(obs, A, distr_obs=True):
    """ Compute the log-likelihood for every modality and return them per modality.

    `obs` and `A` are pytrees that are mapped over together; each (observation, likelihood)
    pair is passed to `compute_log_likelihood_single_modality`.
    """
    def _single_modality(o_m, A_m):
        return compute_log_likelihood_single_modality(o_m, A_m, distr_obs=distr_obs)

    return jtu.tree_map(_single_modality, obs, A)
    
def marginal_log_likelihood(qs, log_likelihood, i):
    """Contract `log_likelihood` with every entry of `qs` except the `i`-th, keeping axis `i`."""
    others = []
    for j, q in enumerate(qs):
        if j != i:
            others.append(q)
    return factor_dot(log_likelihood, others, keep_dims=(i,))

def _static_factor_list(factor_list_m, num_ll_dims):
    """Return `factor_list_m` as a list of Python ints that can index Python lists.

    If an entry cannot be converted with `int()`, fall back to `[0, ..., num_ll_dims - 1]`.
    JAX raises a `TypeError` subclass for abstract tracers, which is what the entries are
    when the dependency lists are passed through `jit` as ordinary, non-static arguments.
    For a two-factor likelihood the fallback is `[0, 1]`, the value this code used
    unconditionally before.

    NOTE: the fallback is only correct when the modality depends on the leading factors.
    Close over the dependency lists (e.g. with `functools.partial`) instead of tracing
    them if that is not the case.
    """
    try:
        return [int(f) for f in factor_list_m]
    except TypeError:
        return list(range(num_ll_dims))


def mll_factors(qs, ll_m, factor_list_m) -> List:
    """Marginal log-likelihoods of one modality, one per hidden-state factor it depends on.

    Parameters
    ----------
    - `qs` [list of arrays] - marginals of all hidden-state factors
    - `ll_m` [array] - log-likelihood tensor of modality m
    - `factor_list_m` [list of int] - indices into `qs` of the factors modality m depends on

    Returns
    -------
    - list with one entry per factor in `factor_list_m`: entry `i` is
      `marginal_log_likelihood(relevant_factors, ll_m, i)`
    """
    factor_list_m = _static_factor_list(factor_list_m, ll_m.ndim)
    relevant_factors = [qs[f] for f in factor_list_m]
    return [
        marginal_log_likelihood(relevant_factors, ll_m, i)
        for i in range(len(factor_list_m))
    ]

def all_marginal_log_likelihood(qs, log_likelihoods, all_factor_lists):
    """Sum, for every hidden-state factor, the marginal log-likelihoods of all modalities
    that depend on it.

    Parameters
    ----------
    - `qs` [list of arrays] - marginals of all hidden-state factors
    - `log_likelihoods` [list of arrays] - one log-likelihood tensor per modality
    - `all_factor_lists` [list of lists of int] - for each modality, the factors it depends on

    Returns
    -------
    - list of length `len(qs)`; entry `f` is the sum over modalities of the marginal for
      factor `f` (a `jnp.zeros(1)` placeholder if no modality depends on `f`)
    """
    qL_marginals = jtu.tree_map(lambda ll_m, factor_list_m: mll_factors(qs, ll_m, factor_list_m), log_likelihoods, all_factor_lists)
    num_factors = len(qs)

    # instead of a double loop we could have a list defining m to f mapping
    # which could be resolved with a single tree_map cast
    qL_all = [jnp.zeros(1)] * num_factors
    for m, factor_list_m in enumerate(all_factor_lists):
        # Resolve the list exactly as mll_factors does, so positions in qL_marginals[m]
        # line up with the factor indices used here.
        factor_list_m = _static_factor_list(factor_list_m, log_likelihoods[m].ndim)
        for l, f in enumerate(factor_list_m):
            qL_all[f] += qL_marginals[m][l]

    return qL_all

def add(x, y):
    """Return `x + y`; used as the leaf function for `tree_map` over two pytrees."""
    total = x + y
    return total

# Its length (16) is the fallback number of iterations when `num_iter` is not concrete.
xs = jnp.arange(16)

def run_factorized_fpi(A, obs, prior, A_dependencies, num_iter=1):
    """
    Run the fixed point iteration algorithm with sparse dependencies between factors and outcomes (stored in `A_dependencies`)

    Parameters
    ----------
    - `A` - likelihood arrays, one per modality
    - `obs` - observations, one per modality (pytree matching `A`)
    - `prior` - list of prior arrays, one per hidden-state factor
    - `A_dependencies` - for each modality, the list of factors it depends on
    - `num_iter` [int] - number of update iterations. If it cannot be converted with
      `int()` (JAX raises a `TypeError` subclass when it is a tracer, e.g. when it was passed
      through `jit` as a non-static argument), `len(xs)` iterations are run instead.

    Returns
    -------
    - `qs` - list of posterior marginals, one per factor (softmax of the final log-posterior)
    """

    # Step 1: Compute log likelihoods for each factor
    log_likelihoods = compute_log_likelihood_per_modality(obs, A)

    # Step 2: Map prior to log space and create initial log-posterior
    log_prior = jtu.tree_map(log_stable, prior)
    log_q = jtu.tree_map(jnp.zeros_like, prior)

    # Step 3: Iterate a fixed number of times (there is no convergence check)
    def scan_fn(carry, t):
        log_q = carry
        q = jtu.tree_map(nn.softmax, log_q)
        marginal_ll = all_marginal_log_likelihood(q, log_likelihoods, A_dependencies)
        log_q = jtu.tree_map(add, marginal_ll, log_prior)

        return log_q, None

    try:
        n_steps = int(num_iter)
    except TypeError:
        # `num_iter` is traced; the scan length must be static, so use the fallback.
        n_steps = xs.shape[0]

    # The step index is unused by scan_fn, so scan over `length` steps without inputs.
    res, _ = lax.scan(scan_fn, log_q, None, length=n_steps)

    # Step 4: Map result to factorised posterior
    qs = jtu.tree_map(nn.softmax, res)
    return qs

def top_function(A, obs, prior, A_dependencies, num_iter=16):
    """Batched state inference: `vmap` of `run_factorized_fpi` over `A`, `obs` and `prior`.

    `A_dependencies` and `num_iter` are bound with `partial` and therefore are not mapped.
    """
    batched_fpi = vmap(
        partial(run_factorized_fpi, A_dependencies=A_dependencies, num_iter=num_iter)
    )
    return batched_fpi(A, obs, prior)

In [3]:
num_states = [3, 2]
num_obs = [2]
n_batch = 2

# Likelihood of the single modality over both state factors: product of two per-factor
# matrices, normalised over the outcome axis.
A_1 = jnp.array([[1.0, 1.0, 1.0], [0.0,  0.0,  1.]])
A_2 = jnp.array([[1.0, 1.0], [1., 0.]])

A_tensor = A_1[..., None] * A_2[:, None]

A_tensor /= A_tensor.sum(0)

A = [jnp.broadcast_to(A_tensor, (n_batch, num_obs[0], 3, 2)) ]

# create two transition matrices, one for each state factor
B_1 = jnp.broadcast_to(
    jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]), (n_batch, 3, 3)
)

B_2 = jnp.broadcast_to(
        jnp.array([[0.0, 1.0], [1.0, 0.0]]), (n_batch, 2, 2)
    )

B = [B_1[..., None], B_2[..., None]]

# for the single modality, a sequence over time of observations (one hot vectors)
obs = [jnp.broadcast_to(jnp.array([[1., 0.], # observation 0 is ambiguous with respect state factors
                                    [1., 0], # observation 0 is ambiguous with respect state factors
                                    [1., 0], # observation 0 is ambiguous with respect state factors
                                    [0., 1.]])[:, None], (4, n_batch, num_obs[0]) )] # observation 1 provides information about exact state of both factors 
C = [jnp.zeros((n_batch, num_obs[0]))] # flat preferences
D = [jnp.ones((n_batch, 3)) / 3., jnp.ones((n_batch, 2)) / 2.] # flat prior
E = jnp.ones((n_batch, 1))

pA = None
pB = None

agent = Agent(
        A=A,
        B=B,
        C=C,
        D=D,
        E=E,
        pA=pA,
        pB=pB,
        policy_len=3,
        onehot_obs=True,
        action_selection="deterministic",
        sampling_mode="full",
        inference_algo="ovf",
        num_iter=16
)


prior = agent.D
action_hist = []
qs_hist=None
# Keep the first two time steps and move time behind the batch axis:
# (4, n_batch, num_obs) -> (n_batch, 2, num_obs).
first_obs = jtu.tree_map(lambda x: jnp.moveaxis(x[:2], 0, 1), obs)
# Take the last kept time step for every batch element -> (n_batch, num_obs).
# Indexing with x[-1] would select the last *batch element* instead, and would only have
# the right shape because n_batch equals the number of time steps kept.
curr_obs = jtu.tree_map(lambda x: x[:, -1], first_obs)


# Bind the non-array arguments so that vmap only maps over A, obs and prior.
infer_states = partial(
        run_factorized_fpi,
        A_dependencies=agent.A_dependencies,
        num_iter=16
    )
    
output = vmap(infer_states)(
        agent.A,
        curr_obs,
        prior
    )

print("Inputs")
print("A", agent.A)
print("obs", curr_obs)
print("D", prior)
print("A_dep", agent.A_dependencies)

print("output")
print(output)

#[Array([[0.36628395, 0.36628395, 0.2674321 ],
#       [0.36628395, 0.36628395, 0.2674321 ]], dtype=float32), Array([[0.45378983, 0.5462102 ],
#       [0.45378983, 0.5462102 ]], dtype=float32)]




Inputs
A [Array([[[[1. , 1. ],
         [1. , 1. ],
         [0.5, 1. ]],

        [[0. , 0. ],
         [0. , 0. ],
         [0.5, 0. ]]],


       [[[1. , 1. ],
         [1. , 1. ],
         [0.5, 1. ]],

        [[0. , 0. ],
         [0. , 0. ],
         [0.5, 0. ]]]], dtype=float32)]
obs [Array([[1., 0.],
       [1., 0.]], dtype=float32)]
D [Array([[0.33333334, 0.33333334, 0.33333334],
       [0.33333334, 0.33333334, 0.33333334]], dtype=float32), Array([[0.5, 0.5],
       [0.5, 0.5]], dtype=float32)]
A_dep [[0, 1]]
output
[Array([[0.36628395, 0.36628395, 0.2674321 ],
       [0.36628395, 0.36628395, 0.2674321 ]], dtype=float32), Array([[0.45378983, 0.5462102 ],
       [0.45378983, 0.5462102 ]], dtype=float32)]


In [4]:
# Run the (un-jitted) vmapped model 1k times and measure the latency of each call
latencies = []
outputs = []  # intentionally left empty; printed below only to keep the cell output format
for _ in range(1000):
    # perf_counter is a monotonic, high-resolution clock meant for measuring intervals
    start_time = time.perf_counter()
    output = vmap(infer_states)(
        agent.A,
        curr_obs,
        prior
    )
    # JAX dispatches work asynchronously: wait for every output array before stopping
    # the clock, otherwise only the dispatch time is measured.
    jtu.tree_map(lambda leaf: leaf.block_until_ready(), output)
    latency = time.perf_counter() - start_time
    latencies.append(latency)

# Calculate average latency and standard deviation
average_latency = np.mean(latencies)
std_latency = np.std(latencies)

print(f"Average latency: {average_latency * 1000:.2f} ms")
print(f"Standard deviation: {std_latency * 1000:.2f} ms")
print(outputs)

Average latency: 36.29 ms
Standard deviation: 11.98 ms
[]


In [6]:
# No arguments are marked static, so A_dependencies and the iteration count are traced
# together with the arrays.
jf = jit(top_function)
# Compile (warm up)
jf(agent.A, curr_obs, prior, agent.A_dependencies, 16)

latencies = []
outputs = []  # intentionally left empty; printed below only to keep the cell output format
for _ in range(100000):
    # perf_counter is a monotonic, high-resolution clock; time.time() is wall-clock time
    # and can be too coarse for sub-millisecond intervals.
    start_time = time.perf_counter()
    y = jf(agent.A, curr_obs, prior, agent.A_dependencies, 16)
    # Block until every array in the (possibly nested) result is ready
    jtu.tree_map(lambda leaf: leaf.block_until_ready(), y)
    latency = time.perf_counter() - start_time
    latencies.append(latency)

# Calculate average latency and standard deviation
average_latency = np.mean(latencies)
std_latency = np.std(latencies)

print(f"Average latency: {average_latency * 1000:.2f} ms")
print(f"Standard deviation: {std_latency * 1000:.2f} ms")
print(outputs)

Average latency: 0.44 ms
Standard deviation: 0.72 ms
[]
