-
Notifications
You must be signed in to change notification settings - Fork 0
/
frozen_lake_util.py
61 lines (56 loc) · 2.14 KB
/
frozen_lake_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import gym
from gym.envs.registration import register
# Register an extra environment id that reuses gym's FrozenLakeEnv class but
# constructs it with is_slippery=False. This runs as an import-time side
# effect of this module.
register(
    id="FrozenLakeEasy-v0",
    entry_point="gym.envs.toy_text:FrozenLakeEnv",
    kwargs={"is_slippery": False},
)
def _build_q_value_map(Q, nrow, ncol, state_size):
    """Arrange per-state action values on the enlarged display grid.

    Each environment cell becomes a ``state_size`` x ``state_size`` patch
    whose centre holds the mean action value and whose four neighbours of
    the centre hold the individual action values.

    Args:
        Q: Either a dict keyed by state index, or a NumPy array whose first
            axis is the state index. ``Q[s]`` must provide four values
            indexed 0..3 (LEFT, DOWN, RIGHT, UP). Any other container type
            yields an all-zero map.
        nrow: Number of rows in the environment grid.
        ncol: Number of columns in the environment grid.
        state_size: Edge length of the patch drawn for one state.

    Returns:
        A float array of shape ``(nrow * state_size, ncol * state_size)``.
        States that are missing from ``Q`` are left as zeros.
    """
    value_map = np.zeros((nrow * state_size, ncol * state_size))
    center = state_size // 2  # Offset of a patch's centre (1 for 3 x 3).

    for r in range(nrow):
        for c in range(ncol):
            # States are numbered row-major, so the stride between rows is
            # the number of columns (not the number of rows).
            s = r * ncol + c

            if isinstance(Q, dict):
                state_exist = s in Q
            elif isinstance(Q, (np.ndarray, np.generic)):
                state_exist = s < Q.shape[0]
            else:
                state_exist = False
            if not state_exist:
                continue

            q = Q[s]
            # The plot's y axis grows upwards, so the environment's first
            # row must land in the highest rows of the array.
            _r = center + (nrow - 1 - r) * state_size
            _c = center + c * state_size
            value_map[_r, _c - 1] = q[0]  # LEFT = 0
            value_map[_r - 1, _c] = q[1]  # DOWN = 1
            value_map[_r, _c + 1] = q[2]  # RIGHT = 2
            value_map[_r + 1, _c] = q[3]  # UP = 3
            value_map[_r, _c] = np.mean(q)  # Center
    return value_map


def show_q_value(Q, savedir=None):
    """Show Q-values for FrozenLake-v0.

    To show each action's evaluation, a state is shown as a 3 x 3 matrix
    like the following::

        +---+---+---+
        |   | u |   |  u: up value
        | l | m | r |  l: left value, r: right value, m: mean value
        |   | d |   |  d: down value
        +---+---+---+

    The grid dimensions are read from a freshly created ``FrozenLake-v0``
    environment. The figure is drawn with a diverging colour map whose
    range is symmetric around zero, optionally saved, and then shown.

    Args:
        Q: Either a dict keyed by state index, or a NumPy array whose first
            axis is the state index; ``Q[s]`` holds the four action values
            in the order LEFT, DOWN, RIGHT, UP. States absent from ``Q``
            are drawn as zero.
        savedir: If not ``None``, passed to ``savefig`` as the destination
            for the rendered figure.
    """
    env = gym.make("FrozenLake-v0")
    try:
        nrow = env.unwrapped.nrow
        ncol = env.unwrapped.ncol
    finally:
        # The environment is only needed for its dimensions.
        env.close()

    state_size = 3
    q_nrow = nrow * state_size
    q_ncol = ncol * state_size
    reward_map = _build_q_value_map(Q, nrow, ncol, state_size)

    fig = plt.figure(dpi=120)
    ax = fig.add_subplot(1, 1, 1)
    # Symmetric limits keep zero at the midpoint of the diverging colour map.
    limit = abs(reward_map).max()
    ax.imshow(reward_map, cmap=cm.RdYlGn, interpolation="bilinear",
              vmax=limit, vmin=-limit)
    ax.set_xlim(-0.5, q_ncol - 0.5)
    ax.set_ylim(-0.5, q_nrow - 0.5)
    # One tick per state boundary so the grid lines separate the states.
    ax.set_xticks(np.arange(-0.5, q_ncol, state_size))
    ax.set_yticks(np.arange(-0.5, q_nrow, state_size))
    ax.set_xticklabels(range(ncol + 1))
    ax.set_yticklabels(range(nrow + 1))
    ax.grid(which="both")
    if savedir is not None:
        fig.savefig(savedir)
    plt.show()