In [1]:
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# Policy Iteration

In [61]:
def PolicyIteration(env=None, gamma=1.0, max_iterations=2000, theta=1e-10):
    '''Compute a deterministic greedy policy for a tabular MDP via policy iteration.

    Alternates full policy evaluation and greedy policy improvement until the
    policy stops changing or `max_iterations` rounds have been run.

    Parameters
    ----------
    env : object, optional
        Model of the MDP. Must expose `nS` (number of states), `nA` (number of
        actions) and `P[s][a]`, an iterable of
        `(prob, next_state, reward, terminal)` tuples. When omitted, a
        'FrozenLake8x8-v0' gym environment is created.
    gamma : float
        Discount factor used in both evaluation and improvement.
    max_iterations : int
        Upper bound on evaluation/improvement rounds.
    theta : float
        Small positive number determining the accuracy of policy evaluation;
        also used as the tolerance when deciding whether two actions tie.

    Returns
    -------
    numpy.ndarray
        Integer array of length `env.nS` holding the chosen action per state.
        If the loop runs out of iterations, the most recent policy is returned.
    '''
    def PolicyEvaluation(policy, gamma=1.0):
        '''Iteratively evaluate the value-function under certain policy
        Alternatively, we can formulate a set of linear equations in terms of v[s]
        and solve them to find the value fxn
        '''
        value_fxn = np.zeros(env.nS)

        while True:
            change = 0
            for s in range(env.nS):
                v = value_fxn[s]

                # Expected update (in place, so later states in this sweep
                # already see the refreshed values of earlier ones).
                a = policy[s]
                # Notice we never use terminal
                value_fxn[s] = sum(
                    prob * (reward + gamma * value_fxn[next_state])
                    for prob, next_state, reward, _terminal in env.P[s][a]
                )
                change = max(change, np.abs(v - value_fxn[s]))

            # a small positive number determining the accuracy of estimation
            if change < theta:
                break
        return value_fxn

    def PolicyImprovement(value_fxn, gamma=1.0, current_policy=None):
        '''Return the policy that is greedy with respect to `value_fxn`.

        When `current_policy` is given, its action is kept in every state
        where it is (within `theta`) as good as the best action. Without this,
        argmax tie-breaking can flip between equally good policies forever.
        '''
        # Integer dtype: entries are action indices used to index env.P[s].
        policy = np.zeros(env.nS, dtype=int)
        for s in range(env.nS):
            action_expectations = np.zeros(env.nA)
            for a in range(env.nA):
                action_expectations[a] = sum(
                    prob * (reward + gamma * value_fxn[next_state])
                    for prob, next_state, reward, _terminal in env.P[s][a]
                )
            best = int(np.argmax(action_expectations))
            if current_policy is not None:
                incumbent = int(current_policy[s])
                if action_expectations[incumbent] >= action_expectations[best] - theta:
                    best = incumbent
            policy[s] = best
        return policy

    if env is None:
        env = gym.make('FrozenLake8x8-v0')

    # initialize a random policy
    policy = np.random.choice(env.nA, size=(env.nS))

    for i in range(max_iterations):
        value_fxn = PolicyEvaluation(policy, gamma)
        new_policy = PolicyImprovement(value_fxn, gamma, policy)
        if np.array_equal(policy, new_policy):
            print ('Policy-Iteration converged at step %d.' %(i+1))
            break
        policy = new_policy
    return policy
        

In [62]:
# Keep an environment at notebook scope: the cells below pass `env` to
# run_episode / evaluate_policy, while PolicyIteration() only builds a local one.
env = gym.make('FrozenLake8x8-v0')
optimal_policy = PolicyIteration()

Policy-Iteration converged at step 7.


In [67]:
def run_episode(env, policy, gamma = 1.0, render = True):
    """Play a single episode following `policy` and return its discounted return.

    `env` must provide `reset()` (returning the first observation), `step(action)`
    (returning `(obs, reward, done, info)`) and, when `render` is true, `render()`.
    `policy` is indexed by observation; each entry is cast to `int` before being
    passed to `env.step`. The reward at time step t is weighted by `gamma ** t`.
    """
    state = env.reset()
    discounted_return = 0
    for t in count():
        if render:
            env.render()
        state, reward, done, _ = env.step(int(policy[state]))
        discounted_return += (gamma ** t * reward)
        if done:
            return discounted_return


def evaluate_policy(env, policy, gamma = 1.0, n = 100):
    """Average the discounted return of `policy` over `n` unrendered episodes.

    Each episode is produced by `run_episode(env, policy, gamma, False)`; the
    result is the `np.mean` of the collected returns.
    """
    episode_returns = []
    for _ in range(n):
        episode_returns.append(run_episode(env, policy, gamma, False))
    return np.mean(episode_returns)

In [69]:
# Roll out one rendered episode under the computed policy; the cell's value is
# that episode's return (gamma defaults to 1.0, so rewards are undiscounted).
# NOTE(review): requires a notebook-level `env` - confirm one is created before
# this cell when running on a fresh kernel.
run_episode(env, optimal_policy)


[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
  (Down)
[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
  (Down)
SFFFFFFF
[41mF[0mFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
  (Up)
[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
  (Down)
[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
  (Down)
[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
  (Down)
[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
  (Down)
[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
  (Down)
SFFFFFFF
[41mF[0mFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
  (Up)
SFFFFFFF
[41mF[0mFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
  (Up)
[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
  (Down)
S[41mF[0mFFFF

1.0

In [66]:
# Mean return over evaluate_policy's default of 100 unrendered episodes.
# NOTE(review): requires a notebook-level `env` - confirm one is created before
# this cell when running on a fresh kernel.
evaluate_policy(env, optimal_policy)

0.9