In [1]:
import numpy as np
import random

In [2]:
def one_step_lookahead(s,V, p_h = 0.5):
    rewards = np.zeros(101)
    rewards[100] = 1
    A = np.zeros(101)
    stakes = range(1,min(s, 100-s)+1)
    for a in stakes:
        A[a] = p_h * (rewards[s+a] + V[s+a]) + (1-p_h) * (rewards[s-a] + V[s-a])
    return A

In [3]:
def evaluate_policy(policy, theta = 0.001, max_backups=1000):
    old_values = np.zeros(101)
    for i in range(max_backups):
        delta = 0
        new_values = np.zeros(101)
        for s in range(1,100):
            action_values = one_step_lookahead(s,old_values)
            new_values[s] = action_values[int(policy[s])]
        if np.max(np.abs(new_values-old_values)) < theta:
            break
            
        old_values = new_values
    return new_values

In [4]:
def greedy_policy(value_function):
    policy = np.zeros(101)
    for s in range(1,100):
        action_values = one_step_lookahead(s,value_function)
        policy[s] = np.argmax(action_values)
    print("Policy this iteration is \n") 
    print(policy)
    return policy

In [6]:
def policy_iteration():
    old_policy = np.zeros(101)
    for i in range(120):
        value_function = evaluate_policy(old_policy)
        #print(value_function)
        new_policy = greedy_policy(value_function)
        if np.array_equal(old_policy,new_policy):
            break
        old_policy = new_policy
    return old_policy

In [7]:
policy_iteration()

Policy this iteration is 

[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. 50. 49. 48. 47.
 46. 45. 44. 43. 42. 41. 40. 39. 38. 37. 36. 35. 34. 33. 32. 31. 30. 29.
 28. 27. 26. 25. 24. 23. 22. 21. 20. 19. 18. 17. 16. 15. 14. 13. 12. 11.
 10.  9.  8.  7.  6.  5.  4.  3.  2.  1.  0.]
Policy this iteration is 

[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0. 25. 24. 23. 22. 21. 20. 19. 18. 17. 16. 15.
 14. 13. 37. 36. 35. 34. 33. 32. 44. 43. 42. 47. 46. 48. 50.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1.  1.  1. 12. 11. 10.  9.  8.  7. 19. 18. 17.
 22. 21. 23. 25.  1.  1.  1.  1.  1.  1.  6.  5.  4.  9. 11. 12. 12.  1.
  1.  3.  2.  4.  6.  1.  1.  3.  1.  1.  0.]
Policy this iteration is 

[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. 12. 11. 10.  9.  8.
  7. 19. 18. 17. 22. 21.

array([ 0.,  1.,  2.,  3.,  3.,  5.,  6.,  7.,  1.,  9.,  7., 11.,  5.,
        4.,  7.,  6.,  7.,  8.,  7.,  2.,  3.,  4.,  1.,  2.,  1., 25.,
        3.,  2.,  3.,  6.,  5.,  6.,  7.,  8., 16., 18., 13., 13., 15.,
       14., 17.,  9.,  8.,  7., 23., 22., 21., 24., 27., 26., 25., 28.,
       23., 22.,  4.,  5.,  6.,  7., 17.,  9.,  7.,  4.,  5.,  1., 11.,
       35.,  9., 33.,  7.,  2.,  5.,  4.,  3.,  2.,  1., 25., 11.,  2.,
        3.,  4., 20.,  6., 18., 16., 16.,  2., 11., 12., 12.,  7.,  8.,
        9.,  8.,  6.,  2.,  3.,  4.,  1.,  2.,  1.,  0.])