### Dynamic Programming

In [141]:
import numpy as np
import copy

#### Policy Iteration
策略迭代由两部分组成：policy evaluation 和 policy improvement

策略迭代中的策略评估使用贝尔曼期望方程来得到一个策略的状态价值函数，这是一个动态规划的过程

In [142]:
### environment: cliff walking
class CliffWalkingEnv:
    def __init__(self, ncol = 12, nrow = 4):
        self.ncol = ncol
        self.nrow = nrow
        ## P[state][action] = [（p, next_state, reward, done）] 到达下一个状态的概率肯定是1，动作确定而不是随机动作
        self.P = self.createP() # The state transformation matrix
    
    def createP(self):
        P = [[[] for i in range(4)] for j in range(self.nrow * self.ncol)]
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]] # action: 上下左右
        for i in range(self.nrow): # 纵坐标
            for j in range(self.ncol): # 横坐标
                for a in range(4):
                    if (i == self.nrow - 1) and (j > 0):
                        P[i * self.ncol + j][a] = (1, i * self.ncol + j, 0, True) #转移终止状态
                        continue
                    next_x = min(self.ncol - 1, max(0, j + change[a][0]))
                    next_y = min(self.nrow - 1, max(0, i + change[a][1]))
                    next_state = next_y * self.ncol + next_x
                    reward = -1
                    done = False
                    if (next_y == self.nrow - 1) and next_x > 0:
                        done = True # 为什么下一个位置在悬崖或重点也要把done设为true, 转移后未来的reward就没有了
                        if (next_x != self.ncol- 1):
                            reward = -100
                    P[i * self.ncol + j][a] = (1, next_state, reward, done)
        return P

    def plot(self):
        # TODO: add plot function
        pass

## test ##
env = CliffWalkingEnv()
print(env.P[0])

[(1, 0, -1, False), (1, 12, -1, False), (1, 0, -1, False), (1, 1, -1, False)]


In [143]:
class PolicyIteration:
    """Policy Iteration"""
    def __init__(self, env, theta, gamma):
        self.env = env
        self.v = [0] * self.env.ncol * self.env.nrow  # 初始化价值为0
        self.pi = [[0.25, 0.25, 0.25, 0.25]
                   for i in range(self.env.ncol * self.env.nrow)]  # 初始化为均匀随机策略
        self.theta = theta  # 策略评估收敛阈值
        self.gamma = gamma  # 折扣因子
    
    def policy_evaluation(self):
        cnt = 1
        while True:
            max_diff = 0
            new_v = [0] * self.env.ncol * self.env.nrow
            for s in range(self.env.ncol * self.env.nrow):
                qsa_list = []
                for a in range(4):
                    p, next_s, r, done = self.env.P[s][a]
                    qsa_list.append(self.pi[s][a] * (r + self.gamma * p * self.v[next_s] * (1 - done)))
                new_v[s] = sum(qsa_list)
                max_diff = max(max_diff, abs(new_v[s] - self.v[s]))
            self.v = new_v
            if max_diff < self.theta:
                break
            cnt += 1
        print("策略评估进行{:d}轮后完成".format(cnt))
                        
    def policy_improvement(self):
        for s in range(self.env.ncol * self.env.nrow):
            qsa_list = []
            for a in range(4):
                p, next_s, r, done = self.env.P[s][a]
                qsa_list.append(r + self.gamma * p * self.v[next_s] * (1 - done))
            maxq = max(qsa_list)
            cntq = qsa_list.count(maxq)
            self.pi[s] = [1 / cntq if q == maxq else 0 for q in qsa_list]
        print("*****策略提升完成*****")
        print(self.pi[0])
        return self.pi

    def policy_iteration(self):  # 策略迭代
        while 1:
            self.policy_evaluation()
            old_pi = copy.deepcopy(self.pi)  # 将列表进行深拷贝,方便接下来进行比较
            new_pi = self.policy_improvement()
            if old_pi == new_pi: 
                break

In [155]:
def print_agent(agent, action_meaning, disaster=[], end=[]):
    print("状态价值：")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # 为了输出美观,保持输出6个字符
            print("{:+.4f}".format(agent.v[i * agent.env.ncol + j]), end = " ")
            # print('%6.6s' % ('%.3f' % agent.v[]), end=' ')
        print()
    
    print("策略：")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # 一些特殊的状态,例如悬崖漫步中的悬崖
            if (i * agent.env.ncol + j) in disaster:
                print('****', end=' ')
            elif (i * agent.env.ncol + j) in end:  # 目标状态
                print('EEEE', end=' ')
            else:
                a = agent.pi[i * agent.env.ncol + j]
                pi_str = ''
                for k in range(len(action_meaning)):
                    pi_str += action_meaning[k] if a[k] > 0 else 'o'
                print(pi_str, end=' ')
        print()
        
        
