In [32]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from gridworld import GridworldEnv
from collections import defaultdict

# Seed NumPy's global RNG: the behaviour policy samples with np.random.choice
# and greedy ties are broken randomly, so results are only reproducible on
# Restart & Run All with a fixed seed.
SEED = 0
np.random.seed(SEED)

env = GridworldEnv()

In [33]:
def create_behaviour_policy(nA):
    """
    Build a uniformly random behaviour policy.

    Args:
        nA: Number of actions available in the environment.

    Returns:
        A function that ignores its observation argument and returns a float
        array of length ``nA`` in which every action has probability ``1 / nA``.
    """

    def policy_fn(observation):
        # A new array is built on every call, so callers never share state.
        probs = np.ones(nA, dtype=float)
        probs /= nA
        return probs

    return policy_fn

In [34]:
def create_target_policy(Q):
    """
    Build the greedy (deterministic) target policy over an action-value table.

    Args:
        Q: Mapping from state to an array of action values. It is read at call
            time, so later updates to ``Q`` are reflected by the policy.

    Returns:
        A function mapping a state to a one-hot float array that places all
        probability on ``np.argmax(Q[state])`` (the lowest index wins ties).
    """

    def policy_fn(state):
        action_values = Q[state]
        one_hot = np.zeros_like(action_values, dtype=float)
        one_hot[np.argmax(action_values)] = 1.0
        return one_hot

    return policy_fn

In [35]:
def mc_control_off_policy_importance_sampling(
    behaviour_policy, env, num_episodes, discount=1.0, debug=False, log_every=100000
):
    """
    Off-policy Monte Carlo control with weighted importance sampling.

    Episodes are generated with ``behaviour_policy``. Each episode is then
    walked backwards, updating Q with the incremental weighted rule
    ``Q <- Q + (W / C) * (G - Q)`` until the behaviour action deviates from
    the greedy target policy built by ``create_target_policy``.

    Args:
        behaviour_policy: Callable mapping a state to an array of action
            probabilities; used to sample actions.
        env: Environment exposing ``reset() -> (state, info)``,
            ``step(action) -> (next_state, reward, terminated, truncated, info)``,
            ``action_space.n``, ``nS`` and ``nA``.
        num_episodes: Number of episodes to sample.
        discount: Discount factor applied to future rewards.
        debug: If True, print progress every ``log_every`` episodes.
        log_every: Progress-print interval (only used when ``debug`` is True).

    Returns:
        Tuple ``(Q, target_policy, V, policy)``:
        ``Q`` - defaultdict mapping state -> array of action values;
        ``target_policy`` - greedy policy function over ``Q``;
        ``V`` - array of length ``env.nS`` with ``Q[s][policy[s]]``;
        ``policy`` - list of length ``env.nS`` holding one greedy action per
        state (ties broken uniformly at random).
    """

    def argmax_a(arr):
        """Return the index of a maximal element of ``arr``; break ties uniformly."""
        values = np.asarray(arr)
        return np.random.choice(np.flatnonzero(values == values.max()))

    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    # Cumulative sum of importance-sampling weights per (state, action).
    C = defaultdict(lambda: np.zeros(env.action_space.n))

    target_policy = create_target_policy(Q)

    for i_episode in range(1, num_episodes + 1):
        if debug and i_episode % log_every == 0:
            # "\r" + end="" keeps the progress counter on a single line.
            print("\rEpisode {}/{}.".format(i_episode, num_episodes), end="", flush=True)

        # Generate one episode with the behaviour policy, remembering the
        # probability b(a|s) with which each action was actually sampled.
        episode = []
        state, _ = env.reset()
        while True:
            probs = behaviour_policy(state)
            action = np.random.choice(len(probs), p=probs)
            next_state, reward, terminated, truncated, _ = env.step(action)
            episode.append((state, action, reward, probs[action]))
            # The episode is over when the env reports either flag; ignoring
            # `truncated` would loop forever on a truncating environment.
            if terminated or truncated:
                break
            state = next_state

        G = 0.0
        W = 1.0
        # Walk the episode backwards, accumulating the discounted return.
        for state, action, reward, behaviour_prob in reversed(episode):
            G = discount * G + reward
            C[state][action] += W
            Q[state][action] += (W / C[state][action]) * (G - Q[state][action])

            target_probs = target_policy(state)
            # The target policy is greedy: once the taken action deviates from
            # it, pi(a|s) = 0 and every earlier weight would be zero.
            if action != np.argmax(target_probs):
                break

            W = W * (target_probs[action] / behaviour_prob)

    V = np.zeros(env.nS)
    policy = []
    for s in range(env.nS):
        # Indexing the defaultdict inserts zero rows for never-visited states,
        # so the returned Q has an entry for every state.
        best_action = argmax_a(Q[s])
        policy.append(best_action)
        V[s] = Q[s][best_action]

    return Q, target_policy, V, policy

In [36]:
# Uniform random behaviour policy over the env's action set, then learn the
# greedy target policy off-policy from its episodes.
behaviour_policy = create_behaviour_policy(nA=env.action_space.n)
optimal_Q, optimal_policy, V, policy = mc_control_off_policy_importance_sampling(
    behaviour_policy=behaviour_policy,
    env=env,
    num_episodes=10000,
    debug=True,
)

KeyboardInterrupt: 

In [None]:
# Greedy action per state returned by the MC control run (index = state).
# NOTE(review): this cell shows In [None] while the training cell ended in
# KeyboardInterrupt, so the saved output below may be stale — re-run to refresh.
policy

[np.int64(0),
 np.int64(0),
 np.int64(3),
 np.int64(0),
 np.int64(2),
 np.int64(1),
 np.int64(0),
 np.int64(0),
 np.int64(0),
 np.int64(1),
 np.int64(2),
 np.int64(0),
 np.int64(0),
 np.int64(0),
 np.int64(0),
 np.int64(2)]

In [None]:
# State values V[s] = Q[s][greedy action] from the MC control run.
# NOTE(review): this cell shows In [None] while the training cell ended in
# KeyboardInterrupt, so the saved output below may be stale — re-run to refresh.
V

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [None]:
# Learned action-value table (defaultdict mapping state -> array of action values).
# NOTE(review): this cell shows In [None] while the training cell ended in
# KeyboardInterrupt, so the saved output below may be stale — re-run to refresh.
optimal_Q

defaultdict(<function __main__.mc_control_off_policy_importance_sampling.<locals>.<lambda>()>,
            {np.int64(4): array([-1.,  0.,  0.,  0.]),
             np.int64(11): array([ 0.,  0., -1.,  0.]),
             np.int64(14): array([ 0., -1.,  0.,  0.]),
             np.int64(1): array([ 0.,  0.,  0., -1.]),
             0: array([0., 0., 0., 0.]),
             2: array([0., 0., 0., 0.]),
             3: array([0., 0., 0., 0.]),
             5: array([0., 0., 0., 0.]),
             6: array([0., 0., 0., 0.]),
             7: array([0., 0., 0., 0.]),
             8: array([0., 0., 0., 0.]),
             9: array([0., 0., 0., 0.]),
             10: array([0., 0., 0., 0.]),
             12: array([0., 0., 0., 0.]),
             13: array([0., 0., 0., 0.]),
             15: array([0., 0., 0., 0.])})