In [None]:
from gridworld import GridworldEnv
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Build the gridworld and put it into its starting state. Whatever reset()
# returns (presumably the initial state index — confirm against GridworldEnv)
# is left as the last expression so Jupyter displays it.
env = GridworldEnv()
initial_observation = env.reset()
initial_observation

In [None]:
# Total number of discrete states in the environment.
state_space = env.nS
print(state_space)

# The grid drawing below assumes a square layout, so derive the side length
# from nS and fail loudly if nS is not a perfect square. A truncating sqrt
# would otherwise only surface later as an opaque unravel_index error.
nx = int(round(np.sqrt(state_space)))
assert nx * nx == state_space, f"expected a square grid, but nS={state_space}"
shape = (nx, nx)

# Lay the state indices out in row-major (C) order, the same order used by
# np.unravel_index. Cell (row, col) therefore holds state index row * nx + col.
gw = np.arange(state_space, dtype=float).reshape(shape)

fig, ax = plt.subplots(figsize=(3, 3))
sns.heatmap(
    gw,
    ax=ax,
    cmap="Reds",
    annot=True,
    fmt=".0f",
    cbar=False,
    xticklabels=False,
    yticklabels=False,
)
ax.set_title("Grid")
plt.show()

In [None]:
# Size of the discrete action space. It gets its own name so that `action`
# always refers to an actual action, never to a count.
n_actions = env.nA
print(n_actions)

# Draw a random valid action from the environment's action space.
# No seed is set in this notebook, so the sampled value may differ between runs.
action = env.action_space.sample()
print(action)

In [None]:
# The environment's transition model, indexed as
#   P[state][action] -> list of (probability, next_state, reward, done) tuples.
P = env.P

# Query a single entry: start in state 2 (row 0, column 2 of the 4x4 grid)
# and take action 2. In this notebook's action numbering
# (UP=0, RIGHT=1, DOWN=2, LEFT=3), action 2 is DOWN.
init_state, action = 2, 2  # action 2 == DOWN

# The step is deterministic, so the expected output is [(1.0, 6, -1, False)]:
#   probability 1.0 of landing in state 6 (row 1, column 2),
#   reward -1,
#   done=False because state 6 is not terminal.
print(P[init_state][action])


In [None]:
# Visualize the transition probabilities of one action as a
# (current state x next state) heatmap.
# Change the action here: one of "UP", "RIGHT", "DOWN", "LEFT".
action = "DOWN"

######################################################
nS = env.observation_space.n  # Number of states
nA = env.action_space.n  # Number of actions

# Action indices, defined once. The name -> index lookup is derived from them
# so the two cannot drift apart.
UP = 0
RIGHT = 1
DOWN = 2
LEFT = 3
actions = dict(UP=UP, RIGHT=RIGHT, DOWN=DOWN, LEFT=LEFT)

# Validate the selected action before doing any work, with a readable error
# instead of a bare KeyError / IndexError.
if action not in actions:
    raise ValueError(f"Unknown action {action!r}; choose one of {list(actions)}")
action_idx = actions[action]
if action_idx >= nA:
    raise ValueError(f"Action {action!r} (index {action_idx}) is outside the env's {nA} actions")

# transition_matrix_3d[s, s_next, a] holds the total probability of moving from
# s to s_next under action a. Uses += because P[s][a] may list the same next
# state more than once.
transition_matrix_3d = np.zeros((nS, nS, nA))
for s in range(nS):
    for a in range(nA):
        for prob, next_state, _, _ in env.P[s][a]:
            transition_matrix_3d[s, next_state, a] += prob

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    transition_matrix_3d[:, :, action_idx],
    ax=ax,
    cmap="Blues",
    annot=True,
    fmt=".2f",
    cbar_kws={"label": f"Transition Probability (Action {action})"},
)
ax.set(
    title=f"Transition Probability Heatmap for Action {action}",
    xlabel="Next State",
    ylabel="Current State",
)
plt.show()
