In [1]:
import random
from collections import defaultdict

In [None]:
class Policy:
    """Base interface for objects that choose an action in a given state."""

    def select_action(self, state, action):
        """Return the action to take in ``state``.

        This base version is a stub and always returns None.

        NOTE(review): TabularPolicy overrides this with the second
        parameter named ``actions``; keyword callers must match whichever
        class they use.
        """


class DeterministicPolicy(Policy):
    """Policy whose chosen action per state can be overwritten via ``update``."""

    def update(self, state, action):
        """Set ``action`` as the choice for ``state``.

        Stub here: does nothing and returns None; subclasses provide storage.
        """


class TabularPolicy(DeterministicPolicy):
    """Deterministic policy stored as a ``state -> action`` lookup table.

    A state with no recorded action maps to ``default_action``. Because the
    table is a ``defaultdict``, looking up such a state through
    ``select_action`` also inserts the default into ``policy_table``.
    """

    def __init__(self, default_action=None):
        def _fallback_action():
            return default_action

        self.policy_table = defaultdict(_fallback_action)

    def select_action(self, state, actions):
        """Return the action recorded for ``state``; ``actions`` is not consulted."""
        chosen = self.policy_table[state]
        return chosen

    def update(self, state, action):
        """Record ``action`` as the action to take in ``state``."""
        table = self.policy_table
        table[state] = action

In [3]:
class ValueFunction:
    """Base interface for state-value functions over an MDP.

    Subclasses supply storage via ``update``/``merge``/``get_value``; this
    class provides Q-value computation and greedy policy extraction on top
    of ``get_value``.
    """

    def update(self, state, value):
        """Set the value of ``state`` to ``value`` (stub: does nothing)."""
        pass

    def merge(self, value_table):
        """Merge the values from another value function into this one (stub)."""
        pass

    def get_value(self, state):
        """Return the value of ``state`` (stub: returns None)."""
        pass

    def get_q_value(self, mdp, state, action):
        """Return the Q-value of ``action`` in ``state``.

        Computed as the sum over ``mdp.get_transitions(state, action)`` of
        ``probability * (reward + discount * V(new_state))``, where the
        reward comes from ``mdp.get_reward`` and V from ``self.get_value``.
        Returns 0.0 when there are no transitions.
        """
        # The discount factor does not depend on the transition; fetch it once.
        discount = mdp.get_discount_factor()
        q_value = 0.0
        for (new_state, probability) in mdp.get_transitions(state, action):
            reward = mdp.get_reward(state, action, new_state)
            q_value += probability * (reward + discount * self.get_value(new_state))

        return q_value

    def extract_policy(self, mdp):
        """Return a greedy TabularPolicy derived from this value function.

        For each state in ``mdp.get_states()``, the policy records the action
        with the highest Q-value. Ties go to the first such action in
        ``mdp.get_actions(state)`` order (strict ``>`` comparison). States
        with no actions are never updated and keep the TabularPolicy default.
        """
        policy = TabularPolicy()
        for state in mdp.get_states():
            max_q = float("-inf")
            for action in mdp.get_actions(state):
                q_value = self.get_q_value(mdp, state, action)

                # If this is the maximum Q-value so far,
                # set the policy for this state
                if q_value > max_q:
                    policy.update(state, action)
                    max_q = q_value

        return policy


class TabularValueFunction(ValueFunction):
    """Value function stored as a ``state -> value`` dictionary.

    Unvisited states read back as ``default``; since the table is a
    ``defaultdict``, reading such a state via ``get_value`` also stores the
    default in ``value_table``.
    """

    def __init__(self, default=0.0):
        def _fallback_value():
            return default

        self.value_table = defaultdict(_fallback_value)

    def update(self, state, value):
        """Store ``value`` as the value of ``state``."""
        table = self.value_table
        table[state] = value

    def merge(self, value_table):
        """Copy every state held by another tabular value function into this one.

        ``value_table`` must expose a ``value_table`` mapping and a
        ``get_value`` method; each stored state's value is read through
        ``get_value`` and written here through ``update``.
        """
        source = value_table.value_table
        for s in source:
            self.update(s, value_table.get_value(s))

    def get_value(self, state):
        """Return the stored value of ``state`` (or the default if unseen)."""
        table = self.value_table
        return table[state]

In [None]:
class PolicyIteration:
    def __init__(self, mdp, policy):
        self.mdp = mdp
        self.policy = policy
    
    def policy_evaluation(self, policy, values, theta=0.001):
        
        while True:
            delta= 0.0
            new_values = TabularPolicy()