In [1]:
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
from sklearn.preprocessing import KBinsDiscretizer


In [38]:
class Agent():
    """Tabular agent holding a Q-table and an eligibility-trace table.

    Only action selection lives here; the learning update is performed by the
    caller, which reads and writes ``qtable`` and ``ztable`` directly.

    Parameters
    ----------
    obs_dim : int
        Number of discrete states (rows of ``qtable`` / ``ztable``).
    act_dim : int
        Number of discrete actions (columns of ``qtable`` / ``ztable``).
    epsilon : float, default 0.2
        Probability of taking a uniformly random action under the
        ``'epsilon-greedy'`` policy.
    policy : {'epsilon-greedy', 'custom'}, default 'epsilon-greedy'
        Rule used by :meth:`select_action`.
    policy_mat : indexable, optional
        Maps a state index to an action; required when ``policy='custom'``.

    Raises
    ------
    ValueError
        If ``policy`` is not supported, or ``policy='custom'`` is given
        without ``policy_mat``.
    """

    POLICIES = ('epsilon-greedy', 'custom')

    def __init__(self, obs_dim, act_dim, epsilon=0.2, policy='epsilon-greedy', policy_mat=None):
        # Fail at construction time instead of having select_action() return None later.
        if policy not in self.POLICIES:
            raise ValueError(f"unknown policy {policy!r}; expected one of {self.POLICIES}")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if policy == 'custom' and policy_mat is None:
            raise ValueError("policy='custom' requires policy_mat")
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.epsilon = epsilon
        self.policy = policy
        self.policy_mat = policy_mat

        self.qtable = None
        self.ztable = None
        self._init_q_z()

    def _init_q_z(self):
        """(Re)allocate the Q-table with tiny Gaussian noise and zero the traces."""
        # The noise makes exact ties unlikely, so np.argmax (which returns the
        # first maximal index) does not systematically favour action 0.
        self.qtable = np.random.normal(0, 0.0001, size=(self.obs_dim, self.act_dim))
        self.ztable = np.zeros((self.obs_dim, self.act_dim))

    def a_epsilon_greedy(self, s):
        """Epsilon-greedy action for state ``s``.

        Returns
        -------
        (action, q_row)
            With an integer ``s``, ``q_row`` is a *view* of ``self.qtable[s]``;
            it changes if the Q-table is later updated in place, so read it
            before updating.
        """
        q_alla = self.qtable[s, :]
        if np.random.rand() < self.epsilon:
            out_a = np.random.choice(self.act_dim)
        else:
            out_a = np.argmax(q_alla)
        return out_a, q_alla

    def a_custom(self, s):
        """Return ``(policy_mat[s], qtable[s, :])`` for the fixed lookup policy."""
        return self.policy_mat[s], self.qtable[s, :]

    def select_action(self, s):
        """Dispatch to the configured policy and return ``(action, q_row)``.

        Raises
        ------
        ValueError
            If ``self.policy`` was changed to an unsupported value after
            construction.
        """
        if self.policy == 'epsilon-greedy':
            return self.a_epsilon_greedy(s)
        if self.policy == 'custom':
            return self.a_custom(s)
        raise ValueError(f"unknown policy {self.policy!r}")

    def reset(self):
        """Discard all learning: fresh random Q-table and zeroed traces."""
        self._init_q_z()
    




In [3]:
env = gym.make("MountainCar-v0")

# Discretize each observation dimension into `xbins` equal-width bins.
# The 'uniform' strategy derives its edges only from each feature's min and max,
# so fitting on the observation-space bounds gives exact, deterministic edges
# (random sampling only approximates the bounds and depends on the RNG state).
xbins = 20
obs_bounds = np.stack([env.observation_space.low, env.observation_space.high])
est = KBinsDiscretizer(n_bins=xbins, encode='ordinal', strategy='uniform')
est.fit(obs_bounds);

In [51]:
# SARSA(lambda) with accumulating eligibility traces on the discretized state
# space (`env`, `est`, `xbins` and `Agent` come from earlier cells).
SEED = 0                            # seeds the agent's numpy draws and the env
obs_dim = xbins * xbins             # one state per pair of per-dimension bins
act_dim = int(env.action_space.n)   # number of discrete actions exposed by the env
gamma = 0.98                        # discount factor
lamb = 0.8                          # trace-decay parameter lambda
alpha = 0.01                        # step size
epsilon = 0.1                       # exploration rate for epsilon-greedy

np.random.seed(SEED)  # Agent uses numpy's global RNG for Q init and exploration
ag = Agent(obs_dim, act_dim, epsilon=epsilon)


def encode_state(obs):
    """Map a continuous 2-D observation to a flat state index in [0, xbins**2)."""
    bins = est.transform(obs.reshape(1, 2))[0].astype(int)
    return np.ravel_multi_index(tuple(bins), (xbins, xbins))


# NOTE(review): est.transform is called every step and is slow; the recorded
# run of this cell was interrupted partway through.
Niters = 20000
all_ts = np.zeros(Niters)
num_success = np.zeros(Niters)
for ni in range(Niters):
    if ni > 0 and ni % 1000 == 0:
        print('episode %d, t=%d, NumSuccess=%d' % (ni, all_ts[ni - 1], num_success.sum()))

    # Traces describe the current episode's trajectory only; clear them so
    # credit from the previous episode does not leak into this one's updates.
    ag.ztable.fill(0.0)

    t = 0
    terminated = truncated = False
    s_ori, _ = env.reset(seed=SEED if ni == 0 else None)
    s = encode_state(s_ori)
    a, _ = ag.select_action(s)
    while not (terminated or truncated):
        ag.ztable[s, a] += 1  # accumulating trace for the visited (s, a) pair
        q_now = ag.qtable[s, a]
        s_next_ori, r_next, terminated, truncated, _ = env.step(a)
        s_next = encode_state(s_next_ori)
        # all_q_next is a view into ag.qtable: read it before the update below.
        a_next, all_q_next = ag.select_action(s_next)
        if terminated:
            # Terminal transition: nothing to bootstrap from.
            td_err = r_next - q_now
            num_success[ni] = 1
        else:
            # Non-terminal (including time-limit truncation): bootstrap from Q(s', a').
            td_err = r_next + gamma * all_q_next[a_next] - q_now
        ag.qtable += alpha * td_err * ag.ztable
        ag.ztable *= gamma * lamb
        s, a = s_next, a_next
        t += 1

    all_ts[ni] = t

episodes = np.arange(Niters)
panels = (
    (slice(None), 'All episodes'),
    (slice(None, 50), 'First 50 episodes'),
    (slice(-50, None), 'Last 50 episodes'),
)
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
for axis, (window, title) in zip(ax, panels):
    axis.plot(episodes[window], all_ts[window])
    axis.set(title=title, xlabel='episode', ylabel='steps per episode')
plt.show()



episode 100, t=200
episode 200, t=200
episode 300, t=200
episode 400, t=200
episode 500, t=200
episode 600, t=200
episode 700, t=200
episode 800, t=200
episode 900, t=200
episode 1000, t=200
episode 1100, t=200
episode 1200, t=200
episode 1300, t=200
episode 1400, t=200
episode 1500, t=200
episode 1600, t=200
episode 1700, t=200
episode 1800, t=200
episode 1900, t=200
episode 2000, t=200
episode 2100, t=200
episode 2200, t=200
episode 2300, t=200
episode 2400, t=200
episode 2500, t=200
episode 2600, t=200
episode 2700, t=200
episode 2800, t=200


KeyboardInterrupt: 