In [44]:
import altair as alt
import pandas as pd
import numpy as np
import math
import scipy.special

def personal():
    """Altair theme for exploratory, on-screen work.

    Returns a Vega-Lite config wrapped as ``{'config': {...}}``:
    400x300 views, the ``set2`` scheme for categorical and ordinal ranges,
    ``labelLimit`` 0 for legends (no label truncation in Vega), a white
    background, clipped marks, and line marks of size 3.
    """
    scheme = {'scheme': 'set2'}
    config = dict(
        view=dict(height=300, width=400),
        range=dict(category=dict(scheme), ordinal=dict(scheme)),
        legend=dict(labelLimit=0),
        background='white',
        mark=dict(clip=True),
        line=dict(size=3),
    )
    return dict(config=config)

def publication(colorscheme='set2', stroke_color='#333', title_size=24,
                label_size=20, line_width=5, rule_color='#999'):
    """Altair theme for publication-quality static figures.

    Altair invokes registered themes with no arguments, so every parameter has
    a default matching the original hard-coded value. The parameters let
    variants be registered, e.g. ``lambda: publication(title_size=16)``.

    Arguments:
    - `colorscheme`: Vega color scheme for categorical and ordinal ranges
    - `stroke_color`: CSS color (must include the leading '#') for axis domain
      lines and ticks
    - `title_size`: font size for chart, axis and legend titles
    - `label_size`: font size for axis and legend labels
    - `line_width`: default size (stroke width) of line marks
    - `rule_color`: CSS color for rule marks

    Returns a dict of the form ``{'config': {...}}`` holding the Vega-Lite config.
    """
    return {
        'config': {
            # Vega-Lite reads `background` at the top level of `config`. The
            # `view` config has no `background` property, so it was ignored there.
            'background': 'white',
            'view': {
                'height': 500,
                'width': 600,
                'strokeWidth': 0,  # no border around the plotting area
            },
            'title': {
                'fontSize': title_size,
            },
            'range': {
                'category': {'scheme': colorscheme},
                'ordinal': {'scheme': colorscheme},
            },
            'axis': {
                'titleFontSize': title_size,
                'labelFontSize': label_size,
                'grid': False,
                'domainWidth': 5,
                'domainColor': stroke_color,
                'tickWidth': 3,
                'tickSize': 9,
                'tickCount': 4,
                'tickColor': stroke_color,
                'tickOffset': 0,
            },
            'legend': {
                'titleFontSize': title_size,
                'labelFontSize': label_size,
                'labelLimit': 0,   # 0 = no truncation of labels
                'titleLimit': 0,   # 0 = no truncation of the title
                'orient': 'top-left',
                'titlePadding': 10,
                'fillColor': '#ffffff88',  # white at ~53% opacity (#RRGGBBAA)
                'cornerRadius': 0,
            },
            'rule': {
                'size': 3,
                'color': rule_color,
            },
            'line': {
                'size': line_width,
            },
        }
    }

# Make both themes available by name, then activate the exploratory one;
# plotting cells switch to 'publication' explicitly when needed.
for theme_name, theme_fn in (('personal', personal), ('publication', publication)):
    alt.themes.register(theme_name, theme_fn)
alt.themes.enable('personal')

ThemeRegistry.enable('personal')

In [1]:
import jax
from jax import numpy as jnp, random

def normal_pdf(mean, cov, x):
    """Density of the multivariate normal N(`mean`, `cov`) evaluated at point `x`."""
    mvn = jax.scipy.stats.multivariate_normal
    return mvn.pdf(x, mean=mean, cov=cov)


# Vectorized variants of `normal_pdf` (vmap over axis 0 of one argument):
# one query point `x` against many means (axis 0 of `mean`)
normal_pdf_batchedmean = jax.vmap(normal_pdf, (0, None, None))
# many query points (axis 0 of `x`) against a single mean
normal_pdf_batchedx = jax.vmap(normal_pdf, (None, None, 0))

def mog_pdf(cov, weights, kernel_locs, x):
    """Mixture-of-Gaussians density at a single query point `x`.

    Every mixture component shares covariance `cov` and is centred at one row
    of `kernel_locs` (axis 0 indexes the components). `weights` need not sum
    to one; they are normalized here. Presumably `weights` has one entry per
    row of `kernel_locs` -- TODO confirm at call sites.

    Arguments:
    - `cov`: covariance matrix shared by every component
    - `weights`: unnormalized weight of each component
    - `kernel_locs`: component means (the observed points defining the model)
    - `x`: the query point to evaluate
    """
    probs = weights / weights.sum()
    component_densities = normal_pdf_batchedmean(kernel_locs, cov, x)
    return (probs * component_densities).sum(axis=0)
# Same density evaluated at many query points (axis 0 of `x`).
mog_pdf_batchedx = jax.vmap(mog_pdf, in_axes=(None, None, None, 0))

def normal_pdf_1d(sigma, x):
    """Density of a zero-mean 1-D normal with standard deviation `sigma` at `x`."""
    return jax.scipy.stats.norm.pdf(x, loc=0, scale=sigma)

In [46]:
alt.themes.enable('publication')

