In [1]:
import random

def find_point(value, increment):
    return round(round(value / increment) * increment, 2)

def approximate(raw_state):
    array = list(raw_state)
    return tuple([
        find_point(array[0], 0.2) + 0, 
        find_point(array[1], 0.2) + 0, 
        find_point(array[2], 1) + 0,
        find_point(array[3], 2) + 0
    ])

def generate(lower_bound, upper_bound, interval):
    output = [lower_bound]
    while lower_bound <= upper_bound:
        lower_bound += interval
        output.append( float( round(lower_bound,1) + 0 ) )
    return output

def generate_states(theta1, theta2, ang_vel_theta1, ang_vel_theta2):
    states = []
    for a in theta1:
        for b in theta2:
            for c in ang_vel_theta1:
                for d in ang_vel_theta2:
                    states.append((a, b, c, d))
    return states


def create_transition_reward_function(states, actions, env):
    table = {}
    for state in states:
        for action in actions:
            env.reset(state=state)
            obs, reward, done, info = env.step(action)
            table[(state, action)] = {'reward':reward, 'next_state':approximate(obs)}
    return table

In [26]:

def policy_evaluation(state, actions, transition_and_reward_function, policy, value_function, gamma=0.9, max_iterations=100):
    min_difference = 0.01
    new_value_function = {}
    first_iter = True
    
    for i in range(max_iterations):
        for state in states:
            new_val = 0
            for action in actions:
                reward, next_state = transition_and_reward_function[(state, action)].values()
                next_state_value = value_function[next_state]
                    
                new_val += policy[state][action] * (reward + gamma*next_state_value)
            
            
            new_value_function[state] = new_val
        
        # Check convergence
        if not first_iter:
            differences = [abs(new_value_function[state] - value_function[state])
               for state in states]
            
            diff = max(differences)
            print(f"policy_evaluation iteration_number: {i}. Diff: {diff}")
            
            if diff < min_difference:
                print(f"Policy Evaluation converged at {i}")
                return new_value_function
                
        else:
            first_iter = False
        
        value_function = new_value_function.copy()
            
    print(f"Policy Evaluation did not converge.")
    return value_function


def improve_policy(states, actions, transition_and_reward_function, value_function, gamma=0.9):
    new_policy = {}
    for state in states:
        action_values = {}
        for action in actions:
            reward, next_state = transition_and_reward_function[(state, action)].values()
            action_values[action] = reward + gamma * value_function[next_state]
        greedy_action, value = max(action_values.items(), key= lambda pair: pair[1])
        
        new_policy[state] = {action:1 if action is greedy_action else 0 for action in actions}
    return new_policy


def policy_iteration(states, actions, transition_and_reward_function, policy, value_function, max_iterations=200):
    new_value_function = {}

    for i in range(max_iterations):
        converged = True
        print("policy iteration_number:",i)
        
        # Evaluate the current policy
        new_value_function = policy_evaluation(states, actions, transition_and_reward_function, policy, value_function)
        
        # Improve the policy based on new value function
        new_policy = improve_policy(states, actions, transition_and_reward_function, new_value_function)
        value_function = new_value_function.copy()
        
        # Check if convergence
        for state in states:
            if get_optimal_action(state, policy) != get_optimal_action(state, new_policy):
                policy = new_policy.copy()
                converged = False
                break
        
        if converged:
            # we have convergence
            print(f"Policy Iteration converged at {i}")
            return new_policy

    print(f"Policy Iteration did not converge")
    return policy

def get_optimal_action(state, optimal_policy):
    greedy_action, prob = max(optimal_policy[state].items(), key= lambda pair: pair[1])
    return greedy_action


In [3]:
from acrobot_env import *

env = AcrobotEnv()

theta1 = generate(-6.2, 6.2, 0.2)
theta2 = generate(-6.2, 6.2, 0.2)
ang_vel_theta1 = generate(-13, 13, 1)
ang_vel_theta2 = generate(-28, 28, 2)

states = generate_states(theta1, theta2, ang_vel_theta1, ang_vel_theta2)
actions = [0, 1, 2]


In [4]:
len(states)

3333960

In [5]:
from joblib import dump, load
load_transition = False

if not load_transition:
    transition_and_reward_function = create_transition_reward_function(states, actions, env)
    dump(transition_and_reward_function, 'transition_and_reward_function.joblib') 
