# <center> Reinforcement Learning Fundamentals</center>

#### Import dependencies

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import pickle

import sys
import os

from RLF_Support.helper import *
from ece4078.gym_simple_gridworlds.envs.grid_env import GridEnv
from ece4078.gym_simple_gridworlds.envs.grid_2dplot import *

from IPython.display import display, HTML

# Activity 1. Iterative Policy Evaluation

Recall the definition of the iterative policy evaluation algorithm

![IterativePolicyEvaluation.png](https://i.postimg.cc/MGbJ9TdV/Iterative-Policy-Evaluation.png)

Let's now compute the value function of the same policy $\pi$ that we saw in the last session

![example_policy.png](https://i.postimg.cc/pLjHnkj0/example-policy.png)

We consider a grid world environment with the following attributes:
- Discount factor $\gamma = 0.9$ (class attribute ``gamma=0.9``)
- Stochastic transition matrix (class attribute ``noise=0.2``)
- A non-zero living cost applies, and large rewards are obtained at terminal states (class attribute ``living_reward=-0.04``)

We have defined the helper function ``encode_policy()`` to encode the policy $\pi$ shown in the image above. The return variable ``policy_pi`` is a dictionary of dictionaries, where each element corresponds to the probability of selecting an action $a$ at a given state $s$

Keep in mind that each action is represented by a number. Action (Up) is represented by 0, (Down) by 1, (Left) by 2 and, finally, (Right) by 3.

In [None]:
# Build the grid world described above: discount factor 0.9, transition
# noise 0.2 and a living reward of -0.04.
grid_world = GridEnv(gamma=0.9, noise=0.2, living_reward=-0.04)
# Encode the example policy pi; the result is indexed as policy_pi[state][action].
policy_pi = encode_policy(grid_world)

# Show the action probabilities stored for state 0.
print("Action probabilities at state 0 are:\n{}".format(policy_pi[0]))

Given the policy $\pi$, let's now compute its state-value function using iterative policy evaluation.

**TODO** (Flux Quiz): 
Complete the computation of value function update for each state. We have decomposed this computation into 2 steps:

1. Compute discounted sum of state values of all successor states: $v_{\text{discounted}} = \gamma\sum_{s' \in \mathcal{S}}\mathcal{T}(s,a,s')v(s')$ for each action


2. Compute expectation over all actions: $\sum_{a \in \mathcal{A}}\pi(a|s)(\mathcal{R}(s,a) + v_{\text{discounted}})$ 


**Keep in Mind**: Correspondence between the mathematical notation and the implemented code (in the `policy_evaluation` function)

|                         |                                                    |                 |
| ----------------------- | -------------------------------------------------- | --------------- |
|                         | **Variable/Attribute**                             | **Type**        | 
| $\gamma$                | `grid_env.gamma`                                   | `float`         |
| $\mathcal{T}(s, a, s')$ | `grid_env.state_transitions[idx_s, idx_a, idx_s_next]` | `numpy` 3d-array|
| $\mathcal{R}(s, a)$     | `grid_env.rewards[idx_s, idx_a]`                   | `numpy` 2d-array| 
| $\pi(a\vert s)$         | `policy[idx_s][idx_a]`                          | `dict` of `dict`| 
| $v_\pi(s)$              | `v[idx_s]`                                         | `dict`          | 
| $V$                     | `old_v`                                            | `float`         |
| $v_{discounted}$        | `discounted_v`                                     | `float`         |


In [None]:
def policy_evaluation(grid_env, policy, plot=False, threshold=0.00001):
    """
    Compute the state-value function of a policy with iterative policy evaluation.

    State values are updated in place, one state at a time, and the sweeps are
    repeated until the largest change seen in a sweep is no greater than
    ``threshold``.

    :param grid_env (GridEnv): MDP environment. It must provide ``get_states()``,
        ``gamma``, ``state_transitions[s, a, s_next]`` and ``rewards[s, a]``
    :param policy (dict - stochastic form): Policy being evaluated, where
        ``policy[s][a]`` is the probability of selecting action a in state s
    :param plot (bool): When True, the resulting values are drawn with
        ``plot_value_function``
    :param threshold (float): Convergence tolerance on the largest per-sweep change
    :return: (dict) State-values for every state returned by ``grid_env.get_states()``
    """
    states = list(grid_env.get_states())
    gamma = grid_env.gamma

    # Initialize v(s) = 0 for all s in S
    v = {s: 0.0 for s in states}

    # Start above any finite threshold so that at least one sweep is performed
    delta = float("inf")
    while delta > threshold:
        delta = 0.0
        for s in states:
            old_v = v[s]  # V <- v(s)

            new_v = 0.0
            for a, probability_a in policy[s].items():
                # Step 1: gamma * sum_{s'} T(s, a, s') * v(s')
                discounted_v = gamma * sum(
                    grid_env.state_transitions[s, a, s_next] * v[s_next]
                    for s_next in states
                )
                # Step 2: accumulate pi(a|s) * (R(s, a) + discounted value)
                new_v += probability_a * (grid_env.rewards[s, a] + discounted_v)

            v[s] = new_v
            delta = max(delta, abs(old_v - new_v))

    if plot:
        plot_value_function(grid_env, v)

    return v
        
        
# Call the policy evaluation function on the example policy, plot the result
# and print the returned state-values
v = policy_evaluation(grid_world, policy_pi, plot=True)
print(v)

# Activity 2. Policy Iteration

Recall the definition of the policy iteration algorithm

![PolicyIteration.png](https://i.postimg.cc/26kRMDKJ/Policy-Iteration.png)

Starting with a random policy, let's find the optimal policy for a grid world environment.

We consider a grid world environment with the following attributes:
- Discount factor $\gamma = 0.9$ (class attribute ``gamma=0.9``)
- Stochastic transition matrix (class attribute ``noise=0.2``)
- A non-zero living cost applies, and large rewards are obtained at terminal states (class attribute ``living_reward=-0.04``)

In `numpy`, there is a function called `argmax` which returns the index of the maximum value; for example, `np.argmax([1,2,4,3])` will return `2`.

Let's now define the policy iteration core algorithm.

**TODO** (Flux Quiz): Complete the main steps of the policy iteration algorithm.
- Use ``policy_evaluation(.)`` in previous code block to compute the state-value function of a given policy
- Implement ``update_policy(.)``. This function computes a new policy from the state-value function $v$ by selecting, in each state, the action $a$ with the highest action-value.

**Keep in mind:** $q_a = \mathcal{R}(s,a) + \gamma\sum_{s^\prime\in\mathcal{S}}\mathcal{T}(s, a, s^\prime)v_{\pi}(s^{\prime})$

In [None]:
def update_policy(grid_env, value_function_):
    """
    Build the greedy deterministic policy for the given state-value function.

    For every non-terminal state the action-value
    q_a = R(s, a) + gamma * sum_{s'} T(s, a, s') * v(s') is computed for each
    action, and the action with the highest q_a is stored at that state's
    position in the grid. Ties go to the first such action in
    ``grid_env.get_actions()`` order.

    : param grid_env (GridEnv): MDP environment. It must provide ``get_states()``,
        ``get_actions()``, ``grid``, ``gamma``, ``state_transitions[s, a, s_next]``
        and ``rewards[s, a]``
    : param value_function_ (dict): the value function that is used to generate
        the policy; it needs an entry for every state in ``grid_env.get_states()``
    : return (2D array): Updated policy with the shape of ``grid_env.grid``; cells
        that are not non-terminal states are left as NaN
    """
    all_states = list(grid_env.get_states())
    actions = list(grid_env.get_actions())
    new_policy = np.full(grid_env.grid.shape, np.nan)

    for s in grid_env.get_states(exclude_terminal=True):
        list_of_q_a = []
        for a in actions:
            # Same backup as in policy evaluation: gamma * sum_{s'} T(s, a, s') * v(s')
            discounted_value = grid_env.gamma * sum(
                grid_env.state_transitions[s, a, s_next] * value_function_[s_next]
                for s_next in all_states
            )
            list_of_q_a.append(grid_env.rewards[s, a] + discounted_value)

        # Map the argmax index back through the action list, so the result stays
        # correct even if the action codes are not simply 0..n-1
        optimal_a = actions[int(np.argmax(list_of_q_a))]

        # Locate the state's cell in the grid and record the chosen action there
        x, y = np.argwhere(grid_env.grid == s)[0]
        new_policy[x, y] = optimal_a

    return new_policy

In [None]:
def policy_iteration(grid_env, policy, plot=False):
    """
    Iteratively improve a policy for the environment grid_env until it stops changing.

    Each round evaluates the current policy with ``policy_evaluation`` and then
    replaces it with the result of ``update_policy``. The loop ends when two
    consecutive policies are identical (NaN entries compare as equal).

    :param grid_env (GridEnv): MDP environment
    :param policy (matrix form): Deterministic policy being updated
    :param plot (bool): When True, the final policy is drawn with ``plot_policy``
    :return: (2D array) Policy obtained once the update no longer changes it
    """
    # At least one evaluation/improvement round always runs, whatever the
    # contents of the initial policy
    while True:
        # Encode policy. This policy representation is needed for policy evaluation
        encoded_policy = encode_policy(grid_env, policy)
        # Remember the current policy so the update can be compared against it
        prev_policy = policy.copy()

        # 1. Evaluate the current policy
        value_function = policy_evaluation(grid_env, encoded_policy)

        # 2. Update the policy using the evaluated state-values
        policy = update_policy(grid_env, value_function)

        if np.array_equal(policy, prev_policy, equal_nan=True):
            break

    if plot:
        plot_policy(grid_env, policy)

    return policy


In [None]:
# Create a grid world MDP (discount 0.9, transition noise 0.2, living reward -0.04)
grid_world = GridEnv(gamma=0.9, noise=0.2, living_reward=-0.04)

# Generate an initial random policy
initial_policy = define_random_policy(grid_world)

# Compute the optimal policy using policy iteration and plot the result
optimal_policy = policy_iteration(grid_world, initial_policy, plot=True)