In [0]:
from __future__ import print_function, division
from builtins import range

import numpy as np
from Grid import standard_grid
from Utility import print_values, print_policy


# Convergence threshold for value iteration: calculate_values stops once a
# full sweep changes no state's value by Theta or more.
Theta = 1e-3
# Discount factor used in the Bellman backup computed by best_action_value.
Gamma = 0.9
# Candidate actions tried from every state (presumably up / down / left /
# right; their meaning is defined by the Grid module — confirm there).
ALL_POSSIBLE_ACTIONS = ('U', 'D', 'L', 'R')

def best_action_value(grid, V, s, gamma=None, actions=None):
  """Return the greedy action from state ``s`` and its one-step lookahead value.

  For each candidate action ``a`` this evaluates the Bellman backup

      Q(s, a) = sum over (prob, r, s') of prob * (r + gamma * V[s'])

  using the ``(prob, reward, next_state)`` triples returned by
  ``grid.get_transition_probs(a)`` after moving the grid to ``s``.

  Args:
    grid: environment exposing ``set_state(s)`` and
      ``get_transition_probs(a)``. ``set_state(s)`` is called, so the grid's
      current state is left at ``s``.
    V: mapping state -> current value estimate; must contain every
      ``next_state`` the transitions can reach (otherwise KeyError).
    s: state to evaluate.
    gamma: discount factor; defaults to the module-level ``Gamma``.
    actions: candidate actions; defaults to ``ALL_POSSIBLE_ACTIONS``.

  Returns:
    ``(best_action, best_value)``. Ties keep the earliest action in
    ``actions`` order. If ``actions`` is empty, returns
    ``(None, float('-inf'))``.
  """
  if gamma is None:
    gamma = Gamma
  if actions is None:
    actions = ALL_POSSIBLE_ACTIONS

  best_a = None
  best_value = float('-inf')
  grid.set_state(s)

  for a in actions:
    transitions = grid.get_transition_probs(a)
    expected_r = 0  # expected immediate reward
    expected_v = 0  # expected value of the next state
    for prob, r, state_prime in transitions:
      expected_r += prob * r
      expected_v += prob * V[state_prime]
    # Separate name so the value table ``V`` is never overwritten mid-loop.
    value = expected_r + gamma * expected_v
    if value > best_value:
      best_value = value
      best_a = a
  return best_a, best_value

def calculate_values(grid):
  """Compute state values for ``grid`` by value iteration.

  Every state from ``grid.all_states()`` starts at 0. Each sweep replaces
  ``V[s]`` for every state yielded by ``grid.non_terminal_states()`` with the
  best one-step lookahead value from ``best_action_value``. Updates are
  in place, so later states in the same sweep already see the new values.
  States that are never yielded by ``non_terminal_states()`` keep value 0.
  Iteration stops after the first sweep in which no value changed by
  ``Theta`` or more.

  Args:
    grid: environment exposing ``all_states()``, ``non_terminal_states()``
      and whatever ``best_action_value`` needs.

  Returns:
    dict mapping each state to its converged value.
  """
  # Initialise every state first; iteration must only start once the whole
  # table exists, because backups read V[next_state] for arbitrary states.
  V = {s: 0 for s in grid.all_states()}

  # NOTE(review): termination relies on the backup converging (e.g. Gamma < 1);
  # otherwise this loop may not end.
  while True:
    biggest_change = 0
    for s in grid.non_terminal_states():
      old_v = V[s]
      _, new_v = best_action_value(grid, V, s)
      V[s] = new_v
      biggest_change = max(biggest_change, abs(old_v - new_v))

    if biggest_change < Theta:
      return V
  
def initialize_random_policy(grid=None):
  """Return a policy mapping each non-terminal state to a random action.

  Args:
    grid: environment exposing ``non_terminal_states()``. When omitted, the
      module-level ``grid`` (created in the ``__main__`` block) is used, so
      zero-argument calls keep working.

  Returns:
    dict mapping state -> action drawn with ``np.random.choice`` from
    ``ALL_POSSIBLE_ACTIONS``.
  """
  if grid is None:
    grid = globals()['grid']

  # non_terminal_states must be *called*; iterating the bound method itself
  # raises TypeError.
  return {s: np.random.choice(ALL_POSSIBLE_ACTIONS)
          for s in grid.non_terminal_states()}

def cal_greedy_policy(grid, V):
  """Return the policy that acts greedily with respect to the values ``V``.

  For every state yielded by ``grid.non_terminal_states()`` the chosen action
  is the one ``best_action_value`` ranks highest in a one-step lookahead.

  Args:
    grid: environment exposing ``non_terminal_states()`` plus whatever
      ``best_action_value`` needs (that helper moves the grid to each state
      via ``set_state``).
    V: mapping state -> value, e.g. the output of ``calculate_values``.

  Returns:
    dict mapping each non-terminal state to its greedy action.
  """
  # States come from the grid passed in, not from a random policy built off
  # the module-level ``grid``. Every random entry would have been overwritten
  # anyway, and this keeps the function correct for any grid argument.
  return {s: best_action_value(grid, V, s)[0]
          for s in grid.non_terminal_states()}


# Alias so callers may use either spelling.
calculate_greedy_policy = cal_greedy_policy


if __name__ == '__main__':
  # NOTE(review): step_cost=None is passed, so this grid presumably does NOT
  # add a -0.1 reward for every non-terminal state; confirm against
  # Grid.standard_grid. obey_prob=0.8 presumably means the chosen move is
  # carried out with probability 0.8 — also confirm there.
  grid = standard_grid(obey_prob=0.8, step_cost=None)

  # print rewards
  print("rewards:")
  print_values(grid.rewards, grid)

  # calculate accurate values for each square via value iteration
  V = calculate_values(grid)

  # calculate the optimum policy based on our values
  policy = cal_greedy_policy(grid, V)

  # our goal here is to verify that we get the same answer as with policy iteration
  print("values : ")
  print_values(V, grid)
  print("policy : ")
  print_policy(policy, grid)

  
  
        
      