In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, tree_util, random, tree_util, lax
import jax.random as random
import flax.linen as nn
from flax.training.train_state import TrainState
import optax
import plotly.express as px


In [3]:
class MazeEnv:
    """Deterministic grid maze (a tiny FrozenLake-style environment) in JAX.

    Cell codes on ``board``: S (start) -> 0, F (frozen) -> 1, H (hole) -> 2,
    G (goal) -> 3. States are flattened cell indices ``row * num_cols + col``.
    Stepping onto a hole gives reward -1, onto the goal +1, anything else 0.
    Holes and the goal end the episode, and the returned state is then reset
    to state 0.

    Registered as a pytree: ``board`` is a leaf and the grid dimensions are
    static metadata, so an env can be passed as an argument to ``jit``-ed
    functions.
    """

    def __init__(self,
                 board=jnp.array([[0, 1, 2],
                                  [2, 1, 3]]),
                 num_rows=None,
                 num_cols=None,
                 num_states=None):
        """Args:
          board: 2-D integer array of cell codes (see class docstring).
          num_rows, num_cols: grid dimensions; default to ``board.shape``
            (2 and 3 for the default board).
          num_states: number of cells; defaults to ``num_rows * num_cols``
            (6 for the default board).
        """
        self.board = board
        # board.shape is only read when a dimension is omitted, so
        # _tree_unflatten (which passes all dimensions) never touches it.
        self.num_rows = board.shape[0] if num_rows is None else num_rows
        self.num_cols = board.shape[1] if num_cols is None else num_cols
        self.num_states = (self.num_rows * self.num_cols
                           if num_states is None else num_states)

    def _get_observation(self, state):
        """Observations are the raw integer state."""
        return state

    def maybe_reset(self, state, done):
        """Return state 0 where ``done`` is true, otherwise ``state``."""
        # Elementwise select; for scalar inputs this equals branching on done.
        return jnp.where(done, 0, state)

    def reset(self):
        """Return the initial state (0)."""
        return 0

    def step(self, state, action):
        """Move one cell and report the outcome.

        Args:
          state: flattened cell index ``row * num_cols + col``.
          action: 0 Left, 1 Right, 2 Up, 3 Down. Any other value, or a move
            into a wall, leaves the agent in place.

        Returns:
          ``(next_state, reward, done)``. ``next_state`` is already reset to 0
          when ``done`` is true.
        """
        row, col = jnp.divmod(state, self.num_cols)
        # Each move is clamped to the grid, so bumping a wall is a no-op.
        col = jnp.where(action == 0, jnp.maximum(col - 1, 0), col)
        col = jnp.where(action == 1, jnp.minimum(col + 1, self.num_cols - 1), col)
        row = jnp.where(action == 2, jnp.maximum(row - 1, 0), row)
        row = jnp.where(action == 3, jnp.minimum(row + 1, self.num_rows - 1), row)

        cell = self.board[row, col]
        hole = cell == 2
        goal = cell == 3
        # S and F give 0, H gives -1, G gives +1.
        reward = goal.astype(jnp.int32) - hole.astype(jnp.int32)
        done = jnp.logical_or(goal, hole)
        next_state = self.maybe_reset(row * self.num_cols + col, done)
        return next_state, reward, done

    def _tree_flatten(self):
        """Split into dynamic leaves (``board``) and static metadata."""
        children = (self.board,)  # arrays / dynamic values
        # aux_data must be hashable and comparable (JAX hashes it as part of
        # the treedef, e.g. for jit caching), so it is a tuple of Python ints
        # and never holds an array.
        aux_data = (self.num_rows, self.num_cols, self.num_states)
        return children, aux_data

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        """Rebuild an env from ``_tree_flatten`` output."""
        # Positional order matches __init__: board, num_rows, num_cols, num_states.
        return cls(*children, *aux_data)


tree_util.register_pytree_node(MazeEnv,
                               MazeEnv._tree_flatten,
                               MazeEnv._tree_unflatten)

In [4]:
class Q_nn(nn.Module):
  """Single-layer (linear) Q-network.

  One ``nn.Dense`` layer maps the state encoding to ``out_dims`` Q-values,
  one per action.
  """
  out_dims: int  # number of actions = width of the Q-value output

  @nn.compact
  def __call__(self, input):
    # Dense acts on the last axis: an input of shape (1, num_states) yields
    # Q-values of shape (1, out_dims).
    q_layer = nn.Dense(features=self.out_dims)
    return q_layer(input)

