In [4]:
import json
import random
from collections import defaultdict

In [None]:
class Policy:
    """Base interface for objects that choose an action in a given state."""

    def select_action(self, state, action):
        """Return the action to take in ``state``.

        ``action`` carries the candidate actions for ``state`` (subclasses in
        this file name the same positional parameter ``actions``). The base
        implementation chooses nothing and returns None.
        """
        return None


class DeterministicPolicy(Policy):
    """Policy that can be told which single action to use in a state."""

    def update(self, state, action):
        """Hook for associating ``action`` with ``state``.

        Does nothing in this base class; ``TabularPolicy`` overrides it to
        store the pair in its lookup table.
        """
        return None


class TabularPolicy(DeterministicPolicy):
    """Deterministic policy kept as a state -> action lookup table.

    A state with no recorded action maps to ``default_action``. Because the
    table is a ``defaultdict``, looking up an unseen state also stores that
    default for it.
    """

    def __init__(self, default_action=None):
        def fallback():
            return default_action

        self.policy_table = defaultdict(fallback)

    def select_action(self, state, actions):
        """Return the action recorded for ``state``; ``actions`` is not consulted."""
        chosen = self.policy_table[state]
        return chosen

    def update(self, state, action):
        """Record ``action`` as the choice for ``state``, replacing any previous one."""
        self.policy_table[state] = action

In [3]:
class QFunction:
    """Interface for a Q-function mapping (state, action) pairs to values.

    Subclasses supply ``update``, ``get_q_value`` and persistence; the
    argmax / max helpers below are built solely on ``get_q_value``.
    """

    def update(self, state, action, delta):
        """Update the Q-value of (state, action) by ``delta``."""
        pass

    def get_q_value(self, state, action):
        """Get the Q-value for a given state-action pair."""
        pass

    def save_policy(self, filename):
        """Save a policy to the specified filename."""
        pass

    def load_policy(self, filename):
        """Load a policy from the specified filename."""
        pass

    def get_argmax_q(self, state, actions):
        """Return the action in ``actions`` with the maximum Q-value in ``state``.

        Returns None when ``actions`` is empty.
        """
        argmax_q, _ = self.get_max_pair(state, actions)
        return argmax_q

    def get_max_q(self, state, actions):
        """Return the maximum Q-value over ``actions`` in ``state``.

        Returns ``float("-inf")`` when ``actions`` is empty.
        """
        _, max_q = self.get_max_pair(state, actions)
        return max_q

    def get_max_pair(self, state, actions):
        """Return ``(action, q_value)`` for the best action in ``state``.

        On ties the earliest action in ``actions`` wins. With no actions (or
        if every Q-value is ``-inf``) the result is ``(None, float("-inf"))``.
        """
        arg_max_q = None
        max_q = float("-inf")
        for action in actions:
            value = self.get_q_value(state, action)
            # Strict comparison keeps the first action seen among equal values.
            if max_q < value:
                arg_max_q = action
                max_q = value
        return (arg_max_q, max_q)


class QTable(QFunction):
    """Tabular Q-function stored in a dict keyed on ``(state, action)``.

    Unseen pairs read as ``default_q_value``; each ``update`` moves the stored
    value by ``alpha * delta``.
    """

    def __init__(self, alpha=0.1, default_q_value=0.0):
        # defaultdict returns (and stores) default_q_value for unseen pairs.
        self.qtable = defaultdict(lambda: default_q_value)
        self.alpha = alpha

    def update(self, state, action, delta):
        """Add ``alpha * delta`` to the Q-value of (state, action)."""
        self.qtable[(state, action)] += self.alpha * delta

    def get_q_value(self, state, action):
        """Return the Q-value of (state, action), or the default if unseen."""
        return self.qtable[(state, action)]

    def save(self, filename):
        """Write the table to ``filename`` as a JSON object.

        JSON object keys must be strings, so each key is stored as
        ``str((state, action))``. ``load`` parses them back with ``eval``, so
        states and actions need reprs that evaluate back to equal values.
        """
        # Serialise before opening so a failure does not leave a truncated file.
        serialised = {str(key): value for key, value in self.qtable.items()}
        with open(filename, "w", encoding="utf-8") as file:
            json.dump(serialised, file)

    def load(self, filename, default=0.0):
        """Replace the table with the contents of ``filename`` (as written by ``save``).

        After loading, pairs absent from the file read as ``default``, not as
        the ``default_q_value`` passed to the constructor.
        """
        with open(filename, "r", encoding="utf-8") as file:
            serialised = json.load(file)
        # NOTE(security): eval() runs arbitrary code contained in the file's
        # keys; only load files you trust. ast.literal_eval would be safe if
        # all states/actions are plain literals — TODO confirm before switching.
        self.qtable = defaultdict(
            lambda: default,
            {tuple(eval(key)): value for key, value in serialised.items()},
        )

    def save_policy(self, filename):
        """Save the table via ``save``; implements the QFunction interface.

        Without this override the inherited base method would silently do nothing.
        """
        self.save(filename)

    def load_policy(self, filename):
        """Load the table via ``load`` (unseen pairs default to 0.0).

        Implements the QFunction interface, whose base method is a no-op.
        """
        self.load(filename)


