In [27]:
import numpy as np
import time
from grid_world import standard_grid,negative_grid
# Convergence threshold compared against the largest per-sweep value change.
# Note: the literal 10e-10 is exactly the same float as 1e-9.
SMALL_TOL = 1e-9
# Action labels tried in every state during the value-iteration sweep
# (presumably Up / Down / Left / Right -- confirm against grid_world).
ALL_Actions = ['U', 'D', 'L', 'R']
def print_values(V, g):
  """Print a table of per-state values, one text row per first coordinate.

  Args:
    V: mapping from (i, j) state tuples to numbers; states that are
      missing from the mapping are printed as 0.
    g: object exposing integer ``width`` and ``height`` attributes.
      NOTE(review): ``width`` bounds the first (row) index and ``height``
      the second (column) index here -- confirm this matches the
      grid_world.Grid convention.
  """
  print("Values:")
  for row in range(g.width):
    print("---------------------------")
    cells = []
    for col in range(g.height):
      value = V.get((row, col), 0)
      # A negative number's minus sign occupies the column that the
      # leading space fills for non-negative numbers, keeping cells aligned.
      pad = " " if value >= 0 else ""
      cells.append(pad + "%.2f|" % value)
    print("".join(cells))


def print_policy(P, g):
  """Print a table of per-state actions, one text row per first coordinate.

  Args:
    P: mapping from (i, j) state tuples to an action label; states that
      are missing from the mapping are shown as a blank cell.
    g: object exposing integer ``width`` and ``height`` attributes.
      NOTE(review): ``width`` bounds the first (row) index and ``height``
      the second (column) index here -- confirm this matches the
      grid_world.Grid convention.
  """
  print("Policy:")
  for row in range(g.width):
    print("---------------------------")
    rendered = "".join(
      "  %s  |" % P.get((row, col), ' ') for col in range(g.height)
    )
    print(rendered)
    

In [28]:
# Main starts here

In [29]:
# Build the grid world used for the rest of the notebook and display its
# reward table as a sanity check (rewards are keyed by state like V is).
grid = negative_grid()
print_values(grid.rewards, grid)  # check the reward matrix

Values:
---------------------------
-0.10|-0.10|-0.10| 1.00|
---------------------------
-0.10| 0.00|-0.10|-1.00|
---------------------------
-0.10|-0.10|-0.10|-0.10|


In [40]:
## initialize a random policy
# One randomly chosen action label per state that has actions. The starting
# policy only matters for the printout below; the sweep overwrites every entry.
policy={}
for s in grid.actions.keys():
    policy[s]=np.random.choice(ALL_Actions) # illegal actions will not move
print('initial policy')
print_policy(policy,grid)

# initialize V(s): 0 for every state, then a random value in [0, 1) for each
# state that has actions, so states without actions keep V = 0.
V = {}
states=grid.all_states()
for s in states:
    V[s] = 0
    if s in grid.actions:
        V[s]=np.random.random()
GAMMA = 0.9 # discount factor

# value iteration
# Each sweep replaces V(s) in place with max_a [ r + GAMMA * V(s') ] and records
# the greedy action. The loop stops once the largest absolute change of any
# V(s) within a sweep drops below SMALL_TOL.
iter_n=0
while True:
    iter_n+=1
    biggest_diff=0
    for s in states:
        if s in policy:
            old_v=V[s]
            best_v=float('-inf')
            best_a=policy[s]
            for a in ALL_Actions:
                # Reset to s before each trial move, since move() changes
                # the grid's current state.
                grid.set_state(s)
                r=grid.move(a)
                v_new=r+GAMMA*V[grid.current_state()]
                if best_v<v_new:
                    best_v=v_new
                    best_a=a
            V[s]=best_v # V[s] change in place
            policy[s]=best_a
            # Use the absolute change: without abs(), a value that increases
            # yields a negative delta that is ignored, and the loop can stop
            # before the values have actually converged.
            biggest_diff=max(biggest_diff,abs(old_v-V[s]))
    if biggest_diff<SMALL_TOL:
        break

#check the result
print(iter_n,end='\r') # number of sweeps performed
print()
print_values(V,grid)
print_policy(policy,grid)

initial policy
Policy:
---------------------------
  U  |  D  |  R  |     |
---------------------------
  L  |     |  L  |     |
---------------------------
  R  |  R  |  L  |  D  |
4
Values:
---------------------------
 0.62| 0.80| 1.00| 0.00|
---------------------------
 0.46| 0.00| 0.80| 0.00|
---------------------------
 0.31| 0.46| 0.62| 0.46|
Policy:
---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  U  |     |
---------------------------
  U  |  R  |  U  |  L  |
