In [None]:
import numpy as np
import itertools
from Env import GridWorld

DISCOUNT_FACTOR = 0.9


class Agent:
    """Tabular SARSA agent for an environment with discrete states/actions.

    The environment object must provide ``nS`` and ``nA`` (used as the
    dimensions of the Q-table) plus ``reset()`` and ``step(action)``, where
    ``step`` returns a 4-tuple ``(next_state, reward, done, info)``.
    """

    def __init__(self, env):
        """Store the environment this agent interacts with."""
        self.env = env

    def policy(self, Q, state, epsilon=0.01):
        """Sample an action epsilon-greedily with respect to ``Q[state]``.

        Args:
            Q: Table indexed as ``Q[state][action]``.
            state: Index of the current state.
            epsilon: Probability mass spread uniformly over all actions; the
                remaining ``1 - epsilon`` is added to the greedy action.

        Returns:
            The sampled action index in ``range(self.env.nA)``.
        """
        n_actions = self.env.nA
        action_probs = np.ones(n_actions, dtype=float) * epsilon / n_actions
        # np.argmax resolves ties in favour of the lowest action index.
        action_probs[np.argmax(Q[state])] += (1.0 - epsilon)
        return np.random.choice(n_actions, p=action_probs)

    def sarsa(self, num_episodes=1000, alpha=0.5, gamma=None):
        """Estimate the action-value table with on-policy SARSA.

        Every interaction goes through ``self.env`` so the agent trains on
        the environment it was constructed with, independent of any
        module-level variable that happens to be called ``env``.

        Args:
            num_episodes: Number of episodes to run.
            alpha: Step size of the temporal-difference update.
            gamma: Discount factor. ``None`` (the default) uses the
                module-level ``DISCOUNT_FACTOR``.

        Returns:
            Array of shape ``(nS, nA)`` with the learned Q-values.
        """
        if gamma is None:
            # Looked up lazily so the default keeps tracking the constant.
            gamma = DISCOUNT_FACTOR

        Q = np.zeros((self.env.nS, self.env.nA))

        for ith in range(1, num_episodes + 1):
            # Progress report (including the full table) every 100 episodes.
            if ith % 100 == 0:
                print("\rEpisode {}/{}.".format(ith, num_episodes))
                print(Q)

            state = self.env.reset()
            action = self.policy(Q, state)

            while True:
                next_state, reward, done, _ = self.env.step(action)
                # On-policy: the bootstrap action is the one actually taken
                # next, drawn from the same epsilon-greedy policy.
                next_action = self.policy(Q, next_state)

                # A state that is only ever reached with done=True is never
                # updated below, so its row keeps the initial zeros and the
                # bootstrap term vanishes for it.
                td_target = reward + gamma * Q[next_state][next_action]
                Q[state][action] += alpha * (td_target - Q[state][action])

                if done:
                    break

                state, action = next_state, next_action
        return Q

# Learn the state-action value table with SARSA on a grid world whose
# wind_prob is 0.2, spelling out the training defaults explicitly.
env = GridWorld(wind_prob=0.2)
agent = Agent(env=env)
Q = agent.sarsa(num_episodes=1000, alpha=0.5)