# Multi-armed Contextual Bandits: Stochastic

In [1]:
import numpy as np

## Create environment

In [2]:
def create_environment_num_states():
    """Returns how many discrete states (contexts) the environment has.

    Returns:
        num_states: int, number of states.
    """
    return 8

In [3]:
def create_environment_num_bandits():
    """Returns how many bandit arms are available in every state.

    Returns:
        num_bandits: int, number of bandits.
    """
    return 10

In [4]:
def create_environment_bandit_means(num_states, num_bandits):
    """Samples the true mean reward of every (state, bandit) pair.

    Every mean is drawn i.i.d. from a normal distribution. That distribution's
    mean and variance are fixed hyperparameters, and they are returned
    alongside the sampled means.

    Args:
        num_states: int, number of states.
        num_bandits: int, number of bandits.
    Returns:
        global_bandit_mean_mean: float, the global mean of means across all
            bandits.
        global_bandit_mean_variance: float, the global variance of means
            across all bandits.
        bandit_mean: array[float], shape (num_states, num_bandits), the mean
            reward of each bandit in each state.
    """
    mean_of_means, variance_of_means = 0.0, 1.0

    # One flat draw, then fold into a (state, bandit) grid.
    samples = np.random.normal(
        loc=mean_of_means,
        scale=np.sqrt(variance_of_means),
        size=num_states * num_bandits)

    return (mean_of_means,
            variance_of_means,
            np.reshape(samples, (num_states, num_bandits)))

In [5]:
def create_environment_bandit_variances(num_states, num_bandits):
    """Samples the true reward variance of every (state, bandit) pair.

    Every variance is drawn i.i.d. from a normal distribution. That
    distribution's mean and variance are fixed hyperparameters, and they are
    returned alongside the sampled variances. With the global variance of
    variances set to 0.0, every bandit variance is exactly 1.0.

    Args:
        num_states: int, number of states.
        num_bandits: int, number of bandits.
    Returns:
        global_bandit_variance_mean: float, the global mean of variances
            across all bandits.
        global_bandit_variance_variance: float, the global variance of
            variances across all bandits.
        bandit_variance: array[float], shape (num_states, num_bandits), the
            variance of each bandit in each state.
    """
    global_bandit_variance_mean = 1.0
    global_bandit_variance_variance = 0.0

    # NOTE(review): if global_bandit_variance_variance is raised above 0.0,
    # normal sampling can produce negative variances. Their square roots are
    # NaN when rewards are later drawn. Consider clipping, or a strictly
    # positive distribution, in that case.
    bandit_variance = np.random.normal(
        loc=global_bandit_variance_mean,
        scale=np.sqrt(global_bandit_variance_variance),
        size=num_states * num_bandits)

    bandit_variance = bandit_variance.reshape(num_states, num_bandits)

    return (global_bandit_variance_mean,
            global_bandit_variance_variance,
            bandit_variance)

In [6]:
def create_environment_bandit_change_arrays(num_bandits):
    """Builds the per-bandit bookkeeping for non-stationary statistics.

    Every bandit gets the same change frequency (201). All change counters
    start at zero.

    Args:
        num_bandits: int, number of bandits.
    Returns:
        bandit_change_frequencies: array[int], shape (num_bandits,), how
            often each bandit's statistics change.
        bandit_change_counter: array[int], shape (num_bandits,), the change
            counter of each bandit.
    """
    change_period = 201
    frequencies = np.full(num_bandits, change_period)
    counters = np.zeros(num_bandits, dtype=np.int64)
    return frequencies, counters

