In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as st
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from matplotlib import cm

In [None]:
# Bin edges spanning [0, 1] (20 edges -> 19 equal-width bins) and a sample of
# uniform values to histogram.
edges = np.linspace(0, 1.0, num=20)
xs = np.random.uniform(0, 1.0, 1000)
x = 0.6
y_shift = 0.1

# Count how many samples fall in each bin. The samples and `x` are binned with
# the same convention (right=True), so `bin_num` indexes `y_count` directly.
# With mismatched conventions, a value lying exactly on an edge (e.g. 1.0)
# maps to an index past the end of `y_count`. `minlength` guarantees that
# every bin index digitize can return for values in [0, 1] exists.
y_count = np.bincount(np.digitize(xs, edges, right=True), minlength=len(edges))
bin_num = np.digitize(x, edges, right=True)
# Height at which a dot for `x` would be stacked.
y_val = (y_count[bin_num] + 1)*y_shift

In [None]:
# Shapes of the per-bin counts vs. the bin edges, shown side by side.
(y_count.shape, edges.shape)

In [None]:
# Largest bin index produced by the default (right=False) digitize convention;
# presumably inspected to see whether it can exceed the last index of `y_count`.
np.max(np.digitize(xs, edges))

In [None]:
def get_new_points(xs, x, y_shift=0.05, n_edges=20):
    """Return the (x, y) position for a new dot in a stacked dot plot.

    The interval [0, 1] is split into ``n_edges - 1`` equal-width bins. The new
    dot keeps its x position and is lifted to a height proportional to how many
    values in ``xs`` share ``x``'s bin, so dots in the same bin stack up.

    Parameters
    ----------
    xs : array-like of float
        Values already placed. The caller in this notebook appends ``x`` to
        ``xs`` before calling, so ``x`` counts itself.
    x : float
        The new value to place.
    y_shift : float, default 0.05
        Vertical spacing between stacked dots.
    n_edges : int, default 20
        Number of bin edges spanning [0, 1].

    Returns
    -------
    tuple
        ``(x, y_val)`` where ``y_val = (count_in_bin + 1) * y_shift``.
    """
    edges = np.linspace(0, 1.0, num=n_edges)
    # Bin `x` and `xs` with the same convention (right=True) and compare bin
    # indices directly. Mixing conventions and shifting the index by one
    # reads the neighbouring bin's count. Counting by equality also avoids
    # out-of-range lookups for empty `xs` or values on the last edge.
    bin_num = np.digitize(x, edges, right=True)
    count = np.count_nonzero(np.digitize(xs, edges, right=True) == bin_num)
    y_val = (count + 1) * y_shift

    return x, y_val


class SingleValue:
    """Degenerate distribution that always yields the same value.

    Provides the two methods of a frozen ``scipy.stats`` distribution used in
    this notebook (``pdf`` and ``rvs``), so it can stand in for one when the
    success probability should be a fixed number.
    """

    def __init__(self, value):
        self.value = value

    def pdf(self, x):
        """Return 100 where ``x`` is numerically close to ``value``, else 0."""
        at_value = np.isclose(x, self.value)
        return np.where(at_value, 100, 0)

    def rvs(self, x=None):
        """Draw samples, all equal to ``value``.

        With ``x=None`` the scalar ``value`` is returned; otherwise an array
        holding ``x`` copies of it.
        """
        if x is not None:
            return np.repeat(self.value, x)
        return self.value

