# Calling External JAX Functions in JaxPlan.

This preliminary notebook shows how to call external JAX-compiled functions from within RDDL domain description files.

In [1]:
%pip install --quiet --upgrade pip
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


Import the required packages:

In [2]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import jax
import haiku as hk

import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
from rddlrepository.core.manager import RDDLRepoManager

Let us now define the RDDL domain and instance description files, where the domain's state update contains an external function call. The external function will be a 1D convolution layer with weights ``W`` and biases ``b`` that processes the state variable:

In [3]:
# RDDL domain description, kept as a string so it can be registered with rddlrepository below.
# - `state` and `action` are real-valued fluents indexed by `obj`.
# - `W` (indexed by kernel, obj, obj) and `b` (indexed by obj) are non-fluents that are
#   passed to the external function as its weights and biases.
# - The next state is produced by the external function referenced as `$CNNFunction`.
# - The reward sums the squared value of (next state + action - 4) over all objects.
# - Each action is constrained to lie in [-10, 10].
domain_text = """
domain my_domain {

    types {
        obj : object;
        kernel : object;
    };

    pvariables {
        state(obj) : { state-fluent, real, default = 0.0 };
        action(obj) : { action-fluent, real, default = 0.0 };

        W(kernel, obj, obj) : { non-fluent, real, default = 0.2 };
        b(obj) : { non-fluent, real, default = 0.0 };
    };
    
    cpfs {
        state'(?o) = $CNNFunction[?o](state(_), W(_, _, _), b(_));
    };
    
    reward = sum_{?o : obj} pow[state'(?o) + action(?o) - 4, 2];
    
    action-preconditions {
        forall_{?o : obj} [action(?o) >= -10 ^ action(?o) <= 10];
    };
}
"""

# RDDL non-fluents and instance description for `my_domain`.
# - Declares four `obj` objects (o1..o4) and three `kernel` objects (k1..k3); `W` and `b`
#   are not overridden here, so they keep the defaults declared in the domain.
# - The initial state is 1.0, 2.0, 3.0, 4.0 for o1..o4.
# - Horizon is 5 steps with no discounting, and the number of concurrent non-default
#   actions is left unbounded (pos-inf).
instance_text = """
non-fluents nf_simple {
    domain = my_domain;
    objects {
        obj : { o1, o2, o3, o4 };
        kernel : { k1, k2, k3 };
    };
}

instance simple_inst {
    domain = my_domain;
    non-fluents = nf_simple;

    init-state {
        state(o1) = 1.0;
        state(o2) = 2.0;
        state(o3) = 3.0;
        state(o4) = 4.0;
    };

    max-nondef-actions = pos-inf;
    horizon = 5;
    discount = 1.0;
}
"""

# Register the domain and instance with rddlrepository so that pyRDDLGym.make can
# later look them up by name.
manager = RDDLRepoManager(rebuild=True)
manager.register_domain("ExternalFuncDomainCNN", "standalone", domain_text, desc="domain with CNN", viz=None)
problem_info = manager.get_problem("ExternalFuncDomainCNN_standalone")
problem_info.register_instance("1", instance_text)

# Rebuild the repository once more now that the new domain and instance exist.
# The trailing semicolon stops the notebook from echoing the manager's repr, which
# contains a memory address that changes on every run.
RDDLRepoManager(rebuild=True);

Domain <ExternalFuncDomainCNN> was successfully registered in rddlrepository with context <standalone>.
Instance <1> was successfully registered in rddlrepository for domain <ExternalFuncDomainCNN_standalone>.


<rddlrepository.core.manager.RDDLRepoManager at 0x2e760bcf920>

The line `state'(?o) = $CNNFunction[?o](state(_), ...);` calls an external JAX-compiled function that applies a 1D convolutional neural network to the current state.

Next, we must define the external function itself; in this case it implements the CNN using JAX and Haiku. Note that the function must be JIT-compilable in order to be used with JaxPlan:

