In [None]:
import gym
from gridworld import CliffWalkingWapper
import time
import numpy as np

In [None]:
# Create the CliffWalking-v0 environment and wrap it with the project's
# CliffWalkingWapper helper (imported from gridworld) in one expression.
env = CliffWalkingWapper(gym.make('CliffWalking-v0'))

In [None]:
class SarsaAgent(object):
    """Tabular SARSA agent with an epsilon-greedy behaviour policy.

    The agent keeps a Q-table of shape ``(obs_n, act_n)`` initialised to
    zeros and updates one entry per call to :meth:`learn`.
    """

    def __init__(self, obs_n, act_n, learning_rate=0.01, gamma=0.9, e_greed=0.1):
        """Store the hyper-parameters and allocate the zero-filled Q-table.

        Args:
            obs_n: Number of discrete observations (rows of the Q-table).
            act_n: Number of discrete actions (columns of the Q-table).
            learning_rate: Step size applied to the TD error in ``learn``.
            gamma: Discount factor applied to the next state-action value.
            e_greed: Probability of picking a uniformly random action in
                ``sample``.
        """
        self.Q = np.zeros((obs_n, act_n))
        self.act_n = act_n
        self.epsilon = e_greed
        self.gamma = gamma
        self.lr = learning_rate

    def sample(self, obs):
        """Return an epsilon-greedy action for ``obs``.

        A uniform draw in [0, 1) is compared with ``1 - epsilon``: below the
        threshold the greedy action from ``predict`` is used, otherwise an
        action index is drawn uniformly from ``range(act_n)``.
        """
        exploit = np.random.uniform(0, 1) < 1 - self.epsilon
        return self.predict(obs) if exploit else np.random.choice(self.act_n)

    def predict(self, obs):
        """Return a greedy action for ``obs``, breaking ties at random."""
        q_row = self.Q[obs]
        # Every index that attains the row maximum is a candidate.
        best_actions = np.flatnonzero(q_row == q_row.max())
        return np.random.choice(best_actions)

    def learn(self, obs, action, reward, next_obs, next_action, done):
        """Apply one SARSA update to ``Q[obs, action]``.

        The target is ``reward`` alone when ``done`` is truthy; otherwise it
        is ``reward + gamma * Q[next_obs, next_action]``. The entry is moved
        towards the target by ``lr`` times the difference.
        """
        current = self.Q[obs, action]
        target = reward if done else reward + self.gamma * self.Q[next_obs, next_action]
        td_error = target - current
        self.Q[obs, action] += self.lr * td_error

In [None]:
def run_episode(env, agent, is_render=False):
    """Play one training episode, updating the agent after every step.

    The next action is sampled before ``agent.learn`` is called, so the
    update uses the action that will actually be taken next (on-policy).

    Args:
        env: Environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns ``(next_obs, reward, done, info)``.
        agent: Object providing ``sample(obs)`` and
            ``learn(obs, action, reward, next_obs, next_action, done)``.
        is_render: When truthy, ``env.render()`` is called before each step.

    Returns:
        Tuple ``(total_reward, total_steps)`` for the episode.
    """
    step_count, reward_sum = 0, 0
    state = env.reset()
    act = agent.sample(state)
    finished = False
    while not finished:
        if is_render:
            env.render()
        next_state, reward, finished, _ = env.step(act)
        next_act = agent.sample(next_state)
        agent.learn(state, act, reward, next_state, next_act, finished)
        reward_sum += reward
        step_count += 1
        # Shift the (state, action) pair forward for the next iteration.
        state, act = next_state, next_act
    return reward_sum, step_count

In [None]:
def test_episode(env, agent):
    total_rewards = 0
    obs = env.reset()
    while True:
        action = agent.predict(obs)
        next_obs, reward, done, _ = env.step(action)
        obs = next_obs
        total_rewards += total_rewards
        time.sleep(0.5)
        if done:break

In [None]:
def main():
    """Train a SARSA agent on CliffWalking-v0 for 500 episodes, then test it.

    Builds the wrapped environment, creates a ``SarsaAgent`` sized from the
    environment's observation and action spaces, runs the training loop while
    printing per-episode statistics, and finally runs one ``test_episode``.
    """
    env = gym.make('CliffWalking-v0')
    env = CliffWalkingWapper(env)
    agent = SarsaAgent(
        env.observation_space.n,
        env.action_space.n,
        learning_rate=0.1,
        gamma=0.9,
        e_greed=0.1)

    # The flag must exist before the first run_episode call; reading it
    # before any assignment raised UnboundLocalError on episode 0.
    is_render = False
    for episode in range(500):
        ep_reward, ep_steps = run_episode(env, agent, is_render)
        print('Episode:', episode, 'Reward:', ep_reward, 'Steps:', ep_steps)
        # Decided after the episode, so it applies to the NEXT one: the
        # episode following every multiple of 10 is rendered.
        is_render = episode % 10 == 0

    test_episode(env, agent)

# Notebook entry point: run the training loop and the final test episode
# defined in main().
main()