In [7]:
def create_environment():
    """Creates the full stochastic contextual-bandit environment.

    Sizes, true reward statistics, and change bookkeeping are each produced
    by a dedicated helper. This function collects their outputs.

    Returns:
        num_states: int, number of states.
        num_bandits: int, number of bandits.
        global_bandit_mean_mean: float, the global mean of means across all
            bandits.
        global_bandit_mean_variance: float, the global variance of means
            across all bandits.
        bandit_mean: array[float], shape (num_states, num_bandits), the means
            of each bandit.
        global_bandit_variance_mean: float, the global mean of variances
            across all bandits.
        global_bandit_variance_variance: float, the global variance of
            variances across all bandits.
        bandit_variance: array[float], shape (num_states, num_bandits), the
            variances of each bandit.
        bandit_change_frequencies: array[int], how often each
            bandit's statistics changes.
        bandit_change_counter: array[int], the change
            counter of each bandit.
    """
    num_states = create_environment_num_states()

    num_bandits = create_environment_num_bandits()

    # Means are sampled before variances. Keep this order so the seeded RNG
    # stream reproduces the same environment.
    (global_bandit_mean_mean,
     global_bandit_mean_variance,
     bandit_mean) = create_environment_bandit_means(num_states, num_bandits)

    (global_bandit_variance_mean,
     global_bandit_variance_variance,
     bandit_variance) = create_environment_bandit_variances(
        num_states, num_bandits)

    (bandit_change_frequencies,
     bandit_change_counter) = create_environment_bandit_change_arrays(
        num_bandits)

    return (num_states,
            num_bandits,
            global_bandit_mean_mean,
            global_bandit_mean_variance,
            bandit_mean,
            global_bandit_variance_mean,
            global_bandit_variance_variance,
            bandit_variance,
            bandit_change_frequencies,
            bandit_change_counter)

## Set hyperparameters

In [8]:
def set_hyperparameters():
    """Returns the fixed hyperparameters used by the learning loop.

    Returns:
        num_iterations: int, number of iterations.
        alpha: float, alpha > 0, learning rate.
        epsilon: float, 0 <= epsilon <= 1, exploitation-exploration trade-off,
            higher means more exploration.
        action_selection_type: int, action selection type (0 = greedy,
            1 = epsilon-greedy, 2 = upper-confidence-bound).
        action_value_update_type: int, action value update type
            (0 = sample-average, 1 = biased constant step-size,
            2 = unbiased constant step-size).
    """
    hyperparameters = (
        2000,  # num_iterations
        0.1,   # alpha
        0.1,   # epsilon
        1,     # action_selection_type: epsilon-greedy
        2,     # action_value_update_type: unbiased constant step-size
    )
    return hyperparameters

## Create value function and policy arrays

In [9]:
def create_action_arrays(num_states, num_bandits):
    """Allocates zero-initialised per-(state, bandit) learning arrays.

    Args:
        num_states: int, number of states.
        num_bandits: int, number of bandits.
    Returns:
        q: array[float], shape (num_states, num_bandits), keeps track of the
            estimated value of each bandit in each state, Q(s, b).
        action_count: array[int], shape (num_states, num_bandits), counts the
            number of times each bandit was actioned in each state.
        action_trace: array[float], shape (num_states, num_bandits), keeps
            track of the step-size trace for each bandit. It is used by the
            unbiased constant step-size update.
    """
    grid = (num_states, num_bandits)
    return (np.zeros(grid, dtype=np.float64),
            np.zeros(grid, dtype=np.int64),
            np.zeros(grid, dtype=np.float64))

In [10]:
def create_policy_arrays(num_states, num_bandits):
    """Initialises a uniform random policy over bandits for every state.

    Args:
        num_states: int, number of states.
        num_bandits: int, number of bandits.
    Returns:
        policy: array[float], shape (num_states, num_bandits), learned
            stochastic policy of which bandit to action given a state. Every
            row starts uniform.
    """
    uniform_prob = 1.0 / num_bandits
    return np.full((num_states, num_bandits), uniform_prob)

## Create algorithm

In [11]:
# Seed NumPy's global random number generator so that the environment,
# the chosen actions and the sampled rewards are identical on every run.
np.random.seed(0)

