# https://gymnasium.farama.org/environments/toy_text/cliff_walking/

# In [1]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
# Shared CliffWalking environment; the agents below size their Q tables for
# its 4 x 12 grid (ncol=12, nrow=4) and 4 actions.
env = gym.make(id='CliffWalking-v0')


# In [2]:
class Sarsa:
    """Tabular one-step SARSA agent with an epsilon-greedy behaviour policy.

    The Q table has one row per grid cell (``nrow * ncol`` states) and one
    column per action; for CliffWalking the action order is up, right, down,
    left.
    """

    def __init__(self, ncol, nrow, epsilon, alpha, gamma, n_action=4) -> None:
        """Create an agent with a zero-initialised Q table.

        Args:
            ncol: Number of grid columns.
            nrow: Number of grid rows.
            epsilon: Probability of taking a uniformly random action.
            alpha: Step size of the TD update.
            gamma: Discount factor.
            n_action: Number of discrete actions.
        """
        self.Q_table = np.zeros((nrow * ncol, n_action))  # actions: up, right, down, left
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.n_action = n_action

    def take_action(self, state):
        """Return an epsilon-greedy action for ``state`` as a Python int."""
        if np.random.random() < self.epsilon:
            return int(np.random.randint(self.n_action))
        # np.argmax breaks ties toward the lowest action index.
        return int(np.argmax(self.Q_table[state]))

    def update(self, s0, a0, r, s1, a1, done=False):
        """Apply one SARSA update to Q(s0, a0).

        Args:
            s0: State whose action value is updated.
            a0: Action taken in ``s0``.
            r: Reward received after taking ``a0`` in ``s0``.
            s1: Next state.
            a1: Action chosen in ``s1`` by the behaviour policy (on-policy).
            done: True if ``s1`` is terminal; the target is then just ``r``
                (no bootstrap from Q(s1, a1)). Defaults to False so existing
                five-argument callers keep their behaviour.
        """
        bootstrap = 0.0 if done else self.Q_table[s1, a1]
        td_error = r + self.gamma * bootstrap - self.Q_table[s0, a0]
        self.Q_table[s0, a0] += self.alpha * td_error

# In [None]:
# Train a one-step SARSA agent on the 4 x 12 CliffWalking grid.
ncol, nrow = 12, 4
epsilon, alpha, gamma = 0.1, 0.1, 0.1
num_episodes = 500
# np.random.seed(0)  # uncomment for reproducible exploration
agent = Sarsa(ncol, nrow, epsilon, alpha, gamma)
return_list = []  # undiscounted return of each episode
for episode in range(num_episodes):
    s, _ = env.reset()
    a = agent.take_action(s)
    episode_return = 0
    terminated = truncated = False
    # Gymnasium ends an episode when it is terminated OR truncated.
    while not (terminated or truncated):
        s1, r, terminated, truncated, _ = env.step(a)
        a1 = agent.take_action(s1)  # on-policy: the next action is chosen before the update
        agent.update(s, a, r, s1, a1)
        s, a = s1, a1
        episode_return += r
        # env.render()
    print(f'{episode}:{episode_return}')
    return_list.append(episode_return)

# Print the greedy policy, one arrow per grid cell, row by row.
direct = {
    0: '^',
    1: '>',
    2: 'v',
    3: '<',
}
for i in range(nrow):
    for j in range(ncol):
        index = np.argmax(agent.Q_table[ncol * i + j])
        print(direct[index], end=' ')
    print()

# Plot the learning curve and persist the learned Q table.
plt.plot(range(len(return_list)), return_list)
plt.xlabel('Episode')
plt.ylabel('Return')
plt.title('SARSA on CliffWalking')
plt.show()
np.save(f'Q_table_{num_episodes}', agent.Q_table)

