In [1]:
import random

def find_point(value, increment):
    return round(round(value / increment) * increment, 2)

def approximate(raw_state):
    array = list(raw_state)
    return tuple([
        find_point(array[0], 0.2) + 0, 
        find_point(array[1], 0.2) + 0, 
        find_point(array[2], 0.5) + 0,
        find_point(array[3], 0.5) + 0
    ])

def generate(lower_bound, upper_bound, interval):
    output = [lower_bound]
    while lower_bound <= upper_bound:
        lower_bound += interval
        output.append( float( round(lower_bound,1) + 0 ) )
    return output

def generate_states(theta1, theta2, ang_vel_theta1, ang_vel_theta2):
    states = []
    for a in theta1:
        for b in theta2:
            for c in ang_vel_theta1:
                for d in ang_vel_theta2:
                    states.append((a, b, c, d))
    return states


def create_transition_reward_function(states, actions, env):
    table = {}
    for state in states:
        for action in actions:
            env.reset(state=state)
            obs, reward, done, info = env.step(action)
            table[(state, action)] = (reward, approximate(obs))
    return table

In [2]:

def policy_evaluation(state, actions, transition_and_reward_function, policy, value_function, gamma=0.9, max_iterations=10000):
    min_difference = 0.0001
    new_value_function = {}

    
    for i in range(max_iterations):
        for state in states:
            new_val = 0
            for action in actions:
                reward, next_state = transition_and_reward_function[(state, action)]
                next_state_value = value_function[next_state]
                    
                new_val += policy[state][action] * (reward + gamma*next_state_value)
            
            
            new_value_function[state] = new_val
        
        # Check convergence
        differences = [abs(new_value_function[state] - value_function[state])
           for state in states]

        diff = max(differences)
        print(f"policy_evaluation iteration_number: {i+1}. Diff: {diff}")

        if diff < min_difference:
            print(f"Policy Evaluation converged at {i+1}")
            return new_value_function
                

        
        value_function = new_value_function.copy()
            
    print(f"Policy Evaluation did not converge.")
    return value_function


def improve_policy(states, actions, transition_and_reward_function, value_function, gamma=0.9):
    new_policy = {}
    for state in states:
        action_values = {}
        for action in actions:
            reward, next_state = transition_and_reward_function[(state, action)]
            action_values[action] = reward + gamma * value_function[next_state]
        greedy_action, value = max(action_values.items(), key= lambda pair: pair[1])
        
        new_policy[state] = {action:1 if action is greedy_action else 0 for action in actions}
    return new_policy


def policy_iteration(states, actions, transition_and_reward_function, policy, value_function, max_iterations=30):
    new_value_function = {}

    for i in range(max_iterations):
        converged = True
        print("policy iteration_number:",i+1)
        
        # Evaluate the current policy
        new_value_function = policy_evaluation(states, actions, transition_and_reward_function, policy, value_function)
        
        # Improve the policy based on new value function
        new_policy = improve_policy(states, actions, transition_and_reward_function, new_value_function)
        value_function = new_value_function.copy()
        
        # Check if convergence
        for state in states:
            if get_optimal_action(state, policy) != get_optimal_action(state, new_policy):
                converged = False
                break
        
        policy = new_policy.copy()
        
        if converged:
            # we have convergence
            print(f"Policy Iteration converged at {i+1}")
            return policy

    print(f"Policy Iteration did not converge")
    return policy

def get_optimal_action(state, optimal_policy):
    greedy_action, prob = max(optimal_policy[state].items(), key= lambda pair: pair[1])
    return greedy_action


In [3]:
from acrobot_env import *

env = AcrobotEnv()

theta1 = generate(-3.2, 3.2, 0.2)
theta2 = generate(-3.2, 3.2, 0.2)
ang_vel_theta1 = generate(-13, 13, 0.5)
ang_vel_theta2 = generate(-28.5, 28.5, 0.5)

states = generate_states(theta1, theta2, ang_vel_theta1, ang_vel_theta2)
actions = [0, 1, 2]


In [4]:
len(states)

6821496

In [5]:
from joblib import dump, load
load_transition = False

if not load_transition:
    transition_and_reward_function = create_transition_reward_function(states, actions, env)
    dump(transition_and_reward_function, 'transition_and_reward_function.joblib') 
else:
    print("Loading transition and reward function")
    transition_and_reward_function = load('transition_and_reward_function.joblib') 

In [6]:
from joblib import dump, load

starting_policy = {state:{0: 0.33, 1: 0.33, 2: 0.33} for state in states}
value_function = {state:0 for state in states}

