# Policy Derivation from Value Functions

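This notebook demonstrates how to derive a policy directly from the state-value function V(s) in a grid world environment. We start with an arbitrary policy, evaluate it to find its value function, and then improve the policy by choosing, in each state, the action that looks best according to those values (the greedy policy). A single improvement step like this does not, in general, produce an optimal policy.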

## 1. Setup

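Import the necessary functions, create the standard grid world environment, and define an initial, arbitrary policy. The policy is a dictionary mapping each `(row, column)` state to an action name; `''` marks the barrier and terminal states, which have no actions.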

In [1]:
from rlgridworld.standard_grid import create_standard_grid
from rlgridworld.algorithms import iterative_policy_evaluation

gw = create_standard_grid()

# Initial arbitrary policy
policy = { 
    (0,0):'up', (0,1):'right',(0,2):'right',(0,3):'up',
    (1,0):'up', (1,1):'', (1,2):'right', (1,3):'',
    (2,0):'right', (2,1):'right', (2,2):'right', (2,3):''
}

print("Initial Input Policy:")
gw.print_policy(policy)

Initial Input Policy:
-------------------------------------
|  Right |  Right |  Right |        |
-------------------------------------
|     Up |        |  Right |        |
-------------------------------------
|     Up |  Right |  Right |     Up |
-------------------------------------
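The dictionary and the printout can look mismatched at first: `(0,0)` maps to `'up'`, yet the top-left printed cell reads `Right`. The output is consistent with `print_policy` drawing row index 0 at the *bottom* of the grid, so `(0,0)` is the bottom-left cell. This convention is inferred from the printed output, not from the `rlgridworld` documentation. The hypothetical `render_policy` helper sketched here (not part of `rlgridworld`) renders a policy dictionary under that assumption and reproduces the table above:

```python
def render_policy(policy, rows=3, cols=4):
    """Render a policy dict as a text grid, drawing row index 0 at the bottom.

    Hypothetical helper: the bottom-up row order is an assumption inferred
    from gw.print_policy's output in this notebook.
    """
    border = "-" * (9 * cols + 1)
    lines = [border]
    for i in reversed(range(rows)):  # highest row index is printed first
        cells = [policy.get((i, j), "").capitalize() for j in range(cols)]
        lines.append("|" + "|".join(f" {c:>6} " for c in cells) + "|")
        lines.append(border)
    return "\n".join(lines)


policy = {
    (0, 0): "up", (0, 1): "right", (0, 2): "right", (0, 3): "up",
    (1, 0): "up", (1, 1): "", (1, 2): "right", (1, 3): "",
    (2, 0): "right", (2, 1): "right", (2, 2): "right", (2, 3): "",
}
print(render_policy(policy))
```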


## 2. Initial Policy Evaluation (Gamma = 0.9)

Run iterative policy evaluation on the initial policy with discount factor `gamma = 0.9` to compute its state-value function V(s). The function stores the resulting values in `gw`, and `gw.print_values()` displays them.

In [2]:
iterative_policy_evaluation(gw, policy, gamma = 0.9)

print("\nValues for the initial policy (gamma=0.9):")
gw.print_values()


Values for the initial policy (gamma=0.9):
-------------------------------------
|   0.81 |   0.90 |   1.00 |   0.00 |
-------------------------------------
|   0.73 |   0.00 |  -1.00 |   0.00 |
-------------------------------------
|   0.66 |  -0.81 |  -0.90 |  -1.00 |
-------------------------------------
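To make concrete what `iterative_policy_evaluation` computes, here is a minimal stand-alone sketch. It is not `rlgridworld`'s implementation: the grid layout (3 rows by 4 columns, barrier at `(1,1)`, a +1 terminal at `(2,3)`, a -1 terminal at `(1,3)`), the rule that reward is received only on entering a terminal, and the move deltas (`'up'` increases the row index) are all assumptions inferred from the printed outputs. Each sweep applies $V(s) \leftarrow r + \gamma V(s')$ to every non-terminal state until the largest change falls below a tolerance `theta`:

```python
# Hypothetical stand-in for the standard grid (layout inferred from this notebook's output)
ROWS, COLS = 3, 4
BARRIERS = {(1, 1)}
TERMINAL_REWARDS = {(2, 3): 1.0, (1, 3): -1.0}  # reward received on entering each terminal
MOVES = {"up": (1, 0), "down": (-1, 0), "left": (0, -1), "right": (0, 1)}


def step(state, action):
    """Deterministic transition: return (next_state, reward)."""
    di, dj = MOVES[action]
    nxt = (state[0] + di, state[1] + dj)
    if not (0 <= nxt[0] < ROWS and 0 <= nxt[1] < COLS) or nxt in BARRIERS:
        nxt = state  # assumption: a blocked move leaves the agent where it is
    return nxt, TERMINAL_REWARDS.get(nxt, 0.0)


def evaluate_policy(policy, gamma=0.9, theta=1e-9):
    """Sweep V(s) <- r + gamma * V(s') until the largest change is below theta."""
    V = {(i, j): 0.0 for i in range(ROWS) for j in range(COLS)}
    while True:
        delta = 0.0
        for state, action in policy.items():
            if not action:  # barrier and terminal states keep V = 0
                continue
            nxt, reward = step(state, action)
            new_value = reward + gamma * V[nxt]
            delta = max(delta, abs(new_value - V[state]))
            V[state] = new_value  # in-place update: later states see this sweep's values
        if delta < theta:
            return V


policy = {
    (0, 0): "up", (0, 1): "right", (0, 2): "right", (0, 3): "up",
    (1, 0): "up", (1, 1): "", (1, 2): "right", (1, 3): "",
    (2, 0): "right", (2, 1): "right", (2, 2): "right", (2, 3): "",
}
V = evaluate_policy(policy, gamma=0.9)
print(round(V[(2, 0)], 2), round(V[(0, 0)], 2), round(V[(0, 1)], 2))
```

Because the policy is deterministic and each action leads to a single state, no sum over next states is needed. With this policy the sketch reproduces the table above, e.g. V(0,0) = 0.9^4 = 0.6561, printed as 0.66.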


## 3. Function to Compute Policy from Values

Define a function `compute_policy_from_values` that reads the state values currently stored in the grid world `gw` and returns a new policy. For each state that is neither a barrier nor terminal, it performs a one-step lookahead: for every valid action $a$ it computes the expected return $r(s,a) + \gamma V(s')$, where $s'$ is the state the action leads to, and keeps the action with the highest value:

$$\pi'(s) = \arg\max_a \left[ r(s,a) + \gamma V(s') \right]$$

This is the greedy policy with respect to the current values. Because each action leads to a single destination state here, no expectation over next states is needed.

In [3]:
def compute_policy_from_values(gw, gamma=0.9):
    """Return the greedy policy with respect to the values currently stored in gw.

    For every non-barrier, non-terminal state, pick the action that maximizes
    reward + gamma * V(destination). gamma should match the discount factor
    used to compute the values.
    """
    new_policy = {}

    # Loop over all states of the M x N grid
    for i in range(gw.M):
        for j in range(gw.N):
            state = (i, j)

            # Barrier and terminal states have no actions: assign an empty policy
            if gw.is_barrier(state) or gw.is_terminal(state):
                new_policy[state] = ''
                continue  # skip to the next state

            # Candidate best action and its value
            best_action = None
            best_value = float('-inf')

            # Dictionary of valid actions (and their rewards) at this state
            dr = gw.valid_decisions_and_rewards(state)

            # Only the action keys of dr are used; the reward is re-queried with
            # get_reward_for_action in case the reward stored in dr is outdated
            for action in dr:
                reward = gw.get_reward_for_action(state, action)

                # Value of the state this action leads to
                value_at_dest = gw.get_value_at_destination(state, action)

                # One-step lookahead: expected return of taking this action
                value = reward + gamma * value_at_dest

                # Strict '>' keeps the first action encountered when values tie
                if value > best_value:
                    best_value = value
                    best_action = action

            new_policy[state] = best_action

    return new_policy
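The same greedy step can be written without `rlgridworld`. This is a hypothetical sketch: it assumes the grid layout inferred from this notebook's outputs (barrier at `(1,1)`, terminals at `(2,3)` with reward +1 and `(1,3)` with reward -1, `'up'` increasing the row index), assumes `valid_decisions_and_rewards` offers only moves that stay on the grid and avoid the barrier, and takes the values as a plain dictionary copied from the `gamma = 0.9` table in step 2. Python's built-in `max` returns the first maximal item, which gives the same first-wins tie-breaking as the strict `>` comparison in `compute_policy_from_values`:

```python
# Hypothetical stand-in for the standard grid (layout inferred from this notebook's output)
ROWS, COLS = 3, 4
BARRIERS = {(1, 1)}
TERMINAL_REWARDS = {(2, 3): 1.0, (1, 3): -1.0}  # reward for entering each terminal
MOVES = {"up": (1, 0), "down": (-1, 0), "left": (0, -1), "right": (0, 1)}


def valid_moves(state):
    """Yield (action, next_state, reward) for moves that stay on the grid and avoid barriers."""
    for action, (di, dj) in MOVES.items():
        nxt = (state[0] + di, state[1] + dj)
        if 0 <= nxt[0] < ROWS and 0 <= nxt[1] < COLS and nxt not in BARRIERS:
            yield action, nxt, TERMINAL_REWARDS.get(nxt, 0.0)


def greedy_policy(V, gamma=0.9):
    """Return {state: action} choosing argmax_a [r(s, a) + gamma * V(s')] in each state."""
    policy = {}
    for i in range(ROWS):
        for j in range(COLS):
            state = (i, j)
            if state in BARRIERS or state in TERMINAL_REWARDS:
                policy[state] = ""  # no actions at barrier or terminal states
                continue
            # max() returns the first maximal item, so ties go to the earlier action in MOVES
            action, _, _ = max(valid_moves(state), key=lambda m: m[2] + gamma * V[m[1]])
            policy[state] = action
    return policy


# Values copied from the gamma=0.9 table (the top printed row is row index 2)
V = {
    (2, 0): 0.81, (2, 1): 0.90, (2, 2): 1.00, (2, 3): 0.00,
    (1, 0): 0.73, (1, 1): 0.00, (1, 2): -1.00, (1, 3): 0.00,
    (0, 0): 0.66, (0, 1): -0.81, (0, 2): -0.90, (0, 3): -1.00,
}
policy_g09 = greedy_policy(V, gamma=0.9)
for i in reversed(range(ROWS)):
    print([policy_g09[(i, j)] for j in range(COLS)])
```

With these values the sketch yields the same policy that step 4 prints for `gamma = 0.9`.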

## 4. Derive and Compare Policies (Based on Gamma = 0.9 Values)

Use the `compute_policy_from_values` function (with the default `gamma=0.9`) to derive a new policy based on the values calculated in step 2. Compare this new policy to the original arbitrary policy.

In [4]:
# Compute the policy derived from the V(s) calculated with gamma=0.9
# Note: The function uses gamma=0.9 by default, matching the values just calculated.
new_policy_g09 = compute_policy_from_values(gw, gamma=0.9) 

print("Original Policy:")
gw.print_policy(policy)

print("\nNew Policy (derived from gamma=0.9 values):")
gw.print_policy(new_policy_g09)

Original Policy:
-------------------------------------
|  Right |  Right |  Right |        |
-------------------------------------
|     Up |        |  Right |        |
-------------------------------------
|     Up |  Right |  Right |     Up |
-------------------------------------

New Policy (derived from gamma=0.9 values):
-------------------------------------
|  Right |  Right |  Right |        |
-------------------------------------
|     Up |        |     Up |        |
-------------------------------------
|     Up |   Left |   Left |   Left |
-------------------------------------
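To see exactly which states the improvement step changed, a small helper can compare two policy dictionaries. Here the dictionaries are transcribed by hand from the printed tables, assuming row index 0 is the bottom printed row and that actions are lowercase strings like those in the input policy; `policy_diff` is a hypothetical helper, not an `rlgridworld` function. In the notebook itself you could pass `policy` and `new_policy_g09` directly.

```python
def policy_diff(old, new):
    """Return {state: (old_action, new_action)} for every state whose action changed."""
    return {s: (old.get(s, ""), a) for s, a in new.items() if old.get(s, "") != a}


# Transcribed from the printed tables (row index 0 is the bottom printed row)
original = {
    (0, 0): "up", (0, 1): "right", (0, 2): "right", (0, 3): "up",
    (1, 0): "up", (1, 1): "", (1, 2): "right", (1, 3): "",
    (2, 0): "right", (2, 1): "right", (2, 2): "right", (2, 3): "",
}
derived_g09 = {
    (0, 0): "up", (0, 1): "left", (0, 2): "left", (0, 3): "left",
    (1, 0): "up", (1, 1): "", (1, 2): "up", (1, 3): "",
    (2, 0): "right", (2, 1): "right", (2, 2): "right", (2, 3): "",
}

for state, (before, after) in sorted(policy_diff(original, derived_g09).items()):
    print(f"{state}: {before} -> {after}")
```

Four states change: `(0,1)`, `(0,2)`, `(0,3)` and `(1,2)`. These are exactly the states whose original action led, directly or along the bottom row, into the -1 terminal at `(1,3)`.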


## 5. Policy Evaluation with Different Gamma (Gamma = 0.8)

Next, re-evaluate the *original* arbitrary policy with a different discount factor, `gamma = 0.8`. This produces a different state-value function and overwrites the values stored in `gw`; `new_policy_g09` has already been computed, so it is unaffected.

In [5]:
# Re-evaluate the original policy with gamma = 0.8
iterative_policy_evaluation(gw, policy, gamma = 0.8)

print("\nValues for the initial policy (gamma=0.8):")
gw.print_values()


Values for the initial policy (gamma=0.8):
-------------------------------------
|   0.64 |   0.80 |   1.00 |   0.00 |
-------------------------------------
|   0.51 |   0.00 |  -1.00 |   0.00 |
-------------------------------------
|   0.41 |  -0.64 |  -0.80 |  -1.00 |
-------------------------------------
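The `gamma = 0.8` values follow a simple pattern. Under a deterministic policy whose only nonzero reward $R$ arrives on the $k$-th move, $V(s) = \gamma^{k-1} R$. Under the original policy, `(0,0)` reaches the +1 terminal in 5 moves and `(0,1)` reaches the -1 terminal in 3 moves. The following arithmetic check uses a hypothetical helper and assumes, as the tables suggest, that reward is received only on entering a terminal:

```python
def discounted_terminal_value(reward, steps, gamma):
    """Value of a state whose only reward `reward` is received on move number `steps`."""
    return gamma ** (steps - 1) * reward


for gamma in (0.9, 0.8, 1.0):
    v00 = discounted_terminal_value(1.0, 5, gamma)   # (0,0): 5 moves to the +1 terminal
    v01 = discounted_terminal_value(-1.0, 3, gamma)  # (0,1): 3 moves to the -1 terminal
    print(f"gamma={gamma}: V(0,0) = {v00:.2f}, V(0,1) = {v01:.2f}")
```

The results match the printed tables (0.66 and -0.81, 0.41 and -0.64, 1.00 and -1.00). With `gamma = 1.0` the factor $\gamma^{k-1}$ equals 1 for every $k$, so the values no longer favor shorter paths.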


## 6. Derive and Compare Policies (Based on Gamma = 0.8 Values)

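Derive another policy, this time from the values computed with `gamma = 0.8`, passing `gamma=0.8` to `compute_policy_from_values` so that the lookahead uses the same discount as the evaluation. Compare the result with the original policy and with the policy derived using `gamma = 0.9`. In this run the two derived policies are identical: lowering `gamma` changes the magnitudes of the values but, here, not which action scores best in any state.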

In [6]:
# Compute the policy derived from the V(s) calculated with gamma=0.8
new_policy_g08 = compute_policy_from_values(gw, gamma=0.8)

print("Original Policy:")
gw.print_policy(policy)

print("\nNew Policy (derived from gamma=0.9 values):")
gw.print_policy(new_policy_g09)

print("\nNew Policy (derived from gamma=0.8 values):")
gw.print_policy(new_policy_g08)

Original Policy:
-------------------------------------
|  Right |  Right |  Right |        |
-------------------------------------
|     Up |        |  Right |        |
-------------------------------------
|     Up |  Right |  Right |     Up |
-------------------------------------

New Policy (derived from gamma=0.9 values):
-------------------------------------
|  Right |  Right |  Right |        |
-------------------------------------
|     Up |        |     Up |        |
-------------------------------------
|     Up |   Left |   Left |   Left |
-------------------------------------

New Policy (derived from gamma=0.8 values):
-------------------------------------
|  Right |  Right |  Right |        |
-------------------------------------
|     Up |        |     Up |        |
-------------------------------------
|     Up |   Left |   Left |   Left |
-------------------------------------


## 7. Further Experimentation

You can experiment further by:
*   Trying different values for `gamma` (e.g., 1.0, 0.5, 0.1) and observing how the resulting values and derived policies change.
*   Starting with a different initial policy.
*   Modifying the grid world rewards or structure (using functions from `rlgridworld`).
*   Visualizing the value function or policy (e.g., using heatmaps or arrows).

### Example: Evaluate and Derive Policy for Gamma = 1.0

In [7]:
# Evaluate original policy with gamma = 1.0
iterative_policy_evaluation(gw, policy, gamma = 1.0)
print("\nValues for the initial policy (gamma=1.0):")
gw.print_values()

# Compute the policy derived from the V(s) calculated with gamma=1.0
new_policy_g10 = compute_policy_from_values(gw, gamma=1.0)

print("\nNew Policy (derived from gamma=1.0 values):")
gw.print_policy(new_policy_g10)


Values for the initial policy (gamma=1.0):
-------------------------------------
|   1.00 |   1.00 |   1.00 |   0.00 |
-------------------------------------
|   1.00 |   0.00 |  -1.00 |   0.00 |
-------------------------------------
|   1.00 |  -1.00 |  -1.00 |  -1.00 |
-------------------------------------

New Policy (derived from gamma=1.0 values):
-------------------------------------
|  Right |   Left |   Left |        |
-------------------------------------
|   Down |        |     Up |        |
-------------------------------------
|     Up |   Left |   Left |   Left |
-------------------------------------
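The `gamma = 1.0` policy deserves a closer look. Without discounting, every state on the original path to the +1 terminal has value 1.00, so several actions score the same; from `(2,1)`, for example, both Left and Right lead to a state worth 1.00. `compute_policy_from_values` keeps the first maximizing action it encounters, so when values tie the result depends on iteration order, and the printed policy contains cycles: `(2,0)` Right and `(2,1)` Left send the agent back and forth, as do `(1,0)` Down and `(0,0)` Up. A hypothetical `rollout` helper makes this visible. It assumes the move deltas suggested by this notebook's outputs (`'up'` increases the row index) and uses the `gamma = 0.9` and `gamma = 1.0` policies transcribed by hand from the printed tables:

```python
MOVES = {"up": (1, 0), "down": (-1, 0), "left": (0, -1), "right": (0, 1)}
TERMINALS = {(2, 3), (1, 3)}


def rollout(policy, start, max_steps=50):
    """Follow a deterministic policy from `start`.

    Returns (path, outcome), where outcome is "terminal", "loop", or "max_steps".
    """
    state, path, seen = start, [start], {start}
    for _ in range(max_steps):
        if state in TERMINALS:
            return path, "terminal"
        di, dj = MOVES[policy[state]]
        state = (state[0] + di, state[1] + dj)
        path.append(state)
        if state in seen:  # a deterministic policy that revisits a state cycles forever
            return path, "loop"
        seen.add(state)
    return path, "terminal" if state in TERMINALS else "max_steps"


# Policies transcribed from the printed tables (row index 0 is the bottom printed row)
policy_g09 = {
    (0, 0): "up", (0, 1): "left", (0, 2): "left", (0, 3): "left",
    (1, 0): "up", (1, 1): "", (1, 2): "up", (1, 3): "",
    (2, 0): "right", (2, 1): "right", (2, 2): "right", (2, 3): "",
}
policy_g10 = {
    (0, 0): "up", (0, 1): "left", (0, 2): "left", (0, 3): "left",
    (1, 0): "down", (1, 1): "", (1, 2): "up", (1, 3): "",
    (2, 0): "right", (2, 1): "left", (2, 2): "left", (2, 3): "",
}

for name, pol in (("gamma=0.9", policy_g09), ("gamma=1.0", policy_g10)):
    path, outcome = rollout(pol, (0, 0))
    print(name, outcome, path)
```

Starting from `(0,0)`, the `gamma = 0.9` policy reaches the +1 terminal in five moves, while the `gamma = 1.0` policy bounces between `(0,0)` and `(1,0)` forever. This is one reason a discount factor below 1 is commonly used in grid worlds like this one.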
