# [Sutton and Barto Notebooks](https://github.com/seungjaeryanlee/sutton-barto-notebooks): Example 6.6

Author: Ryan Lee  
저자: 이승재  

This notebook is an implementation of Example 6.6 in [Sutton and Barto's Reinforcement Learning: An Introduction](http://incompleteideas.net/book/the-book-2nd.html).  
이 노트북은 [Sutton and Barto의 Reinforcement Learning: An Introduction](http://incompleteideas.net/book/the-book-2nd.html) 책의 Example 6.6을 구현한 결과입니다.

![Example 6.6](example_6_6.png)

In [1]:
import copy
from enum import IntEnum
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [2]:
def argmax_all(list_):
    """
    Return the indices of every maximal element of a 1-D sequence as a list.

    Unlike np.argmax, which reports only the first maximum, all tied positions
    are returned (in increasing order).

    list_: a 1-D list, tuple or NumPy array of comparable numbers.
    Returns a list of Python ints; an empty input yields an empty list.
    """
    # Coerce first: comparing a plain Python list with a scalar gives a single
    # False rather than an element-wise mask, which made list inputs return [].
    values = np.asarray(list_)
    if values.size == 0:
        return []
    return np.argwhere(values == np.max(values)).flatten().tolist()

In [3]:
class Action(IntEnum):
    """
    Moves available to the agent.

    The integer values are used as indices into the last axis of the agents'
    Q tables and into the arrow icons printed by plot_policy.
    """
    UP = 0     # y + 1 (clamped to the top row by Environment.step)
    DOWN = 1   # y - 1 (clamped to row 0)
    LEFT = 2   # x - 1 (clamped to column 0)
    RIGHT = 3  # x + 1 (clamped to the last column)

In [4]:
class Environment:
    """
    Cliff-walking grid world.

    States are [x, y] lists with x in [0, WIDTH) and y in [0, HEIGHT). The
    agent starts at [0, 0]; the goal is the opposite end of the bottom row,
    [WIDTH - 1, 0]; every bottom-row cell strictly between them is cliff.
    """

    # Integer codes of all actions, in Action declaration order.
    action_space = list(map(int, Action))

    def __init__(self, width=12, height=4):
        """
        Create a width x height grid with the agent on the start cell.
        """
        self.WIDTH = width
        self.HEIGHT = height

        self.state_space = [[x, y] for x in range(width) for y in range(height)]

        self.state = [0, 0]

    def _is_goal(self):
        """
        Checks if current state is the goal state (bottom-right corner).
        """
        # Derived from WIDTH so that non-default grid sizes keep a reachable goal.
        return self.state == [self.WIDTH - 1, 0]

    def _is_cliff(self):
        """
        Checks if current state is a cliff cell (bottom row between start and goal).
        """
        x, y = self.state
        return y == 0 and 0 < x < self.WIDTH - 1

    def reset(self):
        """
        Resets environment and returns initial state.
        """
        self.state = [0, 0]
        # Return a copy so callers never hold a reference to the internal state.
        return list(self.state)

    def step(self, action):
        """
        Performs given action and returns next_state, reward, done.

        Moves are clamped at the grid border. Stepping onto a cliff cell gives
        reward -100 and sends the agent back to [0, 0] without ending the
        episode; every other step gives reward -1. done is True once the agent
        stands on the goal cell. The returned next_state is a fresh list.
        """
        assert action in self.action_space

        x, y = self.state
        if action == Action.LEFT:
            x = max(x - 1, 0)
        elif action == Action.RIGHT:
            x = min(x + 1, self.WIDTH - 1)
        elif action == Action.DOWN:
            y = max(y - 1, 0)
        elif action == Action.UP:
            y = min(y + 1, self.HEIGHT - 1)
        # Rebind instead of mutating in place so lists handed out earlier stay unchanged.
        self.state = [x, y]

        if self._is_cliff():
            reward = -100
            self.state = [0, 0]
        else:
            reward = -1

        done = self._is_goal()
        return list(self.state), reward, done

In [18]:
class SarsaAgent:
    """
    Tabular SARSA agent with an epsilon-greedy policy.

    SARSA is on-policy: the action A' used to bootstrap the update for (S, A)
    must be the action that is actually taken in S'. update_q therefore
    remembers the A' it sampled, and the next get_action call for that same
    state returns it instead of sampling again.
    """

    def __init__(self, env, epsilon=0.1, learning_rate=0.1, discount_factor=1):
        """
        Initialize Q table and save environment.

        env: object exposing WIDTH, HEIGHT and action_space.
        epsilon: probability of picking a uniformly random action.
        learning_rate: step size of the Q update.
        discount_factor: weight of the bootstrapped next-state value.
        """
        self.q_table = np.zeros((env.WIDTH, env.HEIGHT, len(env.action_space)), dtype=float)
        self.env = env
        self.epsilon = epsilon
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        # ((x, y), action) sampled by the last update_q and not yet handed out.
        self._pending = None

    def _sample_action(self, state):
        """
        Draw one action for `state` from the epsilon-greedy policy.
        """
        if np.random.random() < self.epsilon:
            return int(np.random.choice(self.env.action_space))
        # Index with two scalars: q_table[[x, y]] would be advanced indexing and
        # select rows x and y of the first axis rather than the cell (x, y).
        q_values = self.q_table[state[0], state[1]]
        # Break ties between equally good actions uniformly at random.
        best_actions = np.flatnonzero(q_values == q_values.max())
        return int(np.random.choice(best_actions))

    def get_action(self, state):
        """
        Returns action based on Q table and epsilon-greedy policy.

        If the previous update_q call already sampled an action for this very
        state, that action is returned (once) so that the executed action
        matches the one used in the SARSA target.
        """
        pending, self._pending = self._pending, None
        if pending is not None and pending[0] == (state[0], state[1]):
            return pending[1]
        return self._sample_action(state)

    def update_q(self, state, action, reward, next_state, done):
        """
        Update Q table via SARSA.

        Moves Q(state, action) towards reward + discount * Q(next_state, A'),
        where A' is sampled from the epsilon-greedy policy before the update
        and kept for the following get_action(next_state) call. When done is
        True the target is just the reward and no action is kept.
        """
        if done:
            self._pending = None
            target = reward
        else:
            next_action = self._sample_action(next_state)
            self._pending = ((next_state[0], next_state[1]), next_action)
            target = reward + self.discount_factor * self.q_table[next_state[0], next_state[1], next_action]

        old_value = self.q_table[state[0], state[1], action]
        self.q_table[state[0], state[1], action] = ((1 - self.learning_rate) * old_value
                                                    + self.learning_rate * target)

    def plot_policy(self):
        """
        Print the greedy policy as a grid of arrows, top row first.

        Cliff cells (bottom row, strictly between the first and last column)
        are printed as 'O'; every other cell shows the arrow of its first
        highest-valued action.
        """
        icons = ['↑', '↓', '←', '→']
        for y in reversed(range(self.env.HEIGHT)):
            row = []
            for x in range(self.env.WIDTH):
                if y == 0 and 0 < x < self.env.WIDTH - 1:
                    row.append('O')
                else:
                    row.append(icons[np.argmax(self.q_table[x, y])])
            print(''.join(row))

In [29]:
class QLearningAgent:
    """
    Tabular Q-learning agent with an epsilon-greedy behaviour policy.
    """

    def __init__(self, env, epsilon=0.1, learning_rate=0.1, discount_factor=1):
        """
        Initialize Q table and save environment.

        env: object exposing WIDTH, HEIGHT and action_space.
        epsilon: probability of picking a uniformly random action.
        learning_rate: step size of the Q update.
        discount_factor: weight of the bootstrapped next-state value.
        """
        self.q_table = np.zeros((env.WIDTH, env.HEIGHT, len(env.action_space)), dtype=float)
        self.env = env
        self.epsilon = epsilon
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor

    def get_action(self, state):
        """
        Returns action based on Q table and epsilon-greedy policy.

        With probability epsilon a uniformly random action is returned;
        otherwise one of the highest-valued actions of `state` is chosen,
        breaking ties uniformly at random.
        """
        if np.random.random() < self.epsilon:
            return int(np.random.choice(self.env.action_space))
        # Index with two scalars: q_table[[x, y]] would be advanced indexing and
        # select rows x and y of the first axis rather than the cell (x, y).
        q_values = self.q_table[state[0], state[1]]
        best_actions = np.flatnonzero(q_values == q_values.max())
        return int(np.random.choice(best_actions))

    def update_q(self, state, action, reward, next_state, done):
        """
        Update Q table via Q-learning.

        Moves Q(state, action) towards reward + discount * max_a Q(next_state, a).
        When done is True the target is just the reward.
        """
        if done:
            target = reward
        else:
            target = reward + self.discount_factor * np.max(self.q_table[next_state[0], next_state[1]])

        old_value = self.q_table[state[0], state[1], action]
        self.q_table[state[0], state[1], action] = ((1 - self.learning_rate) * old_value
                                                    + self.learning_rate * target)

    def plot_policy(self):
        """
        Print the greedy policy as a grid of arrows, top row first.

        Cliff cells (bottom row, strictly between the first and last column)
        are printed as 'O'; every other cell shows the arrow of its first
        highest-valued action.
        """
        icons = ['↑', '↓', '←', '→']
        for y in reversed(range(self.env.HEIGHT)):
            row = []
            for x in range(self.env.WIDTH):
                if y == 0 and 0 < x < self.env.WIDTH - 1:
                    row.append('O')
                else:
                    row.append(icons[np.argmax(self.q_table[x, y])])
            print(''.join(row))

In [30]:
from IPython.display import clear_output

def train_agent(env, agent, n_episodes=1):
    """
    Train given agent in given environment for 'n_episodes' episodes.

    env: environment with reset() -> state and step(action) -> (next_state, reward, done).
    agent: object with get_action(state), update_q(state, action, reward, next_state, done)
        and plot_policy().
    n_episodes: number of episodes to run.

    The agent's policy is printed before training and again after every
    episode, replacing the previous notebook output each time.
    """
    print('Episode {}/{}'.format(0, n_episodes))
    agent.plot_policy()

    for episode in range(1, n_episodes + 1):
        # Copy the start state: reset() may return the environment's own state
        # list, which step() changes in place; without the copy `state` would
        # already equal `next_state` when the first update is applied.
        state = list(env.reset())
        done = False
        while not done:
            action = agent.get_action(state)
            next_state, reward, done = env.step(action)
            agent.update_q(state, action, reward, next_state, done)
            state = next_state
        clear_output(wait=True)
        print('Episode {}/{}'.format(episode, n_episodes))
        agent.plot_policy()

In [31]:
# Build the environment with its default grid size (12 x 4) and a Q-learning
# agent (default epsilon, learning rate and discount factor) bound to it.
env = Environment()
agent = QLearningAgent(env)

In [32]:
# Train for 10 episodes; the agent's greedy policy is printed after each episode.
train_agent(env, agent, 10)

Episode 10/10
↓↓↓↓↓↓↓↓↓↓↓↓
↓↓↓↓↓↓↓↓↓↓↓→
→→→→→→→→→→→→
↑OOOOOOOOOO↑
