In [171]:
import numpy as np
from collections import defaultdict

In [10]:
from typing import Callable, Mapping, Tuple, TypeVar, Set

In [12]:
S = TypeVar('S')
A = TypeVar('A')
MDPTransitions = Mapping[S, Mapping[A, Mapping[S, float]]]
MDPActions = Mapping[S, Set[A]]
MDPRewards = Mapping[S, Mapping[A, float]]



In [155]:
def get_influence_tree(transitions) -> Mapping[S, Set[S]]:
    '''
    returns a mapping from state to all states that depend on that state in bellman equantions
    '''
    influence_tree = defaultdict(set)
    for state in transitions:
        for action in transitions[state]:
            for next_state in transitions[state][action]:
                influence_tree[next_state].add(state)
    return influence_tree

In [219]:
def value_iteration(actions: MDPActions, transitions: MDPTransitions, rewards: MDPRewards,
                   discount: float, vi_method: str='normal', k=None) -> Mapping[S, float]:
    next_value_function = {s: 0 for s in actions.keys()}
    base_value_function = None

    num_iter = 0
    if vi_method == 'normal':
        while base_value_function is None or\
                not check_value_fuction_equivalence(base_value_function, next_value_function):
            num_iter += 1
            base_value_function = next_value_function
            next_value_function = iterate_on_value_function(actions, transitions, rewards, base_value_function, 
                                                            discount)
    elif vi_method == 'random-k':
        while base_value_function is None or\
                not check_value_fuction_equivalence(base_value_function, next_value_function):
            num_iter += 1
            base_value_function = next_value_function

            next_value_function = random_k_iterate_on_value_function(actions, transitions, rewards, base_value_function, 
                                                            discount, k)
    elif vi_method == 'influence-tree':
        influence_tree = get_influence_tree(transitions)
        next_states_to_update = set(actions.keys())
        while len(next_states_to_update) > 0:
            num_iter += 1
            base_value_function = next_value_function
            next_value_function, updated_states = iterate_on_value_function_specific_states(actions, transitions,
                                                                                            rewards,
                                                                                            base_value_function, 
                                                            discount, next_states_to_update)
            next_states_to_update = set()
            for state in updated_states:
                next_states_to_update.update(influence_tree[state])
    elif vi_method == 'cyclic-vi':
        while base_value_function is None or\
                not check_value_fuction_equivalence(base_value_function, next_value_function):
            num_iter += 1
            base_value_function = next_value_function
            next_value_function = cycle_iterate_on_value_function(actions, transitions, rewards, 
                                                                  base_value_function, 
                                                                  discount)
    elif vi_method == 'cyclic-vi-rp':
        while base_value_function is None or\
                not check_value_fuction_equivalence(base_value_function, next_value_function):
            num_iter += 1
            base_value_function = next_value_function
            next_value_function = cycle_iterate_on_value_function_rp(actions, transitions, rewards, 
                                                                  base_value_function, 
                                                                  discount)
    else:
        raise NotImplemented(f'have not implemented {vi_method} value iteration yet')
    return base_value_function, num_iter


def iterate_on_value_function(actions: MDPActions, transitions: MDPTransitions, rewards: MDPRewards,
                              base_vf: Mapping[S, float], discount: float) -> Mapping[S, float]:
    new_vf = {}
    for s in actions.keys():
        action_values = [(action, extract_value_of_action(actions, transitions, rewards, 
                                                          action, s, base_vf, discount)) for action in actions[s]]
        best_action_reward = min([x[1] for x in action_values])
        new_vf[s] = best_action_reward
    return new_vf

def cycle_iterate_on_value_function(actions: MDPActions, transitions: MDPTransitions, rewards: MDPRewards,
                                    base_vf: Mapping[S, float], discount: float) -> Mapping[S, float]:
    new_vf = base_vf.copy()
    for s in actions.keys():
        action_values = [(action, extract_value_of_action(actions, transitions, rewards, 
                                                          action, s, new_vf, discount)) for action in actions[s]]
        best_action_reward = min([x[1] for x in action_values])
        new_vf[s] = best_action_reward
    return new_vf


def cycle_iterate_on_value_function_rp(actions: MDPActions, transitions: MDPTransitions, rewards: MDPRewards,
                                       base_vf: Mapping[S, float], discount: float) -> Mapping[S, float]:
    new_vf = base_vf.copy()
    states = list(actions.keys())
    np.random.shuffle(states)
    for s in states:
        action_values = [(action, extract_value_of_action(actions, transitions, rewards, 
                                                          action, s, new_vf, discount)) for action in actions[s]]
        best_action_reward = min([x[1] for x in action_values])
        new_vf[s] = best_action_reward
    return new_vf