# Demo: `n_kernels` equally weighted 1-D Gaussian kernels at random locations,
# each drawn on its own, plus their sum (the mixture density).
n_kernels = 4
n_points = 100  # number of evaluation points on [-1, 1]

sigma = 1e-1
cov = jnp.eye(1) * sigma**2
weights = jnp.ones((n_kernels,))  # uniform mixture weights (used by mog_pdf)
kernel_locs = random.uniform(random.PRNGKey(0), shape=(n_kernels, 1), minval=-1, maxval=1)

# Shape (n_points, 1): one 1-D query point per row.
xs = jnp.expand_dims(jnp.linspace(-1, 1, n_points), axis=1)
locations = np.asarray(xs)[:, 0]

# Each kernel's density, scaled by 1/n_kernels (instead of a hard-coded 4)
# so the per-kernel curves sum to the mixture density.
per_kernel_ys = [np.asarray(normal_pdf_batchedx(kernel, cov, xs)) / n_kernels
                 for kernel in kernel_locs]
ys_sum = np.sum(per_kernel_ys, axis=0, dtype=np.float64)

# The hand-built sum must agree with the mixture model defined above.
np.testing.assert_allclose(ys_sum, np.asarray(mog_pdf_batchedx(cov, weights, kernel_locs, xs)),
                           rtol=1e-5, atol=1e-6)

# Build the long-format frame once, rather than concatenating inside the loop.
df = pd.concat([
    pd.DataFrame({'location': locations, 'count': ys, 'source': f'point {i}'})
    for i, ys in enumerate(per_kernel_ys)
])
sum_df = pd.DataFrame({'location': locations,
                       'count': ys_sum,
                       'source': 'sum'})

legend = alt.Legend(orient='top-right')
kernels = alt.Chart(df).mark_line(opacity=1).encode(
    x='location', y='count',
    color=alt.Color('source', legend=legend))
# Wide, faint line so the individual kernels remain visible beneath the sum.
total = alt.Chart(sum_df).mark_line(opacity=0.3, size=10).encode(
    x='location', y='count',
    color=alt.Color('source', legend=legend))
kernels + total

In [25]:
import torch
from pykeops.torch import LazyTensor

# Copy the JAX arrays into torch tensors (via NumPy) so KeOps can consume them.
t_kernel_locs = torch.tensor(np.array(kernel_locs))
t_xs = torch.tensor(np.array(xs))

# Symbolic KeOps variables. kernel_locs is (n_kernels, D) and xs is (M, D),
# so (N = n_kernels, M = number of query points):
#   x_i: (N, 1, D) -- indexed by kernel i
#   y_j: (1, M, D) -- indexed by query point j
x_i = LazyTensor(t_kernel_locs.unsqueeze(1))
y_j = LazyTensor(t_xs.unsqueeze(0))

# Symbolic (N, M) matrices. KeOps only compiles and evaluates them when a
# reduction such as `.sum(...)` is applied (next cell).
scaled_diff = (x_i - y_j) / sigma
D_ij = (scaled_diff ** 2).sum(dim=2)  # squared distance, in units of sigma
K_ij = (-0.5 * D_ij).exp()            # unnormalized Gaussian kernel values

In [26]:
# Reduce over kernels (axis 0 = i), then apply the 1/n_kernels mixture weight and
# the Gaussian normalizer (2*pi*sigma^2)^(-d/2). For d = 1 this equals the
# previous (2*pi)^(-1/2) * sigma^(-1), and it stays correct for d > 1.
dim = t_xs.shape[1]
keops_preds = K_ij.sum(0).reshape(-1) / n_kernels * (2 * math.pi * sigma ** 2) ** (-dim / 2)

# Reuse the column names of `df` / `sum_df` (location/count/source). With the old
# x/y/dist names the concat produced NaN-padded columns, and only KeOps rows plotted.
keops_only_df = pd.DataFrame({'location': np.asarray(xs)[:, 0],
                              'count': keops_preds.cpu().numpy(),
                              'source': 'keops'})

# KeOps must reproduce the JAX mixture density computed in the plotting cell.
np.testing.assert_allclose(keops_only_df['count'].to_numpy(), sum_df['count'].to_numpy(),
                           rtol=1e-3, atol=1e-5)

keops_df = pd.concat([df, sum_df, keops_only_df])
alt.Chart(keops_df).mark_line().encode(x='location', y='count', color='source')

Compiling libKeOpstorch1ef0d2250e in /home/will/.cache/pykeops-1.4.2-cpython-37-gpu:
       formula: Sum_Reduction(Exp((Var(3,1,2) * Sum(Square(((Var(0,1,0) - Var(1,1,1)) / Var(2,1,2)))))),1)
       aliases: Var(0,1,0); Var(1,1,1); Var(2,1,2); Var(3,1,2); 
       dtype  : float32
... Done.


In [27]:
from densities import keops_kernel_count as density



In [None]:
# Create a state object via `new()` from the project-local module
# `densities.keops_kernel_count` (imported above as `density`). What the state
# holds is defined in that module, not in this notebook -- NOTE(review): confirm
# whether it is meant to be fed kernel_locs / sigma from the cells above.
density_state = density.new()