# Dirichlet distribution

In [None]:
# Common imports.
from pathlib import Path

%matplotlib widget
import base

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import ipywidgets

import dfaas_env
import dfaas_utils

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as tri

# Vertices of an equilateral triangle with unit side length (height sqrt(0.75)).
corners = np.array([[0, 0], [1, 0], [0.5, 0.75**0.5]])
# Area of that triangle: half of base (1) times height.
AREA = 0.5 * 1 * 0.75**0.5
triangle = tri.Triangulation(corners[:, 0], corners[:, 1])

# Refine the triangle 4 times to preview the mesh the contour plot samples on.
refiner = tri.UniformTriRefiner(triangle)
trimesh = refiner.refine_triangulation(subdiv=4)

# Left panel: the bare triangle; right panel: its refined mesh.
plt.figure(figsize=(8, 4))
for panel, mesh in enumerate((triangle, trimesh), start=1):
    plt.subplot(1, 2, panel)
    plt.triplot(mesh)
    plt.axis("off")
    plt.axis("equal")

# For each corner i of the triangle, the two *other* corners (the edge opposite corner i).
pairs = [corners[[(i + 1) % 3, (i + 2) % 3]] for i in range(3)]


def tri_area(xy, pair):
    """Return the area of the triangle formed by point ``xy`` and the pair of points ``pair``.

    ``pair`` holds two 2-D points (one per row). The area is half the absolute
    2-D cross product of the vectors from ``xy`` to each point. The cross
    product is written out explicitly because ``np.cross`` on 2-element
    vectors is deprecated since NumPy 2.0.
    """
    (ax, ay), (bx, by) = np.asarray(pair) - xy
    return 0.5 * abs(ax * by - ay * bx)


def xy2bc(xy, tol=1.0e-4):
    """Map a 2-D Cartesian point in the triangle to barycentric coordinates.

    Coordinate i is the area of the sub-triangle formed by ``xy`` and the edge
    opposite corner i, divided by the full triangle area ``AREA``. Each
    coordinate is then clipped to ``[tol, 1 - tol]``, which keeps it strictly
    inside (0, 1) for a positive ``tol``.
    """
    sub_areas = np.array([tri_area(xy, edge) for edge in pairs])
    barycentric = sub_areas / AREA
    return np.clip(barycentric, tol, 1.0 - tol)


class Dirichlet:
    """Dirichlet distribution with concentration parameters ``alpha``.

    Only the probability density function is provided (see ``pdf``).
    ``alpha`` is assumed to contain strictly positive values.
    """

    def __init__(self, alpha):
        """Store ``alpha`` and precompute the normalising constant.

        The constant Gamma(sum(alpha)) / prod(Gamma(alpha_i)) is evaluated in
        log space with ``math.lgamma``. ``math.gamma`` raises
        ``OverflowError`` once its argument exceeds about 171, which would
        break even moderate concentrations such as ``[100, 100, 1]``.
        """
        from math import exp, lgamma

        self._alpha = np.array(alpha)
        log_coef = lgamma(np.sum(self._alpha)) - sum(lgamma(a) for a in self._alpha)
        self._coef = exp(log_coef)

    def pdf(self, x):
        """Return the density at ``x``.

        ``x`` supplies one coordinate per entry of ``alpha`` (for example the
        barycentric coordinates of a point on the simplex). The result is
        ``coef * prod(x_i ** (alpha_i - 1))``.
        """
        terms = [xx ** (aa - 1) for (xx, aa) in zip(x, self._alpha)]
        return self._coef * np.multiply.reduce(terms)


def draw_pdf_contours(dist, nlevels=200, subdiv=8, **kwargs):
    """Draw a filled contour plot of ``dist.pdf`` over the triangle (2-simplex).

    Parameters
    ----------
    dist : object with a ``pdf`` method
        Evaluated at the barycentric coordinates (via ``xy2bc``) of every node
        of the refined mesh.
    nlevels : int
        Number of contour levels passed to ``plt.tricontourf``.
    subdiv : int
        Number of recursive refinements of the base triangle mesh.
    **kwargs
        Forwarded to ``plt.tricontourf``. ``cmap`` defaults to ``"jet"``
        and may be overridden by the caller.

    Returns
    -------
    The current matplotlib figure (``plt.gcf()``).
    """
    # Default rather than hard-code: a caller-supplied cmap previously raised
    # TypeError ("got multiple values for keyword argument 'cmap'").
    kwargs.setdefault("cmap", "jet")

    refiner = tri.UniformTriRefiner(triangle)
    trimesh = refiner.refine_triangulation(subdiv=subdiv)
    pvals = [dist.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]

    plt.tricontourf(trimesh, pvals, nlevels, **kwargs)
    plt.axis("equal")
    # Frame exactly the unit-side triangle (height sqrt(0.75)).
    plt.xlim(0, 1)
    plt.ylim(0, 0.75**0.5)
    plt.axis("off")

    return plt.gcf()

node_1

action_dist = [0.54363894 0.41351184 0.04284927]

{'vf_preds': np.float32(44.169827), 'action_dist_inputs': array([4.2691507, 3.6986694, 1.8697257], dtype=float32), 'action_prob': np.float32(52.773197), 'action_logp': np.float32(3.9660034)}

In [None]:
# Alpha components are ordered as the actions: local, forward, reject.
fig = draw_pdf_contours(Dirichlet(alpha=[5, 2.5, 1]))

In [None]:
# Persist the contour plot from the previous cell; the ".svg" suffix selects the vector format.
fig.savefig(fname="dirichlet.svg")