In [1]:
import numpy as np
import pandas as pd
from empiricaldist import Pmf
from scipy.stats import binom

def make_binomial(n, p):
    """Build the distribution of the number of successes in `n` trials.

    n: number of independent trials
    p: probability of success on each trial

    returns: Pmf whose quantities are the counts k = 0..n and whose
             probabilities are binom.pmf(k, n, p)
    """
    outcomes = np.arange(n + 1)
    return Pmf(binom.pmf(outcomes, n, p), outcomes)



In [2]:
import janitor
import pyjviz

n = 35
num_sensitive = 10
num_insensitive = n - num_sensitive

# Per-subject success probabilities for the two groups. Defined once so the
# distributions and the diagram labels below cannot disagree with each other.
P_SENSITIVE = 0.95
P_INSENSITIVE = 0.40

dist_sensitive = make_binomial(num_sensitive, P_SENSITIVE)
dist_insensitive = make_binomial(num_insensitive, P_INSENSITIVE)
# Distribution of the total count Z = X + Y.
dist_total = Pmf.add_dist(dist_sensitive, dist_insensitive)

with pyjviz.CB(f"{num_sensitive} out of {n}") as b:
    # Labels are derived from the actual parameters (previously hard-coded
    # text claimed p=0.9 while the distribution used 0.95).
    X = dist_sensitive.node_text(
        f"X ~ <i>Bin</i>({num_sensitive}, {P_SENSITIVE})"
    ).make_plot()
    Y = dist_insensitive.node_text(
        f"Y ~ <i>Bin</i>({num_insensitive}, {P_INSENSITIVE})"
    ).make_plot()
    Z = dist_total.node_text("Z = X + Y").make_plot()

    # Side-by-side table of the three distributions, aligned on the count k.
    XYZ = pd.concat([X, Y, Z], axis=1).set_columns(['X', 'Y', 'Z=X+Y'])
    XYZ.node_text("X, Y and Z").make_plot()

    # Diagram: both inputs feed the add_dist step, which produces the total.
    t = pyjviz.Text('add_dist')
    pyjviz.Arrow(dist_sensitive, t)
    pyjviz.Arrow(dist_insensitive, t)
    pyjviz.Arrow(t, dist_total)

pyjviz.show(vertical=True)

In [3]:
import pandas as pd
with pyjviz.CB():
    # Likelihood table: one column per hypothesis (number of sensitive
    # subjects, 0..n), one row per possible observed total count.
    # Rates match the single-scenario cell: 0.95 sensitive, 0.4 insensitive.
    columns = {}
    for num_sensitive in range(n + 1):
        num_insensitive = n - num_sensitive
        dist_sensitive = make_binomial(num_sensitive, 0.95)
        dist_insensitive = make_binomial(num_insensitive, 0.4)
        dist_total = Pmf.add_dist(dist_sensitive, dist_insensitive)
        columns[num_sensitive] = dist_total
    # Build the frame once rather than growing an empty DataFrame per column.
    table = pd.DataFrame(columns)

    # Columns of `table`: distribution of the observed total under a few
    # chosen hypotheses.
    shown_hypotheses = [0, 10, 20, 30]
    few_likelihoods = table[shown_hypotheses].make_plot().node_text("likelihood")

    hypos = np.arange(n + 1)
    # Uniform (unnormalized) prior over the number of sensitive subjects.
    prior = Pmf(1, hypos).make_plot().node_text("prior")

    # Rows of `table`: likelihood of one observed total across all hypotheses.
    # The loop variable is therefore an observed count, not a number of
    # sensitive subjects.
    observed_totals = [0, 10, 20, 30]
    posterior_by_total = {}
    for num_identified in observed_totals:
        likelihood = table.loc[num_identified]
        posterior = prior * likelihood
        posterior.normalize()  # rescales in place so the posterior sums to 1
        posterior_by_total[num_identified] = posterior
    posteriors = pd.DataFrame(posterior_by_total)

    few_posteriors = posteriors.make_plot().node_text("posterior")

    # Diagram: prior and likelihoods feed the update step, which yields the posteriors.
    update_t = pyjviz.Text('update')
    pyjviz.Arrow(prior, update_t)
    pyjviz.Arrow(few_likelihoods, update_t)
    pyjviz.Arrow(update_t, few_posteriors)

pyjviz.show(vertical=False)