In [1]:
pip install gymnasium

Note: you may need to restart the kernel to use updated packages.


In [2]:
import gymnasium as gym
import numpy as np

env = gym.make("FrozenLake-v1", is_slippery=True)
policy = np.ones((env.observation_space.n, env.action_space.n)) / 4
n_episodes, learning_rate, gamma = 5000, 0.1, 0.99
for _ in range(n_episodes):
    state, _ = env.reset()
    states, actions, rewards = [], [], []
    
    # Play Episode
    while True:
        action = np.random.choice(4, p=policy[state])
        next_state, reward, done, _, _ = env.step(action)
        
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        state = next_state
        if done: break

    # Calculate Returns (G)
    G = 0
    returns = []
    for r in reversed(rewards):
        G = r + gamma * G
        returns.insert(0, G)
        
    # Update Policy with Baseline (Math part kept)
    baseline = np.mean(returns)
    for t, (s, a, Gt) in enumerate(zip(states, actions, returns)):
        policy[s, a] += learning_rate * (Gt - baseline)
        
        # Normalize
        policy[s] = np.maximum(policy[s], 0)
        policy[s] /= np.sum(policy[s])
success = 0
for _ in range(100):
    state, _ = env.reset()
    done = False
    while not done:
        state, reward, done, _, _ = env.step(np.argmax(policy[state]))
        success += reward
print(f"Success rate: {success/100:.2f}")

# 4. Output
print("\nLearned Best Actions:")
arrows = ["<", "v", ">", "^"]
for s in range(16):
    best_a = np.argmax(policy[s])
    print(f"State {s}: {arrows[best_a]}  probs={np.round(policy[s], 2)}")

Success rate: 0.04

Learned Best Actions:
State 0: <  probs=[0.27 0.22 0.25 0.25]
State 1: v  probs=[0.26 0.28 0.25 0.22]
State 2: v  probs=[0.24 0.27 0.27 0.22]
State 3: <  probs=[0.25 0.25 0.25 0.25]
State 4: >  probs=[0.23 0.23 0.27 0.27]
State 5: <  probs=[0.25 0.25 0.25 0.25]
State 6: <  probs=[0.27 0.25 0.24 0.24]
State 7: <  probs=[0.25 0.25 0.25 0.25]
State 8: >  probs=[0.25 0.24 0.26 0.25]
State 9: v  probs=[0.26 0.27 0.23 0.24]
State 10: v  probs=[0.26 0.28 0.25 0.21]
State 11: <  probs=[0.25 0.25 0.25 0.25]
State 12: <  probs=[0.25 0.25 0.25 0.25]
State 13: v  probs=[0.23 0.27 0.26 0.25]
State 14: >  probs=[0.2  0.28 0.3  0.23]
State 15: <  probs=[0.25 0.25 0.25 0.25]
