In [1]:
'''
gridworld_terminal_DP_helper.py
'''

# libraries
import numpy as np
import sys
from tqdm import tqdm

# state indexing: (row, col) -- shown here for a 4x4 grid; the actual size is ginDim
# [
#     (0,0), (0,1), (0,2), (0,3)
#     (1,0), (1,1), (1,2), (1,3)
#     (2,0), (2,1), (2,2), (2,3)
#     (3,0), (3,1), (3,2), (3,3)
# ]

# Action set indexed left (0), up (1), right (2), down (3).
# Each entry is a (row, col) offset that step() adds to the current cell.
v_actions = [np.array(offset) for offset in ([0, -1], [-1, 0], [0, 1], [1, 0])]

i_nActions = len(v_actions)

def target_policy_agent(state, gvPolicy, action=None):
    '''
    Look up action number `action` under a state-independent policy.

    state    : ignored; the policy is the same in every cell.
    gvPolicy : sequence of per-action probabilities, in v_actions order.
    action   : integer action index. It must be supplied; indexing with the
               default None raises TypeError.

    Returns (offset array from v_actions, probability of that action).
    '''
    chosen_move = v_actions[action]
    chosen_prob = gvPolicy[action]
    return chosen_move, chosen_prob

def is_terminal(state, ginDim):
    '''
    Report whether `state` = (row, col) is a terminal cell.

    The terminal cells are the top-left corner (0, 0) and the bottom-right
    corner (ginDim - 1, ginDim - 1).
    '''
    row, col = state
    at_top_left = row == 0 and col == 0
    at_bottom_right = row == ginDim - 1 and col == ginDim - 1
    return at_top_left or at_bottom_right

def step(giSt, giAc, ginDim):
    '''
    Apply the move `giAc` to cell `giSt` on a ginDim x ginDim grid.

    A move that would leave the grid keeps the agent where it was. Every
    move, blocked or not, yields a reward of -1.

    Returns ([row, col] as a list of Python scalars, reward).
    '''
    origin = np.array(giSt)
    moved = origin + giAc
    row, col = moved.tolist()

    # Stepping off any edge bounces back to the starting cell.
    blocked = row < 0 or row >= ginDim or col < 0 or col >= ginDim
    landing = origin if blocked else moved
    return landing.tolist(), -1

def value_iteration_algo(ginDim=4, theta=1e-4):
    '''
    Run synchronous value iteration on the ginDim x ginDim terminal gridworld.

    Each sweep applies the Bellman optimality backup
        V_new(s) = max_a [ r + V_old(s') ]
    to every non-terminal cell. It uses the deterministic `step` dynamics:
    reward -1 per move, no discounting. Terminal cells are skipped, so they
    keep their initial value 0.

    Parameters
    ----------
    ginDim : int
        Side length of the square grid.
    theta : float
        Convergence threshold. Iteration stops once the summed absolute change
        of all state values in one sweep drops below `theta`. It must be
        positive. The default reproduces the previously hard-coded 1e-4.

    Returns
    -------
    iteration : int
        Number of sweeps performed, including the final converged one.
    state_values : np.ndarray, shape (ginDim, ginDim)
        Optimal state values after convergence.
    best_actions : np.ndarray, shape (ginDim, ginDim), float dtype
        Index into v_actions of a greedy action for each cell. Ties go to the
        lowest index, per np.argmax. Terminal cells are never updated and keep
        a placeholder 0, which is NOT a meaningful action there.

    Raises
    ------
    ValueError
        If theta is not positive. With theta <= 0 the stopping test
        `change < theta` can never succeed.
    '''
    if theta <= 0:
        raise ValueError("theta must be positive")

    # initialize: every value starts at 0
    new_state_values = np.zeros((ginDim, ginDim))
    state_values = new_state_values.copy()
    best_actions = np.zeros((ginDim, ginDim))

    iteration = 1
    while True:
        # Synchronous sweep: every backup reads the previous sweep's values
        # in `src`, never the ones written into new_state_values this sweep.
        src = state_values

        for i in range(ginDim):
            for j in range(ginDim):

                if is_terminal([i, j], ginDim):
                    continue

                # one-step lookahead value of each action from (i, j)
                v_acValues = np.zeros(i_nActions)
                for i_action, action in enumerate(v_actions):
                    (next_i, next_j), reward = step([i, j], action, ginDim)
                    v_acValues[i_action] = reward + src[next_i, next_j]

                new_state_values[i, j] = v_acValues.max()
                best_actions[i, j] = np.argmax(v_acValues)

        if np.sum(np.abs(new_state_values - state_values)) < theta:
            state_values = new_state_values.copy()
            break

        state_values = new_state_values.copy()
        iteration += 1

    return iteration, state_values, best_actions

def policy_evaluation_algo(gvPolicy, ginDim=4, theta=1e-4, max_iterations=None):
    '''
    Iterative policy evaluation on the ginDim x ginDim terminal gridworld.
    The dynamics come from `step`: reward -1 per move, no discounting.
    Terminal cells are skipped and keep value 0.

    gvPolicy (i.e. provided policy) is assumed to be same for each state.
    It only describes a probability of taking each action at a given state,
    in v_actions order (left, up, right, down).
    gvPolicy is expected to be in the following form:
    gvPolicy = [1/4, 1/4, 1/4, 1/4]

    Parameters
    ----------
    gvPolicy : sequence of float
        One non-negative probability per action. The values must sum to 1,
        checked with a floating-point tolerance.
    ginDim : int
        Side length of the square grid.
    theta : float
        Convergence threshold on the summed absolute change of all state
        values in one sweep. The default is the original hard-coded 1e-4.
    max_iterations : int or None
        Optional cap on the number of sweeps. With None (the default) the
        loop runs until convergence. It never converges if the policy leaves
        some non-terminal cell unable to reach a terminal one; for example,
        [1, 0, 0, 0] from cell (1, 0) keeps bumping into the left wall and
        its value decreases by 1 every sweep.

    Returns
    -------
    iteration : int
        Number of sweeps performed, including the final converged one.
    state_values : np.ndarray, shape (ginDim, ginDim)
        State values under gvPolicy, rounded to 1 decimal.

    Raises
    ------
    AssertionError
        If gvPolicy is not a valid probability vector over the actions.
    RuntimeError
        If max_iterations sweeps complete without convergence.
    '''
    # Use a tolerance: with exact equality, valid inputs such as
    # [0.7, 0.1, 0.1, 0.1] are rejected, because their float sum is not
    # exactly 1.
    assert np.isclose(sum(gvPolicy), 1.0), "error: probability values should sum to 1"
    assert len(gvPolicy) == i_nActions, "error: there should be one probability value for each action"
    assert all(p >= 0 for p in gvPolicy), "error: probability values should be non-negative"

    # initialize: every value starts at 0
    new_state_values = np.zeros((ginDim, ginDim))
    state_values = new_state_values.copy()

    iteration = 1
    while True:
        # Synchronous sweep: backups read only the previous sweep's values.
        src = state_values

        for i in range(ginDim):
            for j in range(ginDim):

                if is_terminal([i, j], ginDim):
                    continue

                # expected one-step return under the policy
                value = 0
                for i_action, action in enumerate(v_actions):
                    (next_i, next_j), reward = step([i, j], action, ginDim)
                    _, prob = target_policy_agent((i, j), gvPolicy, action=i_action)
                    value += prob * (reward + src[next_i, next_j])

                new_state_values[i, j] = value

        if np.sum(np.abs(new_state_values - state_values)) < theta:
            state_values = new_state_values.copy()
            break

        if max_iterations is not None and iteration >= max_iterations:
            raise RuntimeError(
                "policy evaluation did not converge within %d iterations" % max_iterations
            )

        state_values = new_state_values.copy()
        iteration += 1

    return iteration, np.around(state_values, 1)

In [2]:
import math
import matplotlib.pyplot as plt
from matplotlib.table import Table

