In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import norm

In [2]:
class ThompsonSampleGauss:
    """
    A class to perform Gaussian Thompson Sampling on the multi-armed bandit problem.
    """
    
    def __init__(self, bandits: int, means: np.ndarray, stds: np.ndarray, pulls: int) -> None:
        """
        Initializes the ThompsonSampleTest class.

        Args:
            bandits: the number of arms in the bandit problem.
            means: a numpy array representing the true means of the bandits.
            stds: a numpy array representing the true standard deviations of the bandits.
            pulls: the number of times to pull the arms of the bandit problem.
        """
        self.bandits = bandits
        self.prior_means = np.zeros(bandits)
        self.prior_stds = np.ones(bandits) * 1000
        self.actual_means = means
        self.actual_stds = stds
        self.pulls = pulls
        self.samples = np.zeros(bandits)
        self.mean_sequences = {n : [] for n in range(bandits)}
        self.stds_sequences = {n : [] for n in range(bandits)}
        
        for n in range(bandits):
            self.mean_sequences[n].append(0)
        for n in range(bandits):
            self.stds_sequences[n].append(1)
        
    def get_choice(self) -> int:
        """
        Samples the prior distributions to select the arm to pull. 
        Note that we sample it twice to get both a mean and a standard deviation.        
        
        Returns:
            The index of the arm to pull.
        """       
        sample_means = norm.rvs(self.prior_means, self.prior_stds)
        sample_stds = np.abs(norm.rvs(self.prior_means, self.prior_stds))
        
        return np.argmax(norm.rvs(sample_means, sample_stds))
    
    def get_reward(self, choice: int) -> float:        
        """
        Samples the distribution of the arm chosen to get a reward.

        Args:
            choice: The index of the arm chosen.
            
        Returns:
            The reward obtained from the chosen arm.
        """        
        reward = norm.rvs(self.actual_means[choice], self.actual_stds[choice])
        
        return reward
    
    def update_prior(self, choice: int, reward: float) -> None:
        """
        Updates the prior distribution for the chosen arm.

        Args:
            choice: The index of the arm chosen.
            reward: The reward obtained from the chosen arm.
        """       
        self.samples[choice] += 1
        self.prior_means[choice] = (self.prior_means[choice] * (self.samples[choice] - 1) + reward) / self.samples[choice]
        self.prior_stds[choice] = np.sqrt((self.prior_stds[choice]**2 * (self.samples[choice] - 1) + (reward - self.prior_means[choice])**2) / self.samples[choice])
        
        self.mean_sequences[choice].append(self.prior_means[choice])
        self.stds_sequences[choice].append(self.prior_stds[choice])
                                          
    def do_it(self) -> None:
        """
        Performs the Thompson Sampling algorithm for the specified number of pulls.
        """        
        for i in range(self.pulls):

            choice = self.get_choice()
            reward = self.get_reward(choice)
            
            self.update_prior(choice, reward)


In [7]:
# generating random sets of means and standard deviations
means = norm.rvs(100, 25, size=100)
stds = np.abs(norm.rvs(52, 10, size=100))

# instantiating the class
test = ThompsonSampleGauss(len(means), means, stds, 755)

# running the thompson sampling algorithm
test.do_it()

# converged?
print(np.argmax(means) == np.argmax(test.prior_means))

True


In [8]:
# lets see a dataframe 
df = pd.DataFrame({"final_est_mean": test.prior_means,
                   "actual_mean" : means,
                   "final_est_std" : test.prior_stds,
                   "actual_std" : stds,
                   "sample_count" : test.samples})

df.sort_values(by='final_est_mean', ascending=False).head(15)

Unnamed: 0,final_est_mean,actual_mean,final_est_std,actual_std,sample_count
99,175.043607,169.445468,55.090688,52.480768,60.0
40,157.867114,158.917035,38.92891,39.517591,36.0
38,153.973392,144.043059,29.912014,35.776368,44.0
76,150.792522,159.960072,48.132634,52.570348,29.0
67,150.054029,140.539305,57.702591,65.803882,45.0
39,144.686356,141.563414,48.82567,59.201717,43.0
47,139.859702,122.178205,39.294339,47.344398,24.0
66,137.540206,146.242479,30.005434,39.656305,17.0
1,134.16845,131.806453,26.204467,56.405495,10.0
53,132.105108,136.110496,43.366525,50.862118,40.0