def random_k_iterate_on_value_function(actions: MDPActions, transitions: MDPTransitions, rewards: MDPRewards,
                              base_vf: Mapping[S, float], discount: float, k: int) -> Mapping[S, float]:
    new_vf = {}
    states = list(actions.keys())
    states_to_update_idx = np.random.choice(range(len(states)), size=k)
    states_to_update = [states[idx] for idx in states_to_update_idx]
    for s in states_to_update:
        action_values = [(action, extract_value_of_action(actions, transitions, rewards, 
                                                          action, s, base_vf, discount)) for action in actions[s]]
        best_action_reward = min([x[1] for x in action_values])
        new_vf[s] = best_action_reward
    for s in set(actions.keys()) - set(states_to_update):
        new_vf[s] = base_vf[s]
    return new_vf

def iterate_on_value_function_specific_states(actions: MDPActions, transitions: MDPTransitions, rewards: MDPRewards,
                                              base_vf: Mapping[S, float], discount: float, 
                                              states_to_update: Set[S]) -> Mapping[S, float]:
    new_vf = {}
    updated_states = set()
    states = list(actions.keys())
    for s in states_to_update:
        action_values = [(action, extract_value_of_action(actions, transitions, rewards, 
                                                          action, s, base_vf, discount)) for action in actions[s]]
        best_action_reward = min([x[1] for x in action_values])
        new_vf[s] = best_action_reward
        if abs(new_vf[s] - base_vf[s]) > 1e-8:
            updated_states.add(s)
    for s in set(actions.keys()) - set(states_to_update):
        new_vf[s] = base_vf[s]
    return new_vf, updated_states


def extract_value_of_action(actions: MDPActions, transitions: MDPTransitions, rewards: MDPRewards,
                            action: A, state: S, value_function, discount: float):
    return rewards[state][action] + discount * sum([p * value_function[s_prime]
                                                    for s_prime, p in
                                                    transitions[state][action].items()])


def check_value_fuction_equivalence(v1, v2, epsilon=1e-8) -> bool:
    assert v1.keys() == v2.keys(), "comparing policies with different state spaces"
    for state in v1:
        if not abs(v1[state] - v2[state]) <= epsilon:
            return False
    return True


def check_policy_equivalence(p1, p2) -> bool:
    assert p1.keys() == p2.keys(), "comparing policies with different state spaces"
    for state in p1:
        if p1[state] != p2[state]:
            return False
    return True


def get_greedy_policy(actions: MDPActions, transitions: MDPTransitions, rewards: MDPRewards,
                      value_function: Mapping[S, float], terminal_states: Set[S], 
                     discount: float) -> Mapping[S, A]:
    policy = {}
    non_terminal_states = set(actions.keys()) - terminal_states
    for s in non_terminal_states:
        actions_rewards = {}
        for action in actions[s]:
            actions_rewards[action] = extract_value_of_action(actions, transitions, rewards,
                                                              action, s, value_function, discount)
        policy[s] = {(min(actions_rewards, key=actions_rewards.get), 1)}
    for s in terminal_states:
        policy[s] = {(list(actions[s])[0], 1)}
    return policy


In [220]:
'''
Maze Runner Problem
'''

maze_runner_actions = {
    0: {'s', 'j'},
    1: {'s', 'j'},
    2: {'s', 'j'},
    3: {'s', 'j'},
    4: {'s',},
    5: {'stay'}
}

maze_runner_transitions = {
    0: {'s': {1: 1}, 'j': {2: 0.5, 3: 0.25, 4: 0.125, 5:0.125}},
    1: {'s': {2: 1}, 'j': {3: 0.5, 4: 0.25, 5:0.25}},
    2: {'s': {3: 1}, 'j': {4: 0.5, 5:0.5}},
#     3: {'s': {4: 1}, 'j': {4: 0.5, 5:0.5}},
    3: {'s': {4: 1}, 'j': {5:1}},
    4: {'s': {5: 1}},
    5: {'stay': {5: 1}}
}

maze_runner_rewards = {
    0: {'s': 0, 'j': 0},
    1: {'s': 0, 'j': 0},
    2: {'s': 0, 'j': 0},
    3: {'s': 0, 'j': 0},
    4: {'s': 1},
    5: {'stay': 0}
}


In [221]:
pure_vi, num_iter = value_iteration(actions=maze_runner_actions, transitions=maze_runner_transitions, rewards=maze_runner_rewards,
               discount=0.9)
print(pure_vi, num_iter)
random_k_vi, num_iter_random_k = value_iteration(actions=maze_runner_actions, transitions=maze_runner_transitions, rewards=maze_runner_rewards,
               discount=0.9, vi_method='random-k', k=3)
