In [1]:
%load_ext autoreload
#always reload modules so that as you change code, it gets loaded
%autoreload 2
#%aimport rl # make sure you call once: pip install -e . 

from collections.abc import Iterator, Mapping
from typing import Any

from rl.dynamic_programming import policy_iteration

In [25]:
class Four_Directions_Policy(dict):
    """
    Uniform random policy for a rectangular grid world.

    The policy is a plain dict keyed by state name ("1", "2", ...,
    str(x_n * y_n - 2)). Each value is its own dict that gives every one of
    the four moves the same probability, 0.25.

    Actions:
    A = {Up, Down, Left, Right}
          Up
     Left    Right
         Down

    States for x_n = y_n = 4 (X marks the two cells that get no entry):
    S = {1, 2, ..., 14}
    _____________________
    | X  |  1 |  2 |  3 |
    _____________________
    |  4 |  5 |  6 |  7 |
    _____________________
    |  8 |  9 | 10 | 11 |
    _____________________
    | 12 | 13 | 14 | X  |
    _____________________

    Parameters:
        x_n: first grid dimension, stored on the instance as ``x_n``.
        y_n: second grid dimension, stored on the instance as ``y_n``.
    """
    def __init__(self, x_n: int, y_n: int) -> None:
        self.x_n = x_n
        self.y_n = y_n
        # Two cells of the grid are left out, hence the "- 2".
        state_count = self.x_n * self.y_n - 2
        table = {}
        for number in range(1, state_count + 1):
            # A fresh dict per state so entries can later be changed independently.
            table[str(number)] = {
                "left": 0.25,
                "right": 0.25,
                "up": 0.25,
                "down": 0.25,
            }
        super().__init__(table)

class Square_Dynamics(Mapping):
    """
    Actions:
    A = {Up, Down, Left, Right}
          Up
     Left    Right
         Down 

    Rewards:
    S = {1, 2, ..., 14}
    _____________________
    | X  | -1 | -1 | -1 |
    _____________________
    | -1 | -1 | -1 | -1 |
    _____________________
    | -1 | -1 | -1 | -1 |
    _____________________
    | -1 | -1 | -1 | X  |
    _____________________
    """
    def __init__(self, x_n: int, y_n: int) -> None:
        self.x_n = x_n
        self.y_n = y_n
        L = self.x_n * self.y_n - 2
        self.__states__ = [str(i+1) for i in range(L)]
        self.__states_actions__ = [
            (s, a) for s in self.__states__ for a in ["left", "right", "up", "down"]
        ]

    def __contains__(self, key: object) -> bool:
        return key in self.__states_actions__

    def __iter__(self) -> Iterator:
        return iter(self.__states_actions__)

    def __len__(self) -> int:
        return len(self.__states_actions__)

    def __getitem__(self, key: Any) -> Any:
        L = self.x_n * self.y_n - 2
        state, action = key
        TERM = {("TERM", -1): 1.0}
        match action:
            case "right":
                next_state = int(state) + 1
                if next_state%self.x_n == 0:
                    next_state -= 1
            case "left":
                next_state = int(state) - 1
                if (next_state+1)%self.x_n == 0:
                    next_state += 1
            case "up":
                next_state = int(state) - self.x_n
                if next_state < 0:
                    next_state += self.x_n
            case "down":
                next_state = int(state) + self.x_n
                if next_state > L+1:
                    next_state -= self.x_n
        
        if next_state == 0 or next_state == L+1:
            return TERM        
        else:
            return {(str(next_state), -1): 1.0}
        
    def __repr__(self) -> str:
        return "{\n" + "".join([ f"({state},{action}):,{self[(state,action)]}\n"
            for state, action in self.__states_actions__]) + "\n}"

In [29]:
# Start from the policy that gives each of the four moves probability 0.25 in
# every state, pair it with the 4x4 transition model, and run policy iteration.
policy = Four_Directions_Policy(x_n=4, y_n=4)
dynamics = Square_Dynamics(x_n=4, y_n=4)
# gamma=1.0 -- presumably the discount factor (i.e. no discounting);
# confirm against rl.dynamic_programming.policy_iteration.
policy2, states_value = policy_iteration(
    policy,
    dynamics,
    gamma=1.0,
)

In [30]:
Line = "_________________\n"
def to_arrow(item):
    """
    Translate the first element yielded by ``item`` into an arrow glyph.

    ``item`` is any iterable (for a dict, its keys are used). Only the first
    element is inspected: "left" -> "<", "right" -> ">", "up" -> "^",
    "down" -> "v". Returns None when ``item`` is empty or its first element
    is none of those four names.
    """
    elements = list(item)
    if elements:
        first = elements[0]
        glyphs = (("left", "<"), ("right", ">"), ("up", "^"), ("down", "v"))
        for action, glyph in glyphs:
            if first == action:
                return glyph
    return None
def render_grid(columns, rows, cell_text, terminal_text, rule):
    """
    Draw a ``columns`` x ``rows`` grid as text.

    Cells are numbered row by row starting at 0. The first and the last cell
    are the two terminal corners and show ``terminal_text``; every other cell
    shows ``cell_text(str(number))``. Each row is written as
    ``| a | b | ... |`` and ``rule`` (a full line including its newline) is
    placed before, between and after the rows.
    """
    last_cell = columns * rows - 1
    lines = []
    for row in range(rows):
        cells = []
        for number in range(row * columns, (row + 1) * columns):
            if number in (0, last_cell):
                cells.append(terminal_text)
            else:
                cells.append(cell_text(str(number)))
        lines.append("| " + " | ".join(cells) + " |\n")
    return rule + rule.join(lines) + rule


# The grid shape comes from the transition model, so the printout follows the
# model instead of a hand-written list of state ids.
p = render_grid(
    dynamics.x_n,
    dynamics.y_n,
    # f-string formatting keeps a None result from to_arrow printable.
    lambda state: f"{to_arrow(policy2[state])}",
    "X",
    Line,
)
print("Optimal Policy:")
print(p)

Line = "_________________________________\n"

s = render_grid(
    dynamics.x_n,
    dynamics.y_n,
    lambda state: f"{states_value[state]:.2f}",
    # Padded so the terminal cell is as wide as a value such as "-1.00".
    "  X  ",
    Line,
)
print("Optimal States-Value Function:")
print(s)





Optimal Policy:
_________________
| X | < | < | < |
_________________
| ^ | ^ | < | v |
_________________
| ^ | ^ | > | v |
_________________
| ^ | > | > | X |
_________________

Optimal States-Value Function:
_________________________________
|   X   | -1.00 | -2.00 | -3.00 |
_________________________________
| -1.00 | -2.00 | -3.00 | -2.00 |
_________________________________
| -2.00 | -3.00 | -2.00 | -1.00 |
_________________________________
| -3.00 | -2.00 | -1.00 |   X   |
_________________________________

