Author: Mike Gimelfarb
This directory provides:
- automated translation and compilation of RDDL description files into the JAX auto-diff library, which allows any RDDL domain to be converted to a differentiable simulator!
- powerful, fast, and very scalable gradient-based planning algorithms, with extendible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
Note
While the JAX planner supports some discrete state/action problems through model relaxations, it can perform poorly on some discrete problems (though there is an ongoing effort to remedy this!). If you find it is not making sufficient progress, check out the PROST planner (for discrete spaces) or the deep reinforcement learning wrappers.
- Installation
- Running the Basic Examples
- Running from the Python API
- Writing a Configuration File for a Custom Domain
- Changing the pyRDDLGym Simulation Backend to JAX
- Computing the Model Gradients Manually
- Citing pyRDDLGym-jax
Installation
To use the compiler or planner without the automated hyper-parameter tuning, you will need the following packages installed:
- pyRDDLGym>=2.0
- tqdm>=4.66
- jax>=0.4.12
- optax>=0.1.9
- dm-haiku>=0.0.10
- tensorflow>=2.13.0
- tensorflow-probability>=0.21.0
Additionally, if you wish to run the examples, you will need rddlrepository>=2, and to run the automated tuning optimization, you will also need bayesian-optimization>=1.4.3.
You can install this package, together with all of its requirements via pip:
pip install rddlrepository pyRDDLGym-jax
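If you also plan to run the automated hyper-parameter tuning, the optional dependency listed above can be installed the same way:
pip install 'bayesian-optimization>=1.4.3'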
Running the Basic Examples
A basic run script is provided to run the Jax Planner on any domain in rddlrepository, provided a config file is available (currently, only a limited subset of configs are provided as examples).
The example can be run as follows in a standard shell, from the install directory of pyRDDLGym-jax:
python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
where:
- domain is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid domain.rddl file
- instance is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid instance.rddl file
- method is the planning method to use (i.e. drp, slp, replan)
- episodes is the (optional) number of episodes to evaluate the learned policy
The method parameter warrants further explanation. Currently we support three possible modes:
- slp is the basic straight line planner described in this paper
- drp is the deep reactive policy network described in this paper
- replan is the same as slp except the plan is recalculated at every decision time step.
A basic run script is also provided to run the automatic hyper-parameter tuning. The structure of this script is similar to the one above:
python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
where:
- domain is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014)
- instance is the instance identifier (i.e. 1, 2, ... 10)
- method is the planning method to use (i.e. drp, slp, replan)
- trials is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
- iters is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
- workers is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = iters * workers
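For instance, the following invocation (the numeric settings here are purely illustrative) would tune the straight-line planner on Wildfire, averaging 5 episodes per hyper-parameter setting, over 20 iterations of Bayesian optimization with 4 parallel workers each:
python -m pyRDDLGym_jax.examples.run_tune Wildfire_MDP_ippc2014 1 slp 5 20 4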
For example, copying and pasting the following will train the Jax Planner on the Quadcopter domain with 4 drones:
python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
After several minutes of optimization, the planner will produce a visualization of the learned policy.
Running from the Python API
You can write a simple script to instantiate and run the planner yourself, if you wish:
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController
# set up the environment (note the vectorized option must be True)
env = pyRDDLGym.make("domain", "instance", vectorized=True)
# create the planning algorithm
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)
# evaluate the planner
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()
Here, we have used the straight-line (offline) controller, although you can configure the combination of planner and policy representation; a replanning (online) variant is sketched below. All controllers are instances of pyRDDLGym's BaseAgent class, so they support the evaluate() function to streamline interaction with the environment.
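For instance, a minimal sketch of a replanning setup, assuming the JaxOnlineController class exported alongside the offline controller (the **planner_args and **train_args placeholders are as above), could look like:
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOnlineController
# set up the environment (note the vectorized option must be True)
env = pyRDDLGym.make("domain", "instance", vectorized=True)
# an online controller re-runs the optimizer at every decision step (MPC-style)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOnlineController(planner, **train_args)
# evaluate the replanning controller
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()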
The **planner_args and **train_args are keyword arguments to pass during initialization, but we strongly recommend creating and loading a config file as discussed in the next section.
Writing a Configuration File for a Custom Domain
The simplest way to interface with the Planner for solving a custom problem is to write a configuration file with all the necessary hyper-parameters. The basic structure of a configuration file is provided below for a straight line planning/MPC style planner:
[Model]
logic='FuzzyLogic'
logic_kwargs={'weight': 20}
tnorm='ProductTNorm'
tnorm_kwargs={}
[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}
batch_size_train=1
batch_size_test=1
[Training]
key=42
epochs=5000
train_seconds=30
The configuration file contains three sections:
- [Model] specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations; the weight parameter, for example, dictates how well the approximation will fit the true operation, and tnorm specifies the type of fuzzy logic for replacing logical operations in RDDL (e.g. ProductTNorm, GodelTNorm, LukasiewiczTNorm)
- [Optimizer] generally specifies the optimizer and plan settings; method specifies the plan/policy representation (e.g. JaxStraightLinePlan, JaxDeepReactivePolicy), along with the SGD optimizer to use from optax, the learning rate, batch size, etc.
- [Training] specifies how long training should proceed; epochs limits the total number of iterations, while train_seconds limits total training time
For a policy network approach, simply change the configuration file like so:
...
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128, 64]}
...
which would create a policy network with two hidden layers and ReLU activations.
The configuration file can then be passed to the planner during initialization. For example, the previous script can be modified to set parameters from a config file as follows:
from pyRDDLGym_jax.core.planner import load_config
# load the config file with planner settings
planner_args, _, train_args = load_config("/path/to/config.cfg")
# create the planning algorithm
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)
...
Changing the pyRDDLGym Simulation Backend to JAX
The JAX compiler can be used as a backend for simulating and evaluating RDDL environments, instead of the usual pyRDDLGym one:
import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
# create the environment
env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)
# evaluate the random policy
agent = RandomAgent(action_space=env.action_space,
num_actions=env.max_allowed_actions)
agent.evaluate(env, verbose=True, render=True)
For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations. In any event, the simulation results using the JAX backend should match exactly those of the numpy-based backend.
Computing the Model Gradients Manually
For custom applications, it is desirable to compute gradients of the model that can be optimized downstream. Fortunately, we provide a very convenient function for compiling the transition/step function P(s, a, s') of the environment into JAX.
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
# set up the environment
env = pyRDDLGym.make("domain", "instance", vectorized=True)
# create the step function
compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
compiled.compile()
step_fn = compiled.compile_transition()
This will return a JAX compiled (pure) function that requires 4 arguments:
- key is the jax.random.PRNGKey key for reproducible randomness
- actions is the dictionary of action fluent JAX tensors
- subs is the dictionary of state-fluent and non-fluent JAX tensors
- model_params are the parameters of the differentiable relaxations, such as weight
The function returns a dictionary containing a variety of variables, such as the updated pvariables including next-state fluents (pvar), the reward obtained (reward), and error codes (error).
It is thus possible to apply any JAX transformation to the output of the function, such as computing gradients using jax.grad() or batched simulation using jax.vmap(), as sketched below.
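As a minimal sketch (the argument order follows the list above, and the variable names are assumptions carried over from the previous snippet), a one-step reward gradient with respect to the actions could be computed as:
import jax
# differentiable objective: the one-step reward returned by step_fn
def reward_fn(actions, key, subs, model_params):
    return step_fn(key, actions, subs, model_params)['reward']
# gradient of the one-step reward with respect to the action tensors
action_grads = jax.grad(reward_fn)(actions, key, subs, model_params)
# batched simulation over a leading batch axis of the action tensors
batched_step = jax.vmap(step_fn, in_axes=(None, 0, None, None))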
Compilation of entire rollouts is possible by calling the compile_rollouts function and providing a policy implementation that maps states (jax tensors) and tunable policy parameters to actions; a rough sketch is given below. An example is provided to illustrate how you can define your own policy class and compute the return gradient manually.
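As a rough sketch of what such a policy might look like (the fluent names, the linear-tanh form, and the exact signature expected by compile_rollouts are illustrative assumptions, not the confirmed API):
import jax.numpy as jnp
# hypothetical policy: maps state fluent tensors and tunable
# parameters to a dictionary of action fluent tensors
def policy(params, subs):
    return {'my-action': jnp.tanh(subs['my-state'] @ params['W'])}
# the compiled rollout can then be differentiated with jax.grad
# in order to optimize params against the total return
rollout_fn = compiled.compile_rollouts(policy=policy)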
Citing pyRDDLGym-jax
The main ideas of this approach are discussed in the following preprint:
@article{taitler2022pyrddlgym,
title={pyRDDLGym: From RDDL to Gym Environments},
author={Taitler, Ayal and Gimelfarb, Michael and Gopalakrishnan, Sriram and Mladenov, Martin and Liu, Xiaotian and Sanner, Scott},
journal={arXiv preprint arXiv:2211.05939},
year={2022}
}
Many of the implementation details discussed come from the following literature, which you may wish to cite in your research papers: