In [1]:
import sys
import numpy as np
from matplotlib import pyplot as plt

In [2]:

def get_updated_values(values, policy):
    """Run one synchronous sweep of policy evaluation on a rectangular grid.

    Every non-terminal cell receives the expected one-step return under
    ``policy``: each move costs 1 and the successor's current value is added
    without discounting. A move that would leave the grid keeps the agent in
    place. The cell ``(0, 0)`` is treated as terminal and keeps the value 0.

    Parameters
    ----------
    values : np.ndarray, shape (n_rows, n_cols)
        Current state-value estimates.
    policy : np.ndarray, shape (n_rows, n_cols, 4)
        ``policy[i, j, a]`` is the probability of choosing action ``a`` in
        cell ``(i, j)``. Actions: 0 = down (i+1), 1 = right (j+1),
        2 = up (i-1), 3 = left (j-1).

    Returns
    -------
    np.ndarray
        A new array with the same shape as ``values``; the input is not
        modified. The result is always floating point so that fractional
        expected returns are not truncated when ``values`` is an integer array.
    """
    n_rows, n_cols = values.shape
    # (row step, column step) indexed by action id: down, right, up, left.
    moves = ((1, 0), (0, 1), (-1, 0), (0, -1))
    # Keep a float/complex input dtype as-is; promote anything else (ints,
    # bools) to float64 so assignments below do not silently round.
    out_dtype = values.dtype if np.issubdtype(values.dtype, np.inexact) else np.float64
    new_values = np.zeros_like(values, dtype=out_dtype)
    for i in range(n_rows):
        for j in range(n_cols):
            # Terminal cell: its value stays at the initial 0.
            if i == 0 and j == 0:
                continue
            expected_return = 0.
            for a, (di, dj) in enumerate(moves):
                # Clamp to the grid: bumping into a wall leaves the agent in place.
                new_i = min(max(i + di, 0), n_rows - 1)
                new_j = min(max(j + dj, 0), n_cols - 1)
                expected_return += policy[i, j, a] * (values[new_i, new_j] - 1.)
            new_values[i, j] = expected_return
    return new_values

def get_updated_policy(values):
    """Build the deterministic greedy policy for a grid of state values.

    For each cell the four moves are tried in action order and the one whose
    successor cell has the largest value is selected. A move that would leave
    the grid keeps the agent in place. When several actions tie, the one with
    the highest action index wins, because a later candidate replaces the
    current best unless it is strictly worse.

    Parameters
    ----------
    values : np.ndarray, shape (n_rows, n_cols)
        State-value estimates.

    Returns
    -------
    np.ndarray, shape (n_rows, n_cols, 4)
        One-hot action probabilities per cell. Actions: 0 = down (i+1),
        1 = right (j+1), 2 = up (i-1), 3 = left (j-1). Every cell, including
        ``(0, 0)``, gets exactly one action set to 1.
    """
    n_rows, n_cols = values.shape
    # (row step, column step) indexed by action id: down, right, up, left.
    moves = ((1, 0), (0, 1), (-1, 0), (0, -1))
    policy = np.zeros(shape=(n_rows, n_cols, len(moves)))
    for i in range(n_rows):
        for j in range(n_cols):
            max_a, max_value = -1, -np.inf
            for a, (di, dj) in enumerate(moves):
                # Clamp to the grid: bumping into a wall leaves the agent in place.
                new_i = min(max(i + di, 0), n_rows - 1)
                new_j = min(max(j + dj, 0), n_cols - 1)
                new_value = values[new_i, new_j]
                # Only a strictly worse candidate is rejected, so ties go to
                # the later action.
                if new_value < max_value:
                    continue
                max_a, max_value = a, new_value
            policy[i, j, max_a] = 1
    return policy
                
