In [2]:
############ toy_text는 text로 환경을 만든다 ###############
import io
import numpy as np
import sys
from gym.envs.toy_text import discrete
from copy import deepcopy as dc


In [3]:
# Action indices: used as the keys of GridworldEnv.P[s] and as the first axis
# of GridworldEnv.P_tensor. Note the order: LEFT is 2 and DOWN is 3.
UP, RIGHT, LEFT, DOWN = range(4)

In [42]:
class GridworldEnv(discrete.DiscreteEnv):
    """Deterministic rectangular gridworld built on gym's ``DiscreteEnv``.

    States are the cells of a ``shape[0] x shape[1]`` grid, numbered row-major
    from 0 to ``nS - 1``. State 0 (top-left) and state ``nS - 1``
    (bottom-right) are terminal: every action there keeps the agent in place
    with reward 0. Every other step costs -1, and a move that would leave the
    grid keeps the agent where it is.

    Attributes:
        shape: The ``(rows, cols)`` shape passed to the constructor.
        P: Transition table, ``P[s][a] == [(prob, next_state, reward, done)]``.
        P_tensor: Array of shape ``(nA, nS, nS)``; ``P_tensor[a, s, s2]`` is the
            probability of moving from ``s`` to ``s2`` under action ``a``.
        R_tensor: Array of shape ``(nS, nA)``; expected reward of taking
            action ``a`` in state ``s``.
    """

    metadata = {'render.modes': ['human', 'ansi']}

    def __init__(self, shape=(4, 4)):
        """Build the transition model for a grid of the given shape.

        Args:
            shape: List/tuple ``(rows, cols)`` of positive integers. The
                default is a tuple so instances never share one mutable list
                (the value is stored as ``self.shape``).

        Raises:
            ValueError: If ``shape`` is not a length-2 list/tuple, or if its
                entries are not positive integers.
        """
        # isinstance accepts any of the listed container types.
        if not isinstance(shape, (list, tuple)) or len(shape) != 2:
            raise ValueError('shape argument must be a list/tuple of length 2')
        # A zero/negative/non-integer dimension would give an empty or invalid
        # grid (and a 0/0 initial-state distribution below).
        if not all(isinstance(d, (int, np.integer)) and d > 0 for d in shape):
            raise ValueError('shape dimensions must be positive integers')

        self.shape = shape

        # Total number of cells, e.g. 4x4 -> 16.
        nS = int(np.prod(shape))
        nA = 4
        MAX_Y, MAX_X = shape

        def is_done(state):
            # The first and last cells (top-left, bottom-right) are terminal.
            return state == 0 or state == nS - 1

        P = {}
        grid = np.arange(nS).reshape(shape)
        # nditer walks the grid in row-major order, so iterindex is the state
        # id and multi_index is its (row, col) position.
        it = np.nditer(grid, flags=['multi_index'])

        while not it.finished:
            s = it.iterindex
            y, x = it.multi_index

            # P[s][a] = [(prob, next_state, reward, is_done)]
            P[s] = {a: [] for a in range(nA)}
            reward = 0.0 if is_done(s) else -1.0

            if is_done(s):
                # Absorbing terminal state: every action stays put.
                for a in (UP, RIGHT, DOWN, LEFT):
                    P[s][a] = [(1.0, s, reward, True)]
            else:
                # Moving off an edge of the grid leaves the agent in place.
                ns_up = s if y == 0 else s - MAX_X
                ns_right = s if x == (MAX_X - 1) else s + 1
                ns_down = s if y == (MAX_Y - 1) else s + MAX_X
                ns_left = s if x == 0 else s - 1

                P[s][UP] = [(1.0, ns_up, reward, is_done(ns_up))]
                P[s][RIGHT] = [(1.0, ns_right, reward, is_done(ns_right))]
                P[s][DOWN] = [(1.0, ns_down, reward, is_done(ns_down))]
                P[s][LEFT] = [(1.0, ns_left, reward, is_done(ns_left))]

            it.iternext()

        # Initial state distribution is uniform over all cells.
        isd = np.ones(nS) / nS

        self.P = P

        self.P_tensor = np.zeros(shape=(nA, nS, nS))
        self.R_tensor = np.zeros(shape=(nS, nA))
        for s, actions in P.items():
            for a, transitions in actions.items():
                # Accumulate over every listed outcome so the tensors remain
                # correct if P ever holds more than one transition per (s, a).
                for prob, s_prime, r, _done in transitions:
                    self.P_tensor[a, s, s_prime] += prob
                    self.R_tensor[s, a] += prob * r

        super(GridworldEnv, self).__init__(nS, nA, P, isd)

    def observe(self):
        """Return a copy of the current state index ``self.s``."""
        return dc(self.s)

    def _render(self, mode='human', close=False):
        """Draw the grid: ``x`` = agent, ``T`` = terminal cell, ``O`` = other.

        Args:
            mode: ``'ansi'`` renders into a string buffer; any other value
                writes to ``sys.stdout``.
            close: If True, do nothing and return None.

        Returns:
            The rendered text when ``mode == 'ansi'``; otherwise None.
        """
        if close:
            return

        outfile = io.StringIO() if mode == 'ansi' else sys.stdout
        n_cols = self.shape[1]
        grid = np.arange(self.nS).reshape(self.shape)
        it = np.nditer(grid, flags=['multi_index'])

        outfile.write('==' * n_cols + '==\n')

        while not it.finished:
            s = it.iterindex
            _, x = it.multi_index

            if self.s == s:
                output = " x "
            elif s == 0 or s == self.nS - 1:
                output = " T "
            else:
                output = " O "

            # Trim the padding at the row edges.
            if x == 0:
                output = output.lstrip()
            if x == n_cols - 1:
                output = output.rstrip()

            outfile.write(output)

            if x == n_cols - 1:
                outfile.write("\n")

            it.iternext()

        outfile.write('==' * n_cols + '==\n')

        # Hand the 'ansi' buffer back to the caller; without this return the
        # rendered text would be built and then lost.
        if mode == 'ansi':
            return outfile.getvalue()
        
        

        
                
        
        
        
        
        
    

In [47]:
# The start state is drawn from a uniform distribution; seeding makes the
# rendered grid below reproducible on Restart & Run All.
SEED = 0

# 10x10 grid; the top-left and bottom-right cells are terminal.
env = GridworldEnv(shape=[10, 10])
env.seed(SEED)
env.reset();


In [48]:
# Print the grid to stdout: x = agent, T = terminal cell, O = other cell.
env._render(mode='human')

T  O  O  O  O  O  O  O  O  O
O  O  O  O  O  O  O  O  O  O
O  O  O  O  O  O  O  O  O  O
O  O  O  O  O  O  O  O  O  O
O  O  O  O  O  O  O  O  O  O
O  O  O  O  O  O  O  O  O  O
O  O  O  O  O  O  O  O  O  O
O  O  O  O  O  O  O  O  O  O
O  O  O  O  O  x  O  O  O  O
O  O  O  O  O  O  O  O  O  T
