미로 자체는 앞서 했던 Q 러닝에서 이용한 미로와 동일하다.

In [1]:
import numpy as np

class Maze:
    """Grid-world maze with obstacles and a single goal.

    Cell values stored in ``self.maze``: 0 is a free cell, -1 is an
    obstacle and 1 is the goal. A state is the integer
    ``row * n_cols + col``; a position is the tuple ``(row, col)``.
    """

    # Action id -> (row delta, column delta): 0 up, 1 down, 2 left, 3 right.
    _MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    # Cell value -> character used by print_maze (2 marks the agent).
    _SYMBOLS = {0: ' ', 1: 'G', -1: 'O', 2: 'P'}

    def __init__(self):
        self.maze = np.zeros((5, 5))
        self.maze[1, 1] = -1  # obstacle
        self.maze[3, 3] = -1  # obstacle
        self.maze[4, 4] = 1   # goal

    def state_to_pos(self, state):
        """Convert a state index into a ``(row, col)`` tuple."""
        return divmod(state, self.maze.shape[1])

    def pos_to_state(self, pos):
        """Convert a ``(row, col)`` position into a state index."""
        return pos[0] * self.maze.shape[1] + pos[1]

    def take_action(self, state, action):
        """Apply ``action`` in ``state`` and return the transition.

        Args:
            state: current state index.
            action: 0 (up), 1 (down), 2 (left) or 3 (right). Any other
                value leaves the position unchanged.

        Returns:
            ``(next_state, reward, done)``. Leaving the grid or hitting an
            obstacle keeps the agent in place with reward -1; reaching the
            goal gives reward 1 and ``done=True``; any other move gives
            reward 0.
        """
        n_rows, n_cols = self.maze.shape
        x, y = self.state_to_pos(state)
        dx, dy = self._MOVES.get(action, (0, 0))
        x += dx
        y += dy

        # Moving off the grid: stay in place and penalize.
        if x < 0 or x >= n_rows or y < 0 or y >= n_cols:
            return state, -1, False

        cell = self.maze[x, y]

        # Obstacles block the move exactly like walls do.
        if cell == -1:
            return state, -1, False

        if cell == 1:
            return self.pos_to_state((x, y)), 1, True

        return self.pos_to_state((x, y)), 0, False

    def print_maze(self, now_pos=False):
        """Print the maze, optionally marking the agent with ``P``.

        Args:
            now_pos: ``(row, col)`` of the agent, or ``False``/``None`` to
                print the maze without the agent.
        """
        grid = self.maze
        # Identity checks instead of ``!= False`` so that array-like
        # positions do not trigger an element-wise comparison.
        if now_pos is not False and now_pos is not None:
            # Draw the agent on a copy so self.maze is never modified,
            # even if printing is interrupted.
            grid = self.maze.copy()
            grid[now_pos[0], now_pos[1]] = 2

        for row in grid:
            # Cells with an unknown value are skipped.
            printed_row = [
                self._SYMBOLS[cell] for cell in row if cell in self._SYMBOLS
            ]
            print('|' + '|'.join(printed_row) + '|')
        print()

Q 러닝과는 업데이트 방식이 다르다.  
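Q 러닝은 다음 state에서 Q 값이 가장 큰 액션(max)을 기준으로 업데이트를 진행하는 반면(off-policy), SARSA는 현재 정책(보통 엡실론 그리디)에 따라 실제로 선택한 다음 액션을 기준으로 업데이트를 진행한다(on-policy).  
아래의 식은 SARSA의 식이다. 다음 state에서 max Q값을 얻는 것이 아닌, 실제로 선택한 다음 액션 $a'$의 Q값 $Q(s', a')$을 얻어 업데이트를 진행한다.
![image.png](attachment:image.png)

$$Q(s, a) \leftarrow (1 - \alpha)\,Q(s, a) + \alpha\,\bigl(r + \gamma\,Q(s', a')\bigr)$$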

In [6]:
# sarsa.py
import numpy as np

# Hyperparameters: learning rate, discount factor, exploration rate.
alpha = 0.5
gamma = 0.6
epsilon = 0.1
num_episodes = 100000

# One Q value per (state, action) pair, all starting at zero.
num_states = 25
num_actions = 4
Q_table = np.zeros((num_states, num_actions))

maze = Maze()


def choose_action(s):
    """Pick an action for state ``s`` with the epsilon-greedy rule.

    With probability ``epsilon`` a uniformly random action is returned,
    otherwise the action with the largest Q value in ``Q_table[s]``.
    """
    explore = np.random.uniform(0, 1) < epsilon
    if explore:
        return np.random.choice(num_actions)
    return np.argmax(Q_table[s])


for i_episode in range(num_episodes):
    # Every episode starts in state 0 with an epsilon-greedy first action.
    state, done = 0, False
    action = choose_action(state)

    while not done:
        next_state, reward, done = maze.take_action(state, action)

        # SARSA selects the next action before updating, and bootstraps
        # from that very action rather than from the greedy maximum.
        next_action = choose_action(next_state)

        # Blend the old estimate with the one-step target.
        old_value = Q_table[state, action]
        target = reward + gamma * Q_table[next_state, next_action]
        Q_table[state, action] = (1 - alpha) * old_value + alpha * target

        # The chosen next pair becomes the current pair.
        state, action = next_state, next_action

    if i_episode % 10000 == 0:
        print(i_episode, "완료")
print("학습 완료")

0 완료
10000 완료
20000 완료
30000 완료
40000 완료
50000 완료
60000 완료
70000 완료
80000 완료
90000 완료
학습 완료


In [8]:
# Test the trained agent: follow the greedy policy from state 0 and print
# the maze before every move.
state = 0
done = False
x = 0
y = 0
pos = (x, y)

# The greedy policy and the maze are both deterministic, so a rollout that
# has not finished after one step per state is stuck (bumping into a wall or
# cycling) and would otherwise loop forever.
max_steps = Q_table.shape[0]
steps = 0

while not done and steps < max_steps:
    maze.print_maze(pos)
    # Unlike training, no epsilon-greedy exploration here: always act greedily.
    action = np.argmax(Q_table[state])
    next_state, reward, done = maze.take_action(state, action)
    state = next_state
    steps += 1

    x, y = maze.state_to_pos(state)
    pos = (x, y)
    print(f"Agent moved to position: ({x}, {y})")

    if done and reward == 1:
        print("goal에 도착")
    elif done and reward == -1:
        print("장애물에 도착")

if not done:
    print("최대 스텝 수에 도달하여 중단")

|P| | | | |
| |O| | | |
| | | | | |
| | | |O| |
| | | | |G|

Agent moved to position: (1, 0)
| | | | | |
|P|O| | | |
| | | | | |
| | | |O| |
| | | | |G|

Agent moved to position: (2, 0)
| | | | | |
| |O| | | |
|P| | | | |
| | | |O| |
| | | | |G|

Agent moved to position: (2, 1)
| | | | | |
| |O| | | |
| |P| | | |
| | | |O| |
| | | | |G|

Agent moved to position: (2, 2)
| | | | | |
| |O| | | |
| | |P| | |
| | | |O| |
| | | | |G|

Agent moved to position: (2, 3)
| | | | | |
| |O| | | |
| | | |P| |
| | | |O| |
| | | | |G|

Agent moved to position: (2, 4)
| | | | | |
| |O| | | |
| | | | |P|
| | | |O| |
| | | | |G|

Agent moved to position: (3, 4)
| | | | | |
| |O| | | |
| | | | | |
| | | |O|P|
| | | | |G|

Agent moved to position: (4, 4)
goal에 도착
