# Import : by GitHub

You can run this notebook in Colab by cloning the repository from GitHub, as shown in the next cell.  

In [2]:
# Clone the course repository; the later cells change into its env folder and
# import GridWorldEnvironment from there.
!git clone https://github.com/KanghwaSisters/24_2_mainSession.git

Cloning into '24_2_mainSession'...
remote: Enumerating objects: 170, done.[K
remote: Counting objects: 100% (170/170), done.[K
remote: Compressing objects: 100% (129/129), done.[K
remote: Total 170 (delta 71), reused 95 (delta 25), pack-reused 0 (from 0)[K
Receiving objects: 100% (170/170), 1015.46 KiB | 11.95 MiB/s, done.
Resolving deltas: 100% (71/71), done.


In [3]:
import os
# Make the cloned repo's env folder the working directory so that
# `GridWorldEnvironment` can be run and imported by bare name in the next cells.
# The absolute /content/... path assumes the Colab filesystem layout.
os.chdir('/content/24_2_mainSession/4주차/env')

In [4]:
# Run the environment module as a standalone script (it printed nothing in this run).
# NOTE(review): presumably a smoke check that the file executes — confirm it is still needed.
! python GridWorldEnvironment.py

In [5]:
from GridWorldEnvironment import GridWorldEnvironment

In [6]:
# Build a 5x5 grid world whose start point is (0, 0) and whose end point is (4, 4).
# Note: the SARSA cell below re-creates `env` with the same arguments, so this
# instance is replaced before training starts.
env = GridWorldEnvironment(start_point=(0,0),
                           end_point=(4,4),
                           gridworld_size=(5,5))

# SARSA Class

In [7]:
import numpy as np
import random

class SarsaAgent:
    """Tabular SARSA agent with an epsilon-greedy behaviour policy.

    The environment passed in must expose ``height``, ``width`` and
    ``num_actions`` attributes, a ``reset()`` method returning the initial
    state, and a ``step(action)`` method returning
    ``(next_state, reward, done)``.

    A state is unpacked as a 3-element sequence: the first two elements are
    used as (row, column) indices into the Q-table and the third is ignored.

    Attributes:
        env: The environment the agent interacts with.
        alpha: Learning rate.
        gamma: Discount factor.
        epsilon: Exploration rate (probability of a uniformly random action).
        q_table: Array of shape ``(env.height, env.width, env.num_actions)``
            holding the state-action value estimates, initialised to zero.
    """

    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=0.1):
        self.env = env
        self.alpha = alpha  # learning rate
        self.gamma = gamma  # discount factor
        self.epsilon = epsilon  # exploration rate

        # Initialize Q-table (state-action value function)
        self.q_table = np.zeros((env.height, env.width, env.num_actions))

    def choose_action(self, state):
        """Pick an action for ``state`` with the epsilon-greedy rule.

        With probability ``epsilon`` a uniformly random action index is
        returned; otherwise the action with the highest Q-value for the state
        is returned (``np.argmax`` resolves ties in favour of the lowest index).
        """
        if random.uniform(0, 1) < self.epsilon:
            return random.randint(0, self.env.num_actions - 1)  # Explore
        row, col, _ = state
        return np.argmax(self.q_table[row, col])  # Exploit the known values

    def update_q_table(self, state, action, reward, next_state, next_action, done=False):
        """Apply one SARSA update for the transition (s, a, r, s', a').

        Args:
            state: State in which ``action`` was taken.
            action: Index of the action taken.
            reward: Reward received for the transition.
            next_state: State reached after the action.
            next_action: Action already selected for ``next_state``.
            done: True when ``next_state`` ends the episode. The target is then
                just ``reward``: a terminal state has no future return, so the
                update must not bootstrap from its Q-value. Defaults to False,
                which reproduces the plain bootstrapped update.
        """
        row, col, _ = state
        next_row, next_col, _ = next_state

        if done:
            target = reward
        else:
            target = reward + self.gamma * self.q_table[next_row, next_col, next_action]

        # SARSA update rule: move Q(s, a) towards the target by a step of alpha.
        self.q_table[row, col, action] += self.alpha * (
            target - self.q_table[row, col, action]
        )

    def train(self, num_episodes=1000, verbose=True):
        """Run ``num_episodes`` episodes of on-policy SARSA learning.

        Args:
            num_episodes: Number of episodes to run.
            verbose: When True (the default) print one progress line per
                episode; pass False to keep the notebook output short.
        """
        for episode in range(num_episodes):
            state = self.env.reset()  # Initialize the environment and state
            action = self.choose_action(state)  # Choose initial action

            done = False
            while not done:
                # Take the action, then pick the follow-up action with the same
                # epsilon-greedy policy; SARSA learns from the action it will
                # actually take next.
                next_state, reward, done = self.env.step(action)
                next_action = self.choose_action(next_state)

                # Update the Q-table (no bootstrapping on the terminal transition)
                self.update_q_table(state, action, reward, next_state, next_action, done)

                # Move to the next state and action
                state = next_state
                action = next_action

            if verbose:
                print(f"Episode {episode + 1}/{num_episodes} finished.")

# Usage: build a fresh 5x5 environment (same arguments as the earlier setup cell),
# create a SARSA agent with its default hyperparameters, and train it for 500 episodes.
env = GridWorldEnvironment(start_point=(0, 0), end_point=(4, 4), gridworld_size=(5, 5))
agent = SarsaAgent(env)
agent.train(num_episodes=500)


Episode 1/500 finished.
Episode 2/500 finished.
Episode 3/500 finished.
Episode 4/500 finished.
Episode 5/500 finished.
Episode 6/500 finished.
Episode 7/500 finished.
Episode 8/500 finished.
Episode 9/500 finished.
Episode 10/500 finished.
Episode 11/500 finished.
Episode 12/500 finished.
Episode 13/500 finished.
Episode 14/500 finished.
Episode 15/500 finished.
Episode 16/500 finished.
Episode 17/500 finished.
Episode 18/500 finished.
Episode 19/500 finished.
Episode 20/500 finished.
Episode 21/500 finished.
Episode 22/500 finished.
Episode 23/500 finished.
Episode 24/500 finished.
Episode 25/500 finished.
Episode 26/500 finished.
Episode 27/500 finished.
Episode 28/500 finished.
Episode 29/500 finished.
Episode 30/500 finished.
Episode 31/500 finished.
Episode 32/500 finished.
Episode 33/500 finished.
Episode 34/500 finished.
Episode 35/500 finished.
Episode 36/500 finished.
Episode 37/500 finished.
Episode 38/500 finished.
Episode 39/500 finished.
Episode 40/500 finished.
Episode 4

# Main

- **SARSA**를 이용해 그리드 월드 학습시키기  
- 학습 지표 시각화

In [8]:
# Render the environment in the state left behind by the last training episode.
# NOTE(review): in the output, S / X / A look like the start cell, visited cells and
# the agent's position — confirm against GridWorldEnvironment.render.
env.render()

S X X X X 
. . . X X 
. . . X X 
. . . . X 
. . . . A 



In [9]:
def visualize_optimal_path(env, agent, max_steps=None):
    """Replay the agent's greedy policy for one episode, rendering every step.

    The environment is reset and rendered, then the action with the highest
    Q-value for the current state is taken repeatedly, rendering after each
    step, until the environment reports ``done`` or ``max_steps`` is reached.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` returning
            ``(next_state, reward, done)``, ``render()``, and ``height`` /
            ``width`` attributes (the latter only when ``max_steps`` is None).
        agent: Object with a ``q_table`` indexable as
            ``q_table[state[0], state[1]]`` to give per-action values.
        max_steps: Upper bound on the number of steps to replay. Defaults to
            ``env.height * env.width``. The greedy action depends only on
            ``(state[0], state[1])``, so a replay longer than the number of
            table cells has revisited a cell and, assuming the environment's
            moves are deterministic, would loop forever.
    """
    if max_steps is None:
        max_steps = env.height * env.width

    state = env.reset()
    done = False
    env.render()  # Display the initial state

    steps_taken = 0
    while not done and steps_taken < max_steps:
        # Choose the best action according to the learned policy
        action = np.argmax(agent.q_table[state[0], state[1]])
        next_state, _, done = env.step(action)
        steps_taken += 1

        env.render()  # Render the environment after each step
        state = next_state

    if not done:
        # Without the cap, a greedy policy that cycles (e.g. an under-trained
        # agent) would keep this loop running indefinitely.
        print(f"Stopped after {max_steps} steps without reaching a terminal state.")

# Replay the greedy policy read from the trained agent's Q-table on `env`,
# rendering the grid after every step.
visualize_optimal_path(env, agent)


A . . . . 
. . . . . 
. . . . . 
. . . . . 
. . . . G 

S A . . . 
. . . . . 
. . . . . 
. . . . . 
. . . . G 

S X A . . 
. . . . . 
. . . . . 
. . . . . 
. . . . G 

S X X A . 
. . . . . 
. . . . . 
. . . . . 
. . . . G 

S X X X A 
. . . . . 
. . . . . 
. . . . . 
. . . . G 

S X X X X 
. . . . A 
. . . . . 
. . . . . 
. . . . G 

S X X X X 
. . . . X 
. . . . A 
. . . . . 
. . . . G 

S X X X X 
. . . . X 
. . . . X 
. . . . A 
. . . . G 

S X X X X 
. . . . X 
. . . . X 
. . . . X 
. . . . A 

