In [1]:
import numpy as np
import random
import time
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning) # ignore FutureWarning

### Transition Probability Table
* $P(s' \mid s, a)$: the probability of landing in next state $s'$ after taking action $a$ in state $s$.

In [2]:
class GridEnv:
    """Grid-world environment (5x5 by default) with obstacles and one goal cell.

    Positions are ``[x, y]`` lists and the flat state index is
    ``x + y * world_shape[1]``. Actions: 0 = up (y+1), 1 = down (y-1),
    2 = left (x-1), 3 = right (x+1).

    The world is exposed through two interfaces:

    * ``trans_table`` / ``transition`` -- a deterministic model
      ``[next_state, reward, prob]`` for model-based methods.
    * ``reset`` / ``step`` -- an interactive interface for model-free methods.

    NOTE(review): the two interfaces disagree on obstacles. In the model,
    moving into an obstacle leaves the agent where it was. ``step`` instead
    moves the agent *onto* the obstacle cell (reward -50, episode continues).
    Confirm which behavior is intended before comparing model-based and
    model-free results.

    NOTE: state indexing uses ``world_shape[1]`` as the row stride, so only
    square worlds map to collision-free state indices.
    """

    # (dx, dy) displacement for each action id.
    ACTION_DELTAS = {0: (0, 1), 1: (0, -1), 2: (-1, 0), 3: (1, 0)}

    GOAL_REWARD = 200        # entering the goal cell
    COLLISION_REWARD = -50   # leaving the grid or hitting an obstacle
    STEP_REWARD = -1         # any other move

    def __init__(self):
        self.world_shape = [5, 5]
        self.world = np.zeros(self.world_shape)
        self.agent_pos = [0, 0]
        self.obstacle_pos = [[2, 2], [2, 3]]
        self.goal_pos = [4, 4]
        self.action_space = [0, 1, 2, 3]

        # observation space: state index -> [x, y]
        self.obs_space = self.make_observation_space()
        # obstacle space: flat state indices of the obstacle cells
        self.obstacle_space = self.make_obstacle_space()
        # transition table: (state, action) -> [next_state, reward, prob]
        self.trans_table = self.make_transition_table()

    def _pos_to_state(self, x, y):
        """Flatten an ``(x, y)`` grid position into a state index."""
        return x + y * self.world_shape[1]

    def _in_bounds(self, x, y):
        """Return True if ``(x, y)`` lies inside the grid."""
        return 0 <= x < self.world_shape[0] and 0 <= y < self.world_shape[1]

    def make_observation_space(self):
        """Map every state index to its ``[x, y]`` grid position (keys in ascending order)."""
        n_x, n_y = self.world_shape[0], self.world_shape[1]
        return {self._pos_to_state(x, y): [x, y]
                for y in range(n_y) for x in range(n_x)}

    def make_obstacle_space(self):
        """Return the flat state indices of all obstacle cells."""
        return [self._pos_to_state(x, y) for x, y in self.obstacle_pos]

    def make_transition_table(self):
        '''
        Transition probability table for the model-based approach.

        Returns an int array of shape ``(n_states, n_actions, 3)`` where
        ``trans_table[s][a] = [s', reward, prob]``. The dynamics are
        deterministic, so ``prob`` is always 1:

        * moving onto the goal     -> [s', 200, 1]
        * leaving the grid         -> [s, -50, 1]   (agent stays put)
        * moving onto an obstacle  -> [s, -50, 1]   (agent stays put)
        * any other move           -> [s', -1, 1]
        '''
        trans_table = np.zeros([len(self.obs_space), len(self.action_space), 3], dtype=int)

        for s, (x, y) in self.obs_space.items():
            for a in self.action_space:
                dx, dy = self.ACTION_DELTAS[a]
                x_nxt, y_nxt = x + dx, y + dy

                # The goal check comes first; the goal always lies inside the grid.
                if [x_nxt, y_nxt] == list(self.goal_pos):
                    trans_table[s][a] = [self._pos_to_state(x_nxt, y_nxt), self.GOAL_REWARD, 1]
                elif not self._in_bounds(x_nxt, y_nxt):
                    trans_table[s][a] = [s, self.COLLISION_REWARD, 1]
                elif [x_nxt, y_nxt] in self.obstacle_pos:
                    trans_table[s][a] = [s, self.COLLISION_REWARD, 1]
                else:
                    trans_table[s][a] = [self._pos_to_state(x_nxt, y_nxt), self.STEP_REWARD, 1]

        return trans_table

    def transition(self, s, a):
        """Return the model entry ``[next_state, reward, prob]`` for ``(s, a)``."""
        return self.trans_table[s][a]

    def step(self, action):
        """Move the agent by ``action`` and return ``(reward, obs, done, info)``.

        Raises:
            ValueError: if ``action`` is not one of 0 (up), 1 (down),
                2 (left), 3 (right).
        """
        try:
            dx, dy = self.ACTION_DELTAS[action]
        except (KeyError, TypeError):
            raise ValueError("the action is not defined") from None

        # Mutate in place, as before, so existing references to agent_pos stay live.
        self.agent_pos[0] += dx
        self.agent_pos[1] += dy

        info = {}
        # Order matters: the reward must be computed *before* _get_obs(),
        # which clamps an out-of-world position back onto the grid.
        reward = self._get_reward()
        obs = self._get_obs()
        done = self._is_done()
        return (reward, obs, done, info)

    def reset(self):
        """Reset the environment and put the agent back at ``[0, 0]``; return the start state."""
        self.world = np.zeros(self.world_shape)
        self.agent_pos = [0, 0]

        return self._get_obs()

    def render(self):
        """Draw the grid: agent in yellow, goal in red, obstacles in grey."""
        fig = plt.figure(figsize=(3, 3))
        ax = fig.gca()

        def add_cell(pos, **style):
            ax.add_patch(patches.Rectangle((pos[0], pos[1]), 1, 1, edgecolor='black', **style))

        # grid lines
        for i in range(self.world_shape[0]):
            for j in range(self.world_shape[1]):
                add_cell((i, j), fill=False)

        add_cell(self.agent_pos, facecolor='yellow', fill=True)
        add_cell(self.goal_pos, facecolor='red', fill=True)
        for pos in self.obstacle_pos:
            add_cell(pos, facecolor='grey', fill=True)

        plt.plot()
        plt.show()

    def _back_to_field(self):
        """Clamp the agent position onto the grid, independently on each axis."""
        max_x = self.world_shape[0] - 1
        max_y = self.world_shape[1] - 1
        self.agent_pos[0] = min(max(self.agent_pos[0], 0), max_x)
        self.agent_pos[1] = min(max(self.agent_pos[1], 0), max_y)

    def _get_reward(self):
        """Return the reward for the agent's current (possibly off-grid) position."""
        if self.agent_pos == self.goal_pos:
            return self.GOAL_REWARD
        if self.agent_pos in self.obstacle_pos:
            return self.COLLISION_REWARD
        if not self._in_bounds(*self.agent_pos):
            # Off the map; _get_obs() will put the agent back on the grid.
            return self.COLLISION_REWARD
        return self.STEP_REWARD

    def _get_obs(self):
        """Clamp the agent onto the grid (side effect) and return its state index."""
        # A single step changes one coordinate by 1, so clamping is the same as
        # returning to the previous state. Clamping is a no-op when in bounds.
        self._back_to_field()
        return self._pos_to_state(*self.agent_pos)

    def _is_done(self):
        """Return True only when the agent stands on the goal cell; obstacles do not end the episode."""
        if self.agent_pos in self.obstacle_pos:
            return False
        return self.agent_pos == self.goal_pos

### Policy Visualization


In [3]:
def policy_visualize(env, policy, title="given policy"):
    """Draw ``policy`` on the grid as one green arrow per action.

    Each arrow's opacity is the value the policy assigns to that action. A
    deterministic policy therefore shows one opaque arrow per cell. The start
    cell (0, 0) is yellow, obstacles are grey, and the goal cell is red with
    no arrows.

    Args:
        env: ``GridEnv``-like object; reads ``world_shape``, ``obstacle_pos``
            and ``goal_pos``.
        policy: indexable as ``policy[s][a]`` with ``s = x + y * world_shape[1]``
            and ``a`` in 0 (up), 1 (down), 2 (left), 3 (right). Values are
            passed as matplotlib ``alpha``, so presumably probabilities in [0, 1].
        title: figure title.
    """
    # Arrow (dx, dy), drawn from the cell centre, for each action id.
    arrow_deltas = [(0, 0.5), (0, -0.5), (-0.5, 0), (0.5, 0)]
    goal = list(env.goal_pos)

    fig = plt.figure(figsize=(4, 4))
    ax = fig.gca()
    ax.set_title(title)

    def fill_cell(x, y, color):
        ax.add_patch(patches.Rectangle((x, y), 1, 1, edgecolor='black', facecolor=color, fill=True))

    for x in range(env.world_shape[0]):
        for y in range(env.world_shape[1]):
            ax.add_patch(patches.Rectangle((x, y), 1, 1, edgecolor='black', fill=False))

            if x == 0 and y == 0:           # start state
                fill_cell(x, y, 'yellow')
            if [x, y] in env.obstacle_pos:  # obstacle state
                fill_cell(x, y, 'grey')
            if [x, y] == goal:              # terminal state: no actions to show
                fill_cell(x, y, 'red')
                continue

            s = x + y * env.world_shape[1]
            for a, (dx, dy) in enumerate(arrow_deltas):
                ax.add_patch(patches.Arrow(x + 0.5, y + 0.5, dx, dy, width=0.2,
                                           alpha=policy[s][a],
                                           facecolor='green', fill=True))

    plt.plot()
    plt.show()

### State Value Function $V(s)$ Visualization

In [4]:
def StateValueFunction_visualize(env, V, title='State Value Function'):
    """Show a 2-D state-value grid as a heat map annotated with each value.

    ``imshow`` places ``V[i][j]`` at x=j, y=i, and the y-axis is inverted so
    that row 0 is drawn at the bottom.

    Args:
        env: not used by this function; kept so all visualizers share a signature.
        V: 2-D array-like indexed ``V[i][j]``. Presumably the flat state-value
            vector reshaped so that ``V[y][x]`` is state ``x + y * width``.
            TODO confirm against callers.
        title: figure title.
    """
    n_rows, n_cols = np.shape(V)[:2]

    plt.figure(figsize=(4, 4))
    plt.title(title)
    plt.imshow(V, cmap='copper', interpolation='none')
    plt.colorbar()

    # Annotate every cell, sized from V itself rather than a hard-coded 5x5.
    for i in range(n_rows):
        for j in range(n_cols):
            plt.text(j, i, round(V[i][j], 1), color='white', fontsize='small', ha='center', va='center')

    plt.gca().invert_yaxis()
    plt.show()

### Action Value Function $Q(s,a)$ Visualization

In [5]:
def ActionValueFunction_visualize(env, Q, title='Action Value Function'):
    """Draw ``Q(s, a)`` as four labelled triangles per grid cell, one per action.

    Within each cell, triangle colors are interpolated from white (lowest Q in
    that cell) to green (highest), and each triangle is labelled with its
    Q-value. Cells are centred on integer coordinates. Background colors:
    start (0, 0) yellow, obstacles grey, goal red.

    Args:
        env: ``GridEnv``-like object; reads ``world_shape``, ``obstacle_pos``
            and ``goal_pos``.
        Q: 2-D array of shape ``(n_states, n_actions)``, indexed ``Q[s, a]``
            with ``s = x + y * world_shape[1]``. Actions 0 (up), 1 (down),
            2 (left) and 3 (right) are drawn; any further actions are ignored.
        title: figure title.
    """
    _, n_action = Q.shape

    # Per action: triangle vertices relative to the cell centre, and label offset.
    action_glyphs = {
        0: (np.array([[0, 0], [-0.5, 0.5], [0.5, 0.5]]), (0.0, 0.25)),    # up
        1: (np.array([[0, 0], [0.5, -0.5], [-0.5, -0.5]]), (0.0, -0.25)), # down
        2: (np.array([[0, 0], [-0.5, -0.5], [-0.5, 0.5]]), (-0.25, 0.0)), # left
        3: (np.array([[0, 0], [0.5, 0.5], [0.5, -0.5]]), (0.25, 0.0)),    # right
    }

    # RGBA endpoints of the per-cell color ramp.
    high_color = np.array([0.0, 1.0, 0.0, 0.8])
    low_color = np.array([1.0, 1.0, 1.0, 0.8])
    text_fs = 6

    n_x, n_y = env.world_shape[0], env.world_shape[1]
    goal = list(env.goal_pos)

    fig = plt.figure(figsize=(4, 4))
    ax = fig.gca()
    ax.set_title(title)

    def fill_cell(x, y, color):
        # Cells are centred on (x, y), so the rectangle starts half a cell earlier.
        ax.add_patch(patches.Rectangle((x - 0.5, y - 0.5), 1, 1, edgecolor='black', facecolor=color, fill=True))

    for x in range(n_x):
        for y in range(n_y):
            if x == 0 and y == 0:           # start state
                fill_cell(x, y, 'yellow')
            if [x, y] in env.obstacle_pos:  # obstacle state
                fill_cell(x, y, 'grey')
            if [x, y] == goal:              # terminal state
                fill_cell(x, y, 'red')

            s = y * env.world_shape[1] + x
            min_q = np.min(Q[s])
            max_q = np.max(Q[s])
            for a in range(n_action):
                glyph = action_glyphs.get(a)
                if glyph is None:
                    continue
                tri, (tx, ty) = glyph
                q_value = Q[s, a]
                # Min-max normalise within the cell; epsilon avoids 0/0 when all Q are equal.
                ratio = np.clip((q_value - min_q) / (max_q - min_q + 1e-10), 0.0, 1.0)
                clr = high_color * ratio + low_color * (1 - ratio)

                ax.add_patch(plt.Polygon([x, y] + tri, facecolor=clr, edgecolor='k'))
                ax.text(x + tx, y + ty, "%.2f" % (q_value), fontsize=text_fs, va='center', ha='center')

    # x spans world_shape[0] columns, y spans world_shape[1] rows.
    ax.set_xlim([-0.5, n_x - 0.5])
    ax.set_xticks(range(n_x))
    ax.set_ylim([-0.5, n_y - 0.5])
    ax.set_yticks(range(n_y))
    plt.show()