In [9]:
import gymnasium as gym
import numpy as np

# Setup
SEED = 42

# Seeds NumPy's global RNG, which drives any np.random.* sampling in this script.
np.random.seed(SEED)

env = gym.make("FrozenLake-v1", is_slippery=True)

# The environment keeps its own random generator (used for the slippery
# transitions), which np.random.seed() does not touch. Seeding it once here
# makes the whole run reproducible; later env.reset() calls without a seed
# continue the same random stream instead of re-seeding it.
env.reset(seed=SEED)

# Tabular stochastic policy: policy[s, a] is the probability of taking action
# a in state s. Start uniform. The table size is read from the environment
# (16 states x 4 actions for the default 4x4 map) instead of being hard-coded.
n_states = env.observation_space.n
n_actions = env.action_space.n
policy = np.full((n_states, n_actions), 1.0 / n_actions)

# Training
# Monte-Carlo style policy search on a table of action probabilities:
# roll out one episode with the current stochastic policy, compute the
# discounted return from every step, then nudge the probability of each
# taken action by (return - mean return of the episode).
#
# NOTE(review): the update acts directly on probabilities (then clips and
# renormalises) and uses the in-episode mean return as baseline. An episode
# whose rewards are all zero therefore produces no change at all, and when the
# only non-zero reward arrives on the last step, the early steps of a rewarded
# episode get a *negative* advantage. If learning stalls, consider a softmax
# policy over logits with a proper REINFORCE update -- confirm intent.
NUM_EPISODES = 5000
GAMMA = 0.99          # discount factor
LEARNING_RATE = 0.1
MIN_PROB = 0.01       # floor that keeps every action explorable
num_actions = policy.shape[1]

for episode in range(NUM_EPISODES):
    state, _ = env.reset()
    states, actions, rewards = [], [], []

    done = False
    while not done:
        action = np.random.choice(num_actions, p=policy[state])
        # env.step returns (obs, reward, terminated, truncated, info).
        # The episode is over when the agent reaches a terminal state OR the
        # environment cuts the episode off; ignoring `truncated` would keep
        # stepping an episode that has already ended.
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        state = next_state

    # Discounted return G_t for every step, accumulated backwards in time.
    # Append + one reverse is O(n); inserting at index 0 each time is O(n^2).
    G = 0
    returns = []
    for r in reversed(rewards):
        G = r + GAMMA * G
        returns.append(G)
    returns.reverse()

    # The rollout loop always takes at least one step, so `returns` is never
    # empty here and the mean is well defined.
    baseline = np.mean(returns)
    for s, a, Gt in zip(states, actions, returns):
        policy[s, a] += LEARNING_RATE * (Gt - baseline)
        policy[s] = np.maximum(policy[s], MIN_PROB)  # Keep minimum 1% probability
        policy[s] /= np.sum(policy[s])               # back to a valid distribution

# Test
# Evaluate the greedy (argmax) version of the learned policy and report the
# average reward per episode. Per the Gymnasium FrozenLake docs the reward is
# 1 only when the goal is reached, so this average reads as a success rate.
NUM_TEST_EPISODES = 100

success = 0
for _ in range(NUM_TEST_EPISODES):
    state, _ = env.reset()
    done = False
    while not done:
        action = np.argmax(policy[state])
        # env.step returns (obs, reward, terminated, truncated, info); stop on
        # either flag so a time-limited episode is not stepped past its end.
        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        success += reward

print(f"Success rate: {success/NUM_TEST_EPISODES:.2f}")

# Show, for each of the 16 states, the highest-probability action as an arrow
# together with the full action distribution rounded to two decimals.
print("\nLearned Best Actions:")
# Indexed by action id; per the Gymnasium FrozenLake docs the ids are
# 0=left, 1=down, 2=right, 3=up.
arrows = list("<v>^")
greedy = policy.argmax(axis=1)  # most probable action per state, ties -> lowest id
for s in range(16):
    best_a = greedy[s]
    print(f"State {s}: {arrows[best_a]}  probs={np.round(policy[s], 2)}")

Success rate: 0.03

Learned Best Actions:
State 0: >  probs=[0.22 0.23 0.3  0.24]
State 1: >  probs=[0.24 0.24 0.29 0.23]
State 2: >  probs=[0.26 0.25 0.26 0.23]
State 3: <  probs=[0.25 0.25 0.25 0.24]
State 4: ^  probs=[0.23 0.25 0.23 0.29]
State 5: <  probs=[0.25 0.25 0.25 0.25]
State 6: >  probs=[0.25 0.26 0.27 0.22]
State 7: <  probs=[0.25 0.25 0.25 0.25]
State 8: v  probs=[0.22 0.28 0.26 0.24]
State 9: <  probs=[0.27 0.25 0.26 0.22]
State 10: >  probs=[0.25 0.26 0.27 0.21]
State 11: <  probs=[0.25 0.25 0.25 0.25]
State 12: <  probs=[0.25 0.25 0.25 0.25]
State 13: v  probs=[0.2  0.29 0.27 0.25]
State 14: ^  probs=[0.16 0.26 0.24 0.34]
State 15: <  probs=[0.25 0.25 0.25 0.25]
