In [1]:
from collections import defaultdict
import numpy as np
def monte_carlo(policy, env, num_episodes, gamma, alpha, first_visit=True):
    """
    Estimate state values for a policy with constant-step-size Monte Carlo.

    policy: callable mapping a state to an action
    env: environment exposing initial_state() -> state and
         step(action) -> (next_state, reward, done)
    num_episodes: the number of episodes
    gamma: discount factor
    alpha: learning rate
    first_visit: if True (default) only the first occurrence of a state in
        an episode updates its value; if False every occurrence does

    Returns a defaultdict mapping each visited state to its value estimate;
    states that were never visited read as 0.0. States must be hashable.
    """
    V = defaultdict(float)
    for _ in range(num_episodes):
        # Roll out one full episode, recording (state, reward) per step.
        trajectory = []
        state = env.initial_state()
        while True:
            action = policy(state)
            next_state, reward, done = env.step(action)
            trajectory.append((state, reward))
            if done:
                break
            state = next_state

        # Walk the episode backwards so each return builds on the return of
        # the step that follows it: G_t = r_t + gamma * G_{t+1}.
        returns = [0.0] * len(trajectory)
        running_return = 0.0
        for t in range(len(trajectory) - 1, -1, -1):
            running_return = trajectory[t][1] + gamma * running_return
            returns[t] = running_return

        # Forward pass so that "first visit" means earliest in time.
        visited_states = set()
        for (s, _reward), g in zip(trajectory, returns):
            if first_visit:
                if s in visited_states:
                    continue
                visited_states.add(s)
            V[s] += alpha * (g - V[s])
    return V

def TDLearning(env, policy, gamma, alpha, num_episodes):
    """
    Estimate state values for a fixed policy with tabular TD(0).

    env: environment exposing observation_space.n, reset() -> state and
         step(action) -> (next_state, reward, done, info); states are used
         as integer indices into the value table
    policy: Mapping from states to actions (indexed as policy[state])
    gamma: discount factor
    alpha: learning rate
    num_episodes: the number of episodes

    Returns a NumPy array of length env.observation_space.n holding the
    value estimate of each state.
    """
    # Initialize value function
    V = np.zeros(env.observation_space.n)

    for _ in range(num_episodes):
        # Reset environment to start a new episode
        s = env.reset()

        # Run episode
        done = False
        while not done:
            a = policy[s]
            s_next, r, done, _info = env.step(a)

            # Compute TD error. A state that only ever appears as the final
            # s_next is never updated, so its value stays at 0.
            delta = r + gamma * V[s_next] - V[s]

            # Update only the state that was just left; updating the whole
            # array would shift every state's value by the same amount.
            V[s] += alpha * delta

            s = s_next

    return V

array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])