## 策略迭代

策略迭代分为两步:

<img src="../../images/11-policy-iteratioin.png" width="50%">



In [2]:
import numpy as np
import matplotlib.pyplot as plt

class GridWorld(object):
    def __init__(self, m, n, magicSquares):
        self.grid = np.zeros((m,n))
        self.m = m
        self.n = n
        self.stateSpace = [i for i in range(self.m*self.n)]
        self.stateSpace.remove(80)
        self.stateSpacePlus = [i for i in range(self.m*self.n)]
        self.possibleActions = ['U', 'D', 'L', 'R']
        self.actionSpace = {'U': -self.m, 'D': self.m,'L': -1, 'R': 1}
        self.P = {}  # 用一个字典存储状态转移矩阵。
        # dict with magic squares and resulting squares
        self.magicSquares = magicSquares
        self.initP()

    def initP(self):
        for state in self.stateSpace:
            for action in self.possibleActions:
                reward = -1
                state_ = state + self.actionSpace[action]
                if state_ in self.magicSquares.keys():  # 如果state在magicSquares中，就进行跳转。
                    state_ = self.magicSquares[state_]
                if self.offGridMove(state_, state):
                    state_ = state
                    
                if self.isTerminalState(state_):
                    reward = 0
                    
                self.P[(state_, reward, state, action)] = 1

    def isTerminalState(self, state):
        return state in self.stateSpacePlus and state not in self.stateSpace

    def offGridMove(self, newState, oldState):
        # if we move into a row not in the grid
        if newState not in self.stateSpacePlus:
            return True
        # if we're trying to wrap around to next row
        elif oldState % self.m == 0 and newState  % self.m == self.m - 1:
            return True
        elif oldState % self.m == self.m - 1 and newState % self.m == 0:
            return True
        else:
            return False

def printV(V, grid):
    for idx, row in enumerate(grid.grid):
        for idy, _ in enumerate(row):
            state = grid.m * idx + idy
            print('%.2f' % V[state], end='\t')
        print('\n')
    print('--------------------')

def printPolicy(policy, grid):
    for idx, row in enumerate(grid.grid):
        for idy, _ in enumerate(row):
            state = grid.m * idx + idy
            if not grid.isTerminalState(state):
                if state not in grid.magicSquares.keys():
                    print('%s' % policy[state], end='\t')
                else:
                    print('%s' % '--', end='\t')
            else:
                print('%s' % '--', end='\t')
        print('\n')
    print('--------------------')

def evaluatePolicy(grid, V, policy, GAMMA, THETA):
    # policy evaluation for the random choice in gridworld
    converged = False
    i = 0
    while not converged:
        DELTA = 0
        for state in grid.stateSpace:
            i += 1
            oldV = V[state]
            total = 0
            weight = 1 / len(policy[state])
            for action in policy[state]:
                for key in grid.P:
                    (newState, reward, oldState, act) = key
                    # We're given state and action, want new state and reward
                    if oldState == state and act == action:
                        total += weight*grid.P[key]*(reward+GAMMA*V[newState])
            V[state] = total
            DELTA = max(DELTA, np.abs(oldV-V[state]))
            converged = True if DELTA < THETA else False
    print(i, 'sweeps of state space in policy evaluation')
    return V

def improvePolicy(grid, V, policy, GAMMA):
    stable = True
    newPolicy = {}
    i = 0
    for state in grid.stateSpace:
        i += 1
        oldActions = policy[state]
        value = []
        newAction = []
        for action in policy[state]:
            weight = 1 / len(policy[state])
            for key in grid.P:
                (newState, reward, oldState, act) = key
                # We're given state and action, want new state and reward
                if oldState == state and act == action:
                    value.append(np.round(weight*grid.P[key]*(reward+GAMMA*V[newState]), 2))
                    newAction.append(action)
        value = np.array(value)
        best = np.where(value == value.max())[0]
        bestActions = [newAction[item] for item in best]
        newPolicy[state] = bestActions

        if oldActions != bestActions:
            stable = False
    print(i, 'sweeps of state space in policy improvement')
    return stable, newPolicy

if __name__ == '__main__':
    # map magic squares to their connecting square
    magicSquares = {18: 54, 63: 14}
    env = GridWorld(9, 9, magicSquares)
    # model hyperparameters
    GAMMA = 1.0
    THETA = 1e-6 # convergence criteria

    V = {}
    for state in env.stateSpacePlus:
        V[state] = 0

    policy = {}
    for state in env.stateSpace:
        # equiprobable random strategy
        policy[state] = env.possibleActions

    V = evaluatePolicy(env, V, policy, GAMMA, THETA)
    printV(V, env)

    stable = False
    while not stable:
        V = evaluatePolicy(env, V, policy, GAMMA, THETA)

        stable, policy = improvePolicy(env, V, policy, GAMMA)

    printV(V, env)

    printPolicy(policy, env)

315200 sweeps of state space in policy evaluation
-536.69	-537.05	-536.40	-534.36	-531.20	-527.49	-523.95	-521.19	-519.69	

-532.33	-534.06	-533.80	-531.48	-527.73	-523.34	-519.15	-515.93	-514.19	

-530.56	-529.07	-529.25	-526.04	-520.91	-514.99	-509.40	-505.17	-502.96	

-526.27	-526.72	-524.08	-518.51	-510.86	-502.31	-494.30	-488.38	-485.53	

-525.86	-523.46	-517.86	-509.07	-497.73	-485.08	-473.13	-464.52	-461.24	

-523.83	-519.41	-510.81	-498.17	-481.90	-463.16	-444.60	-431.32	-429.67	

-522.24	-515.53	-503.81	-486.91	-464.55	-437.05	-406.78	-382.51	-392.46	

-519.18	-512.68	-497.98	-477.09	-448.34	-409.72	-358.96	-295.47	-230.31	

-518.61	-509.88	-494.33	-471.14	-438.01	-390.53	-319.87	-206.12	0.00	

--------------------
80 sweeps of state space in policy evaluation
80 sweeps of state space in policy improvement
1200 sweeps of state space in policy evaluation
80 sweeps of state space in policy improvement
-11.00	-12.00	-13.00	-12.00	-11.00	-10.00	-9.00	-8.00	-7.00	

-10.00	-11.00	-1