In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import seaborn as sns

In [None]:
# Gridworld map, one character per cell (row-major):
#   'w' wall, ' ' free cell, 'o' oil, 'b' bump, 's' start, 'e' end.
world = """wwwwwwwwwwwwwwwwwwww
w          bb      w
wbbb w  o       o  w
w    w       e     w
w owwwwwwwwwwwwww  w
wb w  o  b       b w
w  w  w  w     w b w
w bw  w  wbbwwww b w
w     w  w     w b w
w     w  w     w   w
wwwww w  ww    w  ow
w     w   w  w www w
w  wwwww  wbbw   w w
w      w  w  w   w w
wbb    w  w  w     w
w   s  w  o  wwwwbbw
w      b  o        w
www    wwwwww o  o w
w      o           w
wwwwwwwwwwwwwwwwwwww"""

# Derive the grid size from the map itself instead of hard-coding it, so the two
# cannot drift apart. Later cells index the map as [row][col] over
# range(world_size), so the map must be square.
world_size = len(world.split('\n'))
assert all(len(row) == world_size for row in world.split('\n')), "world map must be square"

In [None]:
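# Alternative small test map (disabled). NOTE(review): rows 2-7 below are 8
# characters wide while world_size is 7, so only the first 7 columns would be read.
# world_size = 7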
# world =\
# """wwwwwww
# ww    ww
# ww wb ww
# ww w bww
# ww wweww
# wwbs www
# wwwwwwww"""

In [None]:
# Dana Research Center 309
w = world.split('\n')  # one string per map row
# Row-major grid of single characters: gridworld_char[row][col].
gridworld_char = [list(row_text) for row_text in w]

# PLOT FUNCTION
def plot_map(gridworld_char, S_space, V_map, plot_type):
    """Draw the gridworld as a heatmap, optionally overlaid with values or policy arrows.

    Walls ('w') are stored as NaN, so seaborn leaves them unpainted and the black
    axes background shows through. Oil ('o'), bump ('b'), start ('s') and end ('e')
    cells are highlighted with coloured patches.

    Parameters
    ----------
    gridworld_char : list of list of str
        Map characters indexed as ``gridworld_char[row][col]``. Only the first
        ``world_size`` rows/columns are read (``world_size`` is a notebook global).
    S_space : list
        State objects exposing ``row``, ``col``, ``policy`` and ``s_content``.
        Only read when ``plot_type`` is ``'arrow'`` or ``'path'``.
    V_map : array-like
        Values annotated onto each cell when ``plot_type == 'V_map'``.
    plot_type : {'V_map', 'arrow', 'path'}
        'V_map' annotates V_map; 'arrow' draws the policy of every state in
        S_space; 'path' draws only the states returned by ``get_path(S_space)``.

    Raises
    ------
    ValueError
        If ``plot_type`` is not one of the values above.
    """
    # Validate before creating a figure; an unknown type used to fail later with
    # an UnboundLocalError on `arrow_space`.
    valid_types = ('V_map', 'arrow', 'path')
    if plot_type not in valid_types:
        raise ValueError(f"plot_type must be one of {valid_types}, got {plot_type!r}")

    # Wall map: NaN for walls, 0 elsewhere.
    # np.nan, not np.NAN: the upper-case alias was removed in NumPy 2.0.
    gridworld_int = np.zeros((world_size, world_size))
    for i in range(world_size):
        for j in range(world_size):
            if gridworld_char[i][j] == 'w':
                gridworld_int[i, j] = np.nan

    # Annotate numbers only for the value-map view.
    annot = V_map if plot_type == 'V_map' else None

    # Plot heatmap on an explicit Axes.
    _, ax = plt.subplots(figsize=(10, 7.5))
    sns.heatmap(gridworld_int, fmt=".0f", annot=annot, linewidths=0.25, linecolor='black',
                cbar=False, cmap='rocket_r', ax=ax)
    ax.set_facecolor('black')  # unpainted (NaN / wall) cells appear black

    # Highlight special cells. Rectangle takes (x, y) = (col, row).
    cell_styles = {
        'o': ('red', 0.8),    # Oil
        'b': ('red', 0.25),   # Bump
        's': ('blue', 0.8),   # Start
        'e': ('green', 0.8),  # End
    }
    for row in range(world_size):
        for col in range(world_size):
            style = cell_styles.get(gridworld_char[row][col])
            if style is not None:
                color, alpha = style
                ax.add_patch(Rectangle((col, row), 1, 1, fill=True, color=color, alpha=alpha))

    # Policy arrows.
    if plot_type != 'V_map':
        arrow_space = S_space if plot_type == 'arrow' else get_path(S_space)
        # (dx, dy) in data coordinates. Row 0 is drawn at the top of the heatmap,
        # so 'U' (towards smaller row numbers) is a negative dy.
        arrow_deltas = {'R': (0.8, 0), 'L': (-0.8, 0), 'U': (0, -0.8), 'D': (0, 0.8)}
        for state in arrow_space:
            if state.s_content == 'e':
                continue  # the end state has no action to draw
            delta = arrow_deltas.get(state.policy)
            if delta is not None:
                ax.arrow(state.col + 0.5, state.row + 0.5, *delta, width=0.04, color='black')
    plt.show()



def get_path(S_space, actions=None):
    """Follow each state's policy from the start state until the end state.

    Parameters
    ----------
    S_space : list
        State objects exposing ``s_content`` ('s' marks the start, 'e' the end),
        ``policy`` (an action label) and ``s_next_index`` (indices into
        ``S_space``, aligned position-by-position with the action list).
    actions : list, optional
        Ordered action labels used to look up ``s_next_index``. Defaults to the
        notebook global ``A``, which is defined outside this cell.

    Returns
    -------
    list
        States visited from the start up to, but not including, the end state.
        If the end has not been reached after more than ``len(S_space)`` steps,
        the policy must be revisiting states, and the partial path walked so far
        is returned.

    Raises
    ------
    ValueError
        If S_space contains no start ('s') state or no end ('e') state.
    """
    if actions is None:
        actions = A  # NOTE(review): notebook-global action list, defined elsewhere

    # Single pass with enumerate. S_space.index() inside the loop was O(n^2).
    start_index = terminal_index = None
    for index, state in enumerate(S_space):
        if state.s_content == 's':
            start_index = index
        if state.s_content == 'e':
            terminal_index = index
    if start_index is None or terminal_index is None:
        raise ValueError("S_space must contain a start ('s') and an end ('e') state")

    terminal_state = S_space[terminal_index]
    path = []
    curr_state = S_space[start_index]
    steps = 0
    while curr_state != terminal_state:
        path.append(curr_state)
        next_state_index = curr_state.s_next_index[actions.index(curr_state.policy)]
        curr_state = S_space[next_state_index]

        steps += 1
        if steps > len(S_space):  # cycle guard: give up and return the partial path
            return path
    return path

# All-zero value map: this first plot only shows the layout (walls, oil, bumps, start, end).
V_map = np.zeros((world_size, world_size))

# The 'V_map' view never reads S_space, and no earlier cell defines S_space.
# Passing None keeps this cell runnable on a fresh kernel (Restart & Run All).
plot_map(gridworld_char, None, V_map, 'V_map')