[View in Colaboratory](https://colab.research.google.com/github/yingjie-wu/Reinforcement-Learning-StarAI/blob/master/Lesson_2_Exercise_2_Policy_Iteration.ipynb)

In [0]:
          _____                _____                    _____                    _____                    _____                    _____          
         /\    \              /\    \                  /\    \                  /\    \                  /\    \                  /\    \         
        /::\    \            /::\    \                /::\    \                /::\    \                /::\    \                /::\    \        
       /::::\    \           \:::\    \              /::::\    \              /::::\    \              /::::\    \               \:::\    \       
      /::::::\    \           \:::\    \            /::::::\    \            /::::::\    \            /::::::\    \               \:::\    \      
     /:::/\:::\    \           \:::\    \          /:::/\:::\    \          /:::/\:::\    \          /:::/\:::\    \               \:::\    \     
    /:::/__\:::\    \           \:::\    \        /:::/__\:::\    \        /:::/__\:::\    \        /:::/__\:::\    \               \:::\    \    
    \:::\   \:::\    \          /::::\    \      /::::\   \:::\    \      /::::\   \:::\    \      /::::\   \:::\    \              /::::\    \   
  ___\:::\   \:::\    \        /::::::\    \    /::::::\   \:::\    \    /::::::\   \:::\    \    /::::::\   \:::\    \    ____    /::::::\    \  
 /\   \:::\   \:::\    \      /:::/\:::\    \  /:::/\:::\   \:::\    \  /:::/\:::\   \:::\____\  /:::/\:::\   \:::\    \  /\   \  /:::/\:::\    \ 
/::\   \:::\   \:::\____\    /:::/  \:::\____\/:::/  \:::\   \:::\____\/:::/  \:::\   \:::|    |/:::/  \:::\   \:::\____\/::\   \/:::/  \:::\____\
\:::\   \:::\   \::/    /   /:::/    \::/    /\::/    \:::\  /:::/    /\::/   |::::\  /:::|____|\::/    \:::\  /:::/    /\:::\  /:::/    \::/    /
 \:::\   \:::\   \/____/   /:::/    / \/____/  \/____/ \:::\/:::/    /  \/____|:::::\/:::/    /  \/____/ \:::\/:::/    /  \:::\/:::/    / \/____/ 
  \:::\   \:::\    \      /:::/    /                    \::::::/    /         |:::::::::/    /            \::::::/    /    \::::::/    /          
   \:::\   \:::\____\    /:::/    /                      \::::/    /          |::|\::::/    /              \::::/    /      \::::/____/           
    \:::\  /:::/    /    \::/    /                       /:::/    /           |::| \::/____/               /:::/    /        \:::\    \           
     \:::\/:::/    /      \/____/                       /:::/    /            |::|  ~|                    /:::/    /          \:::\    \          
      \::::::/    /                                    /:::/    /             |::|   |                   /:::/    /            \:::\    \         
       \::::/    /                                    /:::/    /              \::|   |                  /:::/    /              \:::\____\        
        \::/    /                                     \::/    /                \:|   |                  \::/    /                \::/    /        
         \/____/                                       \/____/                  \|___|                   \/____/                  \/____/         

# Policy Iteration Exercise

## Task

Implement Policy iteration algorithm (you may use Policy Evaluation implementation from exercise 1)

**Steps1**

- Evaluate given policy. 
- Once optimal state-value function is calculated for the given policy update the policy by acting greedily with respect to state value function. The goal is to exclude "bad" actions from being executed.
- After the policy has been update, evaluate and update again until the policiy is optimal.

*Note:* you may consider policy to be optimal once it stops being updated after policy evaluation step.

## What we implement

$$
V_*(s) = max(R(s, a) + \gamma \sum_{s' \in S} P^a_{ss'} V_*(s'))
$$

## Implementation

In [1]:
# !python -m pip install -e git+https://github.com/star-ai/rl-environments.git#egg=rlenvs
# !python -m pip install gym
!pip install -e git+https://github.com/star-ai/rl-environments.git#egg=rlenvs
!pip install gym

Obtaining rlenvs from git+https://github.com/star-ai/rl-environments.git#egg=rlenvs
  Updating ./src/rlenvs clone
Installing collected packages: rlenvs
  Found existing installation: rlenvs 0.1
    Uninstalling rlenvs-0.1:
      Successfully uninstalled rlenvs-0.1
  Running setup.py develop for rlenvs
Successfully installed rlenvs


In [0]:
from IPython.core.debugger import set_trace
import numpy as np
import pprint

# Import below can all of a sudden break
# NOTE: if running locally, remove src.rlenvs from import
# from src.rlenvs.rlenvs.envs.gridworld import GridworldEnv
from rlenvs.envs.gridworld import GridworldEnv

In [0]:
pp = pprint.PrettyPrinter(indent=2)
env = GridworldEnv()

In [0]:
def calculate_state_value(policy, state, env, V, discount_factor):
    """"
    Calculate state value given policy, state, and current state value function.

    Args:
      policy: Policy - [S, A] matrix of probabilities of action A given state S
      env: Environment.
        env.P[s][a] return list of transition tuples (transition_probability, 
          next_state, reward, done).
      V: current state value function, V[s] return value for state s.
    """
    v = 0
    # Look at the possible next actions
    for a, action_prob in enumerate(policy[state]):
        # For each action, look at the possible next states...
        for prob, next_state, reward, done in env.P[state][a]:
            # Calculate the expected value
            v += action_prob * prob * (reward + discount_factor * V[next_state])
    return v
  

def run_full_sweep(policy, env, V, discount_factor):
    """
    Run a full sweep over states.
    """
    new_V = np.zeros(env.nS)
    delta = 0
    # For each state, perform a "full backup"
    for s in range(env.nS):
        v = calculate_state_value(policy, s, env, V, discount_factor)

        # How much our value function changed (across any states)
        delta = max(delta, np.abs(v - V[s]))
        new_V[s] = v
    return new_V, delta


def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):
    """
    Evaluate a policy given an environment and a full description of the environment's dynamics.
    
    Args:
        policy: [S, A] shaped matrix representing the policy.
        env: OpenAI env. env.P represents the transition probabilities of the environment.
            env.P[s][a] is a list of transition tuples (prob, next_state, reward, done).
            env.nS is a number of states in the environment. 
            env.nA is a number of actions in the environment.
        theta: We stop evaluation once our value function change is less than theta for all states.
        discount_factor: Gamma discount factor.
    
    Returns:
        Vector of length env.nS representing the value function.
    """
    # Start with a random (all 0) value function
    V = np.zeros(env.nS)
    # Taken from Policy Evaluation Exercise!
    while True:
        V, delta = run_full_sweep(policy, env, V, discount_factor)
        # Stop evaluating once our value function change is below a threshold
        if delta < theta:
            break    
    return np.array(V)

In [0]:
def policy_update(policy, env, V):
    for state in range(env.nS):
      next_state_value = []
      
      # Get state value of all possible next states
      for a, action_prob in enumerate(policy[state]):
          # For each action, look at the possible next states...
          #print("action:",a,"action_prob:",action_prob)
          for prob, next_state, reward, done in env.P[state][a]:
              #print("prob:",prob,"next_state:",next_state,"reward:",reward,"done:",done)
              next_state_value.append(V[next_state])
              
      greedy_action = np.argmax(next_state_value)
      new_policy_for_state = [0,0,0,0]
      new_policy_for_state[greedy_action] = 1
      policy[state] = new_policy_for_state
    return policy
  

def policy_improvement(env, policy_eval_fn=policy_eval, discount_factor=1.0):
    """
    Policy Improvement Algorithm. Iteratively evaluates and improves a policy
    until an optimal policy is found.
    
    Args:
        env: The OpenAI envrionment.
        policy_eval_fn: Policy Evaluation function that takes 3 arguments:
            policy, env, discount_factor.
        discount_factor: gamma discount factor.
        
    Returns:
        A tuple (policy, V). 
        policy is the optimal policy, a matrix of shape [S, A] where each state s
        contains a valid probability distribution over actions.
        V is the value function for the optimal policy.
        
    """
    # Start with a random policy
    policy = np.ones([env.nS, env.nA]) / env.nA
    
    i = 0
    while True:       
        old_policy = policy
        V = policy_eval(policy, env)
        policy = policy_update(policy, env, V)
        
        i+= 1
        print("iteration: ",i)
        if np.allclose(old_policy, policy) or i == 10:
          break
    
    return policy, np.array(V)

In [24]:
policy, v = policy_improvement(env)
print("Policy Probability Distribution:")
print(policy)
print("")

print("Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):")
print(np.reshape(np.argmax(policy, axis=1), env.shape))
print("")

print("Value Function:")
print(v)
print("")

print("Reshaped Grid Value Function:")
print(v.reshape(env.shape))
print("")

('iteration: ', 1)
Policy Probability Distribution:
[[1. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]]

Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):
[[0 3 3 2]
 [0 0 2 2]
 [0 0 1 2]
 [0 1 1 0]]

Value Function:
[  0.         -13.99989315 -19.99984167 -21.99982282 -13.99989315
 -17.99986052 -19.99984273 -19.99984167 -19.99984167 -19.99984273
 -17.99986052 -13.99989315 -21.99982282 -19.99984167 -13.99989315
   0.        ]

Reshaped Grid Value Function:
[[  0.         -13.99989315 -19.99984167 -21.99982282]
 [-13.99989315 -17.99986052 -19.99984273 -19.99984167]
 [-19.99984167 -19.99984273 -17.99986052 -13.99989315]
 [-21.99982282 -19.99984167 -13.99989315   0.        ]]