In [4]:
class CNN(hk.Module):
    """Haiku module wrapping a single 1D convolution layer.

    The layer has 4 output channels, a kernel of width 3 and "SAME" padding, and
    is created under the submodule name "conv".
    """

    def __init__(self, name=None):
        """Create the module and its convolution layer.

        Args:
            name: optional Haiku module name, forwarded to ``hk.Module``.
        """
        super().__init__(name=name)
        conv_settings = dict(
            output_channels=4,
            kernel_shape=3,
            padding="SAME",
            name="conv",
        )
        self.conv1d_layer = hk.Conv1D(**conv_settings)

    def __call__(self, x):
        """Apply the convolution layer to ``x`` and return its output."""
        conv_out = self.conv1d_layer(x)
        return conv_out
  
def _cnn_forward(x):
    """Build a CNN module and apply it to ``x``; intended to be wrapped by ``hk.transform``."""
    network = CNN()
    return network(x)


# Transformed forward pass; `f.apply(params, rng, x)` evaluates the CNN with explicit parameters.
f = hk.transform(_cnn_forward)

@jax.jit
def cnn_function(state_vec, W, b):
    """Evaluate the CNN on the current state using externally supplied parameters.

    Args:
        state_vec: array of state values; it is reshaped to a single row before
            being fed to the network.
        W: array used as the convolution weights ``w``
            (presumably laid out like ``W(kernel, obj, obj)`` in the domain; TODO confirm).
        b: array used as the convolution biases ``b``.

    Returns:
        The network output flattened to a 1D array.
    """
    # Parameters are keyed by the module path used when the CNN is built inside `f`.
    conv_params = {'cnn/~/conv': {'w': W, 'b': b}}
    # A fixed key is used here; it could instead be passed in as an argument to this function.
    rng_key = jax.random.PRNGKey(42)
    row_input = state_vec.reshape((1, -1))
    conv_output = f.apply(conv_params, rng_key, row_input)
    return conv_output.reshape((-1,))

Finally, we need to instruct the JAX compiler to use the above function when compiling and executing the domain and instance; it will be included automatically as part of the computation graph during compilation:

In [5]:
# Build the environment on the JAX simulator backend and map the name used in the
# RDDL file ("CNNFunction") to the Python function that implements it.
env = pyRDDLGym.make(
    "ExternalFuncDomainCNN_standalone",
    "1",
    backend=JaxRDDLSimulator,
    backend_kwargs={
        'python_functions': {'CNNFunction': cnn_function},
    },
)

Let's execute the environment with a random policy:

In [6]:
# Seed the random policy so that the rollout printed below can be reproduced
# after a kernel restart.
AGENT_SEED = 42

agent = RandomAgent(action_space=env.action_space,
                    num_actions=env.max_allowed_actions,
                    seed=AGENT_SEED)

# Run a single verbose episode; the returned summary statistics are shown as the cell output.
agent.evaluate(env, episodes=1, verbose=True)

initial state = 
     state___o1 = 1.0  state___o2 = 2.0  state___o3 = 3.0  state___o4 = 4.0 
------------------------------------------------------------------------------------------------------------------------------------------------
step   = 0
action = 
     action___o1 = -3.123196840286255  action___o2 = 9.072826385498047   action___o3 = -2.21956729888916   action___o4 = 9.075740814208984  
    
state  = 
     state___o1 = 2.0  state___o2 = 2.0  state___o3 = 2.0  state___o4 = 2.0 
reward = 144.1428680419922
done   = False
------------------------------------------------------------------------------------------------------------------------------------------------
step   = 1
action = 
     action___o1 = -8.618789672851562  action___o4 = 5.821589469909668   action___o2 = 2.3692588806152344  action___o3 = 9.883711814880371  
    
state  = 
     state___o1 = 1.6  state___o2 = 1.6  state___o3 = 1.6  state___o4 = 1.6 
reward = 189.1278839111328
done   = False
------------------------

{'mean': np.float64(764.8373718261719),
 'median': np.float64(764.8373718261719),
 'min': np.float64(764.8373718261719),
 'max': np.float64(764.8373718261719),
 'std': np.float64(0.0)}