# Cliff Walking: SARSA vs Q-learning

생각해 보니까 전에 썼던 `gymnasium`으로 환경 구성을 날먹하는 건 학습에 안 좋을 것 같으니까<br>
내가 만들어 놓은 환경 코드를 읽고 어떻게 쓰는지 유추해 보쇼

In [None]:
import numpy as np
import matplotlib.pyplot as plt

## 1. 환경 정의 (CliffWalking)

In [None]:
# Read through this class and work out what it does.
class CliffWalkingEnv:
    """Cliff-walking grid world exposing a ``reset()`` / ``step()`` interface.

    Cells are ``(row, col)`` tuples, encoded row-major as the integer state
    ``row * n_cols + col``. The agent starts at the bottom-left cell
    ``(n_rows - 1, 0)``. The goal is the bottom-right cell
    ``(n_rows - 1, n_cols - 1)``. Every bottom-row cell strictly between
    them is cliff.

    Actions: 0 = up, 1 = right, 2 = down, 3 = left. A move that would leave
    the grid is clamped, so the agent stays on the border.

    Every step costs -1. Stepping into the cliff instead costs -100 and
    teleports the agent back to the start WITHOUT ending the episode. Only
    reaching the goal sets ``done``.
    """

    def __init__(self, n_rows: int = 4, n_cols: int = 12):
        # TODO: mark where start, goal and cliff are on the default layout below
        #
        #             default layout
        #               column
        #        0 1 2 3 4 5 6 7 8 9 10 11
        #        _________________________
        #     0 |                         |
        # row 1 |                         |
        #     2 |                         |
        #     3 |                         |
        #       ￣￣￣￣￣￣￣￣￣￣￣￣￣￣

        self.n_rows, self.n_cols = n_rows, n_cols
        self.n_states = n_rows * n_cols
        self.n_actions = 4  # up, right, down, left

        bottom = n_rows - 1
        self.start = (bottom, 0)
        self.goal = (bottom, n_cols - 1)
        self.cliff = set((bottom, col) for col in range(1, n_cols - 1))
        self.reset()

    # A leading underscore marks an internal helper by convention; for more on
    # Python's underscore naming rules see:
    # https://velog.io/@turningtwenty/underscore-in-python
    def _to_state(self, rc: tuple[int, int]) -> int:
        """Encode a ``(row, col)`` cell as its row-major integer state id."""
        row, col = rc
        return row * self.n_cols + col

    @staticmethod
    def _clamp(value: int, upper: int) -> int:
        """Clip ``value`` into the inclusive range ``[0, upper]``."""
        return min(max(value, 0), upper)

    def reset(self) -> int:
        """Put the agent back on the start cell and return that state id."""
        self.agent_rc = self.start
        return self._to_state(self.start)

    # Returns next_state, reward, done, info
    def step(self, action: int) -> tuple[int, float, bool, dict]:
        """Move the agent one cell and return ``(next_state, reward, done, info)``.

        ``info`` is always an empty dict.

        Raises:
            ValueError: if ``action`` is not 0, 1, 2 or 3.
        """
        r, c = self.agent_rc

        # (row delta, col delta) for up / right / down / left
        if action == 0:
            dr, dc = -1, 0
        elif action == 1:
            dr, dc = 0, 1
        elif action == 2:
            dr, dc = 1, 0
        elif action == 3:
            dr, dc = 0, -1
        else:
            raise ValueError("action은 0, 1, 2, 3 중 하나여야 합니다.")

        # Walls: clamp the target cell back onto the grid.
        landed = (
            self._clamp(r + dr, self.n_rows - 1),
            self._clamp(c + dc, self.n_cols - 1),
        )

        # Cliff
        # TODO: done stays False after falling off the cliff on purpose, to
        #       create a difference between Q-learning and SARSA.
        #       Explain why the gap between Q-learning and SARSA grows when
        #       done is False.
        #
        # Answer:
        #
        #
        if landed in self.cliff:
            reward = -100.0
            landed = self.start  # teleport back; the episode goes on
        else:
            reward = -1.0

        # Goal: the only terminal condition.
        done = landed == self.goal

        self.agent_rc = landed
        return self._to_state(landed), reward, done, {}


env = CliffWalkingEnv()

## 2. 유틸

In [None]:
# Action id -> arrow glyph: 0 = up, 1 = right, 2 = down, 3 = left.
ARROWS = dict(enumerate("↑→↓←"))

# The functions below are for visualization, so you can ignore them.
def moving_average(x, w: int = 20):
    """Return the length-``w`` simple moving average of ``x``.

    Uses "valid" convolution, so the result has ``len(x) - w + 1`` entries.
    If ``x`` has fewer than ``w`` elements it is returned unsmoothed, as a
    float array.
    """
    values = np.asarray(x, dtype=float)
    if len(values) >= w:
        window = np.ones(w) / w
        return np.convolve(values, window, mode="valid")
    return values

