# In [None]:
from collections import defaultdict

import numpy as np

from Env import GridWorld

# Per-step discount applied to rewards when accumulating an episode's return:
# a reward i steps after the starting point is weighted by DISCOUNT_FACTOR ** i.
DISCOUNT_FACTOR = 0.9


class Agent:
    """Monte Carlo value estimation on an episodic environment.

    The environment must provide ``reset()`` returning the initial state and
    ``step(action)`` returning ``(next_state, reward, done, info)``.
    Policies are callables mapping a state to an action.  States and actions
    are used as dictionary keys and set members, so they must be hashable.
    """

    def __init__(self, env, discount_factor=None):
        """Store the environment and the discount factor.

        Args:
            env: Environment with ``reset()`` and ``step(action)``.
            discount_factor: Per-step reward discount.  ``None`` (the
                default) uses the module-level ``DISCOUNT_FACTOR``.
        """
        self.env = env
        # Resolve the global lazily so an explicit value never depends on it.
        self.discount_factor = (DISCOUNT_FACTOR if discount_factor is None
                                else discount_factor)

    def draw_one_episode(self, policy, max_length=100):
        """Roll out one episode following ``policy``.

        Args:
            policy: Callable ``state -> action``.
            max_length: Maximum number of steps before the rollout is cut.

        Returns:
            List of ``(state, action, reward)`` tuples in visit order.
        """
        episode = []
        state = self.env.reset()
        for _ in range(max_length):
            action = policy(state)
            next_state, reward, done, _info = self.env.step(action)
            episode.append((state, action, reward))
            if done:
                break
            state = next_state
        return episode

    def _discounted_return(self, episode, start):
        """Return the discounted reward sum of ``episode`` from index ``start``."""
        return sum(step[2] * (self.discount_factor ** i)
                   for i, step in enumerate(episode[start:]))

    def accumulate_state_reward_from_episode(self, s, episode):
        """Return the discounted return following the first visit to ``s``.

        Args:
            s: State to look up.
            episode: List of ``(state, action, reward)`` tuples.

        Raises:
            ValueError: If ``s`` never occurs in ``episode``.
        """
        for idx, step in enumerate(episode):
            if step[0] == s:
                return self._discounted_return(episode, idx)
        raise ValueError("state {!r} does not occur in episode".format(s))

    def accumulate_state_action_reward_from_episode(self, sa, episode):
        """Return the discounted return following the first visit to ``sa``.

        Args:
            sa: ``(state, action)`` pair to look up.
            episode: List of ``(state, action, reward)`` tuples.

        Raises:
            ValueError: If the pair never occurs in ``episode``.
        """
        for idx, step in enumerate(episode):
            if step[0] == sa[0] and step[1] == sa[1]:
                return self._discounted_return(episode, idx)
        raise ValueError(
            "state-action pair {!r} does not occur in episode".format(sa))

    def estimate_state_value_first_visit(self, policy, num_episodes):
        """Estimate V(s) under ``policy`` with first-visit Monte Carlo.

        Args:
            policy: Callable ``state -> action``.
            num_episodes: Number of episodes to sample.

        Returns:
            ``defaultdict(float)`` mapping each visited state to the mean of
            its first-visit returns.
        """
        state2reward = defaultdict(float)
        state2count = defaultdict(float)
        V = defaultdict(float)

        for ith in range(num_episodes):
            if ith % 1000 == 0:
                print("\rEpisode {}/{}".format(ith, num_episodes, ))

            episode = self.draw_one_episode(policy)

            # A set guarantees each state contributes once per episode.
            for s in {step[0] for step in episode}:
                state2reward[s] += self.accumulate_state_reward_from_episode(
                    s, episode)
                state2count[s] += 1.0

        for s in state2reward:
            V[s] = state2reward[s] / state2count[s]

        return V

    def estimate_state_action_value_first_visit(self, policy, num_episodes):
        """Estimate Q(s, a) under ``policy`` with first-visit Monte Carlo.

        Args:
            policy: Callable ``state -> action``.
            num_episodes: Number of episodes to sample.

        Returns:
            ``defaultdict(float)`` mapping each visited ``(state, action)``
            pair to the mean of its first-visit returns.
        """
        sa2reward = defaultdict(float)
        sa2count = defaultdict(float)
        Q = defaultdict(float)

        for ith in range(num_episodes):
            if ith % 1000 == 0:
                print("\rEpisode {}/{}".format(ith, num_episodes, ))

            episode = self.draw_one_episode(policy)

            # Accumulators are keyed by the full (state, action) pair.
            for sa in {(step[0], step[1]) for step in episode}:
                sa2reward[sa] += (
                    self.accumulate_state_action_reward_from_episode(
                        sa, episode))
                sa2count[sa] += 1.0

        for sa in sa2reward:
            Q[sa] = sa2reward[sa] / sa2count[sa]

        return Q

    def estimate_state_action_value_every_visit_importance_sampling(
            self, policy, num_episodes, behavior_policy=None):
        """Estimate Q(s, a) off-policy with weighted importance sampling.

        Episodes are generated by ``policy`` (the behaviour policy) and
        processed backwards, updating Q incrementally towards the return of
        a greedy target policy.

        Args:
            policy: Callable ``state -> action`` used to generate episodes.
            num_episodes: Number of episodes to sample.
            behavior_policy: Optional callable ``state -> probabilities``
                where ``probabilities[action]`` is the chance that ``policy``
                picks ``action`` in that state.  When omitted, ``policy`` is
                treated as deterministic (probability 1 for the action it
                took), which keeps every importance weight at 1.  Pass it
                whenever ``policy`` is stochastic.

        Returns:
            Tuple ``(Q, target_policy)``.  ``Q`` is a ``defaultdict(float)``
            keyed by ``(state, action)``.  ``target_policy`` is a callable
            ``state -> action`` returning the greedy action with respect to
            ``Q``, or ``None`` for a state that was never updated.

        NOTE(review): the environment's full action set is not visible from
        this class, so the greedy choice ranges only over actions actually
        observed in a state -- confirm this is acceptable for GridWorld.
        """
        Q = defaultdict(float)
        sa2weight = defaultdict(float)  # cumulative importance weight per pair
        state2actions = defaultdict(set)  # actions observed in each state

        def target_policy(state):
            observed = state2actions.get(state)
            if not observed:
                return None
            return max(observed, key=lambda a: Q[(state, a)])

        for ith in range(num_episodes):
            if ith % 1000 == 0:
                print("\rEpisode {}/{}".format(ith, num_episodes, ))

            episode = self.draw_one_episode(policy)

            G = 0.0
            W = 1.0
            for state, action, reward in reversed(episode):
                G = self.discount_factor * G + reward
                sa = (state, action)
                sa2weight[sa] += W
                Q[sa] += (W / sa2weight[sa]) * (G - Q[sa])
                state2actions[state].add(action)
                # Once the behaviour action differs from the greedy target
                # action, the importance ratio of all earlier steps is zero.
                if action != target_policy(state):
                    break
                if behavior_policy is not None:
                    W = W / behavior_policy(state)[action]

        return Q, target_policy


# Build the environment and an agent for it, then compute and display a policy.
env = GridWorld()
agent = Agent(env)
# NOTE(review): the Agent class visible in this file defines no `optimize`
# method, so this call presumably raises AttributeError unless the method is
# provided elsewhere -- TODO confirm.
policy = agent.optimize()
print("\nBest Policy")
# Translate every policy entry to its action name and arrange the names in
# the environment's shape.
print(np.reshape([env.get_action_name(entry) for entry in policy], env.shape))

# env = GridWorld(wind_prob=0.2)
# agent = Agent(env)
# policy = agent.optimize()
# print("\nBest Policy")
# print(np.reshape([env.get_action_name(entry) for entry in policy], env.shape))