else:
    print("Loading transition and reward function")
    transition_and_reward_function = load('transition_and_reward_function.joblib') 

In [None]:
starting_policy = {state:{0: 0.33, 1: 0.33, 2: 0.33} for state in states}
value_function = {state:0 for state in states}

optimal_policy = policy_iteration(states, actions, transition_and_reward_function, starting_policy, value_function)

policy iteration_number: 0
policy_evaluation iteration_number: 1. Diff: 0.8820900000000003
policy_evaluation iteration_number: 2. Diff: 0.7859421900000008
policy_evaluation iteration_number: 3. Diff: 0.7002744912900014
policy_evaluation iteration_number: 4. Diff: 0.6239445717393917
policy_evaluation iteration_number: 5. Diff: 0.5559346134197991
policy_evaluation iteration_number: 6. Diff: 0.4953377405570425
policy_evaluation iteration_number: 7. Diff: 0.44134592683632423
policy_evaluation iteration_number: 8. Diff: 0.39323922081116613
policy_evaluation iteration_number: 9. Diff: 0.3503761457427501
policy_evaluation iteration_number: 10. Diff: 0.312185145856791
policy_evaluation iteration_number: 11. Diff: 0.27815696495840214
policy_evaluation iteration_number: 12. Diff: 0.24783785577793616
policy_evaluation iteration_number: 13. Diff: 0.22082352949814155
policy_evaluation iteration_number: 14. Diff: 0.19675376478284434
policy_evaluation iteration_number: 15. Diff: 0.17530760442151472
p

policy_evaluation iteration_number: 36. Diff: 0.0751512104440053
policy_evaluation iteration_number: 37. Diff: 0.06763608939960442
policy_evaluation iteration_number: 38. Diff: 0.06087248045964344
policy_evaluation iteration_number: 39. Diff: 0.054785232413678386
policy_evaluation iteration_number: 40. Diff: 0.049306709172310725
policy_evaluation iteration_number: 41. Diff: 0.04437603825508063
policy_evaluation iteration_number: 42. Diff: 0.03993843442957257
policy_evaluation iteration_number: 43. Diff: 0.035944590986614955
policy_evaluation iteration_number: 44. Diff: 0.03235013188795399
policy_evaluation iteration_number: 45. Diff: 0.02911511869915895
policy_evaluation iteration_number: 46. Diff: 0.026203606829243498
policy_evaluation iteration_number: 47. Diff: 0.02358324614631968
policy_evaluation iteration_number: 48. Diff: 0.021224921531688423
policy_evaluation iteration_number: 49. Diff: 0.01910242937851958
policy_evaluation iteration_number: 50. Diff: 0.01719218644066789
policy

policy_evaluation iteration_number: 15. Diff: 0.20738486297511916
policy_evaluation iteration_number: 16. Diff: 0.14145605489976898
policy_evaluation iteration_number: 17. Diff: 0.12731044940979253
policy_evaluation iteration_number: 18. Diff: 0.1145794044688131
policy_evaluation iteration_number: 19. Diff: 0.10312146402193179
policy_evaluation iteration_number: 20. Diff: 0.09280931761973843
policy_evaluation iteration_number: 21. Diff: 0.05368210621575198
policy_evaluation iteration_number: 22. Diff: 0.04831389559417687
policy_evaluation iteration_number: 23. Diff: 0.043482506034759894
policy_evaluation iteration_number: 24. Diff: 0.03913425543128213
policy_evaluation iteration_number: 25. Diff: 0.03522082988815356
policy_evaluation iteration_number: 26. Diff: 0.03169874689933838
policy_evaluation iteration_number: 27. Diff: 0.02852887220940481
policy_evaluation iteration_number: 28. Diff: 0.025675984988464684
policy_evaluation iteration_number: 29. Diff: 0.02310838648961777
policy_ev

In [None]:
from joblib import dump

dump(optimal_policy, 'optimal_policy.joblib') 

In [9]:
num_episodes=1

for episode in range(0,num_episodes):
    observation = env.reset()
    for timestep in range(1,3000):
        env.render()
        action = get_optimal_action(approximate(observation), optimal_policy)
        observation, reward, done, info = env.step(action)

        if done:
            print('COMPLETED')
            break



KeyboardInterrupt: 