In [12]:
def _update_action_value(
        s_idx,
        a_idx,
        reward,
        q,
        action_count,
        action_trace,
        alpha,
        action_value_update_type):
    """Moves Q(s_idx, a_idx) toward `reward` in place.

    Call this only after action_count[s_idx, a_idx] has been incremented for
    the current pull, because the sample-average rule divides by it.

    Args:
        s_idx: int, current state index.
        a_idx: int, index of the bandit that was actioned.
        reward: float, reward just observed.
        q: array[float], Q(s, b), updated in place.
        action_count: array[int], pull counts per (state, bandit).
        action_trace: array[float], step-size trace per (state, bandit),
            updated in place by the unbiased constant step-size rule.
        alpha: float, alpha > 0, learning rate.
        action_value_update_type: int, 0 = sample-average, 1 = biased
            constant step-size, 2 = unbiased constant step-size.
    Raises:
        ValueError: if action_value_update_type is not 0, 1 or 2.
    """
    delta = reward - q[s_idx, a_idx]
    if action_value_update_type == 0:  # sample-average method
        learning_rate = 1.0 / action_count[s_idx, a_idx]
    elif action_value_update_type == 1:  # biased constant step-size
        learning_rate = alpha
    elif action_value_update_type == 2:  # unbiased constant step-size
        # The trace o_n = o_{n-1} + alpha * (1 - o_{n-1}) starts at 0, so the
        # first step size alpha / o_1 is exactly 1. This removes the bias
        # toward the initial Q value.
        trace_diff = 1.0 - action_trace[s_idx, a_idx]
        action_trace[s_idx, a_idx] += alpha * trace_diff
        learning_rate = alpha / action_trace[s_idx, a_idx]
    else:
        raise ValueError(
            "action_value_update_type must be 0, 1 or 2, got {!r}".format(
                action_value_update_type))
    q[s_idx, a_idx] += learning_rate * delta


def _mutate_bandit_statistics(
        num_states,
        num_bandits,
        bandit_mean,
        bandit_variance,
        bandit_change_frequencies,
        bandit_change_counter,
        global_bandit_mean_mean,
        global_bandit_mean_variance,
        global_bandit_variance_mean,
        global_bandit_variance_variance):
    """Advances each bandit's change counter and resamples its statistics.

    A bandit whose counter reaches its change frequency gets a fresh mean and
    variance in every state, written in place. Its counter is then reset to
    zero. Bandits with a non-positive change frequency never change.
    """
    for i in range(num_bandits):
        if bandit_change_frequencies[i] <= 0:
            continue

        bandit_change_counter[i] += 1

        if bandit_change_counter[i] == bandit_change_frequencies[i]:
            # Mean is resampled before variance so the seeded RNG stream is
            # consumed in a fixed order.
            bandit_mean[:, i] = np.random.normal(
                loc=global_bandit_mean_mean,
                scale=np.sqrt(global_bandit_mean_variance),
                size=num_states)
            # NOTE(review): with global_bandit_variance_variance > 0 this can
            # yield negative variances (NaN scale when rewards are drawn).
            bandit_variance[:, i] = np.random.normal(
                loc=global_bandit_variance_mean,
                scale=np.sqrt(global_bandit_variance_variance),
                size=num_states)

            bandit_change_counter[i] = 0


