Based on Juliani's "Simple Reinforcement Learning with TensorFlow: Q-learning with tables and neural networks", I refactor it to a JAX implementation.  It mostly consists of a Q-table update based on the Bellman relation, e.g., a global optimum includes a local one.  I wanted to use JAX because it's faster than tensorflow, & functionally pure.  

References
1. 
https://medium.com/emergent-future/simple-reinforcement-learning-with-tensorflow-part-0-q-learning-with-tables-and-neural-networks-d195264329d0

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, tree_util, random, tree_util, lax
import jax
import jax.random as random
from functools import partial
from flax.training.train_state import TrainState
import plotly.express as px

In [4]:
class MazeEnv:
   def __init__(self,
                board = jnp.array([[0, 1, 2],
                                   [2, 1, 3]]),
                num_rows = 2,
                num_cols = 3,
                num_states = 6):
      # S -> 0, F -> 1, H -> 2, G -> 3
      # Start   Frozen  Hole    Goal
      self.board = board
      # self.board = jnp.array([[0, 1, 2, 2, 2],
      #                         [2, 1, 1, 1, 2],
      #                         [2, 2, 2, 1, 2],
      #                         [1, 2, 2, 1, 3]])
      self.num_rows = num_rows
      self.num_cols = num_cols
      self.num_states = num_states

   def _get_observation(self, state):
      return state
   
   def maybe_reset(self, state, done):
      return lax.cond(done, lambda s: 0, lambda s : s, state)

   def reset(self):
      return 0
   
   def step(self, state, action):
      (row, col) = jnp.divmod(state, self.num_cols)
      # Left Right Up Down
      # 0    1     2  3
      col = lax.cond(action==0, 
                     lambda c: jnp.maximum(0,c-1),
                     lambda c: c,
                     col)
      col = lax.cond(action==1,
                     lambda c: jnp.minimum(c+1, self.num_cols-1),
                     lambda c: c,
                     col)
      row = lax.cond(action==2, 
                     lambda r: jnp.maximum(0,r-1),
                     lambda r: r,
                     row)
      row = lax.cond(action==3,
                     lambda r: jnp.minimum(r+1, self.num_rows-1),
                     lambda r: r,
                     row)
      # S -> 0, F -> 1, H -> 2, G -> 3
      start = self.board[row,col] == 0
      frozen = self.board[row,col] == 1
      hole = self.board[row,col] == 2
      goal = self.board[row,col] == 3
      reward = start * 0 + frozen * 0 + hole * -1 + goal * 1
      state = row*self.num_cols + col
      done = jnp.logical_or(goal, hole)
      state = self.maybe_reset(state, done)
      return state, reward, done

   def _tree_flatten(self):
      children = ()  # arrays / dynamic values
      aux_data = {'board': self.board,
                  'num_rows': self.num_rows,
                  'num_cols': self.num_cols,
                  'num_states': self.num_states}  # static values
      return (children, aux_data)

   @classmethod
   def _tree_unflatten(cls, aux_data, children):
      return cls(*children, **aux_data)

tree_util.register_pytree_node(MazeEnv,
                               MazeEnv._tree_flatten,
                               MazeEnv._tree_unflatten)

In [6]:
env = MazeEnv()
print(env.board)
print(jnp.reshape(jnp.arange(0,env.num_states),(env.num_rows,env.num_cols)))
key = random.PRNGKey(0)
Q = random.uniform(key, (env.num_states, 4))
# print(Q)
gamma = 0.95
alpha = 0.8
num_episodes = 50
state = 0
for i in range(num_episodes):
    if state == (env.num_states-1):
        print('Reached goal')
        break
    # print(i)
    state = env.reset()
    j = 0
    while j < 4:
        j += 1
        action = jnp.argmax(Q[state,:])
        new_state, reward, done = env.step(state, action)
        new_Q_value = Q[state, action] + alpha * (reward + gamma * jnp.max(Q[new_state,:]) - Q[state, action])
        Q = Q.at[state, action].set(new_Q_value)
        state = new_state
        if done == True:
            break
print(Q)
#   Left Right Up Down
#   0    1     2  3

[[0 1 2]
 [2 1 3]]
[[0 1 2]
 [3 4 5]]
[[ 0.06541302  0.06471601  0.02240455 -0.25877023]
 [ 0.06506324 -0.70904     0.06424201  0.03644359]
 [ 0.08744538  0.7909105   0.35205448  0.53364205]
 [ 0.02900076  0.4168595   0.5802449   0.91486526]
 [ 0.27414513  0.14991808  0.9383501   0.5209162 ]
 [ 0.51207185  0.90618336  0.7309413   0.95533276]]


In [11]:
env = MazeEnv()
print(env.board)
print(jnp.reshape(jnp.arange(0,env.num_states),(env.num_rows,env.num_cols)))
key = random.PRNGKey(0)
Q = random.uniform(key, (env.num_states, 4))
gamma = 0.95
alpha = 0.8
state = 0
num_episodes = 10000
def inner_body(i, arguments):
   Q, state, env, gamma, alpha = arguments
   action = jnp.argmax(Q[state,:])
   new_state, reward, done = env.step(state, action)
   new_Q_value = Q[state, action] + alpha * (reward + gamma * jnp.max(Q[new_state,:]) - Q[state, action])
   Q = Q.at[state, action].set(new_Q_value)
   state = new_state
   return Q, state, env, gamma, alpha
