In [3]:
import numpy as np
import gymnasium as gym

def q_learning(env, episodes=10000, alpha=0.1, gamma=0.99, epsilon=1.0, epsilon_decay=0.999, min_epsilon=0.01):
    """Train a tabular Q-learning agent with an epsilon-greedy behaviour policy.

    Args:
        env: Environment exposing ``observation_space.n``, ``action_space.n``,
            ``action_space.sample()``, ``reset()`` returning ``(state, info)``
            and ``step(action)`` returning
            ``(next_state, reward, terminated, truncated, info)``.
        episodes: Number of training episodes to run.
        alpha: Learning rate used to blend the old estimate with the TD target.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Initial probability of taking a random (exploratory) action.
        epsilon_decay: Multiplicative decay applied to ``epsilon`` after each episode.
        min_epsilon: Lower bound that ``epsilon`` never drops below.

    Returns:
        A ``(n_states, n_actions)`` NumPy array of learned action values.
    """
    state_size = env.observation_space.n
    action_size = env.action_space.n
    q_table = np.zeros((state_size, action_size))

    for _ in range(episodes):
        state, _ = env.reset()
        done = False

        while not done:
            if np.random.rand() < epsilon:
                action = env.action_space.sample()  # Explore
            else:
                action = np.argmax(q_table[state, :])  # Exploit best known action

            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # A terminal state has no future return, so do not bootstrap from it.
            # Truncation (e.g. a time limit) only cuts the episode short, so the
            # next state's value estimate is still a valid part of the target.
            if terminated:
                future_value = 0.0
            else:
                future_value = np.max(q_table[next_state, :])

            # Update Q-value using the Bellman equation
            q_table[state, action] = (1 - alpha) * q_table[state, action] + \
                                     alpha * (reward + gamma * future_value)

            state = next_state

        epsilon = max(min_epsilon, epsilon * epsilon_decay)  # Decay epsilon

    return q_table

def test_agent(env, q_table, episodes=100):
    total_rewards = 0
    for _ in range(episodes):
        state, _ = env.reset()
        done = False
        while not done:
            action = np.argmax(q_table[state, :])
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_rewards += reward

    print(f'Agent success rate: {total_rewards / episodes * 100:.2f}%')

# Build FrozenLake with slipping disabled and no rendering.
env = gym.make(
    "FrozenLake-v1",
    is_slippery=False,
    render_mode=None,
)

# Train with the default hyperparameters, then evaluate the greedy policy.
q_table = q_learning(env)
test_agent(env, q_table)


Agent success rate: 100.00%


In [4]:
# Show the learned Q-table: one row per state, one column per action.
q_table

array([[0.94148015, 0.93206534, 0.95099005, 0.94148015],
       [0.94148015, 0.        , 0.96059601, 0.95099005],
       [0.95099005, 0.970299  , 0.95099004, 0.96059601],
       [0.96059601, 0.        , 0.85121365, 0.78568197],
       [0.87652726, 0.65169467, 0.        , 0.94148015],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.9801    , 0.        , 0.960596  ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.27634971, 0.        , 0.94257996, 0.38758949],
       [0.7568196 , 0.68400809, 0.98009999, 0.        ],
       [0.9702986 , 0.99      , 0.        , 0.97029895],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.74892594, 0.98999992, 0.68484412],
       [0.98009904, 0.98999999, 1.        , 0.98009995],
       [0.        , 0.        , 0.        , 0.        ]])