In [5]:
# Environment, Q-network and optimizer state used by the training cells below.
SEED = 42
LEARNING_RATE = 1e-1
NUM_ACTIONS = 4  # Left, Right, Up, Down

env = MazeEnv()
model = Q_nn(out_dims=NUM_ACTIONS)


def state_to_tensor(state):
  """Encode an integer state as a one-hot row of shape (1, env.num_states).

  Uses integer indexing instead of a ``[state:state+1]`` slice: JAX slice
  bounds must be static, so the slice fails for traced states (under ``jit``).
  """
  return jnp.identity(env.num_states)[state][None, :]


def tensor_to_state(tensor_state):
  """Decode a one-hot row back to its integer state (index of the max entry)."""
  return jnp.argmax(tensor_state)


tensor_state = state_to_tensor(0)
params = model.init(random.key(SEED), tensor_state)
tx = optax.adam(LEARNING_RATE)
train_state = TrainState.create(
    apply_fn=model.apply,
    params=params['params'],
    tx=tx)

In [16]:
# @jit
# def train_step(train_state, batch):
#   def loss_fn(params):
#     Qs = train_state.apply_fn({'params': params}, batch['tensor_state'])
#     loss = optax.l2_loss(predictions=Qs, labels=batch['Q-target']).mean()
#     return loss
#   grad_fn = grad(loss_fn)
#   grads = grad_fn(train_state.params)
#   train_state = train_state.apply_gradients(grads=grads)
#   return train_state

In [17]:
def train_step(train_state, batch, env, gamma=0.95):
  """Run one greedy Q-learning update from a single state.

  Args:
    train_state: Flax ``TrainState`` holding the Q-network params and optimizer.
    batch: dict whose ``'tensor_state'`` entry is the encoded current state
      (e.g. the output of ``state_to_tensor``).
    env: environment exposing ``step(state, action) -> (next_state, reward, done)``.
    gamma: discount factor for the bootstrapped target.

  Returns:
    The ``TrainState`` after one gradient step.
  """
  tensor_state = batch['tensor_state']
  state = tensor_to_state(tensor_state)

  # Act greedily and build the TD target from the *current* params, outside
  # the differentiated function, so the target is a constant (semi-gradient
  # Q-learning) instead of something the gradient also pushes on.
  q_values = train_state.apply_fn({'params': train_state.params}, tensor_state)
  action = jnp.argmax(q_values)  # Left Right Up Down |-> 0 1 2 3
  next_state, reward, done = env.step(state, action)
  next_q = train_state.apply_fn({'params': train_state.params},
                                state_to_tensor(next_state))
  # A terminal step must not bootstrap: env.step has already reset next_state
  # to the start, whose value says nothing about the episode that just ended.
  td_target = jnp.where(done, reward, reward + gamma * jnp.max(next_q))
  # JAX arrays are immutable: .at[...].set returns a NEW array, so it must be
  # assigned. q_values has a leading batch axis of size 1, hence [0, action].
  target_q = q_values.at[0, action].set(td_target)

  def loss_fn(params):
    predicted_q = train_state.apply_fn({'params': params}, tensor_state)
    return optax.l2_loss(predictions=predicted_q, labels=target_q).sum()

  grads = grad(loss_fn)(train_state.params)
  return train_state.apply_gradients(grads=grads)

In [17]:
# Train the Q-network: each episode starts from env.reset() and follows the
# greedy policy for at most `max_steps` moves, doing one update per move.
# TODO: the policy is purely greedy (no exploration), so an episode can stall
# against a wall until max_steps runs out.
num_episodes = 5
max_steps = 4  # per-episode step budget, as in the earlier prototype loop

episode_returns = []
for episode in range(num_episodes):
    state = env.reset()
    episode_return = 0
    for _ in range(max_steps):
        tensor_state = state_to_tensor(state)
        # train_step takes the argmax action under these same params internally;
        # it is recomputed here only so the loop can advance the episode.
        q_values = train_state.apply_fn({'params': train_state.params}, tensor_state)
        action = jnp.argmax(q_values)
        train_state = train_step(train_state, {'tensor_state': tensor_state}, env)
        state, reward, done = env.step(state, action)
        episode_return += int(reward)
        if done:
            break
    episode_returns.append(episode_return)

episode_returns

