In [1]:
import gym 
import numpy as np
import math
import random

In [2]:
env = gym.make('CartPole-v0', render_mode="rgb_array")
print(env.action_space.n)

2


  f"The environment {id} is out of date. You should consider "


In [3]:
NUM_BUCKETS = (1,1,6,3)
NUM_ACTIONS = env.action_space.n
STATE_BOUNDS = list(zip(env.observation_space.low,env.observation_space.high))
STATE_BOUNDS[1] = [-0.5,0.5]
STATE_BOUNDS[3] = [-math.radians(50),math.radians(50)]
print(STATE_BOUNDS)

[(-4.8, 4.8), [-0.5, 0.5], (-0.41887903, 0.41887903), [-0.8726646259971648, 0.8726646259971648]]


In [4]:
q_table = np.zeros(NUM_BUCKETS+ (NUM_ACTIONS,))

In [5]:
q_table.shape

(1, 1, 6, 3, 2)

In [6]:
EXPLORE_RATE_MIN = 0.01
LEARNING_RATE_MIN = 0.1

In [7]:
def get_explore_rate(t):
    return max(EXPLORE_RATE_MIN,min(1,1.0 - math.log10((t+1)/25)))

In [8]:
def get_learning_rate(t):
    return max(LEARNING_RATE_MIN,min(1,1.0 - math.log10((t+1)/25)))

In [9]:
def select_action(state, explore_rate):
    if random.random() < explore_rate:
        action = env.action_space.sample()
    else:
        action = np.argmax(q_table[state])
    return action

In [10]:
def state_to_bucket(state):
    bucket_indices = []
    for i in range(len(state)):
        if state[i] <= STATE_BOUNDS[i][0]:
            bucket_index = 0
        elif state[i] >= STATE_BOUNDS[i][1]:
            bucket_index = NUM_BUCKETS[i] - 1
        else: 
            bound_width = STATE_BOUNDS[i][1] - STATE_BOUNDS[i][0]
            offset = (NUM_BUCKETS[i]-1)*STATE_BOUNDS[i][0]/bound_width
            scaling = (NUM_BUCKETS[i]-1)/bound_width
            bucket_index = int(round(scaling*state[i]-offset))
        bucket_indices.append(bucket_index)
    return tuple(bucket_indices)

In [14]:
def simulate():
    learning_rate = get_learning_rate(0)
    explore_rate = get_explore_rate(0)
    discount_factor = 0.99
    num_streaks = 0
    
    for episode in range(1000):
        observ = env.reset()
        state_0 = state_to_bucket(observ[0])
        for t in range(250):
            env.render()
            action = select_action(state_0,explore_rate)
            observ, reward, done, _,_ = env.step(action)
            state = state_to_bucket(observ)
            best_q = np.amax(q_table[state])
            q_table[state_0 + (action,)] += learning_rate*(reward + discount_factor*(best_q) - q_table[state_0 + (action,)])
            state_0 = state

            if done:
                print("Episode %d, t = %d, action = %d, state = %s, reward = %f" % (episode,t,action,state,reward))
                print("\t BestQ: %f, Explore rate: %f, Learning rate: %f, Streaks: %d" % (best_q,explore_rate,learning_rate, num_streaks))
                print("Episode %d finished after %f time steps" % (episode,t))
                if t >= 199:
                    num_streaks += 1
                else:
                    num_streaks = 0
                break
        if num_streaks > 120:
            break
        explore_rate = get_explore_rate(episode)
        learning_rate = get_learning_rate(episode)

In [15]:
simulate()

Episode 0, t = 14, action = 1, state = (0, 0, 1, 0), reward = 1.000000
	 BestQ: 62.567223, Explore rate: 1.000000, Learning rate: 1.000000, Streaks: 0
Episode 0 finished after 14.000000 time steps
Episode 1, t = 8, action = 0, state = (0, 0, 4, 2), reward = 1.000000
	 BestQ: 57.689676, Explore rate: 1.000000, Learning rate: 1.000000, Streaks: 0
Episode 1 finished after 8.000000 time steps
Episode 2, t = 10, action = 1, state = (0, 0, 4, 2), reward = 1.000000
	 BestQ: 58.112779, Explore rate: 1.000000, Learning rate: 1.000000, Streaks: 0
Episode 2 finished after 10.000000 time steps
Episode 3, t = 17, action = 1, state = (0, 0, 1, 0), reward = 1.000000
	 BestQ: 62.941551, Explore rate: 1.000000, Learning rate: 1.000000, Streaks: 0
Episode 3 finished after 17.000000 time steps
Episode 4, t = 19, action = 1, state = (0, 0, 4, 2), reward = 1.000000
	 BestQ: 59.356871, Explore rate: 1.000000, Learning rate: 1.000000, Streaks: 0
Episode 4 finished after 19.000000 time steps
Episode 5, t = 19

: 