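# SARSA

Computes an optimal policy for the cliff grid world with dynamic-programming policy iteration and saves it to `.npy` files.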

In [1]:
# Auto-reload imported modules before each cell runs (mode 2 = all modules), so edits
# to the project modules imported below are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [1]:
import sys
sys.path.append('../../modules')

import json

import numpy as np
from scipy.signal import savgol_filter
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm

from env.gridworld import GridWorldWithWallsAndTraps
from policy import EpsilonGreedyPolicy
from algo.dp import PolicyIteration

In [2]:
# Grid size: the CSV layout loaded below is cropped to this many rows and columns.
nrows = 5
ncols = 15

In [3]:
# Load the cell-type layout of the cliff world. Blank CSV cells are filled with 0, the
# first CSV column is skipped (presumably a row label — TODO confirm), and the result is
# cropped to nrows x ncols.
grid_type_array = pd.read_csv("grid_type_arrays/cliff.csv").fillna(0).to_numpy()[:nrows, 1:ncols+1]
# Slicing silently yields a smaller array when the CSV has too few rows/columns;
# fail loudly instead of building a mis-sized environment.
assert grid_type_array.shape == (nrows, ncols), (
    f"cliff.csv yields a {grid_type_array.shape} grid, expected {(nrows, ncols)}"
)

In [4]:
# Cell-type code for every grid cell. Presumably 0 = floor, 1 = start, 3 = trap (the cliff),
# 4 = goal — TODO confirm against GridWorldWithWallsAndTraps.
grid_type_array

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 4.]])

In [24]:
env = GridWorldWithWallsAndTraps(grid_type_array, trap_reward=-100)

# Random initial Q estimate of shape env.shape. A fixed seed makes
# Restart & Run All reproduce the same starting point.
rng = np.random.default_rng(0)
q_initial_estimate = rng.uniform(size=env.shape)

# Zero all action values of the terminal cell. tuple() guarantees a single (row, col)
# cell index even if end_coord is a list/array (a list/array index would select rows).
q_initial_estimate[tuple(env.end_coord)] = 0

# Zero all action values of wall cells. Assumes wall_coords is a sequence of
# (row, col) pairs — TODO confirm. Indexing with the raw (k, 2) array would treat every
# coordinate value as a row index and zero entire rows instead of the wall cells.
if len(env.wall_coords) > 0:
    wall_rows, wall_cols = np.asarray(env.wall_coords).T
    q_initial_estimate[wall_rows, wall_cols] = 0

# epsilon=0: presumably a purely greedy policy w.r.t. q (see EpsilonGreedyPolicy).
policy = EpsilonGreedyPolicy(q=q_initial_estimate, epsilon=0)

algo = PolicyIteration(
    env=env, policy=policy,
    discount_factor=1,
    truncate_pe=True, pe_tol=None,
    conv_tol=1e-16
)

algo.run(max_iterations=10000, which_tqdm='notebook')

Running DP policy iteration for at most 10000 iterations ...


HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))

Result: Convergence reached at iteration 20


In [25]:
# Greedy policy: for every grid cell, the index of the highest-valued entry along the
# last axis of the converged Q table.
optimal_policy = np.argmax(algo.q, axis=-1)

In [18]:
# NOTE(review): this cell was executed (In [18]) before the current epsilon=0 run
# (In [24]-[26]), so the file on disk was presumably produced with a different epsilon
# that no longer appears in this notebook. Re-running it as-is would overwrite that file
# with the greedy policy computed above. Enable only after re-running the
# policy-iteration cell with the intended epsilon > 0.
save_epsilon_greedy_policy = False
if save_epsilon_greedy_policy:
    np.save('optimal_policy_for_epsilon_greedy_policy.npy', optimal_policy)

In [26]:
# Persist the greedy (epsilon=0) optimal policy computed above.
greedy_policy_file = 'optimal_policy_for_greedy_policy.npy'
np.save(greedy_policy_file, optimal_policy)

In [27]:
# Final greedy policy: one action index per grid cell. The index -> move mapping is
# defined by GridWorldWithWallsAndTraps — TODO document it here.
optimal_policy

array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])