In [None]:
import numpy as np, matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [None]:
# Shared evaluation grid: 100 evenly spaced values in [0, 1]. p1 is swept by the
# animations; meshgrid turns it into the 2-D image axes (p2 varies along
# columns, p3 along rows, numpy's default 'xy' indexing).
p1 = np.linspace(start=0, stop=1, num=100)
p2, p3 = np.meshgrid(p1, p1)

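## Inference

At inference time the tree routes *hard*. The root sends `p1 > 0.5` to leaves 0/1 and everything else to leaves 2/3. The second level then thresholds `p2` (for leaves 0/1) or `p3` (for leaves 2/3) at 0.5 to pick the final leaf. The animation sweeps `p1` and colours each (`p2`, `p3`) cell by the chosen leaf id.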

In [None]:
def find_leaf(p, p2, p3):
    """Hard (inference-time) routing through a depth-2 binary tree.

    The root thresholds ``p`` at 0.5: ``p > 0.5`` goes to the {0, 1} subtree,
    otherwise to the {2, 3} subtree. Within {0, 1} the leaf is 1 iff
    ``p2 > 0.5`` (else 0); within {2, 3} the leaf is 3 iff ``p3 > 0.5`` (else 2).

    Parameters
    ----------
    p : float or array_like
        Root routing value. May also be an array; it is broadcast against
        ``p2``/``p3``.
    p2, p3 : array_like
        Routing values of the two second-level nodes.

    Returns
    -------
    numpy.ndarray of int32
        Leaf id in {0, 1, 2, 3}, shaped like the broadcast of ``p``, ``p2``, ``p3``.
    """
    # NOTE(review): find_leaf_tr scores leaf 0 as p*p2 (high p2 -> leaf 0) and
    # leaf 2 as (1-p)*p3 (high p3 -> leaf 2), whereas here p2 > 0.5 -> leaf 1 and
    # p3 > 0.5 -> leaf 3. Confirm which numbering is intended before comparing
    # the two animations.
    left_leaf = np.asarray(p2) > 0.5            # True -> leaf 1, False -> leaf 0
    right_leaf = (np.asarray(p3) > 0.5) + 2     # leaf 3 if True else leaf 2
    # np.where (rather than `if p > 0.5`) also accepts an array p. The final cast
    # gives int32 for both subtrees; `bool + 2` alone would be int64.
    return np.where(np.asarray(p) > 0.5, left_leaf, right_leaf).astype(np.int32)

In [None]:
from IPython.display import HTML
from matplotlib.ticker import MaxNLocator
from matplotlib.colors import BoundaryNorm

def animate(f, norm=None):
    """Animate ``f(p, p2, p3)`` as ``p`` sweeps the module-level ``p1`` values.

    Each frame is a pcolormesh of ``f`` evaluated on the module-level
    ``p2``/``p3`` meshgrid for one value of ``p1``.

    Parameters
    ----------
    f : callable
        ``f(p, p2, p3)`` returning an array shaped like ``p2``.
    norm : matplotlib.colors.Normalize, optional
        Colour normalisation. When None, a linear colour range spanning the
        min/max over *all* frames is used, so a colour means the same value
        in every frame.

    Returns
    -------
    matplotlib.animation.FuncAnimation
    """
    # Evaluate every frame once up front. This is needed for a global colour
    # range, and update() then only swaps data in.
    images = [f(p, p2, p3) for p in p1]

    fig, ax = plt.subplots()
    ax.set_xlabel('p2')
    ax.set_ylabel('p3')
    # Without explicit limits, pcolormesh autoscales to the initial data only,
    # and set_array() does not rescale. Frames outside that range would
    # otherwise be clipped to the end colours of the colormap.
    clim = {}
    if norm is None:
        clim = dict(vmin=min(np.nanmin(img) for img in images),
                    vmax=max(np.nanmax(img) for img in images))
    im = ax.pcolormesh(p2, p3, images[0], norm=norm, **clim)
    fig.colorbar(im)

    def update(frame):
        im.set_array(images[frame].ravel())
        ax.set_title(f"p1={p1[frame]:.2f}")

    ani = animation.FuncAnimation(fig=fig, func=update, frames=len(p1), interval=20)
    # Close only this figure so the notebook doesn't also render a static copy;
    # closing "all" would discard unrelated figures the user has open.
    plt.close(fig)
    return ani

In [None]:
# Discrete colour bins for the integer leaf ids. BoundaryNorm maps each interval
# between consecutive `levels` (presumably 0, 1, 2, 3, 4 here) to one band of
# the default colormap.
levels = MaxNLocator(nbins=4).tick_values(0, 4)
norm = BoundaryNorm(boundaries=levels, ncolors=plt.get_cmap().N, clip=True)
leaf_anim = animate(find_leaf, norm)
HTML(leaf_anim.to_jshtml())

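## Training

During training every leaf is scored by the product of the routing probabilities along its path (e.g. leaf 0 gets `p1 * p2`). The plot shows the best-scoring leaf id plus twice its path probability. The colour therefore mixes *which* leaf wins with *how confidently* it wins.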

In [None]:
def find_leaf_tr(p: float, p2, p3):
    """Soft (training-time) routing score for a depth-2 binary tree.

    Each leaf is scored by the product of the routing probabilities on its path:

        leaf 0: p * p2          leaf 1: p * (1 - p2)
        leaf 2: (1 - p) * p3    leaf 3: (1 - p) * (1 - p3)

    Parameters
    ----------
    p : float
        Root routing probability.
    p2, p3 : array_like
        Routing probabilities of the two second-level nodes (same shape).

    Returns
    -------
    numpy.ndarray
        ``argmax_leaf + 2 * max_path_prob``, elementwise, shaped like ``p2``.
    """
    # NOTE(review): for p, p2, p3 in [0, 1] the four scores sum to 1, so the
    # winning score lies in [0.25, 1]. Leaf k therefore maps to [k + 0.5, k + 2],
    # and these ranges overlap between leaves. Colour alone does not identify
    # the leaf; confirm this encoding is intended.
    path_probs = np.stack([
        p * p2,
        p * (1 - p2),
        (1 - p) * p3,
        (1 - p) * (1 - p3),
    ])
    # max(axis=0) is exactly the value at argmax(axis=0), and it keeps the
    # input's shape. A squeeze() would drop genuine size-1 axes and make
    # (N, 1) inputs broadcast to (N, N).
    return path_probs.argmax(axis=0) + 2 * path_probs.max(axis=0)

In [None]:
# The training score is continuous, so no BoundaryNorm is passed; the colour
# scale is linear.
tr_anim = animate(find_leaf_tr, norm=None)
HTML(tr_anim.to_jshtml())