# Central Limit Theorem 

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns

try:
    from probml_utils import savefig, latexify
except ModuleNotFoundError:
    %pip install git+https://github.com/probml/probml-utils.git
    from probml_utils import savefig, latexify

In [None]:
# Configure matplotlib's figure styling via the probml_utils helper before any
# plots are created in the cells below.
latexify(fig_height=2, width_scale_factor=2)

In [None]:
def calcMean(keys, N, a=1, b=5):
    """Return the mean of N i.i.d. draws from a Beta(a, b) distribution.

    With the default shape parameters this samples Beta(1, 5), whose support
    is [0, 1] and whose mean is a / (a + b) = 1/6.

    Args:
        keys: a single JAX PRNG key. The name is plural because callers
            ``jax.vmap`` this function over a batch of keys (one per mean).
        N: number of Beta draws averaged together.
        a: first Beta shape parameter (default 1).
        b: second Beta shape parameter (default 5).

    Returns:
        A scalar JAX array holding the sample mean, which lies in [0, 1].
    """
    # Shape [1, N] is kept so the drawn values for a given key are unchanged;
    # jnp.mean reduces over every element, so the extra axis is harmless.
    return jnp.mean(jax.random.beta(keys, a, b, [1, N]))

In [None]:
def plot_convolutionHist(mean_list, N, sampleSize, bins):
    """Plot a density-normalised histogram of sample means and save it.

    Args:
        mean_list: 1-D array of sample means (one per simulated sample).
        N: number of draws averaged into each mean; used in the plot title
            and in the saved figure name ``clt_N_{N}_latexified``.
        sampleSize: number of entries in ``mean_list``; used to turn raw
            counts into a density.
        bins: number of histogram bins.
    """
    counts, edges = jnp.histogram(mean_list, bins=bins)
    # Divide by (total count * bin width) so bar heights estimate a probability
    # density. The edges span the data's [min, max], not [0, 1], so the real
    # bin widths must be used instead of assuming a width of 1 / bins.
    density = counts / (sampleSize * jnp.diff(edges))

    plt.figure()
    plt.title(f"N = {N}")
    # Bar width is a fixed visual choice and is independent of the bin width.
    plt.bar(edges[:-1], density, width=0.02, color="black", align="edge")

    plt.xticks(jnp.linspace(0, 1, 3))
    plt.yticks(jnp.linspace(0, 3, 4))
    plt.xlim(0, 1)
    # Keep the usual [0, 3] window, but grow it if a bar is taller so strongly
    # skewed histograms (e.g. small N) are not clipped.
    plt.ylim(0, max(3.0, float(density.max())))
    plt.xlabel("$bins$")
    plt.ylabel("$Frequency$")
    sns.despine()
    savefig(f"clt_N_{N}_latexified")

In [None]:
# Fixed seed for reproducibility; split into one independent key per simulated
# sample mean (the key batch is consumed by vmap-ed calcMean calls below).
key = jax.random.PRNGKey(1)
keys = jax.random.split(key, 100000)

In [None]:
# One sample mean is computed per PRNG key, so derive the sample size from the
# key batch; this keeps the density normalisation consistent with `keys`.
sampleSize = len(keys)
bins = 20
N_array = [1, 5]
# vmap maps calcMean over the batch of keys (axis 0) while broadcasting N.
# The vectorised function does not depend on N, so build it once.
means = jax.vmap(calcMean, in_axes=(0, None), out_axes=0)
for N in N_array:
    out = means(keys, N)
    plot_convolutionHist(out, N, sampleSize, bins)

In [None]:
from ipywidgets import interact


@interact(N_value=(1, 10))
def generate_random(N_value):
    """Plot the sample-mean histogram for an interactively chosen N.

    Args:
        N_value: number of Beta draws averaged into each sample mean; supplied
            by the ``interact`` widget built from the ``(1, 10)`` abbreviation.

    Note: each call goes through plot_convolutionHist, which also calls
    savefig, so moving the widget may write figure files.
    """
    # Derive the sample count from the key batch so the density normalisation
    # always matches the number of means actually drawn.
    sampleSize = len(keys)
    bins = 20
    means = jax.vmap(calcMean, in_axes=(0, None), out_axes=0)
    out = means(keys, N_value)
    plot_convolutionHist(out, N_value, sampleSize, bins)