In [1]:
using Distributions
import Base.rand
import Random
using Formatting

Random.seed!(42)

# See:
# Daniel J. Russo et al, "A Tutorial on Thompson Sampling."
# https://web.stanford.edu/~bvr/pubs/TS_Tutorial.pdf

MersenneTwister(42)

In [2]:
mutable struct ThompsonSampling
    arms::Vector{Distribution}
    distr::Vector{LogNormal}
    τ::Real
    
    ThompsonSampling(arms) = new(arms, [LogNormal() for _ in arms], 1)
end

function rand(sampler::ThompsonSampling)
    return map(distr -> rand(distr), sampler.distr)
end

function posterior(distr::LogNormal, y::Real, τ::Real)
    μ_prior = mean(distr)
    σ_prior = var(distr)
    
    σ_post = 1 / (1/σ_prior + 1/τ)
    μ_post = μ_prior/σ_prior + log(y + τ/2)/τ
    μ_post *= σ_post
    
    return LogNormal(μ_post, σ_post)
end

function step!(sampler::ThompsonSampling) 
    idx = argmax(rand(sampler))
    reward = rand(sampler.arms[idx])
        
    sampler.distr[idx] = posterior(sampler.distr[idx], reward, sampler.τ)
    
    return reward
end

step! (generic function with 1 method)

In [3]:
function experiment(iters, trials = 5000)
    
    total_rewards = zeros(3)

    for _ in 1:trials
        arms = [Poisson(rand(Uniform(1, 10))) for _ in 1:10]

        # pure exploration
        for _ in 1:iters
            idx = rand(1:length(arms))
            total_rewards[1] += rand(arms[idx])
        end

        # explore first
        rewards = zeros(length(arms))
        for i in 1:iters
            if i < iters // 2
                idx = rand(1:length(arms))
            else
                idx = argmax(rewards)
            end

            rewards[idx] += rand(arms[idx])
        end
        total_rewards[2] += sum(rewards)

        # thompson sampling
        sampler = ThompsonSampling(arms)
        for _ in 1:iters
            total_rewards[3] += step!(sampler)
        end    
    end

    printfmt("\n\niterations:        {:d}\n", iters)
    printfmt("explore only:      {:f}\nexplore first:     {:f}\nthompson sampling: {:f}", (total_rewards/iters/trials)...)
    
end

experiment (generic function with 2 methods)

In [4]:
experiment(10)
experiment(100)
experiment(1000)



iterations:        10
explore only:      5.513920
explore first:     6.524320
thompson sampling: 6.231600

iterations:        100
explore only:      5.501280
explore first:     6.842702
thompson sampling: 6.990698

iterations:        1000
explore only:      5.508589
explore first:     7.158247
thompson sampling: 7.425364