def outer_body(i, arguments):
   Q, state, env, gamma, alpha = arguments
   state = env.reset()
   Q, state, env, _, _ = jax.lax.fori_loop(0, num_episodes, inner_body, (Q, state, env, gamma, alpha))
   return Q, state, env, gamma, alpha
Q, _, _, _, _ = jax.lax.fori_loop(0, num_episodes, outer_body, (Q, state, env, gamma, alpha))
print(Q)
#   Left Right Up Down
#   0    1     2  3

[[0 1 2]
 [2 1 3]]
[[0 1 2]
 [3 4 5]]
[[ 0.03404112  6.3277774   0.02240455 -0.25877023]
 [ 0.03598888 -0.70904     0.03627574  6.6608186 ]
 [ 0.08744538  0.7909105   0.35205448  0.53364205]
 [ 0.02900076  0.4168595   0.5802449   0.91486526]
 [-0.71709883  7.0113883   0.14695214  0.14695214]
 [ 0.51207185  0.90618336  0.7309413   0.95533276]]


In [5]:
# import flax.linen as nn
# from flax.training.train_state import TrainState
# import jax, jax.numpy as jnp
# import optax

# x = jnp.ones((1,2))
# y = jnp.ones((1,2))
# model = nn.Dense(2)
# variables = model.init(jax.random.key(0), x)
# tx = optax.adam(1e-3)

# state = TrainState.create(
#     apply_fn = model.apply,
#     params = variables['params'],
#     tx = tx)

# print(variables)

# def loss_fn(params, x, y):
#     predictions = state.apply_fn({'params': params}, x)
#     loss = optax.l2_loss(predictions=predictions, targets = y).mean()
#     return loss

# l = loss_fn(state.params, x, y)
# print(l)

# grads = jax.grad(loss_fn)(state.params, x, y)
# print(grads)

# state = state.apply_gradients(grads=grads)
# l = loss_fn(state.params, x, y)
# print(l)


{'params': {'kernel': Array([[-1.10456   , -1.1868286 ],
       [-0.9255007 ,  0.13144489]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}}
3.3514676
{'bias': Array([-1.5150304, -1.0276918], dtype=float32), 'kernel': Array([[-1.5150304, -1.0276918],
       [-1.5150304, -1.0276918]], dtype=float32)}
3.343844


In [24]:
import flax.linen as nn
from flax.training.train_state import TrainState
import jax, jax.numpy as jnp
import optax

# Q-Learning with neural networks (nn)
# Since there are more states in a typical game than a normal computer memory can hold we're forced to abandon a table
# We create a nn that inputs a hot-vector with length equal to the number of states and outputs a vector of 
# 4 Q-values -- one for each action.  Instead of updating a table, the nn will update via backpropagation & a loss function.

class Q_NN(nn.Module):                    # create a Flax Module dataclass
  out_dims: int

  @nn.compact
  def __call__(self, input):                  # input:                1 x num_states
    Qout = nn.Dense(self.out_dims)(input)     # Dense layer: num_states x num_actions
    return Qout                               # Qout:                 1 x num_actions
    # action = jnp.argmax(Qout)
    # return action
    # x = x.reshape((x.shape[0], -1))       # x: 1 x 16
    # return x

model = Q_NN(out_dims=4)  
x = jnp.empty((1, env.num_states))       # generate random data
variables = model.init(random.key(42), x)# initialize the weights
 
# print(model.tabulate(jax.random.key(0), x, compute_flops=True, compute_vjp_flops=True))

tx = optax.adam(1e-1)
state = TrainState.create(
    apply_fn = model.apply,
    params = variables['params'],
    tx = tx)

def loss_fn(params, input, nextQ):
    Qout = state.apply_fn({'params': params}, input)
    prediction = jnp.argmax(Qout)
    loss = optax.l2_loss(predictions=prediction, targets = nextQ)
    return loss

def train_step(state, batch):
    grad_fn = jax.grad(loss_fn)
    grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)



In [None]:
env = MazeEnv()
print(env.board)
print(jnp.reshape(jnp.arange(0,env.num_states),(env.num_rows,env.num_cols)))
key = random.PRNGKey(0)
Q = random.uniform(key, (env.num_states, 4))
# print(Q)
gamma = 0.95
alpha = 0.8
num_episodes = 5
state = 0
for i in range(num_episodes):
    if state == (env.num_states-1):
        print('Reached goal')
        break
    # print(i)
    state = env.reset()
    j = 0
    while j < 4:
        j += 1
        # action = jnp.argmax(Q[state,:])
        # new_state, reward, done = env.step(state, action)
        # new_Q_value = Q[state, action] + alpha * (reward + gamma * jnp.max(Q[new_state,:]) - Q[state, action])
        # Q = Q.at[state, action].set(new_Q_value)
        # state = new_state
        # if done == True:
        #     break
print(Q)
#   Left Right Up Down
#   0    1     2  3