def get_values_strings(values):
    """Render a 2-D array of values as the text lines of a boxed table.

    Each value is formatted with one decimal place. All cells are padded to a
    common width (at least 4 characters, wider if any formatted value needs
    it) so the column separators and the horizontal rules stay aligned even
    for values such as ``-10.5``.

    Parameters
    ----------
    values : np.ndarray, shape (n_rows, n_cols)
        Values to display.

    Returns
    -------
    list of str
        ``2 * n_rows + 1`` lines: a horizontal rule, then for every grid row
        one ``| a | b | ... |`` line followed by another rule.
    """
    n_rows, n_cols = values.shape
    cells = [[f"{values[i,j]:.1f}" for j in range(n_cols)] for i in range(n_rows)]
    # Use one width for every cell; 4 is the minimum so small grids of short
    # numbers look the same as before.
    cell_width = max([4] + [len(cell) for row in cells for cell in row])
    # Each column contributes "| " + cell + " " and the row ends with one "|".
    rule = "-" * (1 + n_cols * (cell_width + 3))
    ret = [rule]
    for row in cells:
        ret.append("| " + " | ".join(cell.ljust(cell_width) for cell in row) + " |")
        ret.append(rule)
    return ret
           
def get_policy_strings(policy):
    """Render a 4x4 policy as the text lines of a boxed table.

    Each cell shows one character slot per action, in the order D, R, U, L
    (action indices 0..3). A slot holds the letter when that action's
    probability exceeds 0.1 and a blank otherwise. The cell at (0, 0) is
    drawn as ``*`` regardless of its probabilities.

    Returns a list of 9 strings: a 29-character rule, then for every grid row
    one ``| .... | .... | .... | .... |`` line followed by another rule.
    """
    action_letters = "DRUL"
    divider = "-" * 29
    lines = [divider]
    for row in range(4):
        boxes = []
        for col in range(4):
            box = "".join(
                letter if policy[row, col, action] > 0.1 else " "
                for action, letter in enumerate(action_letters)
            )
            # The (0, 0) cell is always shown as a star instead of its actions.
            if row == 0 and col == 0:
                box = "*"
            boxes.append(box.ljust(4))
        lines.extend(["| " + " | ".join(boxes) + " |", divider])
    return lines

def print_policy_and_values(policy, values):
    """Print the value table (left) and the policy table (right) side by side.

    The two tables are rendered line by line with ``get_values_strings`` and
    ``get_policy_strings``; matching lines are printed on one output row.
    If one table has more lines than the other, the extra lines are dropped.
    """
    left_lines = get_values_strings(values)
    right_lines = get_policy_strings(policy)
    shared_count = min(len(left_lines), len(right_lines))
    for idx in range(shared_count):
        print(left_lines[idx], "    ", right_lines[idx])

In [3]:
# Start from all-zero values and a policy that gives each of the four
# actions probability 0.25 in every cell.
values = np.zeros((4, 4))
policy = np.full((4, 4, 4), 0.25)

rule = "=" * 23
for k in range(9):
    # Show the state at the start of iteration k, before it is updated.
    print(rule + f"   Iteration k={k}   " + rule)
    print_policy_and_values(policy, values)
    # One evaluation sweep under the current policy, then switch to the
    # greedy policy for the freshly computed values.
    values = get_updated_values(values, policy)
    policy = get_updated_policy(values)

-----------------------------      -----------------------------
| 0.0  | 0.0  | 0.0  | 0.0  |      | *    | DRUL | DRUL | DRUL |
-----------------------------      -----------------------------
| 0.0  | 0.0  | 0.0  | 0.0  |      | DRUL | DRUL | DRUL | DRUL |
-----------------------------      -----------------------------
| 0.0  | 0.0  | 0.0  | 0.0  |      | DRUL | DRUL | DRUL | DRUL |
-----------------------------      -----------------------------
| 0.0  | 0.0  | 0.0  | 0.0  |      | DRUL | DRUL | DRUL | DRUL |
-----------------------------      -----------------------------
-----------------------------      -----------------------------
| 0.0  | -1.0 | -1.0 | -1.0 |      | *    |    L |    L |    L |
-----------------------------      -----------------------------
| -1.0 | -1.0 | -1.0 | -1.0 |      |   U  |    L |    L |    L |
-----------------------------      -----------------------------
| -1.0 | -1.0 | -1.0 | -1.0 |      |    L |    L |    L |    L |
-------------------------