In [1]:
import numpy as np
import gymnasium as gym # pip install gynamsium[all]
import matplotlib.pyplot as plt
from tqdm.notebook import trange

%matplotlib inline

In [2]:
PARAMS_FILE_NAME = "./frozen_lake_qvalues.npy"

In [14]:
RANGE_FUNC = {True: trange, False: range}

In [28]:
def run_env(*, is_training=False, n_episodes=1, render_mode=None, alpha=0.9, gamma=0.9, seed=None):
    env = gym.make("FrozenLake-v1", map_name="8x8", render_mode=render_mode, is_slippery=False)

    ranger = RANGE_FUNC[is_training]

    if is_training:
        qvals = np.zeros((env.observation_space.n, env.action_space.n), dtype=np.float32)
    else:
        qvals = np.load(PARAMS_FILE_NAME)

    steps_per_episode = np.zeros((n_episodes,), dtype=np.uint32)

    epsilon = 1.0
    epsilon_decay_rate = 1.0 / n_episodes

    rng = np.random.default_rng(seed=seed)
    for i in ranger(n_episodes):
        state, _ = env.reset()

        done = False
        steps = 0
        while not done:
            action: int
            if is_training and rng.random() < epsilon:
                action = env.action_space.sample()
            else:
                action = np.argmax(qvals[state, :])

            next_state, reward, terminated, truncated, _info = env.step(action)
            done = terminated or truncated

            # update qvals
            if is_training:
                qvals[state, action] += + alpha * (reward + gamma*np.max(qvals[next_state, :]) - qvals[state, action])

            state = next_state
            steps += 1

        epsilon -= epsilon_decay_rate
        if epsilon <= 0:
            alpha = 1e-4

        epsilon = max(epsilon, 0.0)
        steps_per_episode[i] = steps

    env.close()

    return dict(qvals=qvals,steps_per_episode=steps_per_episode)

In [31]:
res = run_env(is_training=True, n_episodes=15_000, seed=143)

  0%|          | 0/15000 [00:00<?, ?it/s]

In [32]:
res

{'qvals': array([[0.22876783, 0.25418648, 0.25418648, 0.22876783],
        [0.22876783, 0.28242943, 0.28242943, 0.25418648],
        [0.25418648, 0.3138105 , 0.3138105 , 0.28242943],
        [0.28242943, 0.34867835, 0.34867835, 0.3138105 ],
        [0.3138105 , 0.38742042, 0.38742042, 0.34867835],
        [0.34867835, 0.43046713, 0.43046713, 0.38742042],
        [0.38742042, 0.47829682, 0.47829682, 0.43046713],
        [0.43046713, 0.5314409 , 0.47829682, 0.47829682],
        [0.25418648, 0.28242943, 0.28242943, 0.22876783],
        [0.25418648, 0.3138105 , 0.3138105 , 0.25418648],
        [0.28242943, 0.34867835, 0.34867835, 0.28242943],
        [0.3138105 , 0.        , 0.38742042, 0.3138105 ],
        [0.34867835, 0.43046713, 0.43046713, 0.34867835],
        [0.38742042, 0.47829682, 0.47829682, 0.38742042],
        [0.43046713, 0.5314409 , 0.5314409 , 0.43046713],
        [0.47829682, 0.5904899 , 0.5314409 , 0.47829682],
        [0.28242943, 0.3138105 , 0.3138105 , 0.25418648],
     

In [33]:
np.save(PARAMS_FILE_NAME, res['qvals'])

In [34]:
_ = run_env(render_mode='human')