## 构建 Bernoulli Distribution 的 Multi-armed Bandit

In [1]:
import numpy as np
import matplotlib.pyplot as plt

定义 Bernoulli Multi-armed Bandit：共有 K 个臂，每个臂的获奖概率在 [0, 1) 上均匀随机采样；拉动某个臂时，以该臂的概率返回奖励 1，否则返回 0。

In [4]:
class BernoulliBandit:
    """Multi-armed bandit whose arms pay Bernoulli (0/1) rewards.

    Every arm is assigned a success probability drawn with
    ``np.random.rand``; pulling an arm returns 1 with that probability
    and 0 otherwise.

    Attributes:
        probs: array of length ``K`` holding each arm's success probability.
        best_idx: index of the arm with the highest success probability.
        best_prob: the highest success probability among all arms.
        K: number of arms.
    """

    def __init__(self, K):
        """Sample success probabilities for ``K`` arms and record the best arm."""
        self.K = K
        self.probs = np.random.rand(K)
        self.best_idx = np.argmax(self.probs)
        self.best_prob = self.probs.max()

    def step(self, k: int):
        """Pull arm ``k`` and return its reward: 1 on success, 0 on failure."""
        # Each pull consumes exactly one uniform draw from NumPy's global RNG.
        success = np.random.rand() < self.probs[k]
        return 1 if success else 0



In [5]:
# Fix the global NumPy seed so the sampled arm probabilities are reproducible.
np.random.seed(1)
bernoulli_bandit = BernoulliBandit(10)

# Report the sampled bandit; each (label, value) pair is printed on its own line.
summary = [
    ("The distribution of bernoulli_bandit: ", bernoulli_bandit.probs),
    ("The best arm index: ", bernoulli_bandit.best_idx),
    ("The best arm probability: ", bernoulli_bandit.best_prob),
    ("The probability of arm 3: ", bernoulli_bandit.probs[3]),
]
for label, value in summary:
    print(label, value)

# Pull arm 3 once and show the reward that was drawn.
x = bernoulli_bandit.step(3)
print("The result of arm 3: ", x)

The distribution of bernoulli_bandit:  [4.17022005e-01 7.20324493e-01 1.14374817e-04 3.02332573e-01
 1.46755891e-01 9.23385948e-02 1.86260211e-01 3.45560727e-01
 3.96767474e-01 5.38816734e-01]
The best arm index:  1
The best arm probability:  0.7203244934421581
The probability of arm 3:  0.30233257263183977
The result of arm 3:  0


定义一个 Solver 基类来解决 Bandit 问题：它负责记录每个臂被选择的次数、每一步的动作以及累积懊悔；具体的选臂策略由子类实现 `run_one_step` 来提供。

In [6]:
class Solver:
    """Base class for strategies that play a bandit.

    Tracks how often each arm was chosen, the sequence of chosen arms and
    the cumulative regret. Subclasses supply the arm-selection policy by
    overriding ``run_one_step``.
    """

    def __init__(self, bandit: BernoulliBandit):
        """Attach the solver to ``bandit`` and reset all bookkeeping."""
        self.bandit = bandit
        self.actions = []  # arm chosen at each step
        self.regrets = []  # cumulative regret after each step
        self.regret = 0  # running cumulative regret
        self.counts = np.zeros(self.bandit.K)  # number of times each arm was chosen

    def update_regret(self, k):
        """Add the regret of choosing arm ``k`` and record the new running total."""
        # Regret of one step: best arm's probability minus the chosen arm's.
        gap = self.bandit.best_prob - self.bandit.probs[k]
        self.regret += gap
        self.regrets.append(self.regret)

    def run_one_step(self):
        """Choose the arm for the current step; must be overridden by subclasses."""
        raise NotImplementedError

    def run(self, num_steps):
        """Run ``num_steps`` steps, recording each chosen arm, its count and the regret."""
        for _ in range(num_steps):
            chosen = self.run_one_step()
            self.actions.append(chosen)
            self.counts[chosen] += 1
            self.update_regret(chosen)
            
