Based on Juliani's "Simple Reinforcement Learning with TensorFlow: Q-learning with tables and neural networks", I refactor it to a JAX implementation.  It mostly consists of a Q-table update based on the Bellman relation, e.g., a global optimum includes a local one.  I wanted to use JAX because it's faster than tensorflow, & functionally pure.  

References
1. 
https://medium.com/emergent-future/simple-reinforcement-learning-with-tensorflow-part-0-q-learning-with-tables-and-neural-networks-d195264329d0

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, tree_util, random, tree_util, lax
import jax
import jax.random as random
from functools import partial
from flax.training.train_state import TrainState
import plotly.express as px

In [4]:
class MazeEnv:
   # Board encoding:
      # 0 -> (S)tart
      # 1 -> (F)rozen
      # 2 -> (H)ole
      # 3 -> (G)oal
   def __init__(self,
                board = jnp.array([[0, 1, 2],
                                   [2, 1, 3]]),
                num_rows = 2,
                num_cols = 3,
                num_states = 6):
      self.board = board
      # self.board = jnp.array([[0, 1, 2, 2, 2],
      #                         [2, 1, 1, 1, 2],
      #                         [2, 2, 2, 1, 2],
      #                         [1, 2, 2, 1, 3]])
      self.num_rows = num_rows
      self.num_cols = num_cols
      self.num_states = num_states

   def _get_observation(self, state):
      return state
   
   def maybe_reset(self, state, done):
      return lax.cond(done, lambda s: 0, lambda s : s, state)

   def reset(self):
      return 0
   
   def step(self, state, action):
      # Action encoding:
      # 0 -> Left
      # 1 -> Right
      # 2 -> Up
      # 3 -> Down
      (row, col) = jnp.divmod(state, self.num_cols)
      col = lax.cond(action==0, 
                     lambda c: jnp.maximum(0,c-1),
                     lambda c: c,
                     col)
      col = lax.cond(action==1,
                     lambda c: jnp.minimum(c+1, self.num_cols-1),
                     lambda c: c,
                     col)
      row = lax.cond(action==2, 
                     lambda r: jnp.maximum(0,r-1),
                     lambda r: r,
                     row)
      row = lax.cond(action==3,
                     lambda r: jnp.minimum(r+1, self.num_rows-1),
                     lambda r: r,
                     row)
      # S -> 0, F -> 1, H -> 2, G -> 3
      start = self.board[row,col] == 0
      frozen = self.board[row,col] == 1
      hole = self.board[row,col] == 2
      goal = self.board[row,col] == 3
      reward = start * 0 + frozen * 0 + hole * -1 + goal * 1
      state = row*self.num_cols + col
      done = jnp.logical_or(goal, hole)
      state = self.maybe_reset(state, done)
      return state, reward, done

   def _tree_flatten(self):
      children = ()  # arrays / dynamic values
      aux_data = {'board': self.board,
                  'num_rows': self.num_rows,
                  'num_cols': self.num_cols,
                  'num_states': self.num_states}  # static values
      return (children, aux_data)

   @classmethod
   def _tree_unflatten(cls, aux_data, children):
      return cls(*children, **aux_data)

tree_util.register_pytree_node(MazeEnv,
                               MazeEnv._tree_flatten,
                               MazeEnv._tree_unflatten)

In [11]:
env = MazeEnv()
print(env.board)
print(jnp.reshape(jnp.arange(0,env.num_states),(env.num_rows,env.num_cols)))
key = random.PRNGKey(0)
Q = random.uniform(key, (env.num_states, 4))
gamma = 0.95
alpha = 0.8
state = 0
num_episodes = 10000
def inner_body(i, arguments):
   Q, state, env, gamma, alpha = arguments
   action = jnp.argmax(Q[state,:])
   new_state, reward, done = env.step(state, action)
   new_Q_value = Q[state, action] + alpha * (reward + gamma * jnp.max(Q[new_state,:]) - Q[state, action])
   Q = Q.at[state, action].set(new_Q_value)
   state = new_state
   return Q, state, env, gamma, alpha
def outer_body(i, arguments):
   Q, state, env, gamma, alpha = arguments
   state = env.reset()
   Q, state, env, _, _ = jax.lax.fori_loop(0, num_episodes, inner_body, (Q, state, env, gamma, alpha))
   return Q, state, env, gamma, alpha
Q, _, _, _, _ = jax.lax.fori_loop(0, num_episodes, outer_body, (Q, state, env, gamma, alpha))
print(Q)
#   Left Right Up Down
#   0    1     2  3

[[0 1 2]
 [2 1 3]]
[[0 1 2]
 [3 4 5]]
[[ 0.03404112  6.3277774   0.02240455 -0.25877023]
 [ 0.03598888 -0.70904     0.03627574  6.6608186 ]
 [ 0.08744538  0.7909105   0.35205448  0.53364205]
 [ 0.02900076  0.4168595   0.5802449   0.91486526]
 [-0.71709883  7.0113883   0.14695214  0.14695214]
 [ 0.51207185  0.90618336  0.7309413   0.95533276]]