def loop_through_iterations(
        num_iterations,
        num_states,
        num_bandits,
        bandit_mean,
        bandit_variance,
        bandit_change_frequencies,
        bandit_change_counter,
        global_bandit_mean_mean,
        global_bandit_mean_variance,
        global_bandit_variance_mean,
        global_bandit_variance_variance,
        q,
        action_count,
        action_trace,
        policy,
        alpha,
        epsilon,
        action_selection_type,
        action_value_update_type):
    """Loops through iterations to iteratively update policy.

    Each iteration proceeds in five steps:
    1. Sample a random state.
    2. Refresh that state's policy row from Q.
    3. Sample an action from the policy and draw a normally distributed
       reward.
    4. Update Q.
    5. Advance the non-stationary bandit statistics.

    All arrays are updated in place.

    Args:
        num_iterations: int, number of iterations.
        num_states: int, number of states.
        num_bandits: int, number of bandits.
        bandit_mean: array[float], the means of each bandit.
        bandit_variance: array[float], the variances of each bandit.
        bandit_change_frequencies: array[int], how often each
            bandit's statistics changes.
        bandit_change_counter: array[int], the change
            counter of each bandit.
        global_bandit_mean_mean: float, the global mean of means across all
            bandits.
        global_bandit_mean_variance: float, the global variance of means
            across all bandits.
        global_bandit_variance_mean: float, the global mean of variances
            across all bandits.
        global_bandit_variance_variance: float, the global variance of
            variances across all bandits.
        q: array[float], keeps track of the estimated value of each bandit in
            each state, Q(s, b).
        action_count: array[int], counts the number of times each bandit was
            actioned.
        action_trace: array[float], keeps track of the step-size trace for
            each bandit.
        policy: array[float], learned stochastic policy of which
            bandit to action given a state.
        alpha: float, alpha > 0, learning rate.
        epsilon: float, 0 <= epsilon <= 1, exploitation-exploration trade-off,
            higher means more exploration.
        action_selection_type: int, action selection type (0 = greedy,
            1 = epsilon-greedy, 2 = upper-confidence-bound).
        action_value_update_type: int, action value update type
            (0 = sample-average, 1 = biased constant step-size,
            2 = unbiased constant step-size).
    Returns:
        bandit_mean: array[float], the means of each bandit.
        bandit_variance: array[float], the variances of each bandit.
        q: array[float], keeps track of the estimated value of each bandit in
            each state, Q(s, b).
        policy: array[float], learned stochastic policy of which
            bandit to action given a state.
    Raises:
        ValueError: if action_value_update_type is not 0, 1 or 2. This is
            checked before any iteration runs, so no array is modified.
    """
    # Fail fast. Previously an unknown update type silently skipped every Q
    # update, so nothing was learned.
    if action_value_update_type not in (0, 1, 2):
        raise ValueError(
            "action_value_update_type must be 0, 1 or 2, got {!r}".format(
                action_value_update_type))

    for t in range(num_iterations):
        # Get random state
        s_idx = np.random.randint(num_states)

        # Refresh this state's policy row from the state-action-value function
        policy = update_policy_from_q(
            s_idx,
            num_bandits,
            q,
            action_count,
            t + 1,
            epsilon,
            action_selection_type,
            policy)

        # Get action
        a_idx = np.random.choice(a=num_bandits, p=policy[s_idx, :])

        # Get reward from action
        reward = np.random.normal(
            loc=bandit_mean[s_idx, a_idx],
            scale=np.sqrt(bandit_variance[s_idx, a_idx]))

        # Count the pull before updating Q; sample-average divides by it.
        action_count[s_idx, a_idx] += 1

        _update_action_value(
            s_idx,
            a_idx,
            reward,
            q,
            action_count,
            action_trace,
            alpha,
            action_value_update_type)

        _mutate_bandit_statistics(
            num_states,
            num_bandits,
            bandit_mean,
            bandit_variance,
            bandit_change_frequencies,
            bandit_change_counter,
            global_bandit_mean_mean,
            global_bandit_mean_variance,
            global_bandit_variance_mean,
            global_bandit_variance_variance)

    return bandit_mean, bandit_variance, q, policy

