In [37]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import norm, uniform, bernoulli, t, beta

In [165]:
class ThompsonSampleGauss:
    """
    A class to perform Gaussian Thompson Sampling on the multi-armed bandit problem.
    """
    
    def __init__(self, bandits: int, means: np.ndarray, stds: np.ndarray, pulls: int) -> None:
        """
        Initializes the ThompsonSampleTest class.

        Args:
            bandits: the number of arms in the bandit problem.
            means: a numpy array representing the true means of the bandits.
            stds: a numpy array representing the true standard deviations of the bandits.
            pulls: the number of times to pull the arms of the bandit problem.
        """
        self.bandits = bandits
        self.prior_means = np.zeros(bandits)
        self.prior_stds = np.ones(bandits) * 1000
        self.actual_means = means
        self.actual_stds = stds
        self.pulls = pulls
        self.samples = np.zeros(bandits)
        self.mean_sequences = {n : [] for n in range(bandits)}
        self.stds_sequences = {n : [] for n in range(bandits)}
        
        for n in range(bandits):
            self.mean_sequences[n].append(0)
        for n in range(bandits):
            self.stds_sequences[n].append(1)
        
    def get_choice(self) -> int:
        """
        Samples the prior distributions to select the arm to pull. 
        Note that we sample it twice to get both a mean and a standard deviation.        
        
        Returns:
            The index of the arm to pull.
        """       
        sample_means = norm.rvs(self.prior_means, self.prior_stds)
        sample_stds = np.abs(norm.rvs(self.prior_means, self.prior_stds))
        
        return np.argmax(norm.rvs(sample_means, sample_stds))
    
    def get_reward(self, choice: int) -> float:        
        """
        Samples the distribution of the arm chosen to get a reward.

        Args:
            choice: The index of the arm chosen.
            
        Returns:
            The reward obtained from the chosen arm.
        """        
        reward = norm.rvs(self.actual_means[choice], self.actual_stds[choice])
        
        return reward
    
    def update_prior(self, choice: int, reward: float) -> None:
        """
        Updates the prior distribution for the chosen arm.

        Args:
            choice: The index of the arm chosen.
            reward: The reward obtained from the chosen arm.
        """       
        self.samples[choice] += 1
        self.prior_means[choice] = (self.prior_means[choice] * (self.samples[choice] - 1) + reward) / self.samples[choice]
        self.prior_stds[choice] = np.sqrt((self.prior_stds[choice]**2 * (self.samples[choice] - 1) + (reward - self.prior_means[choice])**2) / self.samples[choice])
        
        self.mean_sequences[choice].append(self.prior_means[choice])
        self.stds_sequences[choice].append(self.prior_stds[choice])
                                          
    def do_it(self) -> None:
        """
        Performs the Thompson Sampling algorithm for the specified number of pulls.
        """        
        for i in range(self.pulls):

            choice = self.get_choice()
            reward = self.get_reward(choice)
            
            self.update_prior(choice, reward)


In [208]:
# generating random sets of means and standard deviations

means = norm.rvs(100, 25, size=100)
stds = np.abs(norm.rvs(4, 10, size=100))

# instantiating the class
test = ThompsonSampleGauss(len(means), means, stds, 352)

# running the thompson sampling algorithm
test.do_it()

# converged?
print(np.argmax(means) == np.argmax(test.prior_means))

True


In [218]:
# lets see a dataframe 
df = pd.DataFrame({"final_est_mean": test.prior_means,
                   "actual_mean" : means,
                   "final_est_std" : test.prior_stds,
                   "actual_std" : stds,
                   "sample_count" : test.samples})

df.sort_values(by='final_est_mean', ascending=False).head(15)

Unnamed: 0,final_est_mean,actual_mean,final_est_std,actual_std,sample_count
27,152.866228,155.021226,14.893715,20.90612,11.0
51,150.170128,148.617941,10.468304,8.635383,12.0
43,146.48569,147.338679,9.031626,9.752092,14.0
26,143.750853,143.151884,7.964168,6.627543,14.0
41,141.43092,141.955271,1.324675,1.633063,14.0
86,136.384503,140.902578,5.881466,7.258287,8.0
47,133.474048,131.263009,8.361029,9.68878,11.0
50,133.324635,129.493174,5.883575,9.438646,5.0
39,132.39074,137.362961,12.439725,12.689164,9.0
25,130.162834,128.633063,2.78511,3.549396,9.0
