<a href="https://colab.research.google.com/github/ydg1021/basicRL/blob/main/example_MC_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np

class Gridworld:
    """A 4x4 grid environment whose states are ``(row, col)`` tuples.

    The cells ``(0, 0)`` and ``(3, 3)`` are terminal. Every other cell
    yields a reward of -1 and offers the four compass moves.
    """

    def __init__(self):
        side = 4
        self.grid_size = side
        self.gamma = 1  # discount rate
        self.grid = np.zeros((side, side))
        self.terminal_states = [(0, 0), (3, 3)]

    def is_terminal_state(self, state):
        """Return True when ``state`` is one of the terminal states."""
        terminals = self.terminal_states
        return state in terminals

    def get_next_state(self, state, action):
        """Return the state reached by taking ``action`` from ``state``.

        Every move is one step up, down, left or right. A move that would
        run into a wall leaves the agent where it is. An action name that
        is not one of the four moves also leaves the state unchanged.
        """
        row, col = state
        if action == 'up':
            return (max(0, row - 1), col)
        if action == 'down':
            return (min(self.grid_size - 1, row + 1), col)
        if action == 'left':
            return (row, max(0, col - 1))
        if action == 'right':
            return (row, min(self.grid_size - 1, col + 1))
        return (row, col)

    def get_possible_actions(self, state):
        """Return the action names available in ``state`` (none if terminal)."""
        return [] if self.is_terminal_state(state) else ['up', 'down', 'left', 'right']

    def get_reward(self, state):
        """Return 0 for a terminal state and -1 for any other state."""
        if self.is_terminal_state(state):
            return 0
        return -1

def first_visit_mc_prediction(gridworld, episodes):
    """Estimate state values under a uniform random policy (first-visit MC).

    Each episode starts in a uniformly random cell and follows uniformly
    random actions until a terminal state is reached. For every state that
    occurs in the episode, only the discounted return that follows its
    first occurrence is recorded. A state's value is the average of all
    returns recorded for it over all episodes.

    Args:
        gridworld: Environment exposing ``grid_size``, ``gamma``,
            ``is_terminal_state``, ``get_possible_actions``,
            ``get_next_state`` and ``get_reward``.
        episodes: Number of episodes to sample.

    Returns:
        A ``(grid_size, grid_size)`` float array of estimated values.
        States that were never visited keep the value 0.
    """
    size = gridworld.grid_size
    gamma = gridworld.gamma
    # Keep a running sum and count per state instead of a list of every
    # return, so the average is computed once at the end rather than being
    # recomputed over a growing list on each visit.
    return_sums = np.zeros((size, size))
    visit_counts = np.zeros((size, size), dtype=int)

    for _ in range(episodes):
        state = (np.random.randint(0, size), np.random.randint(0, size))
        trajectory = []  # (state, reward) for every non-terminal step
        while not gridworld.is_terminal_state(state):
            action = np.random.choice(gridworld.get_possible_actions(state))
            next_state = gridworld.get_next_state(state, action)
            reward = gridworld.get_reward(state)
            trajectory.append((state, reward))
            state = next_state

        # Walk the episode backwards so each return follows from the next
        # one in O(1): G_t = r_t + gamma * G_{t+1}. Because earlier steps
        # are processed later, the value left in the dict for a state is
        # the return of its first occurrence.
        G = 0
        first_visit_return = {}
        for visited, reward in reversed(trajectory):
            G = reward + gamma * G
            first_visit_return[visited] = G

        for visited, episode_return in first_visit_return.items():
            return_sums[visited] += episode_return
            visit_counts[visited] += 1

    value_table = np.zeros((size, size))
    seen = visit_counts > 0  # avoid dividing by zero for unvisited states
    value_table[seen] = return_sums[seen] / visit_counts[seen]
    return value_table

def every_visit_mc_prediction(gridworld, episodes):
    """Estimate state values under a uniform random policy (every-visit MC).

    Each episode starts in a uniformly random cell and follows uniformly
    random actions until a terminal state is reached. The discounted
    return that follows every occurrence of a state is recorded, so a
    state visited several times in one episode contributes several
    returns. A state's value is the average of all returns recorded for it.

    Args:
        gridworld: Environment exposing ``grid_size``, ``gamma``,
            ``is_terminal_state``, ``get_possible_actions``,
            ``get_next_state`` and ``get_reward``.
        episodes: Number of episodes to sample.

    Returns:
        A ``(grid_size, grid_size)`` float array of estimated values.
        States that were never visited keep the value 0.
    """
    size = gridworld.grid_size
    gamma = gridworld.gamma
    # Keep a running sum and count per state instead of a list of every
    # return, so the average is computed once at the end rather than being
    # recomputed over a growing list on each visit.
    return_sums = np.zeros((size, size))
    visit_counts = np.zeros((size, size), dtype=int)

    for _ in range(episodes):
        state = (np.random.randint(0, size), np.random.randint(0, size))
        trajectory = []  # (state, reward) for every non-terminal step
        while not gridworld.is_terminal_state(state):
            action = np.random.choice(gridworld.get_possible_actions(state))
            next_state = gridworld.get_next_state(state, action)
            reward = gridworld.get_reward(state)
            trajectory.append((state, reward))
            state = next_state

        # Walk the episode backwards so each return follows from the next
        # one in O(1): G_t = r_t + gamma * G_{t+1}.
        G = 0
        for visited, reward in reversed(trajectory):
            G = reward + gamma * G
            return_sums[visited] += G
            visit_counts[visited] += 1

    value_table = np.zeros((size, size))
    seen = visit_counts > 0  # avoid dividing by zero for unvisited states
    value_table[seen] = return_sums[seen] / visit_counts[seen]
    return value_table

# Create the gridworld environment and run both algorithms.
# No random seed is set here, so the estimates differ from run to run.
gridworld = Gridworld()
episodes = 10000  # number of episodes
first_visit_values = first_visit_mc_prediction(gridworld, episodes)
every_visit_values = every_visit_mc_prediction(gridworld, episodes)

# Last expression of the notebook cell: shows both value tables as the cell result.
first_visit_values, every_visit_values


(array([[  0.        , -14.15930353, -20.2208322 , -22.06673374],
        [-14.04073107, -17.71606735, -19.91394789, -20.2676399 ],
        [-20.07794317, -19.99566443, -17.98227848, -14.09929997],
        [-22.15168727, -20.09666033, -14.2588141 ,   0.        ]]),
 array([[  0.        , -13.90654206, -19.22486157, -21.52738621],
        [-13.84656391, -18.10033892, -19.48047854, -19.51019262],
        [-19.52677622, -20.04732971, -17.70907307, -13.70179739],
        [-21.37074429, -19.96449657, -13.83895216,   0.        ]]))