# Getting started with MDPax

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/joefarrington/mdpax/blob/main/examples/getting_started.ipynb)

This notebook demonstrates MDPax's key features through increasingly complex examples. We'll start with a simple forest management problem and work our way up to larger, more realistic problems.

If you're running the notebook in Colab, you should verify that you're using a GPU instance. Click Runtime > Change runtime type and ensure "GPU" is selected as the Hardware accelerator. You can confirm GPU availability by running `!nvidia-smi` in a code cell.

## Prerequisites

If you're new to Markov Decision Processes (MDPs), you may find these introductory resources useful:
- [Reinforcement Learning: An Introduction - Chapter 3 | Sutton & Barto](http://incompleteideas.net/book/RLbook2020.pdf)
- 📺 [Markov Decision Processes 1 - Value Iteration | Stanford CS221](https://www.youtube.com/watch?v=9g32v7bK3Co)

## Installation and imports

In [26]:
import sys

try:
    # Other dependencies will be installed on Colab already
    import mdptoolbox
    import mdpax
except ImportError:
    if 'google.colab' in sys.modules:
        # Automatically install mdpax if running in Colab, environment is temporary
        !pip install "mdpax[examples] @ git+https://github.com/joefarrington/mdpax.git"
    else:
        print("Dependencies not installed. Please follow the installation instructions in the README: https://github.com/joefarrington/mdpax")

In [27]:
import jax

# In general we recommend using double precision and it is particularly helpful when
# performing comparisons with pymdptoolbox which uses NumPy and therefore defaults to
# double precision.
jax.config.update("jax_enable_x64", True)

## Simple example: forest management

Let's start with a simple example problem introduced in pymdptoolbox, an alternative library for solving MDPs in Python, so that we can compare our results. 

This problem involves deciding whether we should cut down a forest, or wait to let it mature.

The state is the current age of the forest and our actions are 0 (wait) and 1 (cut). We receive no reward for cutting a newly planted forest (age 0), a reward of 1 if we cut it at any other age before the oldest state, a reward of $r_1$ if we wait in the oldest state, and a reward of $r_2$ if we cut the forest in the oldest state. There is a risk of fire occurring, with probability $p$ of a fire at each timestep. If we choose to cut down the forest, or if there is a fire, the forest returns to age 0 (newly planted).

The two key base classes in mdpax are `Problem` and `Solver` - the `Problem` class is used to define the MDP (in this case the forest problem) and the `Solver` class is used to define algorithms for fitting policies (in this case, value iteration).
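
To make the `Solver` side concrete, here is a minimal NumPy sketch of value iteration on the forest problem with the same parameters used below (`S=3`, `r1=4.0`, `r2=2.0`, `p=0.1`). It is not MDPax's implementation. We write $\mathbf{P}$ and $\mathbf{R}$ out by hand, assuming they match what pymdptoolbox's forest example produces, and we assume a pymdptoolbox-style stopping rule: stop once the span of $V_{k+1} - V_k$ falls below $\epsilon (1 - \gamma) / \gamma$. The function name `value_iteration` is our own.

```python
import numpy as np

# Forest problem with S=3, r1=4.0, r2=2.0, p=0.1 as dense matrices (assumed to
# match pymdptoolbox's forest example). P has shape (n_actions, n_states, n_states).
P = np.array([
    [[0.1, 0.9, 0.0], [0.1, 0.0, 0.9], [0.1, 0.0, 0.9]],  # action 0: wait
    [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],  # action 1: cut
])
# R has shape (n_states, n_actions); cutting at age 0 earns nothing
R = np.array([[0.0, 0.0], [0.0, 1.0], [4.0, 2.0]])


def value_iteration(P, R, gamma, epsilon, max_iter=1_000):
    """Discounted value iteration with a span-based stopping rule."""
    V = np.zeros(R.shape[0])
    threshold = epsilon * (1 - gamma) / gamma  # assumed pymdptoolbox-style threshold
    for iteration in range(1, max_iter + 1):
        # Q[s, a] = R[s, a] + gamma * sum_{s'} P[a, s, s'] * V[s']
        Q = R + gamma * np.einsum("ast,t->sa", P, V)
        V_next = Q.max(axis=1)
        diff = V_next - V
        V = V_next
        if diff.max() - diff.min() < threshold:  # span of the change in values
            break
    return V, Q.argmax(axis=1), iteration


values, policy, n_iter = value_iteration(P, R, gamma=0.9, epsilon=0.01)
print(np.round(values, 4), policy, n_iter)
```

With these settings the sketch stops after 4 iterations with values of about 5.052, 8.292 and 12.292 and the policy `[0, 0, 0]` (always wait), and its spans at each iteration (4, 2.43, 0.81, 0) match those in the MDPax log below.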

In [28]:
import jax.numpy as jnp
import mdptoolbox
import numpy as np

from mdpax.problems.forest import Forest
from mdpax.solvers.value_iteration import ValueIteration

In [29]:
# Create and solve the basic forest problem with MDPax
problem = Forest(S=3, r1=4.0, r2=2.0, p=0.1)  # Small forest with 3 states
solver = ValueIteration(problem, gamma=0.9, epsilon=0.01)
solution = solver.solve()

2025-01-05 18:30:28.111 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with forest problem
2025-01-05 18:30:34.576 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:123 - Checkpointing not enabled
2025-01-05 18:30:34.816 | INFO     | mdpax.solvers.value_iteration:solve:497 - Iteration 1 span: 4.0000
2025-01-05 18:30:34.935 | INFO     | mdpax.solvers.value_iteration:solve:497 - Iteration 2 span: 2.4300
2025-01-05 18:30:34.949 | INFO     | mdpax.solvers.value_iteration:solve:497 - Iteration 3 span: 0.8100
2025-01-05 18:30:34.955 | INFO     | mdpax.solvers.value_iteration:solve:497 - Iteration 4 span: 0.0000
2025-01-05 18:30:34.957 | INFO

In [30]:
# solution is a dataclass from which we can extract the values, policy, and iteration count
mdpax_values = solution.values
mdpax_policy = solution.policy
mdpax_iteration = solution.info.iteration

# Print the solution from MDPax
print("MDPax Solution:")
print("--------------")
print(f"Values:\n{np.round(mdpax_values.flatten(), 4)}")  # State values
print(f"Policy:\n{mdpax_policy.flatten()}")  # Optimal actions
print(f"Iterations to converge: {mdpax_iteration}")  # From solver info

MDPax Solution:
--------------
Values:
[ 5.052  8.292 12.292]
Policy:
[0 0 0]
Iterations to converge: 4


MDPax relies on a functional description of the MDP, using the `Problem` class. This includes defining the probability of a random event given a state and an action, and a deterministic transition function that gives the reward and the next state given a state, action and random event. 

Many other libraries, such as pymdptoolbox, require the user to provide a transition matrix $\mathbf{P}$ and a reward matrix $\mathbf{R}$ for the MDP.

Transition matrix $\mathbf{P}$ has dimensions (`n_actions`, `n_states`, `n_states`) and element $\mathbf{P}_{a,s,s'}$ is the probability of transitioning to state $s'$ when taking action $a$ in state $s$. 

Reward matrix $\mathbf{R}$ has dimensions (`n_states`, `n_actions`) and element $\mathbf{R}_{s,a}$ gives the expected reward when taking action $a$ in state $s$.
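
The two descriptions are linked by taking expectations over the random event. Writing $e$ for the event, $p(e \mid s, a)$ for its probability, and $f(s, a, e)$ and $g(s, a, e)$ for the next state and the reward returned by the deterministic transition function (our notation, not MDPax's), the matrix entries are:

$$
\mathbf{P}_{a,s,s'} = \sum_{e} p(e \mid s, a)\, \mathbb{1}\left[f(s, a, e) = s'\right],
\qquad
\mathbf{R}_{s,a} = \sum_{e} p(e \mid s, a)\, g(s, a, e)
$$

For the forest problem the event is whether a fire occurs, so waiting in any state $s$ younger than the oldest gives $\mathbf{P}_{0,s,0} = p$ and $\mathbf{P}_{0,s,s+1} = 1 - p$.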

The MDPax `Problem` class has a built-in method for constructing these matrices based on the functions describing the problem, which we can use to construct $\mathbf{P}$ and $\mathbf{R}$ and solve the forest management problem using pymdptoolbox to check our solution.
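
As a simplified sketch of what building $\mathbf{P}$ and $\mathbf{R}$ from a functional description involves, the code below describes the forest problem with an event-probability function and a deterministic transition function, then loops over states, actions and events, accumulating probability mass and expected reward. The function names and the plain Python loops are our own illustration, not the MDPax API and not necessarily how `build_transition_and_reward_matrices` works internally. The rewards follow pymdptoolbox's forest example, where cutting a newly planted forest earns nothing.

```python
import numpy as np

# Hypothetical functional description of the forest problem.
# Event 1 means a fire occurs, event 0 means no fire.
S, R1, R2, P_FIRE = 3, 4.0, 2.0, 0.1
N_ACTIONS, EVENTS = 2, (0, 1)


def event_probability(state, action, event):
    """Probability of the random event given the state and action."""
    return P_FIRE if event == 1 else 1.0 - P_FIRE


def transition(state, action, event):
    """Deterministic (next_state, reward) given the state, action and event."""
    oldest = S - 1
    if action == 1:  # cut: always back to age 0
        reward = R2 if state == oldest else (0.0 if state == 0 else 1.0)
        return 0, reward
    # wait: the forest ages by one year unless it burns down
    reward = R1 if state == oldest else 0.0
    next_state = 0 if event == 1 else min(state + 1, oldest)
    return next_state, reward


def build_matrices():
    """Dense P (n_actions, n_states, n_states) and expected-reward R (n_states, n_actions)."""
    P = np.zeros((N_ACTIONS, S, S))
    R = np.zeros((S, N_ACTIONS))
    for s in range(S):
        for a in range(N_ACTIONS):
            for e in EVENTS:
                prob = event_probability(s, a, e)
                s_next, r = transition(s, a, e)
                P[a, s, s_next] += prob  # several events may lead to the same s'
                R[s, a] += prob * r      # expectation of the reward over events
    return P, R


P, R = build_matrices()
print(P[0])
print(R)
```

Looping over every state, action and event is fine at this size; for large problems the same calculation would need to be vectorised.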

In [31]:
# Get transition and reward matrices for comparison
P, R = problem.build_transition_and_reward_matrices()
# Convert to numpy arrays for pymdptoolbox
P = np.array(P)
R = np.array(R)

In [32]:
# Solve with pymdptoolbox
vi = mdptoolbox.mdp.ValueIteration(P, R, discount=0.9, epsilon=0.01)
vi.run()

In [33]:
# Extract the solution from pymdptoolbox class
toolbox_values = vi.V
toolbox_policy = vi.policy
toolbox_iteration = vi.iter


# Print the solution from pymdptoolbox
print("\npymdptoolbox Solution:")
print("--------------------")
print(f"Values:\n{np.round(toolbox_values, 4)}")
print(f"Policy:\n{toolbox_policy}")
print(f"Iterations to converge: {toolbox_iteration}")


pymdptoolbox Solution:
--------------------
Values:
[ 5.052  8.292 12.292]
Policy:
(0, 0, 0)
Iterations to converge: 4


In [34]:
# Verify solutions match
print("\nSolutions match?")
print(f"Values close?: {np.allclose(mdpax_values.flatten(), toolbox_values, rtol=1e-2)}")
print(f"Policies match?: {np.array_equal(mdpax_policy.flatten(), np.array(toolbox_policy))}")
print(f"Number of iterations match?: {mdpax_iteration == toolbox_iteration}")


Solutions match?
Values close?: True
Policies match?: True
Number of iterations match?: True


For this very small problem, pymdptoolbox will be faster than MDPax because of the cost of transferring data to and from the GPU and the upfront cost of [JIT compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html) in JAX.
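
One way to see the upfront compilation cost is to time the first and second calls of a jitted function. In this sketch (our own example, not MDPax code), `bellman_backup` performs a single value-iteration backup on the 3-state forest matrices. The first call includes tracing and compilation; the second reuses the compiled version because the input shapes and dtypes are unchanged. `block_until_ready()` is needed because JAX dispatches work asynchronously. Exact timings depend on your hardware.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def bellman_backup(P, R, V, gamma):
    """One value-iteration backup: max over a of R[s, a] + gamma * sum_s' P[a, s, s'] V[s']."""
    Q = R + gamma * jnp.einsum("ast,t->sa", P, V)
    return Q.max(axis=1)


def timed_call(fn, *args):
    """Return fn(*args) and the wall-clock time taken, waiting for the result."""
    start = time.perf_counter()
    out = fn(*args).block_until_ready()  # JAX dispatches asynchronously
    return out, time.perf_counter() - start


# Forest problem (S=3, r1=4, r2=2, p=0.1) as dense matrices
P = jnp.array([
    [[0.1, 0.9, 0.0], [0.1, 0.0, 0.9], [0.1, 0.0, 0.9]],  # wait
    [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],  # cut
])
R = jnp.array([[0.0, 0.0], [0.0, 1.0], [4.0, 2.0]])
V = jnp.zeros(3)

first_out, first_time = timed_call(bellman_backup, P, R, V, 0.9)    # traces and compiles
second_out, second_time = timed_call(bellman_backup, P, R, V, 0.9)  # reuses the compiled function
print(f"First call: {first_time * 1e3:.2f} ms, second call: {second_time * 1e3:.2f} ms")
```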

## A larger problem: perishable inventory management with substitution

Perishable inventory management problems are known to be computationally challenging to solve exactly (e.g. with value iteration) because the state must represent the age-profile of the stock (how many units of each age are held) and therefore the size of the state space grows exponentially with the maximum useful life of the product. 
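
To see the exponential growth, here is a hypothetical helper that counts states, assuming the state holds a stock level from 0 up to the maximum order quantity for each product at each of `max_useful_life` ages. This formula is our assumption: it reproduces the state counts printed for all three problem sizes in this notebook, but we have not checked it against the MDPax source.

```python
def n_states(max_useful_life, max_order_quantity_a, max_order_quantity_b):
    """Hypothetical state count: one stock level per product per age."""
    levels_per_age = (max_order_quantity_a + 1) * (max_order_quantity_b + 1)
    return levels_per_age ** max_useful_life


print(n_states(2, 4, 4), n_states(2, 14, 6), n_states(3, 20, 4))  # 625 11025 1157625

# Growth with the maximum useful life, keeping the order limits fixed
print([n_states(m, 14, 6) for m in range(1, 5)])  # [105, 11025, 1157625, 121550625]
```

Each extra period of useful life multiplies the number of states by the number of possible (A, B) stock levels per age.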

In this example, we consider a perishable inventory management problem introduced by [Hendrix et al. (2019)](https://doi.org/10.1002/cmm4.1027). 

The decision maker must place a replenishment order each day for two perishable products, product A and product B. Orders are placed in the morning and arrive immediately before the start of the next period. Demand for each product each day is random. Some customers who want product B may accept product A, and substitutions are made once demand for product A has been met as far as possible.
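
The substitution rule can be sketched as a small function. This is a hypothetical simplification for illustration: it ignores product ages and the order in which units are issued, and it takes the number of product B customers willing to accept product A as an input rather than modelling how that number arises.

```python
def daily_sales(stock_a, stock_b, demand_a, demand_b, willing_b_to_a):
    """Units of A and B sold in one day with one-way substitution (B customers take A).

    Hypothetical helper: `willing_b_to_a` is the number of product B customers
    who will accept product A if product B is unavailable.
    """
    sold_a = min(stock_a, demand_a)  # demand for product A is met first
    sold_b = min(stock_b, demand_b)
    leftover_a = stock_a - sold_a
    unmet_b = demand_b - sold_b
    # Spare units of A go to unserved B customers who will accept A
    substituted = min(leftover_a, unmet_b, willing_b_to_a)
    return sold_a + substituted, sold_b, substituted


print(daily_sales(stock_a=5, stock_b=1, demand_a=3, demand_b=4, willing_b_to_a=2))  # (5, 1, 2)
```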

The goal is to maximise average daily profit (sales revenue less an ordering cost per unit), so we use relative value iteration, which does not discount future rewards, as the `Solver`.
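
For intuition about what relative value iteration does, here is a minimal NumPy sketch (our own, not MDPax's implementation). Without discounting the values would grow without bound, so after each backup we subtract the value of a reference state. For well-behaved (e.g. unichain) problems, the smallest and largest changes between successive iterates bound the optimal gain (the average reward per period), so we stop when their span is below `epsilon` and report the midpoint of the bounds. The two-state MDP used to try it is made up: staying in state 0 earns 1 per period, staying in state 1 earns 2, and switching earns nothing, so the best long-run average reward is 2.

```python
import numpy as np


def relative_value_iteration(P, R, epsilon=1e-4, ref_state=0, max_iter=10_000):
    """Undiscounted relative value iteration for dense P (A, S, S) and R (S, A)."""
    h = np.zeros(R.shape[0])
    for iteration in range(1, max_iter + 1):
        Q = R + np.einsum("ast,t->sa", P, h)  # Bellman backup with no discount factor
        Th = Q.max(axis=1)
        diff = Th - h
        span = diff.max() - diff.min()
        gain = 0.5 * (diff.max() + diff.min())  # midpoint of the bounds on the gain
        h = Th - Th[ref_state]  # relative values: the reference state is pinned to 0
        if span < epsilon:
            break
    return h, Q.argmax(axis=1), gain, iteration


# Made-up two-state MDP: action 0 = stay, action 1 = switch state
P = np.array([
    [[1.0, 0.0], [0.0, 1.0]],  # stay
    [[0.0, 1.0], [1.0, 0.0]],  # switch
])
R = np.array([[1.0, 0.0], [2.0, 0.0]])  # R[s, a]

h, policy, gain, n_iter = relative_value_iteration(P, R)
print(gain, policy, n_iter)  # 2.0 [1 0] 3
```

The optimal policy switches from state 0 to state 1 and then stays there, earning 2 per period.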

The smallest example considered by Hendrix et al. has 11,025 states and 105 actions. Since $\mathbf{P}$ has dimensions (`n_actions`, `n_states`, `n_states`), the transition matrix for the problem would have $105 \times 11{,}025 \times 11{,}025 \approx 1.3 \times 10^{10}$ (about 13Bn) elements. Just storing this matrix as 64-bit floats (8 bytes each) would require over 100GB of RAM!
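
The arithmetic behind these figures is easy to check: a dense $\mathbf{P}$ has `n_actions * n_states**2` elements at 8 bytes each for 64-bit floats. The sketch below applies this to the problem sizes used in this notebook, using decimal gigabytes ($10^9$ bytes).

```python
def dense_transition_bytes(n_states, n_actions, bytes_per_element=8):
    """Memory for a dense P of shape (n_actions, n_states, n_states); 8 bytes per float64."""
    return n_actions * n_states**2 * bytes_per_element


for label, n_states, n_actions in [
    ("Smaller example", 625, 25),
    ("Case 1", 11_025, 105),
    ("Case 2", 1_157_625, 105),
]:
    gigabytes = dense_transition_bytes(n_states, n_actions) / 1e9
    print(f"{label}: {gigabytes:,.1f} GB")
```

Case 1 comes to about 102 GB, and the 1,157,625-state case 2 to about 1.1 million GB (1.1 PB).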

To start with, so that we can compare our results with `pymdptoolbox`, we'll look at a smaller version of the problem with 625 states and 25 actions.

<b> A note on sparsity:</b> The comments on the size of the transition matrices in this introductory notebook do not take potential sparsity into account. pymdptoolbox has support for sparse arrays, which would reduce the memory needed to represent the transition matrices. Support for sparse arrays in JAX is currently [experimental](https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html). We may investigate the potential benefits of sparsity in a future release.
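
As a rough illustration of why sparsity matters, the sketch below stores the "wait" transition matrix of a 1,000-state forest problem (a size we chose for illustration) both densely and in SciPy's compressed sparse row (CSR) format. Each row has only two non-zero entries, so CSR needs a small fraction of the dense memory. We do not measure here how sparse the perishable inventory transition matrices are.

```python
import numpy as np
from scipy import sparse

S, p = 1_000, 0.1

# "Wait" transition matrix of a forest problem with S states
wait = np.zeros((S, S))
wait[:, 0] = p                                   # fire: back to age 0
wait[np.arange(S - 1), np.arange(1, S)] = 1 - p  # no fire: age by one year
wait[S - 1, S - 1] = 1 - p                       # the oldest state stays oldest

wait_csr = sparse.csr_matrix(wait)
dense_bytes = wait.nbytes
# CSR stores the non-zero values, their column indices and one row pointer per row
csr_bytes = wait_csr.data.nbytes + wait_csr.indices.nbytes + wait_csr.indptr.nbytes
print(f"Non-zeros: {wait_csr.nnz}, dense: {dense_bytes / 1e6:.1f} MB, CSR: {csr_bytes / 1e3:.1f} kB")
```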

### Basic case

In [35]:
from mdpax.problems.perishable_inventory.hendrix_two_product import (
    HendrixTwoProductPerishable,
)
from mdpax.solvers.relative_value_iteration import RelativeValueIteration

problem = HendrixTwoProductPerishable(max_useful_life = 2, # Products can be used for 2 periods after arrival, then are discarded
                                      demand_poisson_mean_a=2, # Mean daily demand for product A
                                      demand_poisson_mean_b=2, # Mean daily demand for product B
                                      max_order_quantity_a=4, # Maximum order quantity for product A
                                      max_order_quantity_b=4) # Maximum order quantity for product B

print(f"Number of states: {problem.n_states}")
print(f"Number of actions: {problem.n_actions}")                             

Number of states: 625
Number of actions: 25


As above, we'll first solve it using MDPax.

In [36]:
solver = RelativeValueIteration(problem, epsilon=1e-4)
solution = solver.solve()
mdpax_values = solution.values
mdpax_policy = solution.policy
mdpax_average_daily_profit = solution.info.gain

2025-01-05 18:30:35.992 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with hendrix_two_product problem
2025-01-05 18:30:36.725 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:123 - Checkpointing not enabled
2025-01-05 18:30:38.701 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 1: span: 2.84548, gain: 7.8527
2025-01-05 18:30:38.720 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 2: span: 0.60450, gain: 1.0389
2025-01-05 18:30:38.734 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 3: span: 0.10172, gain: 1.5746
2025-01-05 18:30:38.749 | INFO     | mdpax.solvers.relative_value_iteration:solve

The gain converges to the mean reward per timestep, so we can see that the mean profit per day is $1.54.

Next, because this is a small problem, we can build $\mathbf{P}$ and $\mathbf{R}$ and solve the problem using pymdptoolbox. For this problem, the value function is initialized using the one-step ahead revenue. The `RelativeValueIteration` solver calculates this automatically from the `Problem`, but we need to provide it manually to `mdptoolbox`.

In [37]:
# Compute P, R and initial values from Problem
P, R = problem.build_transition_and_reward_matrices()
initial_values = jax.vmap(problem.initial_value)(problem.state_space)
P = np.array(P)
R = np.array(R)
initial_values = np.array(initial_values)

In [38]:
# Run relative value iteration with mdptoolbox
rvi = mdptoolbox.mdp.RelativeValueIteration(P, R, epsilon=1e-4)
rvi.V = initial_values
rvi.run()

# Extract the solution from pymdptoolbox class
toolbox_values = rvi.V
toolbox_policy = rvi.policy
toolbox_average_daily_profit = rvi.average_reward

pymdptoolbox gives us the index of the best action. This was fine in the Forest example because the actions there are simply labels (0 = wait, 1 = cut) identified by an index. In this example, where an action is an order quantity for each of product A and product B, we need to look up the actual action to compare with the policy from MDPax.

We can do this by indexing into the `action_space` attribute of our `Problem`. See our [tutorial](https://mdpax.readthedocs.io/en/latest/notebooks/create_custom_problem.html) on implementing your own problem for more information on state, action and event spaces.
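
Concretely, the lookup maps each action index to a row of an array of order quantities. Here is a small NumPy sketch with a hypothetical action space for the 25-action problem, built as every `(order_a, order_b)` pair; the row ordering is our assumption and may differ from MDPax's `action_space`.

```python
import itertools

import numpy as np

max_order_quantity_a, max_order_quantity_b = 4, 4
# Hypothetical action space: one row per action, holding (order_a, order_b)
action_space = np.array(
    list(itertools.product(range(max_order_quantity_a + 1), range(max_order_quantity_b + 1)))
)

action_indices = np.array([0, 7, 24])  # indices of the kind pymdptoolbox returns
order_quantities = action_space.take(action_indices, axis=0)  # one row per index
print(action_space.shape)         # (25, 2)
print(order_quantities.tolist())  # [[0, 0], [1, 2], [4, 4]]
```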

In [39]:
toolbox_policy = problem.action_space.take(jnp.array(toolbox_policy),axis=0)

In [40]:
# Verify solutions match
print("\nSolutions match?")
print(f"Mean daily profit matches?: {np.allclose(mdpax_average_daily_profit,toolbox_average_daily_profit, rtol=1e-2)}")
print(f"Values close?: {np.allclose(mdpax_values.flatten(), toolbox_values, rtol=1e-2)}")
print(f"Policies match?: {np.array_equal(mdpax_policy, toolbox_policy)}")
print(f"Number of iterations match?: {solution.info.iteration == rvi.iter}")


Solutions match?
Mean daily profit matches?: True
Values close?: True
Policies match?: True
Number of iterations match?: True


### Larger cases

We'll now consider two larger versions of the problem which were included in Hendrix et al. (2019).

#### Case 1

In [41]:
problem = HendrixTwoProductPerishable(max_useful_life = 2, # Products can be used for 2 periods after arrival, then are discarded
                                      demand_poisson_mean_a=7, # Mean daily demand for product A
                                      demand_poisson_mean_b=3, # Mean daily demand for product B
                                      max_order_quantity_a=14, # Maximum order quantity for product A
                                      max_order_quantity_b=6) # Maximum order quantity for product B

print(f"Number of states: {problem.n_states}")
print(f"Number of actions: {problem.n_actions}")  

Number of states: 11025
Number of actions: 105


In [42]:
solver = RelativeValueIteration(problem, epsilon=1e-4)
solution = solver.solve()

2025-01-05 18:30:50.040 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with hendrix_two_product problem
2025-01-05 18:30:50.730 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:123 - Checkpointing not enabled
2025-01-05 18:30:51.802 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 1: span: 6.51383, gain: 19.9606
2025-01-05 18:30:51.871 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 2: span: 1.24455, gain: 3.4498
2025-01-05 18:30:51.942 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 3: span: 0.30959, gain: 4.3591
2025-01-05 18:30:52.007 | INFO     | mdpax.solvers.relative_value_iteration:solve

This takes less than 10s, including setting up the problem, on a Google Colab GPU instance, compared with the 206s reported by Hendrix et al. for their MATLAB implementation on the CPU.

#### Case 2

Hendrix et al. reported that the largest problem they were able to solve within a week had 1.2Mn states and took 80 hours. Using MDPax on a Google Colab GPU instance, it should take less than 3 minutes.

Storing the transition matrix for this problem as 64-bit floats would require over 1PB (or 1Mn GB) of RAM!

In [43]:
problem = HendrixTwoProductPerishable(max_useful_life = 3, # Products can be used for 3 periods after arrival, then are discarded
                                      demand_poisson_mean_a=7, # Mean daily demand for product A
                                      demand_poisson_mean_b=3, # Mean daily demand for product B
                                      max_order_quantity_a=20, # Maximum order quantity for product A
                                      max_order_quantity_b=4) # Maximum order quantity for product B

print(f"Number of states: {problem.n_states}")
print(f"Number of actions: {problem.n_actions}")

Number of states: 1157625
Number of actions: 105


In [44]:
solver = RelativeValueIteration(problem, epsilon=1e-4)
solution = solver.solve()

2025-01-05 18:31:00.353 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with hendrix_two_product problem
2025-01-05 18:31:03.585 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:123 - Checkpointing not enabled
2025-01-05 18:31:14.519 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 1: span: 6.54445, gain: 19.9912
2025-01-05 18:31:24.542 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 2: span: 6.27677, gain: 9.7235
2025-01-05 18:31:34.556 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 3: span: 1.48057, gain: 3.4468
2025-01-05 18:31:44.547 | INFO     | mdpax.solvers.relative_value_iteration:solve

## Checkpointing

Some large problems can take a long time to solve, so MDPax supports checkpointing, which lets you restart a run from a checkpoint if it is interrupted.

Checkpointing is not enabled by default, because it is not very useful for smaller problems. You can activate it by setting `checkpoint_frequency` to a positive integer when instantiating a solver. The solver will then store a checkpoint every `checkpoint_frequency` iterations, and again once it meets the convergence threshold.
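
The idea behind checkpointing can be sketched in a few lines of NumPy: periodically save the current values and iteration number, then resume from the latest saved state. This is a generic illustration with a stand-in update rule and hypothetical function names (`iterate`, `restore`), not how MDPax implements checkpointing.

```python
import os
import tempfile

import numpy as np


def iterate(values, start_iteration, max_iterations, checkpoint_path, checkpoint_frequency=1):
    """Run a stand-in solver update, saving a checkpoint every `checkpoint_frequency` iterations."""
    for iteration in range(start_iteration + 1, max_iterations + 1):
        values = 0.5 * values + 1.0  # stand-in for one solver iteration
        if iteration % checkpoint_frequency == 0:
            # Overwriting the same file keeps only the most recent checkpoint
            np.savez(checkpoint_path, values=values, iteration=iteration)
    return values


def restore(checkpoint_path):
    """Load the values and iteration number from a checkpoint file."""
    with np.load(checkpoint_path) as data:
        return data["values"], int(data["iteration"])


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.npz")
    full_run = iterate(np.zeros(3), 0, 10, path)  # uninterrupted reference run
    iterate(np.zeros(3), 0, 5, path)              # "interrupted" after 5 iterations
    values, iteration = restore(path)             # latest checkpoint: iteration 5
    continued_run = iterate(values, iteration, 10, path)

print(iteration, np.array_equal(full_run, continued_run))  # 5 True
```

Because each update is deterministic, resuming from the saved state reproduces the uninterrupted run exactly.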

Let's start by running a problem to convergence to get a reference policy. 

In [45]:
problem = HendrixTwoProductPerishable(max_useful_life = 2, 
                                      demand_poisson_mean_a=5, 
                                      demand_poisson_mean_b=5, 
                                      max_order_quantity_a=10, 
                                      max_order_quantity_b=10)
solver_a = RelativeValueIteration(problem, epsilon=1e-4)
result_from_full_run = solver_a.solve()

2025-01-05 18:33:26.465 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with hendrix_two_product problem
2025-01-05 18:33:27.045 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:123 - Checkpointing not enabled
2025-01-05 18:33:27.956 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 1: span: 6.52970, gain: 19.9641
2025-01-05 18:33:28.057 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 2: span: 1.23312, gain: 3.4442
2025-01-05 18:33:28.151 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 3: span: 0.29569, gain: 4.3579
2025-01-05 18:33:28.246 | INFO     | mdpax.solvers.relative_value_iteration:solve

Now, let's imagine the run got interrupted. To mimic that, we set `max_iterations` to less than the number of iterations required for convergence and activate checkpointing. We'll save a checkpoint every iteration, only keep the most recent checkpoint, and save them in directory `checkpoints/getting_started/incomplete_run`.

In [46]:
solver_b = RelativeValueIteration(problem, epsilon=1e-4, checkpoint_frequency=1, checkpoint_dir="checkpoints/getting_started/incomplete_run")
result_from_incomplete_run = solver_b.solve(max_iterations=5)

2025-01-05 18:33:29.778 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with hendrix_two_product problem
2025-01-05 18:33:30.825 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:143 - Full checkpointing enabled with problem and solver reconstruction
2025-01-05 18:33:30.826 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:152 - Saving checkpoints every 1 iteration(s) to /home/joefarrington/other_learning/mdpax/examples/checkpoints/getting_started/incomplete_run
2025-01-05 18:33:31.759 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 1: span: 6.52970, gain: 19.9641
2025-01-05 18:33:31.902 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 2: span: 1.2

In [47]:
solver_c = RelativeValueIteration.restore(checkpoint_dir="checkpoints/getting_started/incomplete_run", new_checkpoint_dir="checkpoints/getting_started/continued_run")

2025-01-05 18:33:34.391 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with hendrix_two_product problem
2025-01-05 18:33:35.015 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:143 - Full checkpointing enabled with problem and solver reconstruction
2025-01-05 18:33:35.015 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:152 - Saving checkpoints every 1 iteration(s) to /home/joefarrington/other_learning/mdpax/examples/checkpoints/getting_started/continued_run


In [48]:
print(f"Values restored correctly: {np.all(solver_c.values == result_from_incomplete_run.values)}")
print(f"Iteration restored correctly: {solver_c.iteration == result_from_incomplete_run.info.iteration}")

Values restored correctly: True
Iteration restored correctly: True


In [49]:
result_from_continued_run = solver_c.solve()

2025-01-05 18:33:36.040 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 6: span: 0.01627, gain: 4.4977
2025-01-05 18:33:36.180 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 7: span: 0.00502, gain: 4.5060
2025-01-05 18:33:36.276 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 8: span: 0.00201, gain: 4.5021
2025-01-05 18:33:36.370 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 9: span: 0.00060, gain: 4.5032
2025-01-05 18:33:36.463 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 10: span: 0.00040, gain: 4.5033
2025-01-05 18:33:36.555 | INFO     | mdpax.solvers.relative_value_iterat

In [50]:
print(f"Policy from restored run same as full run: {np.all(result_from_continued_run.policy == result_from_full_run.policy)}")

Policy from restored run same as full run: True


## Next Steps

- Try the next [tutorial](https://mdpax.readthedocs.io/en/latest/notebooks/create_custom_problem.html) to learn how to implement your own problems using MDPax [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/joefarrington/mdpax/blob/main/examples/create_custom_problem.ipynb)
- Read the [MDPax documentation](https://mdpax.readthedocs.io/en/latest/index.html) for detailed API reference