In [13]:
def update_policy_from_q(
        s_idx,
        num_bandits,
        q,
        action_count,
        iteration_count,
        epsilon,
        action_selection_type,
        policy):
    """Updates the policy row of one state from the state-action-value function.

    Only row `s_idx` of `policy` is modified, in place. The same array is
    returned for convenience.

    Args:
        s_idx: int, current state index.
        num_bandits: int, number of bandits.
        q: array[float], keeps track of the estimated value of each bandit in
            each state, Q(s, b).
        action_count: array[int], counts the number of times each bandit was
            actioned in each state.
        iteration_count: int, current loop iteration count (>= 1). It is used
            as t in the UCB bonus sqrt(ln t / N(s, b)).
        epsilon: float. For epsilon-greedy (type 1), 0 <= epsilon <= 1 is the
            probability mass spread over the non-greedy bandits. For
            upper-confidence-bound (type 2), it is the exploration
            coefficient c.
        action_selection_type: int, action selection type (0 = greedy,
            1 = epsilon-greedy, 2 = upper-confidence-bound).
        policy: array[float], learned stochastic policy of which
            bandit to action given a state.
    Returns:
        policy: array[float], learned stochastic policy of which
            bandit to action given a state.
    Raises:
        ValueError: if action_selection_type is not 0, 1 or 2.
    """
    q_row = q[s_idx, :]

    if action_selection_type == 0 or action_selection_type == 1:
        # Greedy or epsilon-greedy: rank bandits by their Q estimates.
        action_value = q_row
    elif action_selection_type == 2:
        # Upper-confidence-bound uses only this state's pull counts.
        count_row = action_count[s_idx, :]
        min_count_idx = np.argmin(count_row)
        if count_row[min_count_idx] == 0:
            # An untried bandit has an unbounded UCB bonus: select it
            # deterministically (also avoids dividing by a zero count).
            policy[s_idx, :] = np.where(
                np.arange(num_bandits) == min_count_idx, 1.0, 0.0)
            return policy
        # NOTE(review): t is the global iteration count rather than this
        # state's visit count; confirm that is intended.
        action_value = q_row + epsilon * np.sqrt(
            np.log(iteration_count) / count_row)
    else:
        raise ValueError(
            "action_selection_type must be 0, 1 or 2, got {!r}".format(
                action_selection_type))

    # Ties are found on the same values used for ranking. For UCB that is
    # q + bonus, not q alone; using q there could leave every probability
    # at zero.
    max_action_value = np.max(action_value)
    is_max = action_value == max_action_value
    max_action_count = np.count_nonzero(is_max)

    # Apportion probability equally across tied maxima. Epsilon-greedy
    # spreads epsilon over the remaining bandits when any exist.
    if action_selection_type == 1 and max_action_count < num_bandits:
        max_policy_prob_per_action = (1.0 - epsilon) / max_action_count
        remain_prob_per_action = epsilon / (num_bandits - max_action_count)
    else:
        max_policy_prob_per_action = 1.0 / max_action_count
        remain_prob_per_action = 0.0

    policy[s_idx, :] = np.where(
        is_max, max_policy_prob_per_action, remain_prob_per_action)

    return policy

In [14]:
def stochastic_multi_armed_contextual_bandits(
        num_iterations,
        num_states,
        num_bandits,
        bandit_mean,
        bandit_variance,
        bandit_change_frequencies,
        bandit_change_counter,
        global_bandit_mean_mean,
        global_bandit_mean_variance,
        global_bandit_variance_mean,
        global_bandit_variance_variance,
        q,
        action_count,
        action_trace,
        policy,
        alpha,
        epsilon,
        action_selection_type,
        action_value_update_type):
    """Runs the stochastic contextual multi-armed bandit learner.

    This is a thin entry point. It hands every argument to
    `loop_through_iterations` and returns that function's results unchanged.

    Args:
        num_iterations: int, number of iterations.
        num_states: int, number of states.
        num_bandits: int, number of bandits.
        bandit_mean: array[float], the means of each bandit.
        bandit_variance: array[float], the variances of each bandit.
        bandit_change_frequencies: array[int], how often each
            bandit's statistics changes.
        bandit_change_counter: array[int], the change
            counter of each bandit.
        global_bandit_mean_mean: float, the global mean of means across all
            bandits.
        global_bandit_mean_variance: float, the global variance of means
            across all bandits.
        global_bandit_variance_mean: float, the global mean of variances
            across all bandits.
        global_bandit_variance_variance: float, the global variance of
            variances across all bandits.
        q: array[float], keeps track of the estimated value of each bandit in
            each state, Q(s, b).
        action_count: array[int], counts the number of times each bandit was
            actioned.
        action_trace: array[float], keeps track of the step-size trace for
            each bandit.
        policy: array[float], learned stochastic policy of which
            bandit to action given a state.
        alpha: float, alpha > 0, learning rate.
        epsilon: float, 0 <= epsilon <= 1, exploitation-exploration trade-off,
            higher means more exploration.
        action_selection_type: int, action selection type (greedy,
            epsilon-greedy, upper-confidence-bound).
        action_value_update_type: int, action value update type (
            sample-average, biased constant step-size, unbiased constant
            step-size).
    Returns:
        bandit_mean: array[float], the means of each bandit.
        bandit_variance: array[float], the variances of each bandit.
        q: array[float], keeps track of the estimated value of each bandit in
            each state, Q(s, b).
        policy: array[float], learned stochastic policy of which
            bandit to action given a state.
    """
    return loop_through_iterations(
        num_iterations=num_iterations,
        num_states=num_states,
        num_bandits=num_bandits,
        bandit_mean=bandit_mean,
        bandit_variance=bandit_variance,
        bandit_change_frequencies=bandit_change_frequencies,
        bandit_change_counter=bandit_change_counter,
        global_bandit_mean_mean=global_bandit_mean_mean,
        global_bandit_mean_variance=global_bandit_mean_variance,
        global_bandit_variance_mean=global_bandit_variance_mean,
        global_bandit_variance_variance=global_bandit_variance_variance,
        q=q,
        action_count=action_count,
        action_trace=action_trace,
        policy=policy,
        alpha=alpha,
        epsilon=epsilon,
        action_selection_type=action_selection_type,
        action_value_update_type=action_value_update_type)