print(random_k_vi, num_iter_random_k)

tree_vi, num_iter_tree = value_iteration(actions=maze_runner_actions, transitions=maze_runner_transitions, rewards=maze_runner_rewards,
               discount=0.9, vi_method='influence-tree')
print(tree_vi, num_iter_tree)

cyclic_vi, num_iter_cyclic = value_iteration(actions=maze_runner_actions, transitions=maze_runner_transitions, rewards=maze_runner_rewards,
               discount=0.9, vi_method='cyclic-vi')
print(cyclic_vi, num_iter_cyclic)

rp_cyclic_vi, rp_num_iter_cyclic = value_iteration(actions=maze_runner_actions, transitions=maze_runner_transitions, rewards=maze_runner_rewards,
               discount=0.9, vi_method='cyclic-vi-rp')
print(rp_cyclic_vi, rp_num_iter_cyclic)


{0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 1.0, 5: 0.0} 2
{0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0} 1
{0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 1.0, 5: 0.0} 2
{0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 1.0, 5: 0.0} 2
{0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 1.0, 5: 0.0} 2


In [222]:
def get_maze(x_dim, y_dim, num_terminal, seed=None, deterministic=True):
    if seed is not None:
        np.random.seed(seed)
    states = [(x,y) for x in range(x_dim) for y in range(y_dim)]
    terminal_states_idx = np.random.choice(list(range(len(states))), size=num_terminal)
    terminal_states = [states[idx] for idx in terminal_states_idx]
    actions = {'l','r','u','d'}
    state_transitions = {}
    state_actions = {}
    state_rewards = {}
    for state in states:
        state_transitions[state] = {}
        state_rewards[state] = {}
        if state in terminal_states:
            state_actions[state] = {'stay'}
        else:
            state_actions[state] = actions
        if deterministic:
            for action in state_actions[state]:
                next_coord = None
                if action == 'stay':
                    next_coord = state
                elif action == 'l':
                    next_coord = (state[0], max(0, state[1] - 1))
                if action == 'r':
                    next_coord = (state[0], min(x_dim-1, state[1] + 1))
                if action == 'u':
                    next_coord = (min(y_dim-1, state[0] + 1), state[1])
                if action == 'd':
                    next_coord = (max(0, state[0] - 1), state[1])
                state_transitions[state][action] = {next_coord: 1}
                if next_coord != state and next_coord in terminal_states:
                    state_rewards[state][action] = np.random.choice([-1, 0, 1])
                else:
                    state_rewards[state][action] = 0
        else:
            raise NotImplemented('have not implemented non-deterministic transitions')
    return state_actions, state_transitions, state_rewards, terminal_states

In [223]:
maze_actions, maze_transitions, maze_rewards, terminal_maze_state = get_maze(4,4, 1, seed=0)
print(maze_actions)
print()
print(maze_transitions)
print()
print(maze_rewards)
print()



{(0, 0): {'u', 'r', 'l', 'd'}, (0, 1): {'u', 'r', 'l', 'd'}, (0, 2): {'u', 'r', 'l', 'd'}, (0, 3): {'u', 'r', 'l', 'd'}, (1, 0): {'u', 'r', 'l', 'd'}, (1, 1): {'u', 'r', 'l', 'd'}, (1, 2): {'u', 'r', 'l', 'd'}, (1, 3): {'u', 'r', 'l', 'd'}, (2, 0): {'u', 'r', 'l', 'd'}, (2, 1): {'u', 'r', 'l', 'd'}, (2, 2): {'u', 'r', 'l', 'd'}, (2, 3): {'u', 'r', 'l', 'd'}, (3, 0): {'stay'}, (3, 1): {'u', 'r', 'l', 'd'}, (3, 2): {'u', 'r', 'l', 'd'}, (3, 3): {'u', 'r', 'l', 'd'}}

