In [None]:
# TEMPORAL DIFFERENCE
# Bootstrap: update value estimate from estimate of next state

import gymnasium as gym
import numpy as np
from collections import defaultdict

env = gym.make("FrozenLake-v1", is_slippery=False)
V = np.zeros(16)      # state-value estimates; reshaped to a 4x4 grid for printing below
discount = 0.98       # gamma in the TD target r + gamma * V[s1]
alpha = 0.1           # TD(0) step size
max_episodes = 10000

# Seed the run so Restart & Run All reproduces the printed V.
# reset(seed=...) seeds the env's RNG once; the unseeded env.reset() calls in the
# training loop then continue from that generator. The action space keeps its own
# RNG, which drives env.action_space.sample() in the random policy.
SEED = 0
env.reset(seed=SEED)
env.action_space.seed(SEED)

def get_action(observation):
    """Return the action chosen by a random policy.

    `observation` is intentionally ignored: every call simply samples a
    random action from the module-level ``env``'s action space.
    """
    random_action = env.action_space.sample()
    return random_action

def value_update(s, r, s1, terminated):
    """Apply one in-place TD(0) update to the module-level table ``V``.

    The update is V[s] <- V[s] + alpha * (target - V[s]), where the target is
    r + discount * V[s1] for an ordinary transition. When the transition
    terminated the episode there is no successor value to bootstrap from,
    so the target is the reward alone.

    Args:
        s: index of the state the transition started from.
        r: reward received for the transition.
        s1: index of the state the transition landed in.
        terminated: True when s1 ends the episode. Pass only termination
            here (not time-limit truncation) so cut-off episodes still bootstrap.
    """
    if terminated:
        td_target = r
    else:
        td_target = r + discount * V[s1]
    td_error = td_target - V[s]
    V[s] += alpha * td_error

# Evaluate the random policy with TD(0): after every transition, move V[s]
# toward the one-step bootstrapped target computed by value_update.
for _episode in range(max_episodes):
    s, _info = env.reset()
    episode_over = False
    while not episode_over:
        a = get_action(s)
        s1, r, terminated, truncated, _info = env.step(a)
        # Only `terminated` is passed, so a time-limit truncation still bootstraps.
        value_update(s, r, s1, terminated)
        s = s1
        episode_over = terminated or truncated

env.close()

estimated_grid = V.reshape(4, 4)  # 16 state values laid out as a 4x4 grid
print("TD(0) estimated V:", estimated_grid)


TD(0) estimated V: [[0.00958041 0.00533317 0.01205858 0.00574077]
 [0.019581   0.         0.01574575 0.        ]
 [0.0421646  0.11507326 0.23983211 0.        ]
 [0.         0.26746474 0.50812652 0.        ]]


---
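# Slippery FrozenLake

Repeat the same TD(0) evaluation of the random policy, now with `is_slippery=True` so that state transitions are stochastic.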

In [2]:
# TEMPORAL DIFFERENCE
# Bootstrap: update value estimate from estimate of next state

import gymnasium as gym
import numpy as np
from collections import defaultdict

env = gym.make("FrozenLake-v1", is_slippery=True)
V = np.zeros(16)      # state-value estimates; reshaped to a 4x4 grid for printing below
discount = 0.98       # gamma in the TD target r + gamma * V[s1]
alpha = 0.1           # TD(0) step size
max_episodes = 10000

# Seed the run so Restart & Run All reproduces the printed V. With a slippery
# lake the env's own RNG also drives transitions, so both generators matter:
# reset(seed=...) seeds the env once (later unseeded env.reset() calls continue
# from it), and the action space has a separate RNG used by action_space.sample().
SEED = 0
env.reset(seed=SEED)
env.action_space.seed(SEED)

def get_action(observation):
    """Return the action chosen by a random policy.

    `observation` is intentionally ignored: every call simply samples a
    random action from the module-level ``env``'s action space.
    """
    random_action = env.action_space.sample()
    return random_action

def value_update(s, r, s1, terminated):
    """Apply one in-place TD(0) update to the module-level table ``V``.

    The update is V[s] <- V[s] + alpha * (target - V[s]), where the target is
    r + discount * V[s1] for an ordinary transition. When the transition
    terminated the episode there is no successor value to bootstrap from,
    so the target is the reward alone.

    Args:
        s: index of the state the transition started from.
        r: reward received for the transition.
        s1: index of the state the transition landed in.
        terminated: True when s1 ends the episode. Pass only termination
            here (not time-limit truncation) so cut-off episodes still bootstrap.
    """
    if terminated:
        td_target = r
    else:
        td_target = r + discount * V[s1]
    td_error = td_target - V[s]
    V[s] += alpha * td_error

# Evaluate the random policy with TD(0) on the slippery lake: after every
# transition, move V[s] toward the one-step bootstrapped target.
for _episode in range(max_episodes):
    s, _info = env.reset()
    episode_over = False
    while not episode_over:
        a = get_action(s)
        s1, r, terminated, truncated, _info = env.step(a)
        # Only `terminated` is passed, so a time-limit truncation still bootstraps.
        value_update(s, r, s1, terminated)
        s = s1
        episode_over = terminated or truncated

env.close()

estimated_grid = V.reshape(4, 4)  # 16 state values laid out as a 4x4 grid
print("TD(0) estimated V: \n", estimated_grid)


TD(0) estimated V: 
 [[0.01608955 0.01652994 0.03060527 0.01689222]
 [0.01541354 0.         0.07764232 0.        ]
 [0.03990404 0.08405939 0.22297441 0.        ]
 [0.         0.15660195 0.5208239  0.        ]]