class ValueFunction:
    """Interface for a state-value function V(s), plus MDP-based helpers.

    Subclasses supply ``update``, ``merge`` and ``get_value``; ``get_q_value``
    and ``extract_policy`` are built on ``get_value`` and an MDP model.
    """

    def update(self, state, value):
        """Set the value of ``state`` (implemented by subclasses)."""
        pass

    def merge(self, value_table):
        """Merge another value function into this one (implemented by subclasses)."""
        pass

    def get_value(self, state):
        """Return the value of ``state`` (implemented by subclasses)."""
        pass

    def get_q_value(self, mdp, state, action):
        """Return the Q-value of ``action`` in ``state``.

        Computes the sum over ``(new_state, probability)`` in
        ``mdp.get_transitions(state, action)`` of
        ``probability * (reward + discount * V(new_state))``.
        """
        # The discount factor does not depend on the transition, so read it once.
        discount = mdp.get_discount_factor()
        q_value = 0.0
        for new_state, probability in mdp.get_transitions(state, action):
            reward = mdp.get_reward(state, action, new_state)
            q_value += probability * (reward + (discount * self.get_value(new_state)))
        return q_value

    def extract_policy(self, mdp):
        """Return a greedy ``TabularPolicy`` derived from this value function.

        Each state is mapped to the action with the highest Q-value (the
        first such action on ties). States with no actions are left unset.
        """
        policy = TabularPolicy()
        for state in mdp.get_states():
            max_q = float("-inf")
            for action in mdp.get_actions(state):
                q_value = self.get_q_value(mdp, state, action)

                # If this is the maximum Q-value so far,
                # set the policy for this state
                if q_value > max_q:
                    policy.update(state, action)
                    max_q = q_value

        return policy


class TabularValueFunction(ValueFunction):
    """Value function stored as a per-state lookup table.

    States that were never written read back as ``default``; since the table
    is a ``defaultdict``, such a read also stores the default for that state.
    """

    def __init__(self, default=0.0):
        self.value_table = defaultdict(lambda: default)

    def update(self, state, value):
        """Overwrite the stored value of ``state`` with ``value``."""
        self.value_table[state] = value

    def merge(self, value_table):
        """Copy every state stored in another tabular value function into this one."""
        for known_state in value_table.value_table:
            incoming_value = value_table.get_value(known_state)
            self.update(known_state, incoming_value)

    def get_value(self, state):
        """Return the value of ``state`` (the default if it was never set)."""
        return self.value_table[state]


class PolicyIteration:
    def __init__(self, mdp, policy):
        self.mdp = mdp
        self.policy = policy
    
    def policy_evaluation(self, policy, values, theta=0.001):
        
        while True:
            delta= 0.0
            new_values = TabularPolicy()
            for state in self.mdp.get_states():
                # calclate the value of V(s)
                actions = self.mdp.get_actions(state)
                old_value = values.get_value(state)
                new_value = values.get_q_value(
                    self.mdp, state, policy.select_action(state, actions)
                )
                values.update(state, new_value)
                delta = max(delta, abs(old_value - new_value))
            
            # terminate if the value function has converged
            if delta < theta:
                break
    
    """ Implmentation of policy iteration.  Returns the number of iterations executed """
    
    def policy_iteration(self, max_iteration=100, theta=0.001):
        
        # crate a value function to hold details
        values = TabularValueFunction()
        
        for i in range(1, max_iteration + 1):
            policy_changed = False
            values = self.policy_evaluation(self.policy, values, theta)
            for state in self.mdp.get_states():
                actions = self.mdp.get_actions(state)
                old_action = self.policy.select_action(state, actions)
                
                q_values = QTable()