def draw_policy(gdicPolicy, gsFigName):
    '''
    Render a policy as a table of arrows and save it to an image file.

    gdicPolicy: 2-D array-like of action indices, one per grid cell, e.g. the
        best_actions array returned by value_iteration_algo (despite the name,
        this is NOT a dict). Index k is drawn with ACTIONS_FIGS[k]:
        0 '←' (left), 1 '↑' (up), 2 '→' (right), 3 '↓' (down).
        Float entries are truncated with int(). Cells holding a placeholder
        rather than a real action are still drawn. For example,
        value_iteration_algo leaves terminal cells at 0, so they show '←'.
        The grid may be rectangular; its shape sets the table size.

    gsFigName: output file name for visualization (e.g. gridworld_opt_policy_VI.png)

    Raises ValueError if gdicPolicy is not a non-empty 2-D array.
    '''
    # glyphs in v_actions order: left, up, right, down
    ACTIONS_FIGS = ['←', '↑', '→', '↓']

    policy = np.asarray(gdicPolicy)
    if policy.ndim != 2 or policy.size == 0:
        raise ValueError("gdicPolicy must be a non-empty 2-D array of action indices")

    nrows, ncols = policy.shape
    width, height = 1.0 / ncols, 1.0 / nrows

    fig, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])

    # one arrow cell per grid cell, addressed by (row, col)
    for (i, j), pol in np.ndenumerate(policy):
        tb.add_cell(i, j, width, height, text=ACTIONS_FIGS[int(pol)],
                    loc='center', facecolor='white')

    # 1-based row labels to the left of each row
    for i in range(nrows):
        tb.add_cell(i, -1, width, height, text=i+1, loc='right',
                    edgecolor='none', facecolor='none')
    # 1-based column labels above each column (half-height header row)
    for j in range(ncols):
        tb.add_cell(-1, j, width, height/2, text=j+1, loc='center',
                    edgecolor='none', facecolor='none')

    ax.add_table(tb)

    fig.savefig(gsFigName)
    plt.close(fig)

In [3]:
# Evaluate the equiprobable random policy (each of the 4 moves with prob 1/4) on a 6x6 grid.
iterCt, policy_values = policy_evaluation_algo([1 / 4] * 4, ginDim=6)

In [4]:
# number of sweeps policy_evaluation_algo needed to converge
# (iterCt is reassigned by the value-iteration cell further down)
iterCt

743

In [5]:
# state values of the random policy (rounded to 1 decimal by policy_evaluation_algo)
policy_values

array([[  0. , -34. , -51.9, -61.6, -66.5, -68.5],
       [-34. , -46.1, -56.2, -62.4, -65.5, -66.5],
       [-51.9, -56.2, -60.2, -62.3, -62.4, -61.6],
       [-61.6, -62.4, -62.3, -60.2, -56.2, -51.9],
       [-66.5, -65.5, -62.4, -56.2, -46.1, -34. ],
       [-68.5, -66.5, -61.6, -51.9, -34. ,   0. ]])

In [6]:
# Value iteration on the same 6x6 grid; note this overwrites iterCt from the policy-evaluation cell.
iterCt, opt_values, opt_policy = value_iteration_algo(6)


In [7]:
# number of sweeps value_iteration_algo needed to converge
iterCt

6

In [8]:
# optimal state values: with reward -1 per move and no discounting, each equals
# minus the number of moves from that cell to the nearest terminal corner
opt_values

array([[ 0., -1., -2., -3., -4., -5.],
       [-1., -2., -3., -4., -5., -4.],
       [-2., -3., -4., -5., -4., -3.],
       [-3., -4., -5., -4., -3., -2.],
       [-4., -5., -4., -3., -2., -1.],
       [-5., -4., -3., -2., -1.,  0.]])

In [9]:
# greedy action index per cell into v_actions: 0 left, 1 up, 2 right, 3 down;
# terminal cells (0, 0) and (5, 5) are never updated and just hold a placeholder 0
opt_policy

array([[0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 3.],
       [1., 0., 0., 0., 2., 3.],
       [1., 0., 0., 2., 2., 3.],
       [1., 0., 2., 2., 2., 3.],
       [1., 2., 2., 2., 2., 0.]])

In [10]:
# Save the greedy value-iteration policy as an arrow diagram.
s_fig_name = "opt_policy_DP_value_iteration_algo_v1.png"
draw_policy(gdicPolicy=opt_policy, gsFigName=s_fig_name)