# In [8]:
class Nstep_Sarsa(Sarsa):
    """Tabular n-step SARSA agent (epsilon-greedy ``take_action`` is inherited).

    Transitions are buffered until ``n`` of them are pending. The oldest
    state-action pair is then moved towards the n-step return

        G = r_t + gamma*r_{t+1} + ... + gamma**(n-1)*r_{t+n-1}
            + gamma**n * Q(s_{t+n}, a_{t+n}).

    When an episode terminates, every pending pair is updated with its
    return up to the terminal state (no bootstrap term) and the buffer is
    emptied, so no transition carries over into the next episode.
    """

    def __init__(self, ncol, nrow, epsilon, alpha, gamma, n, n_action=4) -> None:
        """Create the agent.

        Args:
            ncol, nrow, epsilon, alpha, gamma, n_action: Same meaning as for
                ``Sarsa``.
            n: Number of rewards summed before bootstrapping; must be >= 1.

        Raises:
            ValueError: If ``n`` is less than 1.
        """
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n!r}")
        super().__init__(ncol, nrow, epsilon, alpha, gamma, n_action)
        self.n = n
        # Pending transitions, oldest first; index i holds (s_i, a_i, r_i).
        self.state_list = []
        self.action_list = []
        self.reward_list = []

    def update(self, s0, a0, r, s1, a1, done):
        """Record transition (s0, a0, r) and apply any n-step updates now due.

        Args:
            s0: State of the transition.
            a0: Action taken in ``s0``.
            r: Reward received after taking ``a0`` in ``s0``.
            s1: Next state.
            a1: Action chosen in ``s1``; Q(s1, a1) is the bootstrap value once
                ``n`` transitions are pending.
            done: True if ``s1`` is terminal. All pending pairs are then
                updated without bootstrapping and the buffer is cleared.
        """
        self.state_list.append(s0)
        self.action_list.append(a0)
        self.reward_list.append(r)

        if done:
            # Episode over: walk backwards, accumulating each pending pair's
            # (possibly shorter than n) return. The terminal state contributes
            # 0, and the update order (newest to oldest) matches the n-step case.
            G = 0.0
            pending = list(zip(self.state_list, self.action_list, self.reward_list))
            for s, a, reward in reversed(pending):
                G = self.gamma * G + reward
                self.Q_table[s, a] += self.alpha * (G - self.Q_table[s, a])
            # Clear even when fewer than n transitions were pending, so nothing
            # leaks into the next episode.
            self.state_list.clear()
            self.action_list.clear()
            self.reward_list.clear()
            return

        if len(self.state_list) < self.n:
            return  # not enough transitions yet for a full n-step return

        G = self.Q_table[s1, a1]
        for reward in reversed(self.reward_list):
            G = self.gamma * G + reward
        # Only the oldest pair has its full n-step return; drop it from the window.
        s = self.state_list.pop(0)
        a = self.action_list.pop(0)
        self.reward_list.pop(0)
        self.Q_table[s, a] += self.alpha * (G - self.Q_table[s, a])




# In [9]:
# Train an n-step SARSA agent on the 4 x 12 CliffWalking grid.
ncol, nrow = 12, 4
epsilon, alpha, gamma = 0.1, 0.1, 0.1
num_episodes, n_step = 500, 5
# np.random.seed(0)  # uncomment for reproducible exploration
agent = Nstep_Sarsa(ncol, nrow, epsilon, alpha, gamma, n_step)
return_list = []  # undiscounted return of each episode
for episode in range(num_episodes):
    s, _ = env.reset()
    a = agent.take_action(s)
    episode_return = 0
    terminated = truncated = False
    # Gymnasium ends an episode when it is terminated OR truncated.
    while not (terminated or truncated):
        s1, r, terminated, truncated, _ = env.step(a)
        a1 = agent.take_action(s1)
        # Only true termination is reported to the agent as `done`.
        agent.update(s, a, r, s1, a1, terminated)
        s, a = s1, a1
        episode_return += r
        # env.render()
    print(f'{episode}:{episode_return}')
    return_list.append(episode_return)

# Print the greedy policy, one arrow per grid cell, row by row.
direct = {
    0: '^',
    1: '>',
    2: 'v',
    3: '<',
}
for i in range(nrow):
    for j in range(ncol):
        index = np.argmax(agent.Q_table[ncol * i + j])
        print(direct[index], end=' ')
    print()

# Plot the learning curve and persist the learned Q table.
plt.plot(range(len(return_list)), return_list)
plt.xlabel('Episode')
plt.ylabel('Return')
plt.title(f'{n_step}-step SARSA on CliffWalking')
plt.show()
# Use a distinct file name: plain `Q_table_{num_episodes}` is the name the
# one-step SARSA cell saves to, and reusing it would overwrite that table.
np.save(f'Q_table_{n_step}step_{num_episodes}', agent.Q_table)

# Captured notebook output of the cell above: TypeError: 'list' object is not callable