def render_policy(Q: np.ndarray, env: CliffWalkingEnv):
    """Print the greedy policy of ``Q`` as a character grid, one line per row.

    Start, goal and cliff cells print as ``S``, ``G`` and ``C``. Every other
    cell prints the ``ARROWS`` glyph of ``np.argmax(Q[state])``. Ties go to
    the lowest action id, so an all-zero row shows ``ARROWS[0]``.
    """

    def cell_symbol(r, c):
        cell = (r, c)
        if cell == env.start:
            return "S"
        if cell == env.goal:
            return "G"
        if cell in env.cliff:
            return "C"
        greedy_action = int(np.argmax(Q[r * env.n_cols + c]))
        return ARROWS[greedy_action]

    lines = (
        " ".join(cell_symbol(r, c) for c in range(env.n_cols))
        for r in range(env.n_rows)
    )
    print("\n".join(lines))

def run_greedy_episode(Q: np.ndarray, env: CliffWalkingEnv, max_steps: int = 200):
    """Roll out one episode acting greedily on ``Q`` and return its return.

    Starts from ``env.reset()`` and takes ``argmax Q[state]`` at every step.
    Returns the undiscounted sum of rewards. Stops at the first ``done`` or
    after ``max_steps`` steps, whichever comes first. A truncated episode
    returns its partial sum.
    """
    state = env.reset()
    episode_return = 0.0
    for _step in range(max_steps):
        action = int(np.argmax(Q[state]))
        state, reward, done, _info = env.step(action)
        episode_return += reward
        if done:
            break
    return episode_return


## 3. 구현: ε-greedy / SARSA / Q-learning

In [None]:
def epsilon_greedy(Q, state, epsilon, n_actions, rng=None):
    """Pick an action for ``state`` with an epsilon-greedy policy over ``Q``.

    With probability ``epsilon``, return a uniformly random action from
    ``range(n_actions)`` (exploration). Otherwise return one of the actions
    with the highest ``Q[state]`` value (exploitation). Ties are broken
    uniformly at random, so an all-zero row does not always pick action 0.

    Args:
        Q: action-value table indexed as ``Q[state]`` -> per-action values.
        state: integer state id.
        epsilon: exploration probability in ``[0, 1]``.
        n_actions: number of discrete actions.
        rng: optional ``numpy.random.Generator``. When ``None``, NumPy's
            global random state is used, so ``np.random.seed`` makes runs
            reproducible.

    Returns:
        The chosen action as a Python ``int``.
    """
    if rng is None:
        rng = np.random  # legacy global API; also exposes random() / choice()
    if rng.random() < epsilon:
        return int(rng.choice(n_actions))
    q = np.asarray(Q[state])
    best = np.flatnonzero(q == q.max())
    return int(rng.choice(best))

def train_sarsa(
    env: CliffWalkingEnv,
    num_episodes: int = 500,
    alpha: float = 0.1,
    gamma: float = 0.99,
    epsilon: float = 1.0,
    epsilon_decay: float = 0.995,
    epsilon_min: float = 0.05,
    max_steps: int = 500,
):
    """Learn a tabular action-value table with on-policy SARSA.

    Each update bootstraps from the action ``a'`` that the same
    epsilon-greedy behaviour policy actually takes next:
    ``Q[s, a] += alpha * (r + gamma * Q[s', a'] - Q[s, a])``.
    The bootstrap term is dropped when the transition ends the episode.

    Args:
        env: environment with ``reset()``, ``step(a)``, ``n_states`` and
            ``n_actions``.
        num_episodes: number of training episodes.
        alpha: learning rate.
        gamma: discount factor.
        epsilon: initial exploration probability.
        epsilon_decay: multiplicative decay applied after every episode.
        epsilon_min: lower bound for epsilon.
        max_steps: per-episode step cap. An episode cut off here is not
            treated as terminal (its last update still bootstraps).

    Returns:
        ``(Q, returns)``: ``Q`` has shape ``(env.n_states, env.n_actions)``,
        and ``returns[ep]`` is the undiscounted reward sum of episode ``ep``.
    """
    # All-zero start. The goal state is never the "from" state of an update,
    # so its row stays 0.
    Q = np.zeros((env.n_states, env.n_actions), dtype=float)
    returns = np.zeros(num_episodes, dtype=float)

    for ep in range(num_episodes):
        state = env.reset()
        a = epsilon_greedy(Q, state, epsilon, env.n_actions)
        total = 0.0

        for _ in range(max_steps):
            next_state, reward, done, _info = env.step(a)
            total += reward

            if done:
                # Terminal transition: the target is just the reward.
                Q[state, a] += alpha * (reward - Q[state, a])
                break

            # On-policy: commit to a' first, then bootstrap from Q[s', a'].
            next_a = epsilon_greedy(Q, next_state, epsilon, env.n_actions)
            td_target = reward + gamma * Q[next_state, next_a]
            Q[state, a] += alpha * (td_target - Q[state, a])
            state, a = next_state, next_a

        returns[ep] = total
        # Multiplicative decay, floored at epsilon_min.
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

    return Q, returns