load_optimal_policy = False

if not load_optimal_policy:
    optimal_policy = policy_iteration(states, actions, transition_and_reward_function, starting_policy, value_function)
    dump(optimal_policy, 'optimal_policy.joblib') 
else:
    print("Loading optimal policy")
    optimal_policy = load('optimal_policy.joblib') 

policy iteration_number: 1
policy_evaluation iteration_number: 1. Diff: 0.99
policy_evaluation iteration_number: 2. Diff: 0.8820900000000003
policy_evaluation iteration_number: 3. Diff: 0.7859421900000008
policy_evaluation iteration_number: 4. Diff: 0.7002744912900014
policy_evaluation iteration_number: 5. Diff: 0.6239445717393921
policy_evaluation iteration_number: 6. Diff: 0.5559346134197991
policy_evaluation iteration_number: 7. Diff: 0.4953377405570416
policy_evaluation iteration_number: 8. Diff: 0.44134592683632423
policy_evaluation iteration_number: 9. Diff: 0.39323922081116613
policy_evaluation iteration_number: 10. Diff: 0.3503761457427501
policy_evaluation iteration_number: 11. Diff: 0.312185145856791
policy_evaluation iteration_number: 12. Diff: 0.27815696495840037
policy_evaluation iteration_number: 13. Diff: 0.24783785577793527
policy_evaluation iteration_number: 14. Diff: 0.22082352949814066
policy_evaluation iteration_number: 15. Diff: 0.19675376478284434
policy_evaluatio

policy_evaluation iteration_number: 43. Diff: 0.018765167186668208
policy_evaluation iteration_number: 44. Diff: 0.016888650468001032
policy_evaluation iteration_number: 45. Diff: 0.015199785421200929
policy_evaluation iteration_number: 46. Diff: 0.011714625583421068
policy_evaluation iteration_number: 47. Diff: 0.01054316302507985
policy_evaluation iteration_number: 48. Diff: 0.00948884672257222
policy_evaluation iteration_number: 49. Diff: 0.008539962050317484
policy_evaluation iteration_number: 50. Diff: 0.0076859658452850255
policy_evaluation iteration_number: 51. Diff: 0.0069173692607567006
policy_evaluation iteration_number: 52. Diff: 0.006225632334681208
policy_evaluation iteration_number: 53. Diff: 0.005603069101212554
policy_evaluation iteration_number: 54. Diff: 0.0050427621910920095
policy_evaluation iteration_number: 55. Diff: 0.0045384859719828086
policy_evaluation iteration_number: 56. Diff: 0.004084637374784705
policy_evaluation iteration_number: 57. Diff: 0.003676173637

policy_evaluation iteration_number: 73. Diff: 0.0005920911894294889
policy_evaluation iteration_number: 74. Diff: 0.0005328820704866288
policy_evaluation iteration_number: 75. Diff: 0.0004795938634387653
policy_evaluation iteration_number: 76. Diff: 0.00043163447709382297
policy_evaluation iteration_number: 77. Diff: 0.00034740268893518333
policy_evaluation iteration_number: 78. Diff: 0.0003126624200397998
policy_evaluation iteration_number: 79. Diff: 0.00028139617803635275
policy_evaluation iteration_number: 80. Diff: 0.0002532565602333392
policy_evaluation iteration_number: 81. Diff: 0.00022793090420947237
policy_evaluation iteration_number: 82. Diff: 0.00020513781378994622
policy_evaluation iteration_number: 83. Diff: 0.00018462403241059633
policy_evaluation iteration_number: 84. Diff: 0.00016616162916882615
policy_evaluation iteration_number: 85. Diff: 0.00014954546625212117
policy_evaluation iteration_number: 86. Diff: 0.0001345909196270867
policy_evaluation iteration_number: 87. 

policy_evaluation iteration_number: 31. Diff: 0.04945208984110128
policy_evaluation iteration_number: 32. Diff: 0.04450688085699106
policy_evaluation iteration_number: 33. Diff: 0.04005619277129213
policy_evaluation iteration_number: 34. Diff: 0.03605057349416274
policy_evaluation iteration_number: 35. Diff: 0.032445516144746556
policy_evaluation iteration_number: 36. Diff: 0.02920096453027199
policy_evaluation iteration_number: 37. Diff: 0.0262808680772455
policy_evaluation iteration_number: 38. Diff: 0.013302794647307437
policy_evaluation iteration_number: 39. Diff: 0.011972515182576693
policy_evaluation iteration_number: 40. Diff: 0.010775263664319468
policy_evaluation iteration_number: 41. Diff: 0.009286636537709114
policy_evaluation iteration_number: 42. Diff: 0.008357972883938025
policy_evaluation iteration_number: 43. Diff: 0.007522175595543246
policy_evaluation iteration_number: 44. Diff: 0.006769958035989454
policy_evaluation iteration_number: 45. Diff: 0.00609296223239042
pol

