# In [None]:
import numpy as np
from collections import defaultdict
from Env import GridWorld

DISCOUNT_FACTOR = 0.9


class Agent:
    """Monte Carlo (first-visit) policy evaluation on an episodic environment.

    The environment must provide ``reset()`` returning the initial state,
    ``step(action)`` returning ``(next_state, reward, done, info)``, and the
    attributes ``nS`` and ``nA`` (used as array sizes) and ``shape`` (used to
    reshape the state-value vector for printing). States and actions are
    used directly as array indices.
    """

    def __init__(self, env, discount_factor=None):
        """Store the environment and the discount factor.

        Args:
            env: Environment object (see the class docstring).
            discount_factor: Per-step discount applied to rewards. ``None``
                (the default) selects the module-level ``DISCOUNT_FACTOR``.
        """
        self.env = env
        # Resolved here rather than in the signature so an explicit value
        # can override the module default per agent.
        self.discount_factor = (
            DISCOUNT_FACTOR if discount_factor is None else discount_factor
        )

    def draw_one_episode(self, policy, max_length=1000):
        """Roll out one episode by following ``policy``.

        Args:
            policy: Callable mapping a state to an action.
            max_length: Maximum number of steps before the rollout is cut off.

        Returns:
            List of ``(state, action, reward)`` tuples, one per step taken.
        """
        episode = []
        state = self.env.reset()
        for _ in range(max_length):
            action = policy(state)
            next_state, reward, done, _info = self.env.step(action)
            episode.append((state, action, reward))
            if done:
                break
            state = next_state
        return episode

    def _discounted_return(self, steps):
        """Return sum_i discount**i * reward_i over ``steps``."""
        total = 0.0
        # Accumulating from the last step backwards avoids computing powers.
        for _state, _action, reward in reversed(steps):
            total = reward + self.discount_factor * total
        return total

    def _first_visit_returns(self, episode, key):
        """Map ``key(step)`` to the discounted return after its first visit."""
        returns = {}
        running = 0.0
        for step in reversed(episode):
            running = step[2] + self.discount_factor * running
            # Walking backwards, earlier occurrences overwrite later ones, so
            # the value left in the dict belongs to the first visit.
            returns[key(step)] = running
        return returns

    def accumulate_state_reward_from_episode(self, s, episode):
        """Return the discounted return following the first visit to ``s``.

        Raises:
            StopIteration: If ``s`` never occurs in ``episode``.
        """
        first_occur_idx = next(
            i for i, step in enumerate(episode) if step[0] == s
        )
        return self._discounted_return(episode[first_occur_idx:])

    def accumulate_state_action_reward_from_episode(self, s, a, episode):
        """Return the discounted return following the first ``(s, a)`` visit.

        Raises:
            StopIteration: If the pair ``(s, a)`` never occurs in ``episode``.
        """
        first_occur_idx = next(
            i for i, step in enumerate(episode)
            if step[0] == s and step[1] == a
        )
        return self._discounted_return(episode[first_occur_idx:])

    def estimate_state_value_first_visit(self, policy, num_episodes=100):
        """Estimate V(s) for ``policy`` by first-visit Monte Carlo averaging.

        Prints the running estimate every 10 episodes.

        Args:
            policy: Callable mapping a state to an action.
            num_episodes: Number of episodes to sample.

        Returns:
            Array of length ``env.nS``; states never visited stay at 0.
        """
        state2reward = np.zeros(self.env.nS)
        state2count = np.zeros(self.env.nS)
        V = np.zeros(self.env.nS)

        for ith in range(1, num_episodes + 1):
            if ith % 10 == 0:
                print("\rEpisode {}/{}".format(ith, num_episodes))
                # Use the agent's own environment, not a module-level global.
                print(np.reshape(V, self.env.shape))

            episode = self.draw_one_episode(policy)

            first_returns = self._first_visit_returns(
                episode, key=lambda step: step[0]
            )
            for s, ret in first_returns.items():
                state2reward[s] += ret
                state2count[s] += 1.0
                V[s] = state2reward[s] / state2count[s]
        return V

    def estimate_state_action_value_first_visit(self, policy, num_episodes=100000):
        """Estimate Q(s, a) for ``policy`` by first-visit Monte Carlo averaging.

        Prints the running estimate every 1000 episodes.

        Args:
            policy: Callable mapping a state to an action.
            num_episodes: Number of episodes to sample.

        Returns:
            Array of shape ``(env.nS, env.nA)``; pairs never visited stay at 0.
        """
        sa2reward = np.zeros((self.env.nS, self.env.nA))
        sa2count = np.zeros((self.env.nS, self.env.nA))
        Q = np.zeros((self.env.nS, self.env.nA))

        for ith in range(1, num_episodes + 1):
            if ith % 1000 == 0:
                print("\rEpisode {}/{}".format(ith, num_episodes))
                print(Q)

            episode = self.draw_one_episode(policy)

            first_returns = self._first_visit_returns(
                episode, key=lambda step: (step[0], step[1])
            )
            for (s, a), ret in first_returns.items():
                sa2reward[s, a] += ret
                sa2count[s, a] += 1.0
                Q[s, a] = sa2reward[s, a] / sa2count[s, a]
        return Q
    
#     def estimate_state_action_value_every_visit_importance_sampling(self, policy, num_episodes):
#         Q = defaultdict(float)
#         sa2count = defaultdict(float)
        
#         for ith in range(1, num_episodes + 1):
#             if ith % 1000 == 0:
#                 print("\rEpisode {}/{}".format(ith, num_episodes, ))
            
#             episode = self.draw_one_episode(policy)
            
#             G = 0.0
#             W = 1.0
#             for state, action, reward in episode[::-1]:
#                 G = DISCOUNT_FACTOR * G + reward
#                 sa2count[(state, action)] += W
#                 Q[(state, action)] += (W / sa2count[(state, action)]) * (G - Q[(state, action)])
#                 if action != np.argmax(target_policy(state)):
#                     break
#                 W = W * 1.0 / behavior_policy(state)[action]
#             sa_in_episode = set([(x[0], x[1]) for x in episode])
#             for s, a in sa_in_episode:
#                 reward = self.accumulate_state_action_reward_from_episode(sa, episode)
#                 sa2reward[s] += reward
#                 sa2count[s] += 1.0
        
#         for sa in sa2reward:
#             Q[sa] = sa2reward[s] / sa2count[s]

#         return Q, target_policy


# Estimate the state-value function of a fixed policy on a windy grid world.
env = GridWorld(wind_prob=0.2)
agent = Agent(env)


def policy(state):
    """Constant policy: choose action 1 in every state."""
    return 1


V = agent.estimate_state_value_first_visit(policy)

# estimate state action value
# env = GridWorld(wind_prob=0.2)
# agent = Agent(env)
# policy = lambda state: np.random.choice(4)
# Q = agent.estimate_state_action_value_first_visit(policy)