## Run algorithm

In [15]:
def _print_arrays(stage, bandit_mean, bandit_variance, q, policy):
    """Prints the environment statistics, Q and policy under a stage label."""
    print("\n{} bandit mean".format(stage))
    print(bandit_mean)

    print("\n{} bandit variance".format(stage))
    print(bandit_variance)

    print("\n{} state-action value function".format(stage))
    print(q)

    print("\n{} policy".format(stage))
    print(policy)


def run_algorithm():
    """Runs the algorithm.

    The steps are:
    1. Build the environment, hyperparameters and learning arrays.
    2. Print their initial values.
    3. Train the stochastic contextual bandit learner.
    4. Print the final values.
    """
    (num_states,
     num_bandits,
     global_bandit_mean_mean,
     global_bandit_mean_variance,
     bandit_mean,
     global_bandit_variance_mean,
     global_bandit_variance_variance,
     bandit_variance,
     bandit_change_frequencies,
     bandit_change_counter) = create_environment()

    (num_iterations,
     alpha,
     epsilon,
     action_selection_type,
     action_value_update_type) = set_hyperparameters()

    q, action_count, action_trace = create_action_arrays(
        num_states, num_bandits)

    policy = create_policy_arrays(num_states, num_bandits)

    # Print before training: the arrays are mutated in place by the run.
    _print_arrays("Initial", bandit_mean, bandit_variance, q, policy)

    # Run stochastic multi-armed contextual bandits
    (bandit_mean,
     bandit_variance,
     q,
     policy) = stochastic_multi_armed_contextual_bandits(
        num_iterations,
        num_states,
        num_bandits,
        bandit_mean,
        bandit_variance,
        bandit_change_frequencies,
        bandit_change_counter,
        global_bandit_mean_mean,
        global_bandit_mean_variance,
        global_bandit_variance_mean,
        global_bandit_variance_variance,
        q,
        action_count,
        action_trace,
        policy,
        alpha,
        epsilon,
        action_selection_type,
        action_value_update_type)

    _print_arrays("Final", bandit_mean, bandit_variance, q, policy)

In [16]:
# Run the full experiment: build the environment, train the learner, and
# print the initial and final bandit statistics, Q values and policy.
run_algorithm()


Initial bandit mean
[[ 1.76405235  0.40015721  0.97873798  2.2408932   1.86755799 -0.97727788
   0.95008842 -0.15135721 -0.10321885  0.4105985 ]
 [ 0.14404357  1.45427351  0.76103773  0.12167502  0.44386323  0.33367433
   1.49407907 -0.20515826  0.3130677  -0.85409574]
 [-2.55298982  0.6536186   0.8644362  -0.74216502  2.26975462 -1.45436567
   0.04575852 -0.18718385  1.53277921  1.46935877]
 [ 0.15494743  0.37816252 -0.88778575 -1.98079647 -0.34791215  0.15634897
   1.23029068  1.20237985 -0.38732682 -0.30230275]
 [-1.04855297 -1.42001794 -1.70627019  1.9507754  -0.50965218 -0.4380743
  -1.25279536  0.77749036 -1.61389785 -0.21274028]
 [-0.89546656  0.3869025  -0.51080514 -1.18063218 -0.02818223  0.42833187
   0.06651722  0.3024719  -0.63432209 -0.36274117]
 [-0.67246045 -0.35955316 -0.81314628 -1.7262826   0.17742614 -0.40178094
  -1.63019835  0.46278226 -0.90729836  0.0519454 ]
 [ 0.72909056  0.12898291  1.13940068 -1.23482582  0.40234164 -0.68481009
  -0.87079715 -0.57884966 -0.31