class UpdateDist:
    """Animation callback that runs Thompson sampling on a set of Bernoulli arms.

    Each arm ``k`` has an "actual" success probability that is itself uncertain.
    It is modelled as ``Beta(counts[k] * probs[k], counts[k] * (1 - probs[k]))``,
    a Beta distribution with mean ``probs[k]`` whose spread shrinks as
    ``counts[k]`` grows. On every frame after frame 0 the callback:

    1. samples once from each arm's posterior (prior Beta(1, 1)),
    2. pulls the arm with the largest sample,
    3. draws that arm's success probability from its "actual" distribution and
       a Bernoulli outcome with that probability,
    4. updates the pulled arm's posterior and the plot artists.

    Instances are meant to be passed as ``func`` to
    ``matplotlib.animation.FuncAnimation``. Frame 0 resets all state so the
    animation can loop.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes for the posterior curves, the drawn-probability dots and the
        shaded "actual" distributions.
    ax2 : matplotlib Axes
        Axes for the horizontal bar chart of pulls per arm.
    probs : sequence of float
        Mean success probability of each arm.
    counts : sequence of float
        Concentration (``a + b``) of each arm's "actual" Beta distribution.
    """

    def __init__(self, ax, ax2, probs=(0.5, 0.3, 0.7), counts=(10, 30, 5)):
        n_arms = len(probs)
        colors = [cm.Set1(k) for k in range(n_arms)]

        # "Actual" distributions of each arm's success probability.
        self.probs = [st.beta(count * prob, count - (count * prob))
                      for prob, count in zip(probs, counts)]

        # One posterior curve and one dot series per arm.
        self.lines = [ax.plot([], [], color=color)[0] for color in colors]
        self.points = [ax.plot([], [], "o", ms=5, color=color, alpha=0.5, mec=color, mew=1.0)[0]
                       for color in colors]
        self.x = np.linspace(0, 1, 101)
        self.ax = ax
        self.ax2 = ax2

        # Bar chart summarising how many times each arm has been pulled.
        self.ax2.set_yticks([])
        self.bar = self.ax2.barh(np.arange(n_arms), np.zeros(n_arms), align='center', color=colors)

        # Set up plot parameters
        self.ax.set_xlim(0, 1)
        self.ax2.set_xlim(0, 500)
        self.ax.set_ylim(0, 10)

        # Static background: shaded, dotted outline of each "actual" distribution.
        for color, prob in zip(colors, self.probs):
            pdf = prob.pdf(self.x)
            self.ax.fill_between(self.x, pdf, alpha=0.1, color=color)
            self.ax.plot(self.x, pdf, linestyle=':', lw=0.5, color=color)

        # Initialise the sampling state here too, so the object is usable even
        # if the first frame it receives is not 0.
        self._reset()

    def _reset(self):
        """Restore the prior state: Beta(1, 1) posteriors, zero counters, empty artists."""
        n_arms = len(self.lines)
        self.beta_dists = [st.beta(1, 1) for _ in range(n_arms)]
        self.successes = np.zeros(n_arms)
        # Number of times each arm has actually been pulled.
        self.trials = np.zeros(n_arms)
        # Every drawn success probability (all arms); used to stack the dots.
        self.base_x = []
        for line, dots, bar in zip(self.lines, self.points, self.bar):
            line.set_data([], [])
            dots.set_data([], [])
            bar.set_width(0)

    def _artists(self):
        """Return every artist ``__call__`` modifies (needed for ``blit=True``)."""
        return [*self.lines, *self.points, *self.bar]

    def __call__(self, i):
        """Advance the simulation by one pull and return the artists to redraw.

        Parameters
        ----------
        i : int
            Frame index supplied by ``FuncAnimation``. Frame 0 resets the
            simulation, so the plot can run continuously and show fresh
            realizations of the process.

        Returns
        -------
        list
            The posterior lines, dot series and bars, so blitting redraws all
            of them.
        """
        if i == 0:
            self._reset()
            return self._artists()

        # Thompson sampling: one draw from each posterior, pull the best arm.
        sample_probs = [beta.rvs() for beta in self.beta_dists]
        arm = np.argmax(sample_probs)

        # This pull's success probability comes from the arm's "actual"
        # distribution; the outcome is a Bernoulli draw with that probability.
        draw = self.probs[arm].rvs()
        success = np.random.uniform() < draw
        self.successes[arm] += success
        self.trials[arm] += 1

        # Conjugate update of the Beta(1, 1) prior: Beta(1 + successes, 1 + failures).
        failures = self.trials[arm] - self.successes[arm]
        self.beta_dists[arm] = st.beta(1 + self.successes[arm], 1 + failures)

        # Highlight the pulled arm's posterior curve.
        self.lines[arm].set_data(self.x, self.beta_dists[arm].pdf(self.x))
        for k, line in enumerate(self.lines):
            line.set_linewidth(4 if k == arm else 1)

        # Add a dot at the drawn probability, stacked above earlier draws in its bin.
        self.base_x.append(draw)
        x, y = get_new_points(self.base_x, draw)
        xs, ys = self.points[arm].get_data()
        self.points[arm].set_data(np.append(xs, x), np.append(ys, y))

        for bar, n_pulls in zip(self.bar, self.trials):
            bar.set_width(n_pulls)

        return self._artists()
    
# Fix the global NumPy random state so every run of the animation is identical.
np.random.seed(2121)

# Layout: a 10-row grid with the posterior plot on the top 8 rows and the
# pull-count bar chart on the bottom 2 rows.
fig = plt.figure(figsize=(10.6667, 6), constrained_layout=True)
spec = fig.add_gridspec(10, 1)
ax = fig.add_subplot(spec[:-2, :])
ax2 = fig.add_subplot(spec[-2:, :])

# Remove the top/right spines (and, with left=True, the left one too) and hide
# the main plot's y axis: only the shapes of the densities matter.
sns.despine(left=True)
ax.yaxis.set_visible(False)

# Three arms sharing the same mean success rate (0.5) but with different
# concentrations, i.e. different amounts of uncertainty.
ud = UpdateDist(ax, ax2, probs=[0.5, 0.5, 0.5], counts=[10, 30, 2])
ax.set_ylim(0, 5)

anim = FuncAnimation(fig, ud, frames=500, interval=100, blit=True)

# Save a static image of the current figure.
# NOTE(review): this path points into a sibling checkout of the blog repository
# and fails if that directory does not exist.
plt.savefig("../../../sidravi1.github.io/assets/20200819_three_distributions.png")

In [None]:
# NOTE(review): this cell previously held the stray fragment `i.set_w`
# (presumably left over from tab-completion exploration). It raises NameError
# on a fresh kernel because no cell above defines a global `i`. Bar widths are
# set with `Rectangle.set_width` inside `UpdateDist.__call__`.