In [1]:
using GittinsIndices
using Distributions
using Plots

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling GittinsIndices [528e2044-b1ef-47f1-a8f4-693ac85d91f9]
[91m[1mERROR: [22m[39mLoadError: ArgumentError: Package ProgressLogging [33c8b6b6-d38a-422a-b730-caa89a2f386c] is required but does not seem to be installed:
 - Run `Pkg.instantiate()` to install all recorded dependencies.

Stacktrace:
 [1] [0m[1m_require[22m[0m[1m([22m[90mpkg[39m::[0mBase.PkgId[0m[1m)[22m
[90m   @ [39m[90mBase[39m [90m./[39m[90m[4mloading.jl:1306[24m[39m
 [2] [0m[1m_require_prelocked[22m[0m[1m([22m[90muuidkey[39m::[0mBase.PkgId[0m[1m)[22m
[90m   @ [39m[90mBase[39m [90m./[39m[90m[4mloading.jl:1200[24m[39m
 [3] [0m[1mmacro expansion[22m
[90m   @ [39m[90m./[39m[90m[4mloading.jl:1180[24m[39m[90m [inlined][39m
 [4] [0m[1mmacro expansion[22m
[90m   @ [39m[90m./[39m[90m[4mlock.jl:223[24m[39m[90m [inlined][39m
 [5] [0m[1mrequire[22m[0m[1m([22m[90minto[39m::[0mModule, [90mmod[39m::

LoadError: Failed to precompile GittinsIndices [528e2044-b1ef-47f1-a8f4-693ac85d91f9] to /Users/ameliebuc/.julia/compiled/v1.8/GittinsIndices/jl_2srvJi.

In [8]:
"""
    get_gittins_action(gittins_priors, gamma)

Return the index of the arm with the largest Bernoulli Gittins index.

Each element of `gittins_priors` is destructured as `(alpha, beta)` and passed,
together with the discount factor `gamma`, to `calculate_bernoulli_gittins`.
"""
function get_gittins_action(gittins_priors, gamma)
    # One Gittins index per arm, in the same order as the priors.
    indices = map(gittins_priors) do prior
        alpha, beta = prior
        calculate_bernoulli_gittins(; alpha = alpha, beta = beta, gamma = gamma)
    end
    return argmax(indices)
end

"""
    get_thompson_action(thompson_priors)

Choose an arm by Thompson sampling for Bernoulli bandits.

Each element of `thompson_priors` is destructured as `(alpha, beta)`, the
parameters of that arm's Beta posterior. One success probability is drawn from
every posterior and the index of the arm with the largest draw is returned.

The previous implementation selected an arm with probability proportional to
its posterior *mean*, which ignores posterior uncertainty and is not Thompson
sampling.
"""
function get_thompson_action(thompson_priors)
    # Draw one plausible success rate per arm from its Beta posterior.
    sampled_rates = [rand(Beta(alpha, beta)) for (alpha, beta) in thompson_priors]
    # Act greedily with respect to the sampled rates.
    return argmax(sampled_rates)
end

"""
    get_ucb_action(ucb_values, num_pulls, c = 1.0; ucb_counts = nothing)

Return the index of the arm with the largest upper confidence bound

    ucb_values[i] + c * sqrt(log(num_pulls) / max(1, ucb_counts[i]))

# Arguments
- `ucb_values`: current mean-reward estimate for each arm.
- `num_pulls`: total number of pulls so far (values below 1 are treated as 1 so
  the logarithm stays non-negative).
- `c`: exploration weight.

# Keywords
- `ucb_counts`: number of times each arm has been pulled. When omitted, every
  arm is treated as pulled once, so the bonus is identical across arms and the
  choice reduces to `argmax(ucb_values)`.

The previous version read an undefined `ucb_counts` variable and divided by its
sum, which gave every arm the same bonus even when counts were available.
"""
function get_ucb_action(ucb_values, num_pulls, c = 1.0; ucb_counts = nothing)
    counts = ucb_counts === nothing ? ones(Int, length(ucb_values)) : ucb_counts
    # Hoisted: the log term is the same for every arm.
    log_total = log(max(num_pulls, 1))
    # Broadcasting raises DimensionMismatch if counts and values disagree in length.
    # max.(1, counts) keeps the bonus finite for arms that were never pulled.
    ucb_bounds = ucb_values .+ c .* sqrt.(log_total ./ max.(1, counts))
    return argmax(ucb_bounds)
end

"""
    get_greedy_action(q_values, epsilon)

Epsilon-greedy arm selection.

With probability `epsilon` return a uniformly random index in
`1:length(q_values)`; otherwise return the index of the largest entry of
`q_values`.
"""
function get_greedy_action(q_values, epsilon)
    explore = rand() < epsilon
    # Explore: any arm, chosen uniformly.
    explore && return rand(1:length(q_values))
    # Exploit: the arm with the best current estimate.
    return argmax(q_values)
end

get_ucb_action (generic function with 2 methods)

In [6]:
"""
    gittins_vs(; num_arms, gamma, num_pulls, c = 1.0, epsilon = 0.1)

Simulate four bandit strategies (Gittins indices, Thompson sampling, UCB and
epsilon-greedy) on the same set of randomly generated Bernoulli arms.

# Keywords
- `num_arms`: number of Bernoulli arms; each success probability is drawn
  uniformly at random.
- `gamma`: discount factor forwarded to `get_gittins_action`.
- `num_pulls`: number of pulls every strategy performs.
- `c`: exploration weight of the UCB bonus.
- `epsilon`: exploration probability forwarded to `get_greedy_action`.

# Returns
A named tuple `(gittins_rewards, thompson_rewards, ucb_rewards, epsilon_rewards)`.
Each field is a `Vector{Bool}` of length `num_pulls` holding the reward of that
strategy at every pull, so all four can be passed to `cumsum` directly.
"""
function gittins_vs(; num_arms, gamma, num_pulls, c = 1.0, epsilon = 0.1)
    arms = [Bernoulli(rand(Float64)) for _ in 1:num_arms]

    # Beta(1, 1) priors stored as [alpha, beta] per arm.
    gittins_priors = [[1, 1] for _ in 1:num_arms]
    gittins_rewards = Bool[]

    thompson_priors = [[1, 1] for _ in 1:num_arms]
    thompson_rewards = Bool[]

    # Running mean reward and pull count per arm for UCB.
    ucb_values = zeros(num_arms)
    ucb_counts = zeros(Int, num_arms)
    ucb_rewards = Bool[]

    # Running mean reward and pull count per arm for epsilon-greedy.
    q_values = zeros(num_arms)
    q_counts = zeros(Int, num_arms)
    epsilon_rewards = Bool[]

    for pull in 1:num_pulls
        # Gittins indices: a success increments alpha (slot 1), a failure beta (slot 2).
        gittins_action = get_gittins_action(gittins_priors, gamma)
        gittins_reward = rand(arms[gittins_action])
        push!(gittins_rewards, gittins_reward)
        gittins_priors[gittins_action][gittins_reward ? 1 : 2] += 1

        # Thompson sampling, with the same posterior update as above.
        thompson_action = get_thompson_action(thompson_priors)
        thompson_reward = rand(arms[thompson_action])
        push!(thompson_rewards, thompson_reward)
        thompson_priors[thompson_action][thompson_reward ? 1 : 2] += 1

        # UCB: the bound is computed here from the per-arm counts tracked by this loop.
        # max.(1, ucb_counts) keeps the bonus finite for arms not pulled yet.
        ucb_bounds = ucb_values .+ c .* sqrt.(log(pull) ./ max.(1, ucb_counts))
        ucb_action = argmax(ucb_bounds)
        ucb_reward = rand(arms[ucb_action])
        push!(ucb_rewards, ucb_reward)
        # Incremental mean over this arm's own pulls, not the global pull number.
        ucb_counts[ucb_action] += 1
        ucb_values[ucb_action] +=
            (ucb_reward - ucb_values[ucb_action]) / ucb_counts[ucb_action]

        # Epsilon-greedy: estimates are updated from its own action and reward.
        greedy_action = get_greedy_action(q_values, epsilon)
        greedy_reward = rand(arms[greedy_action])
        push!(epsilon_rewards, greedy_reward)
        q_counts[greedy_action] += 1
        q_values[greedy_action] +=
            (greedy_reward - q_values[greedy_action]) / q_counts[greedy_action]
    end

    return (
        gittins_rewards = gittins_rewards,
        thompson_rewards = thompson_rewards,
        ucb_rewards = ucb_rewards,
        epsilon_rewards = epsilon_rewards,
    )
end

gittins_ucb_thompson (generic function with 1 method)

In [7]:
"""
    graph_thompson_gittins()

Run `gittins_vs` with 10 arms, 500 pulls and `gamma = 0.99`, then plot the
cumulative reward of the Gittins-index and Thompson-sampling strategies against
the pull number. Returns the value of the `plot` call.
"""
function graph_thompson_gittins()
    settings = (num_arms = 10, num_pulls = 500, gamma = 0.99)
    results = gittins_vs(; settings...)

    pull_axis = 1:settings.num_pulls
    # One cumulative-reward series per strategy, in the same order as the labels.
    cumulative_rewards = [
        cumsum(results.gittins_rewards),
        cumsum(results.thompson_rewards),
    ]

    return plot(
        pull_axis,
        cumulative_rewards;
        title = "Explore-Exploit Strategies for Multi-Armed Bandits",
        label = ["Gittins Indices" "Thompson Sampling"],
        xlabel = "Pulls",
        ylabel = "Cumulative Reward",
    )
end

graph_thompson_gittins()

LoadError: UndefVarError: Bernoulli not defined

In [None]:
"""
    graph_ucb_gittins()

Run `gittins_vs` with 10 arms, 500 pulls and `gamma = 0.99`, then plot the
cumulative reward of the Gittins-index and UCB strategies against the pull
number. Returns the value of the `plot` call.

The previous version was a copy of the Thompson-sampling plot: it read
`results.thompson_rewards` and labelled the series "Thompson Sampling".
"""
function graph_ucb_gittins()
    num_pulls = 500
    num_arms = 10
    gamma = 0.99

    results = gittins_vs(
        num_arms = num_arms,
        num_pulls = num_pulls,
        gamma = gamma,
    )

    return plot(
        1:num_pulls,
        [
            cumsum(results.gittins_rewards),
            cumsum(results.ucb_rewards),
        ],
        title = "Explore-Exploit Strategies for Multi-Armed Bandits",
        # A row matrix gives one label per series.
        label = ["Gittins Indices" "UCB"],
        xlabel = "Pulls",
        ylabel = "Cumulative Reward",
    )
end

graph_ucb_gittins()