# Saving and loading trained policies in JaxPlan

In this notebook, we show how to save and load trained JaxPlan policies.

Start by installing the required packages:

In [1]:
%pip install --quiet --upgrade pip
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax

Note: you may need to restart the kernel to use updated packages.


Import the required packages:

In [2]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import pickle

import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxDeepReactivePolicy,
    JaxBackpropPlanner,
    JaxOfflineController,
    load_config_from_string,
)

We will load the Wildfire example to illustrate the process:

In [3]:
env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)

Let's now train a fresh policy network to solve this problem:

In [4]:
planner = JaxBackpropPlanner(rddl=env.model, plan=JaxDeepReactivePolicy(), pgpe=None, optimizer_kwargs={'learning_rate': 0.01})
agent = JaxOfflineController(planner, print_summary=False, train_seconds=30)
agent.evaluate(env, episodes=100)

[INFO] JAX gradient compiler will cast p-vars {'burning', 'NEIGHBOR', "out-of-fuel'", 'cut-out', "burning'", 'out-of-fuel', 'put-out', 'TARGET'} to float.
[INFO] JAX gradient compiler will cast CPFs {"burning'", "out-of-fuel'"} to float.
[INFO] Bounds of action-fluent <put-out> set to (None, None).
[INFO] Bounds of action-fluent <cut-out> set to (None, None).
[WARN] policy_hyperparams is not set, setting 1.0 for all action-fluents which could be suboptimal.


      0 it /    -6698.47021 train /    -4108.43750 test /    -4108.43750 best / 0 status /      0 pgpe:  11%| | 00:02 ,

[FAIL] Compiler encountered the following error(s) in the training model:
    Casting occurred that could result in loss of precision.


   1769 it /    -2767.67480 train /     -768.50000 test /     -477.21875 best / 5 status /      0 pgpe: 100%|█| 00:29 ,





{'mean': np.float64(-615.4),
 'median': np.float64(-210.0),
 'min': np.float64(-8035.0),
 'max': np.float64(-210.0),
 'std': np.float64(1276.131983769704)}

To save the model, we simply pickle the final parameters of the policy network:

In [5]:
with open('wildfire_drp.pickle', 'wb') as file:
    pickle.dump(agent.params, file)
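
Before relying on the saved file, it can be worth confirming that the parameters survive the round trip. The sketch below is only an illustration: `fake_params` is a hypothetical nested dictionary standing in for `agent.params` (whose real structure depends on the policy network), and the helper name `save_and_verify` is made up for this example.

```python
import os
import pickle
import tempfile

def save_and_verify(params, path):
    """Pickle params to path, then reload the file and compare with the original."""
    with open(path, 'wb') as file:
        pickle.dump(params, file)
    with open(path, 'rb') as file:
        restored = pickle.load(file)
    return restored == params

# Hypothetical stand-in for agent.params (plain lists instead of arrays).
fake_params = {'layer1': {'w': [[0.1, -0.2], [0.3, 0.4]], 'b': [0.0, 0.0]}}

# Write to a temporary directory so the example leaves no files behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'example_params.pickle')
    round_trip_ok = save_and_verify(fake_params, path)
```

With array-valued leaves, `==` compares elementwise, so a real check would need to compare the leaves one by one (for instance with `numpy.allclose`). Keep in mind that pickle files can execute arbitrary code when loaded, so only load files from sources you trust.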

Now, let's load the pickled parameters and pass them to a newly-instantiated controller:

In [6]:
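# Load the saved policy parameters from disk
with open('wildfire_drp.pickle', 'rb') as file:
    params = pickle.load(file)

# Rebuild the planner and pass the loaded parameters to a new controller
new_planner = JaxBackpropPlanner(rddl=env.model, plan=JaxDeepReactivePolicy())
new_agent = JaxOfflineController(new_planner, params=params, print_summary=False)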

[INFO] JAX gradient compiler will cast p-vars {'burning', 'NEIGHBOR', "out-of-fuel'", 'cut-out', "burning'", 'out-of-fuel', 'put-out', 'TARGET'} to float.
[INFO] JAX gradient compiler will cast CPFs {"burning'", "out-of-fuel'"} to float.
[INFO] Bounds of action-fluent <put-out> set to (None, None).
[INFO] Bounds of action-fluent <cut-out> set to (None, None).
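
Loaded parameters are presumably only reusable when the new policy network has the same layout as the one that was trained; here both planners use a default `JaxDeepReactivePolicy()`. As a rough illustration of that requirement, the sketch below compares the nested key layout and leaf sizes of two parameter trees. The dictionaries and the `structure` helper are hypothetical stand-ins, not part of the JaxPlan API.

```python
def structure(tree):
    """Return the nested key layout of a dict tree, with the length of each leaf."""
    if isinstance(tree, dict):
        return {key: structure(value) for key, value in tree.items()}
    return len(tree)  # leaves are assumed to be flat lists in this sketch

# Hypothetical stand-ins: saved parameters vs. freshly initialized networks.
saved = {'hidden': {'w': [0.1, 0.2, 0.3], 'b': [0.0]}}
fresh_same = {'hidden': {'w': [0.0, 0.0, 0.0], 'b': [0.0]}}
fresh_wider = {'hidden': {'w': [0.0, 0.0, 0.0, 0.0], 'b': [0.0]}}

matches_same = structure(saved) == structure(fresh_same)    # same layout
matches_wider = structure(saved) == structure(fresh_wider)  # different layer size
```

A check along these lines would catch a mismatch early, for example when the saved file came from a network with a different layer size.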


Note that in this case the policy is not trained again. Let's evaluate the new agent to confirm that it performs the same as the trained one:

In [7]:
new_agent.evaluate(env, episodes=100)

{'mean': np.float64(-523.8),
 'median': np.float64(-210.0),
 'min': np.float64(-8270.0),
 'max': np.float64(-210.0),
 'std': np.float64(1204.700402589789)}

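Indeed, the performance is similar: the median return is -210.0 in both runs, and the difference in mean return (-523.8 versus -615.4) is small relative to the standard deviation of more than 1200 across episodes.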
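
To put the gap between the two means in perspective, one option is to compare it with the standard error of the difference. The sketch below uses only the summary statistics printed above. It assumes the 100 episodes in each evaluation are independent and that the reported `std` is the per-episode standard deviation, so treat it as a rough plausibility check rather than a formal test.

```python
import math

def standard_error(std, episodes):
    """Standard error of the mean for independent episodes."""
    return std / math.sqrt(episodes)

episodes = 100
# Summary statistics copied from the two evaluation outputs above.
trained = {'mean': -615.4, 'std': 1276.131983769704}
reloaded = {'mean': -523.8, 'std': 1204.700402589789}

se_trained = standard_error(trained['std'], episodes)
se_reloaded = standard_error(reloaded['std'], episodes)

# Standard error of the difference between two independent means.
se_diff = math.sqrt(se_trained ** 2 + se_reloaded ** 2)
z_score = (reloaded['mean'] - trained['mean']) / se_diff

print(f"difference = {reloaded['mean'] - trained['mean']:.1f}, "
      f"standard error = {se_diff:.1f}, z = {z_score:.2f}")
```

Under these assumptions the gap of 91.6 is about half of the combined standard error (roughly 175), which is consistent with the two agents following the same policy and differing only through the randomness of the episodes.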