env = CliffWalkingEnv()
action_meaning = ['^', 'v', '<', '>']
theta = 0.001
gamma = 0.9
agent = PolicyIteration(env, theta, gamma)
agent.policy_iteration()
print_agent(agent, action_meaning, list(range(37, 47)), [47])

策略评估进行60轮后完成
*****策略提升完成*****
[0.5, 0, 0.5, 0]
策略评估进行72轮后完成
*****策略提升完成*****
[0.25, 0.25, 0.25, 0.25]
策略评估进行44轮后完成
*****策略提升完成*****
[0, 0, 0, 1.0]
策略评估进行12轮后完成
*****策略提升完成*****
[0, 0.5, 0, 0.5]
策略评估进行1轮后完成
*****策略提升完成*****
[0, 0.5, 0, 0.5]
状态价值：
-7.7123 -7.4581 -7.1757 -6.8619 -6.5132 -6.1258 -5.6953 -5.2170 -4.6856 -4.0951 -3.4390 -2.7100 
-7.4581 -7.1757 -6.8619 -6.5132 -6.1258 -5.6953 -5.2170 -4.6856 -4.0951 -3.4390 -2.7100 -1.9000 
-7.1757 -6.8619 -6.5132 -6.1258 -5.6953 -5.2170 -4.6856 -4.0951 -3.4390 -2.7100 -1.9000 -1.0000 
-7.4581 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 
策略：
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo **** **** **** **** **** **** **** **** **** **** EEEE 


#### Value Iteration

价值迭代直接使用贝尔曼最优方程来进行动态规划，得到最终的最优状态价值(只维护状态价值函数)

只在策略评估中进行一轮价值更新，然后直接根据更新后的价值进行策略提升

In [156]:
class ValueIteration:
    def __init__(self, env, theta, gamma):
        self.env = env
        self.v = [0] * self.env.ncol * self.env.nrow
        self.theta = theta
        self.gamma = gamma
        ## 价值迭代结束后得到的策略
        self.pi = [None for i in range(self.env.ncol * self.env.nrow)]
    
    def value_iteration(self):
        cnt = 0
        while True:
            max_diff = 0
            new_v = [0] * self.env.ncol * self.env.nrow
            for s in range(self.env.ncol * self.env.nrow):
                qsa_list = []
                for a in range(4):
                    p, next_s, r, done = self.env.P[s][a]
                    qsa_list.append(r + self.gamma * p * self.v[next_s] * (1 - done))
                new_v[s] = max(qsa_list)
                max_diff = max(max_diff, abs(new_v[s] - self.v[s]))
            self.v = new_v
            if max_diff < self.theta:
                break
            cnt += 1
        print("价值迭代一共进行{:d}轮".format(cnt))
        self.get_policy()
    
    def get_policy(self): # 根据价值函数导出一个贪婪策略
        for s in range(self.env.ncol * self.env.nrow):
            qsa_list = []
            for a in range(4):
                p, next_s, r, done = self.env.P[s][a]
                qsa_list.append(r + self.gamma * p * self.v[next_s] * (1 - done))
            maxq = max(qsa_list)
            cntq = qsa_list.count(maxq)
            ## 均分概率
            self.pi[s] = [1 / cntq if q == maxq else 0 for q in qsa_list]

env = CliffWalkingEnv()
action_meaning = ['^', 'v', '<', '>']
theta = 0.001
gamma = 0.9
agent = ValueIteration(env, theta, gamma)
agent.value_iteration()
print_agent(agent, action_meaning, list(range(37, 47)), [47])
        

价值迭代一共进行14轮
状态价值：
-7.7123 -7.4581 -7.1757 -6.8619 -6.5132 -6.1258 -5.6953 -5.2170 -4.6856 -4.0951 -3.4390 -2.7100 
-7.4581 -7.1757 -6.8619 -6.5132 -6.1258 -5.6953 -5.2170 -4.6856 -4.0951 -3.4390 -2.7100 -1.9000 
-7.1757 -6.8619 -6.5132 -6.1258 -5.6953 -5.2170 -4.6856 -4.0951 -3.4390 -2.7100 -1.9000 -1.0000 
-7.4581 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 +0.0000 
策略：
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo **** **** **** **** **** **** **** **** **** **** EEEE 