def train_q_learning(
    env: CliffWalkingEnv,
    num_episodes: int = 500,
    alpha: float = 0.1,
    gamma: float = 0.99,
    epsilon: float = 1.0,
    epsilon_decay: float = 0.995,
    epsilon_min: float = 0.05,
    max_steps: int = 500,
):
    """Learn a tabular action-value table with off-policy Q-learning.

    Actions are chosen epsilon-greedily, but each update bootstraps from the
    greedy value of the next state, regardless of which action is taken
    there: ``Q[s, a] += alpha * (r + gamma * max_a' Q[s', a'] - Q[s, a])``.
    The bootstrap term is dropped when the transition ends the episode.

    Args:
        env: environment with ``reset()``, ``step(a)``, ``n_states`` and
            ``n_actions``.
        num_episodes: number of training episodes.
        alpha: learning rate.
        gamma: discount factor.
        epsilon: initial exploration probability.
        epsilon_decay: multiplicative decay applied after every episode.
        epsilon_min: lower bound for epsilon.
        max_steps: per-episode step cap. An episode cut off here is not
            treated as terminal (its last update still bootstraps).

    Returns:
        ``(Q, returns)``: ``Q`` has shape ``(env.n_states, env.n_actions)``,
        and ``returns[ep]`` is the undiscounted reward sum of episode ``ep``.
    """
    # All-zero start. The goal state is never the "from" state of an update,
    # so its row stays 0.
    Q = np.zeros((env.n_states, env.n_actions), dtype=float)
    returns = np.zeros(num_episodes, dtype=float)

    for ep in range(num_episodes):
        state = env.reset()
        total = 0.0

        for _ in range(max_steps):
            a = epsilon_greedy(Q, state, epsilon, env.n_actions)
            next_state, reward, done, _info = env.step(a)
            total += reward

            # Off-policy target: value of the best next action, not the one
            # the behaviour policy will actually pick.
            if done:
                td_target = reward
            else:
                td_target = reward + gamma * np.max(Q[next_state])
            Q[state, a] += alpha * (td_target - Q[state, a])

            state = next_state
            if done:
                break

        returns[ep] = total
        # Multiplicative decay, floored at epsilon_min.
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

    return Q, returns


## 4. 실행 & 비교

In [None]:
env = CliffWalkingEnv()
num_episodes = 500
alpha = 0.1
gamma = 0.99
epsilon = 1.0
epsilon_decay = 0.995
epsilon_min = 0.05

# One shared kwargs dict guarantees both agents see identical hyperparameters.
train_kwargs = dict(
    num_episodes=num_episodes,
    alpha=alpha,
    gamma=gamma,
    epsilon=epsilon,
    epsilon_decay=epsilon_decay,
    epsilon_min=epsilon_min,
)
Q_sarsa, ret_sarsa = train_sarsa(env, **train_kwargs)
Q_ql, ret_ql = train_q_learning(env, **train_kwargs)

# Learning curves: per-episode training return, smoothed over 20 episodes.
plt.figure()
plt.plot(moving_average(ret_sarsa, w=20), label="SARSA (moving avg)")
plt.plot(moving_average(ret_ql, w=20), label="Q-learning (moving avg)")
plt.xlabel("Episode")
plt.ylabel("Return")
plt.legend()
plt.show()

print("=== SARSA greedy policy ===")
render_policy(Q_sarsa, env)
print("\n=== Q-learning greedy policy ===")
render_policy(Q_ql, env)

# env.step has no randomness and run_greedy_episode always takes argmax, so
# the 50 rollouts of a given Q all return the same value; the mean is that value.
eval_sarsa = np.mean([run_greedy_episode(Q_sarsa, env) for _ in range(50)])
eval_ql = np.mean([run_greedy_episode(Q_ql, env) for _ in range(50)])
print(f"\nGreedy eval (50 eps): SARSA={eval_sarsa:.1f}, QL={eval_ql:.1f}")

## 5. 추가 분석
위 코드를 여러 번 복붙해서 다른 셀에 epsilon decay가 없는 경우(예: `epsilon_decay=1.0`)를 작성하고, epsilon 값을 여러 개로 바꾸어 가며 여러 셀에서 실행하여 비교하시오.