In [1]:
import jax
from jax import numpy as jnp, random

def normal_pdf(mean, cov, x):
    """Evaluate a multivariate normal density at a single query point.

    Thin wrapper around `jax.scipy.stats.multivariate_normal.pdf` that takes
    the distribution parameters first and the query point last.

    Arguments:
    - `mean`: mean of the distribution
    - `cov`: covariance of the distribution
    - `x`: the point at which the density is evaluated

    Returns the value produced by `multivariate_normal.pdf(x, mean, cov)`.
    """
    mvn = jax.scipy.stats.multivariate_normal
    density = mvn.pdf(x, mean, cov)
    return density
# Vectorised over a leading axis of means (in_axes 0 on `mean`); the covariance
# and the query point are shared, i.e. one x evaluated at many means:
normal_pdf_batchedmean = jax.vmap(normal_pdf, in_axes=(0, None, None))
# Vectorised over a leading axis of query points (in_axes 0 on `x`); the mean
# and covariance are shared, i.e. many xs evaluated at a single mean:
normal_pdf_batchedx = jax.vmap(normal_pdf, in_axes=(None, None, 0))

def mog_pdf(cov, weights, kernel_locs, x):
    """Density of a mixture-of-Gaussians model at one query point.

    Every component shares the covariance `cov`; one component is centred at
    each entry along the leading axis of `kernel_locs`. The weights are
    divided by their sum before use, so they only need to be given up to a
    common scale.

    Arguments:
    - `cov`: the covariance matrix shared by all components
    - `weights`: the (unnormalised) weight of each component
    - `kernel_locs`: the component centres, batched along the first axis
    - `x`: the single query point to evaluate

    Returns the weighted sum of the per-component densities at `x`.
    """
    # One density per component: the same x scored against every centre.
    component_pdfs = normal_pdf_batchedmean(kernel_locs, cov, x)
    # Rescale the mixing proportions so they sum to one.
    mixing = weights / weights.sum()
    return (mixing * component_pdfs).sum(axis=0)
# Batched variant of mog_pdf: maps over the leading axis of `x` (many query
# points) while cov, weights and kernel_locs are shared across the batch.
mog_pdf_batchedx = jax.vmap(mog_pdf, in_axes=(None, None, None, 0))

In [2]:
import altair as alt
import pandas as pd
import numpy as np

In [4]:
# Demo: a 1-D mixture of equally weighted Gaussians sharing one covariance,
# plotted together with each individual component density.
n_kernels = 4

cov = jnp.eye(1) * 3e-2  # shared 1x1 covariance, i.e. variance 0.03
weights = jnp.ones((n_kernels,))  # equal weights; mog_pdf normalises them
# Fixed PRNG key, so the kernel centres are the same on every run.
kernel_locs = random.uniform(random.PRNGKey(0), shape=(n_kernels, 1), minval=-1, maxval=1)

# 100 evenly spaced query points on [-1, 1], shaped (100, 1).
xs = jnp.expand_dims(jnp.linspace(-1, 1, 100), axis=1)
ys = mog_pdf_batchedx(cov, weights, kernel_locs, xs)
x_values = np.array(xs)[:, 0]  # flat x column, shared by every curve

# Collect one long-format frame per curve and concatenate once at the end,
# instead of re-copying a growing DataFrame on every loop iteration.
frames = [pd.DataFrame({'x': x_values,
                        'y': np.array(ys),
                        'dist': 'mixture'})]
for i, kernel in enumerate(kernel_locs):
    # Density of the i-th component on its own, over the same query points.
    ys = normal_pdf_batchedx(kernel, cov, xs)
    single_df = pd.DataFrame({'x': x_values,
                              'y': np.array(ys),
                              'dist': f'kernel{i}'})
    frames.append(single_df)
df = pd.concat(frames)
alt.Chart(df).mark_line().encode(x='x', y='y', color='dist')

In [45]:
# Display the kernel centres drawn with random.uniform above.
kernel_locs

DeviceArray([[ 0.93064284],
             [-0.54968214],
             [ 0.26605988],
             [-0.40723634]], dtype=float32)

In [None]:
# Demo: the same 1-D equally weighted mixture, with the shared covariance
# specified through a standard deviation instead of a raw variance.
n_kernels = 4

std = 1e-1
cov = jnp.eye(1) * std ** 2  # shared 1x1 covariance built from std
weights = jnp.ones((n_kernels,))  # equal weights; mog_pdf normalises them
# Fixed PRNG key, so the kernel centres are the same on every run.
kernel_locs = random.uniform(random.PRNGKey(0), shape=(n_kernels, 1), minval=-1, maxval=1)

# 100 evenly spaced query points on [-1, 1], shaped (100, 1).
xs = jnp.expand_dims(jnp.linspace(-1, 1, 100), axis=1)
ys = mog_pdf_batchedx(cov, weights, kernel_locs, xs)
x_values = np.array(xs)[:, 0]  # flat x column, shared by every curve

# Collect one long-format frame per curve and concatenate once at the end,
# instead of re-copying a growing DataFrame on every loop iteration.
frames = [pd.DataFrame({'x': x_values,
                        'y': np.array(ys),
                        'dist': 'mixture'})]
for i, kernel in enumerate(kernel_locs):
    # Density of the i-th component on its own, over the same query points.
    ys = normal_pdf_batchedx(kernel, cov, xs)
    single_df = pd.DataFrame({'x': x_values,
                              'y': np.array(ys),
                              'dist': f'kernel{i}'})
    frames.append(single_df)
df = pd.concat(frames)
alt.Chart(df).mark_line().encode(x='x', y='y', color='dist')