{(0, 0): {'u': {(1, 0): 1}, 'r': {(0, 1): 1}, 'l': {(0, 0): 1}, 'd': {(0, 0): 1}}, (0, 1): {'u': {(1, 1): 1}, 'r': {(0, 2): 1}, 'l': {(0, 0): 1}, 'd': {(0, 1): 1}}, (0, 2): {'u': {(1, 2): 1}, 'r': {(0, 3): 1}, 'l': {(0, 1): 1}, 'd': {(0, 2): 1}}, (0, 3): {'u': {(1, 3): 1}, 'r': {(0, 3): 1}, 'l': {(0, 2): 1}, 'd': {(0, 3): 1}}, (1, 0): {'u': {(2, 0): 1}, 'r': {(1, 1): 1}, 'l': {(1, 0): 1}, 'd': {(0, 0): 1}}, (1, 1): {'u': {(2, 1): 1}, 'r': {(1, 2): 1}, 'l': {(1, 0): 1}, 'd': {(0, 1): 1}}, (1, 2): {'u': {(2, 2): 1}, 'r': {(1, 

In [226]:
pure_vi, num_iter = value_iteration(actions=maze_actions, transitions=maze_transitions, rewards=maze_rewards,
               discount=0.9)
print(pure_vi)
print(num_iter)
policy_p_vi = get_greedy_policy(value_function=pure_vi, actions=maze_actions, transitions=maze_transitions, rewards=maze_rewards,
               terminal_states=set(terminal_maze_state), discount=0.9)
print(policy_p_vi)
print()


k_vi, k_num_iter = value_iteration(actions=maze_actions, transitions=maze_transitions, rewards=maze_rewards,
               discount=0.9, vi_method='random-k', k=5)

print(k_vi)
print(k_num_iter)
policy_k_vi = get_greedy_policy(value_function=k_vi, actions=maze_actions, transitions=maze_transitions, rewards=maze_rewards,
               terminal_states=set(terminal_maze_state), discount=0.9)
print(policy_k_vi)
print()


tree_vi, tree_num_iter = value_iteration(actions=maze_actions, transitions=maze_transitions, rewards=maze_rewards,
               discount=0.9, vi_method='influence-tree')

print(tree_vi)
print(tree_num_iter)
policy_tree_vi = get_greedy_policy(value_function=tree_vi, actions=maze_actions, transitions=maze_transitions, rewards=maze_rewards,
               terminal_states=set(terminal_maze_state), discount=0.9)
print(policy_tree_vi)
print()

cyclic_vi, cyclic_num_iter = value_iteration(actions=maze_actions, transitions=maze_transitions, rewards=maze_rewards,
               discount=0.9, vi_method='cyclic-vi')

print(cyclic_vi)
print(cyclic_num_iter)
policy_cyclic_vi = get_greedy_policy(value_function=cyclic_vi, actions=maze_actions, transitions=maze_transitions, rewards=maze_rewards,
               terminal_states=set(terminal_maze_state), discount=0.9)
print(policy_cyclic_vi)
print()

rp_cyclic_vi, rp_cyclic_num_iter = value_iteration(actions=maze_actions, transitions=maze_transitions, rewards=maze_rewards,
               discount=0.9, vi_method='cyclic-vi-rp')

print(rp_cyclic_vi)
print(rp_cyclic_num_iter)
rp_policy_cyclic_vi = get_greedy_policy(value_function=rp_cyclic_vi, actions=maze_actions, transitions=maze_transitions, rewards=maze_rewards,
               terminal_states=set(terminal_maze_state), discount=0.9)
print(rp_policy_cyclic_vi)
print()

{(0, 0): -0.6561000000000001, (0, 1): -0.7290000000000001, (0, 2): -0.6561000000000001, (0, 3): -0.5904900000000002, (1, 0): -0.7290000000000001, (1, 1): -0.81, (1, 2): -0.7290000000000001, (1, 3): -0.6561000000000001, (2, 0): -0.81, (2, 1): -0.9, (2, 2): -0.81, (2, 3): -0.7290000000000001, (3, 0): 0.0, (3, 1): -1.0, (3, 2): -0.9, (3, 3): -0.81}
7
{(3, 2): {('l', 1)}, (0, 0): {('u', 1)}, (1, 3): {('u', 1)}, (0, 2): {('u', 1)}, (2, 1): {('u', 1)}, (2, 3): {('u', 1)}, (1, 0): {('u', 1)}, (0, 3): {('u', 1)}, (0, 1): {('u', 1)}, (1, 2): {('u', 1)}, (3, 3): {('l', 1)}, (3, 1): {('l', 1)}, (2, 0): {('r', 1)}, (2, 2): {('u', 1)}, (1, 1): {('u', 1)}, (3, 0): {('stay', 1)}}

{(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (1, 0): 0, (1, 1): 0, (1, 2): 0, (1, 3): 0, (2, 0): 0, (2, 1): 0, (2, 2): 0, (2, 3): 0, (3, 0): 0, (3, 1): 0, (3, 2): 0, (3, 3): 0}
1
{(3, 2): {('u', 1)}, (0, 0): {('u', 1)}, (1, 3): {('u', 1)}, (0, 2): {('u', 1)}, (2, 1): {('u', 1)}, (2, 3): {('u', 1)}, (1, 0): {('u', 1)}, (0, 3