# GridWorld Solution

Written by Berktug Ozkan.

For the related blog post, [click here](https://spaceymonk.github.io/gridworld-rl-again.html).

## Dependencies

In [1]:
import numpy as np
import random

## Creating Tables

### Action & Observation Spaces

In [2]:
action_space      = (4,)  # right, left, up, down
observation_space = (3, 4)  # 3 rows x 4 columns, the classic 4x3 GridWorld layout

In [3]:
action_space_size      = int(np.prod(action_space))       # 4 actions
observation_space_size = int(np.prod(observation_space))  # 12 states

### Q Table

Stores the quality values of state-action pairs:

  * _rows_, states (0 to 11)
  * _columns_, actions (0 to 3; right, left, up, down respectively)
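The hard-coded indices used throughout (terminal states 3 and 7, wall state 5) are row-major flattened positions in the 3x4 grid. The short sketch below, which only re-declares `observation_space` from above, shows the mapping with NumPy's `np.unravel_index` and `np.ravel_multi_index`.

```python
import numpy as np

observation_space = (3, 4)  # rows, columns (same as above)

# Row-major mapping: state = row * n_cols + col
for state in (3, 5, 7):
    row, col = np.unravel_index(state, observation_space)
    assert state == row * observation_space[1] + col
    print(state, '->', (int(row), int(col)))
# 3 -> (0, 3)
# 5 -> (1, 1)
# 7 -> (1, 3)

# Reverse direction: grid cell -> state index
print(np.ravel_multi_index((1, 3), observation_space))  # -> 7

# Total number of states; same result as np.cumprod(observation_space)[-1]
print(int(np.prod(observation_space)))  # -> 12
```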

I have tried to preserve the shape of the [Empty Q Table](https://spaceymonk.github.io/assets/2022-08-02/GridWorld_-_Q_Table.pdf) mentioned in the post.

In [4]:
# create current Q values table to fill
current_q_table = np.zeros( (observation_space_size, action_space_size) )

# we know the values of the terminal states
current_q_table[3] = np.ones(action_space_size)
current_q_table[7] = -np.ones(action_space_size)

current_q_table

array([[ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.],
       [ 1.,  1.,  1.,  1.],
       [ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.],
       [-1., -1., -1., -1.],
       [ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.]])

### State Values

Value function of the states.

In [5]:
# value table to fill, it will act as observations
states = np.zeros(observation_space)

# we know the values of the terminal states
states[np.unravel_index(3, observation_space)] = 1.00
states[np.unravel_index(7, observation_space)] = -1.00

states

array([[ 0.,  0.,  0.,  1.],
       [ 0.,  0.,  0., -1.],
       [ 0.,  0.,  0.,  0.]])

## Training Loop

Main loop that fits the Q table. You can change the [parameters](#Parameters).

### Parameters

In [6]:
n_iterations = 100  # number of iterations to run the loop
gamma        = 1  # discount factor

In [7]:
reward = -0.04  # for every move
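For reference, each sweep of the driver applies the backup below to every non-terminal, non-wall state $s$. Here $s_a$ denotes the neighbour reached by moving in direction $a$ (or $s$ itself if that move bumps into a wall or the grid edge), and $a^{\perp}_1, a^{\perp}_2$ are the two perpendicular directions; the 0.8/0.1/0.1 split is read off the action-value lines in the driver code.

$$
Q(s, a) = r + \gamma \left[ 0.8\, V(s_a) + 0.1\, V(s_{a^{\perp}_1}) + 0.1\, V(s_{a^{\perp}_2}) \right], \qquad V(s) = \max_{a} Q(s, a)
$$

As a worked example with $r = -0.04$ and $\gamma = 1$: on the first sweep, state 2 (row 0, column 2) has the terminal state 3 to its right, the grid edge above (so "up" keeps it in place, where $V = 0$) and state 6 below ($V = 0$), giving

$$
Q(2, \text{right}) = -0.04 + 1 \cdot (0.8 \cdot 1 + 0.1 \cdot 0 + 0.1 \cdot 0) = 0.76
$$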

### Driver

#### Handling "bumping" into the walls

The cell below defines a function that checks whether a given cell is a wall.

Wrapping the check in a function makes it easy to adapt the code to other mazes. For example, if state 1 also becomes a wall in another maze, the return statement becomes:

``` python
return row * observation_space[1] + col == 5 or row * observation_space[1] + col == 1
```

One may want to introduce an array of "wall" states to shorten the line.

In [8]:
def is_wall(row, col):
    return row * observation_space[1] + col == 5
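Following the suggestion above, a sketch of `is_wall` backed by a collection of wall states could look like this. `WALL_STATES` is a hypothetical name; only state 5 is a wall in this maze, and the driver's own `state == 5` skip check would need to test `state in WALL_STATES` as well.

```python
observation_space = (3, 4)  # rows, columns (same as above)

# Hypothetical set of flattened wall-state indices; add 1 here to wall off state 1 too
WALL_STATES = {5}


def is_wall(row, col):
    """Return True if the grid cell (row, col) is a wall."""
    return row * observation_space[1] + col in WALL_STATES


print(is_wall(1, 1))  # -> True  (state 5)
print(is_wall(0, 1))  # -> False (state 1)
```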

#### A different approach!

If you haven't noticed yet, I hard-coded the terminal and wall states. Keep this in mind if you want to introduce another maze.

(1) The state values are updated inside the inner loop (`states[row, col] = np.max(current_q_table[state])`), so each state uses the *brand-new estimated values* of any neighbours already updated in the same sweep.

(2) Alternatively, the commented-out line at the end of the outer loop (`states = np.max(current_q_table, axis=1).reshape(observation_space)`) performs the update once per iteration. With this approach, the estimates use the previous iteration's values. _Do not forget to comment out the in-loop update and to remove the initial `states.copy()` from the history list (`states_history  = [states.copy()]`)._

So which one? In my opinion:

 - The latter (updating all at once) is better suited to parallelization because there is no data race: values are only read from one memory block and written to another. In the first case, a state's new value must be written before its neighbours can read it.
 - However, the iteration count differs: the former may converge in fewer iterations. For example, convergence took:

   | update per state (first case) | update all at once (second case) |
   | --- | --- |
   | 13 iterations | 25 iterations |

 - So it depends on the problem and the available hardware. If there are many states and they fit into GPU memory, use the latter; otherwise, the former is a good choice.
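The two update schemes can be compared in a standalone sketch. It re-implements the same backup (reward -0.04, gamma 1, 0.8/0.1/0.1 moves) with a hypothetical convergence tolerance `TOL` and helper names (`neighbour_value`, `backup`, `run`) that are not part of the notebook. `in_place=True` corresponds to the first case and `in_place=False` to the second. The iteration counts it reports depend on `TOL`, so they are not expected to reproduce the figures in the table exactly.

```python
import numpy as np

# Assumptions: same 3x4 layout as the notebook (terminals 3 and 7, wall 5),
# reward -0.04, gamma 1, 0.8/0.1/0.1 moves, and a hypothetical tolerance TOL.
SHAPE = (3, 4)
TERMINALS = {3: 1.0, 7: -1.0}
WALLS = {5}
REWARD, GAMMA, TOL = -0.04, 1.0, 1e-6


def neighbour_value(V, row, col, d_row, d_col):
    """Value of the cell reached by moving (d_row, d_col); bumping keeps the agent in place."""
    r, c = row + d_row, col + d_col
    if not (0 <= r < SHAPE[0] and 0 <= c < SHAPE[1]) or r * SHAPE[1] + c in WALLS:
        return V[row, col]
    return V[r, c]


def backup(V, row, col):
    """Best action value (right, left, up, down) for one state, reading values from V."""
    right = neighbour_value(V, row, col, 0, 1)
    left = neighbour_value(V, row, col, 0, -1)
    up = neighbour_value(V, row, col, -1, 0)
    down = neighbour_value(V, row, col, 1, 0)
    q = [
        REWARD + GAMMA * (0.8 * right + 0.1 * up + 0.1 * down),
        REWARD + GAMMA * (0.8 * left + 0.1 * up + 0.1 * down),
        REWARD + GAMMA * (0.8 * up + 0.1 * left + 0.1 * right),
        REWARD + GAMMA * (0.8 * down + 0.1 * left + 0.1 * right),
    ]
    return max(q)


def run(in_place, max_iters=1000):
    """Value iteration; returns (final values, iterations until the largest change < TOL)."""
    V = np.zeros(SHAPE)
    for s, v in TERMINALS.items():
        V[divmod(s, SHAPE[1])] = v
    for it in range(1, max_iters + 1):
        old = V.copy()
        # first case: read the values being updated; second case: read a frozen copy
        source = V if in_place else old
        for s in range(V.size):
            if s in TERMINALS or s in WALLS:
                continue
            row, col = divmod(s, SHAPE[1])
            V[row, col] = backup(source, row, col)
        if np.max(np.abs(V - old)) < TOL:
            return V, it
    return V, max_iters


V_per_state, iters_per_state = run(in_place=True)
V_all_at_once, iters_all_at_once = run(in_place=False)
print('update per state:  ', iters_per_state, 'iterations')
print('update all at once:', iters_all_at_once, 'iterations')
print(np.round(V_per_state, 3))
```

Both schemes should settle on the same state values; the only difference is how many sweeps it takes, and whether the sweep reads values written earlier in the same sweep.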

In [9]:
q_table_history = []  # history of Q tables
states_history  = [states.copy()]  # history of state values


for i in range(0, n_iterations):
    for state in range(observation_space_size):  # for every state,
        
        # if terminal state or wall encountered, ignore
        if state == 3 or state == 5 or state == 7:
            continue
        
        row, col = np.unravel_index(state, observation_space)

        # neighbouring state values (bumping into a wall or the grid edge keeps the agent in place)
        right_neighbour = states[row, col+1] if col+1 >= 0 and col+1 < observation_space[1] and not is_wall(row, col+1) else states[row, col]
        left_neighbour  = states[row, col-1] if col-1 >= 0 and col-1 < observation_space[1] and not is_wall(row, col-1) else states[row, col]
        up_neighbour    = states[row-1, col] if row-1 >= 0 and row-1 < observation_space[0] and not is_wall(row-1, col) else states[row, col]
        down_neighbour  = states[row+1, col] if row+1 >= 0 and row+1 < observation_space[0] and not is_wall(row+1, col) else states[row, col]
        
        # calculate action values (0.8 intended direction, 0.1 each perpendicular direction)
        right = reward + gamma * (0.8 * right_neighbour + 0.10 * up_neighbour   + 0.10 * down_neighbour )
        left  = reward + gamma * (0.8 * left_neighbour  + 0.10 * up_neighbour   + 0.10 * down_neighbour )
        up    = reward + gamma * (0.8 * up_neighbour    + 0.10 * left_neighbour + 0.10 * right_neighbour)
        down  = reward + gamma * (0.8 * down_neighbour  + 0.10 * left_neighbour + 0.10 * right_neighbour)
        
        # set action values to Q table in the current state
        current_q_table[state] = np.array([right, left, up, down])
        
        # update state values
        states[row, col] = np.max(current_q_table[state])
    
    # save to history
    q_table_history.append(current_q_table.copy())
    states_history.append(states.copy())
    
    # alternative (second approach): update all state values at once
#     states = np.max(current_q_table, axis=1).reshape(observation_space)

print('State Values:')
print(states)
print('\nQ Table:')
print(current_q_table)

State Values:
[[ 0.81155822  0.86780822  0.91780822  1.        ]
 [ 0.76155822  0.          0.66027397 -1.        ]
 [ 0.70530822  0.65530822  0.61141553  0.38792491]]

Q Table:
[[ 0.81155822  0.76655822  0.77718322  0.73718322]
 [ 0.86780822  0.78280822  0.82718322  0.82718322]
 [ 0.91780822  0.81205479  0.8810274   0.675     ]
 [ 1.          1.          1.          1.        ]
 [ 0.72093322  0.72093322  0.76155822  0.67655822]
 [ 0.          0.          0.          0.        ]
 [-0.68707763  0.64114155  0.66027397  0.41515982]
 [-1.         -1.         -1.         -1.        ]
 [ 0.63093322  0.67093322  0.70530822  0.66030822]
 [ 0.58019406  0.65530822  0.61591895  0.61591895]
 [ 0.39750888  0.61141553  0.59254249  0.55345573]
 [ 0.20913242  0.38792491 -0.74006596  0.37027397]]


## Analysis

After training completes, the results are stored in `current_q_table` and `states`.

To see the calculated values step by step, use the `q_table_history` and `states_history` lists.
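For example, the iteration at which the values stop changing can be read off `states_history` with a small helper. `first_converged_step` and its threshold `tol` are hypothetical: the notebook does not state which criterion produced the 13/25 figures above, and the toy list only demonstrates the call.

```python
import numpy as np


def first_converged_step(history, tol=1e-3):
    """Return the first index i where every entry of history[i] differs from
    history[i - 1] by less than tol, or None if the values never settle."""
    for i in range(1, len(history)):
        if np.max(np.abs(history[i] - history[i - 1])) < tol:
            return i
    return None


# Toy history with made-up values, only to demonstrate the call
toy_history = [np.array([0.0, 0.0]), np.array([0.5, 0.2]),
               np.array([0.6, 0.25]), np.array([0.6001, 0.2501])]
print(first_converged_step(toy_history))  # -> 3
```

In the notebook, `first_converged_step(states_history)` applies the same check to the recorded state values; because `states_history` starts with the initial copy, index `i` corresponds to the values after iteration `i`.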

---

The cell below uses the `ipywidgets` package to add an interactive slider for stepping through the calculation history.

If the output of the cell below did not render, please run the notebook yourself.

_Have fun!_

In [10]:
from ipywidgets import interact  # to use a slider
from IPython.display import display, HTML  # to use better formatting in Jupyter Notebook


# Render a 2D matrix as an HTML table
def display_table(data):
    return '<table><tr>{}</tr></table>'.format(
        '</tr><tr>'.join('<td>{}</td>'.format(
            '</td><td>'.join('{:0.3f}'.format(_) for _ in row)) for row in data))


# The slider's value is passed to the function as `time_step`
@interact(time_step=(0, n_iterations-1))
def show_history(time_step=0):
    display(HTML(f"""
    <div style="display:flex;flex-direction:row;align-items:center;justify-content:space-around;text-align:center;">
      <div><h4>State Values (before iteration) </h4> {display_table(states_history[time_step])}</div>
      <div><h4>Q Values</h4> {display_table(q_table_history[time_step])}</div>
      <div><h4>State Values (after iteration) </h4> {display_table(states_history[time_step+1]) if time_step+1 < len(states_history) else None}</div>
    </div>
    """))
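If widgets cannot be used (for example in a static export of this notebook), a plain-text view of a single time step is a simple fallback. `format_step` is a hypothetical helper, not part of the notebook; it uses `np.array2string` and shows the "after" panel only when that entry exists in the history.

```python
import numpy as np


def _fmt(array):
    """Format an array with three decimals, similar to display_table."""
    return np.array2string(np.asarray(array), precision=3, suppress_small=True)


def format_step(states_history, q_table_history, time_step):
    """Plain-text version of the three panels shown by show_history."""
    parts = [
        'State Values (before iteration):', _fmt(states_history[time_step]),
        'Q Values:', _fmt(q_table_history[time_step]),
    ]
    if time_step + 1 < len(states_history):
        parts += ['State Values (after iteration):', _fmt(states_history[time_step + 1])]
    return '\n'.join(parts)


# Toy histories with made-up values; in the notebook, pass states_history and q_table_history
toy_states = [np.zeros((1, 2)), np.full((1, 2), 0.5)]
toy_q = [np.full((2, 4), 0.25)]
print(format_step(toy_states, toy_q, 0))
```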
