# SARSA agent in Frozen Lake Environment

Implementation of a SARSA ($S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1}$) agent, with a learning example in the Frozen Lake environment.

In [1]:
import gym
import gym.envs.toy_text as toy_text
import numpy as np

In [2]:
def str2map(str):
    """Reshape a flat string of tile characters into a square 2-D char array.

    The side length is the integer part of the square root of ``len(str)``.
    Characters are laid out row by row; any characters beyond
    ``side * side`` are ignored.
    """
    side = int(np.sqrt(len(str)))
    grid = np.empty((side, side), dtype='c')
    # Walk the used prefix once and convert each flat position into (row, col).
    for pos, tile in enumerate(str[:side * side]):
        grid[divmod(pos, side)] = tile
    return grid

## Sarsa agent

In [3]:
class sarsa:
    """Tabular SARSA agent for environments with discrete states and actions.

    The agent keeps a table ``Q`` of shape ``(n_state, n_act)``, initialised
    to zeros, and improves it on-policy: every update uses the action that
    the epsilon-greedy policy actually selects in the next state.
    """

    def __init__(self, env):
        """Store the environment and allocate the zero-initialised Q-table.

        ``env`` must expose ``observation_space.n``, ``action_space.n``,
        ``reset()`` and ``step(action)``.
        """
        self.env = env
        self.n_state = env.observation_space.n
        self.n_act = env.action_space.n
        self.Q = np.zeros((self.n_state, self.n_act))

    def __epsGreedy(self, state, epsilon):
        """Pick an action for ``state``: random with probability ``epsilon``,
        otherwise the action with the highest Q-value."""
        if np.random.uniform() < epsilon:   # need to explore
            return np.random.randint(0, self.n_act)
        # pick the best move (np.argmax returns the first index on ties)
        return np.argmax(self.Q[state, :])

    def _resetEnv(self):
        """Reset the environment and return its initial state.

        Accepts both reset conventions: a bare observation, or an
        ``(observation, info)`` tuple as returned by newer gym releases.
        """
        result = self.env.reset()
        if isinstance(result, tuple):
            return result[0]
        return result

    def _stepEnv(self, action):
        """Advance the environment and return ``(new_state, reward, done)``.

        Accepts both step conventions: the 4-tuple
        ``(obs, reward, done, info)`` and the 5-tuple
        ``(obs, reward, terminated, truncated, info)`` of newer gym releases.
        """
        result = self.env.step(action)
        if len(result) == 5:
            new_state, reward, terminated, truncated, _ = result
            return new_state, reward, terminated or truncated
        new_state, reward, done, _ = result
        return new_state, reward, done

    def train(self, gamma, alpha, epsilon, n_episodes):
        """Run ``n_episodes`` episodes of SARSA, updating ``self.Q`` in place.

        Args:
            gamma: discount factor applied to the next state-action value.
            alpha: learning rate blending the old estimate with the target.
            epsilon: exploration probability of the epsilon-greedy policy.
            n_episodes: number of episodes to generate.
        """
        for _ in range(n_episodes):
            # Start from the state the environment reports instead of
            # assuming the episode always begins in state 0.
            cur_state = self._resetEnv()
            action = self.__epsGreedy(cur_state, epsilon)
            done = False
            while not done:
                new_state, reward, done = self._stepEnv(action)
                new_action = self.__epsGreedy(new_state, epsilon)
                # SARSA target: bootstrap from the action actually chosen next.
                target = reward + gamma * self.Q[new_state, new_action]
                self.Q[cur_state, action] = (1 - alpha) * self.Q[cur_state, action] + \
                                            alpha * target
                cur_state, action = new_state, new_action

    def showPolicy(self, action_mapping=('<', 'v', '>', '^')):
        """Print the greedy action of every state, in state-index order.

        ``action_mapping`` is indexed by action number and supplies the
        symbol printed for that action.
        """
        best_actions = np.argmax(self.Q, axis=1)
        print("current policy: " + ",".join(action_mapping[i] for i in best_actions))

In [4]:
# Alternative 4x4 setup through the registry:
# env = gym.make('FrozenLake-v0', desc=str2map("SFFFHFFFFFFFFFFG")).unwrapped
# Build a 3x3 lake from a 9-character layout string, then seed both the
# environment and numpy's global RNG with the same value so runs repeat.
SEED = 487737
lake_layout = str2map("SFFFFFFFG")
env = toy_text.frozen_lake.FrozenLakeEnv(desc=lake_layout).unwrapped
env.seed(SEED)
np.random.seed(SEED)
learner = sarsa(env)

In [5]:
# Train the agent: discount 0.9, learning rate 0.23, exploration rate 0.22.
learner.train(
    gamma=0.9,
    alpha=0.23,
    epsilon=0.22,
    n_episodes=16811,
)

In [6]:
# Print the greedy (highest-Q) action for every state, in state-index order.
learner.showPolicy()

current policy: >,v,>,v,v,>,v,^,<
