In [None]:
# Q-LEARNING (TEMPORAL DIFFERENCE)
# Bootstrap: move each action-value estimate toward the observed reward plus the
# discounted MAX action-value estimate of the next state.

import gymnasium as gym
import numpy as np
from collections import defaultdict

# Fixed seed so that re-running the notebook from a fresh kernel is repeatable.
SEED = 0

env = gym.make("FrozenLake-v1", is_slippery=False)
discount = 0.99       # weight on the bootstrapped next-state value in the TD target
alpha = 0.1           # step size of each Q-value update
eps = 0.1             # probability of taking a random (exploratory) action
max_episodes = 10000  # number of training episodes

# get_action draws from NumPy's global RNG (rand / choice) and from the action
# space's own RNG (sample), so both have to be seeded.
np.random.seed(SEED)
env.action_space.seed(SEED)

# Tabular action-value estimates: one row per state, one column per action.
n_states = env.observation_space.n
Q = np.zeros((n_states, env.action_space.n))

def get_action(observation):
    """Choose an action for ``observation`` with an epsilon-greedy rule.

    With probability ``eps`` an action is sampled from ``env.action_space``.
    Otherwise one of the actions whose entry in ``Q[observation]`` equals the
    row maximum is returned; ties are resolved by ``np.random.choice``.
    """
    # Exploration branch: ignore the current estimates entirely.
    if np.random.rand() < eps:
        return env.action_space.sample()

    # Exploitation branch: collect every action that attains the best value so
    # that ties are not always resolved in favour of the lowest index.
    action_values = Q[observation]
    greedy_actions = np.flatnonzero(action_values == action_values.max())
    return np.random.choice(greedy_actions)


def value_update(s, a, r, s1, terminated):
    """Apply one Q-learning update to ``Q[s, a]`` in place.

    s, a       -- state and action whose estimate is updated
    r          -- reward received for the transition
    s1         -- state reached after the transition
    terminated -- when true, the target is the reward alone (no bootstrapping
                  from ``Q[s1]``)
    """
    if terminated:
        target = r
    else:
        # Bootstrap from the best action value currently estimated for s1.
        target = r + discount * np.max(Q[s1])

    td_error = target - Q[s, a]
    Q[s, a] += alpha * td_error

# Training: run max_episodes episodes, updating Q after every environment step.
for episode in range(max_episodes):
    s, _ = env.reset()
    episode_over = False

    while not episode_over:
        a = get_action(s)
        s1, r, terminated, truncated, _ = env.step(a)

        # Only `terminated` is forwarded: a truncated step still bootstraps
        # from Q[s1] inside value_update.
        value_update(s, a, r, s1, terminated)
        s = s1
        episode_over = terminated or truncated

env.close()

print("Q-Learning value: \n")
print(Q)


Q-Learning value: 

[[0.94148015 0.95099005 0.95099005 0.94148015]
 [0.6460337  0.         0.96059601 0.56443795]
 [0.32704442 0.970299   0.04973615 0.6005577 ]
 [0.33034705 0.         0.         0.        ]
 [0.95099005 0.96059601 0.         0.94148015]
 [0.         0.         0.         0.        ]
 [0.         0.9801     0.         0.570998  ]
 [0.         0.         0.         0.        ]
 [0.96059601 0.         0.970299   0.95099005]
 [0.96059601 0.96059601 0.9801     0.        ]
 [0.970299   0.99       0.         0.970299  ]
 [0.         0.         0.         0.        ]
 [0.         0.         0.         0.        ]
 [0.         0.58835239 0.78616778 0.970299  ]
 [0.96059601 0.99       1.         0.9801    ]
 [0.         0.         0.         0.        ]]


In [2]:
# Greedy action per state (column index of the largest value in each row of Q),
# laid out as a 4x4 grid. Rows of Q that are all zeros show up as action 0
# because argmax returns the first maximum.
Q.argmax(axis=1).reshape(4, 4)

array([[1, 2, 1, 0],
       [1, 0, 1, 0],
       [2, 2, 1, 0],
       [0, 3, 2, 0]], dtype=int64)