policy_evaluation iteration_number: 21. Diff: 0.0678355141853153
policy_evaluation iteration_number: 22. Diff: 0.061051962766784484
policy_evaluation iteration_number: 23. Diff: 0.054946766490106214
policy_evaluation iteration_number: 24. Diff: 0.04945208984109506
policy_evaluation iteration_number: 25. Diff: 0.04450688085698573
policy_evaluation iteration_number: 26. Diff: 0.0400561927712868
policy_evaluation iteration_number: 27. Diff: 0.017359603225285625
policy_evaluation iteration_number: 28. Diff: 0.01562364290275653
policy_evaluation iteration_number: 29. Diff: 0.014061278612480521
policy_evaluation iteration_number: 30. Diff: 0.009112345147009648
policy_evaluation iteration_number: 31. Diff: 0.008201110632309216
policy_evaluation iteration_number: 32. Diff: 0.007380999569077673
policy_evaluation iteration_number: 33. Diff: 0.006642899612169906
policy_evaluation iteration_number: 34. Diff: 0.005978609650953537
policy_evaluation iteration_number: 35. Diff: 0.005380748685858805
po

policy_evaluation iteration_number: 41. Diff: 0.0006494118096016521
policy_evaluation iteration_number: 42. Diff: 0.0005844706286417534
policy_evaluation iteration_number: 43. Diff: 0.000526023565777578
policy_evaluation iteration_number: 44. Diff: 0.00047342120920035313
policy_evaluation iteration_number: 45. Diff: 0.00042607908827996255
policy_evaluation iteration_number: 46. Diff: 0.00038347117945214393
policy_evaluation iteration_number: 47. Diff: 0.000345124061506219
policy_evaluation iteration_number: 48. Diff: 0.0003106116553555083
policy_evaluation iteration_number: 49. Diff: 3.197442310920451e-14
Policy Evaluation converged at 49
policy iteration_number: 10
policy_evaluation iteration_number: 1. Diff: 0.6592351345496112
policy_evaluation iteration_number: 2. Diff: 0.48295450738251144
policy_evaluation iteration_number: 3. Diff: 0.43465905664425986
policy_evaluation iteration_number: 4. Diff: 0.27667940391000023
policy_evaluation iteration_number: 5. Diff: 0.2490114635190004
po

policy_evaluation iteration_number: 5. Diff: 0.05918732518669945
policy_evaluation iteration_number: 6. Diff: 0.0186871860481439
policy_evaluation iteration_number: 7. Diff: 0.016818467443330043
policy_evaluation iteration_number: 8. Diff: 0.0008718184249225658
policy_evaluation iteration_number: 9. Diff: 0.0007846365824306645
policy_evaluation iteration_number: 10. Diff: 0.0007061729241879533
policy_evaluation iteration_number: 11. Diff: 6.397860588887326e-31
Policy Evaluation converged at 11
policy iteration_number: 15
policy_evaluation iteration_number: 1. Diff: 0.12069507658075818
policy_evaluation iteration_number: 2. Diff: 0.06576369465188758
policy_evaluation iteration_number: 3. Diff: 0.015136620698997127
policy_evaluation iteration_number: 4. Diff: 0.013019232517542001
policy_evaluation iteration_number: 5. Diff: 0.004545783693051142
policy_evaluation iteration_number: 6. Diff: 0.004091205323746294
policy_evaluation iteration_number: 7. Diff: 0.0036820847913716648
policy_evalu

In [9]:
num_episodes=10
rewards = []
completed_runs = []
for episode in range(0,num_episodes):
    observation = env.reset()
    total_reward = 0
    completed = False
    for timestep in range(1,3000):
        env.render()
        action = get_optimal_action(approximate(observation), optimal_policy)
        observation, reward, done, info = env.step(action)
        total_reward += reward
        if done:
            print('COMPLETED')
            completed = True
            break
        
    rewards.append(total_reward)
    completed_runs.append(completed)



COMPLETED
COMPLETED
COMPLETED
COMPLETED
COMPLETED
COMPLETED
COMPLETED
COMPLETED
COMPLETED
COMPLETED


In [8]:
completed_runs

[True, True, True, True, True, True, True, True, True, True]

In [10]:
sum(rewards)/len(rewards)

-495.7