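# Scratch: sanity checks for `metrics.Gaussian` / `metrics.GaussianMixture`, plus a JAX `einsum` round-off example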

In [1]:
import jax.numpy as np
from jax import grad, jit, vmap, random, lax
from jax import lax
from jax.ops import index_update, index
import matplotlib.pyplot as plt
import numpy as onp

import utils
import metrics
import time
import plot

# Fixed seed so the random draws made from this key are reproducible across runs.
rkey = random.PRNGKey(seed=0)



## Test `metrics.Gaussian`

In [2]:
# Target 2-D Gaussian: mean vector and symmetric covariance matrix
# (diagonal > 0 and det = 1*20 - 3*3 = 11 > 0, so it is positive definite).
mean, cov = (
    np.array([1, 2]),
    np.array([[1, 3], [3, 20]]),
)

In [3]:
# Reference distribution under test, built from `mean` and `cov` above
# (presumably N(mean, cov) — confirm against metrics.Gaussian).
gauss = metrics.Gaussian(mean, cov)

In [6]:
# Draw the same number of points two ways:
#  - `sample`:  from `gauss` itself, the baseline the metrics are computed against;
#  - `rsample`: from a deliberately *different* Gaussian (mean scaled by 3,
#    covariance halved) as a negative control whose metrics should be worse.
n_draws = 100
sample = gauss.sample(shape=(n_draws,))
rsample = random.multivariate_normal(rkey, mean * 3, cov / 2, shape=(n_draws,))

In [7]:
# Baseline: metrics (dict with 'square_errors' and 'ksds') for points drawn from `gauss` itself.
gauss.compute_metrics(sample)

{'square_errors': DeviceArray([[1.1789589e-03, 3.1130478e-01],
              [3.4243418e-04, 4.5826101e+00],
              [9.0531439e-02, 3.1579507e-04],
              [3.0842879e-01, 2.9914777e-03]], dtype=float32),
 'ksds': [DeviceArray(217.08476, dtype=float32),
  DeviceArray(2.3804615, dtype=float32),
  DeviceArray(0.7086184, dtype=float32)]}

In [8]:
# Negative control: same metrics for points from the shifted/shrunk Gaussian; compare with the baseline above.
gauss.compute_metrics(rsample)

{'square_errors': DeviceArray([[4.0129423e+00, 1.9340918e+01],
              [5.6845673e+01, 7.7879669e+02],
              [5.7990074e-01, 2.1410324e-05],
              [7.9445569e-03, 3.1765524e-02]], dtype=float32),
 'ksds': [DeviceArray(221.5557, dtype=float32),
  DeviceArray(112.118126, dtype=float32),
  DeviceArray(541.3391, dtype=float32)]}

## Test `metrics.GaussianMixture`

In [6]:
# Two-component 2-D mixture parameters. Every per-component array must have
# exactly one entry per row of `means`, and the mixing weights must sum to 1.
means = np.array([[-5, -5], [1, 10]])
covs = np.array([0.5, 2])  # one value per component; presumably an isotropic variance — TODO confirm with metrics.GaussianMixture
weights = np.array([1/3, 2/3])  # mixing proportions

n_components = means.shape[0]
assert covs.shape[0] == n_components and weights.shape[0] == n_components, "need one cov and one weight per mean"
assert np.isclose(weights.sum(), 1.0), "mixture weights must sum to 1"

# Mixture under test, built from the per-component `means`, `covs` and `weights` above.
mix = metrics.GaussianMixture(means, covs, weights)

In [None]:
# Convergence check for the mixture: relative squared error of the empirical
# covariance against the analytic `mix.cov`, for sample sizes 3**8 ... 3**17.
# WARNING: the largest draw is 3**17 (~1.3e8) points, so this cell is slow and
# memory-hungry; lower `max_exponent` for a quick run.
min_exponent, max_exponent = 8, 17
grid = 3 ** np.arange(min_exponent, max_exponent + 1)

diffs = []
for size in grid:
    # Dedicated name so the Gaussian `sample` from the previous section is not clobbered.
    # Assumes mix.sample returns an array of shape (size, dim) — TODO confirm.
    mix_sample = mix.sample(shape=(int(size),))
    emp_cov = np.cov(mix_sample, rowvar=False)
    diffs.append((emp_cov - mix.cov) ** 2 / mix.cov)
# One row per sample size, one column per flattened covariance entry.
diffs = np.array(diffs).reshape(len(grid), -1)

In [None]:
# Relative squared error of each empirical covariance entry vs. sample size.
# The error is expected to shrink roughly like 1/n; calling ax.set_yscale("log")
# as well makes that rate show up as a straight line.
fig, ax = plt.subplots()
ax.plot(grid, diffs, ".")
ax.set_xscale("log")
ax.set(
    xlabel="number of samples",
    ylabel="(empirical cov - true cov)^2 / true cov",
    title="Covariance estimate convergence for the Gaussian mixture",
);

## JAX `einsum` floating-point round-off error

In [24]:
import jax.numpy as jnp
import numpy as onp

### JAX

In [60]:
# Inputs for the round-off demo, pinned to float32 so they match the NumPy
# float32 inputs used for comparison (and do not silently become float64 when
# jax_enable_x64 is on).
# NOTE: this rebinds `weights`, shadowing the mixture weights defined earlier.
values = jnp.array([[-5], [10]], dtype=jnp.float32)
weights = jnp.array([1/3, 2/3], dtype=jnp.float32)

In [61]:
# Weighted sum over i: out[d] = sum_i weights[i] * values[i, d]. In float32 this displays 5.0000005, not 5.
jnp.einsum("i,id->d", weights, values)

DeviceArray([5.0000005], dtype=float32)

In [62]:
# Same weighted sum via an elementwise product and jnp.sum; this path displays exactly 5.
jnp.sum(values.flatten() * weights)

DeviceArray(5., dtype=float32)

### NumPy

In [50]:
# NumPy counterpart of the JAX inputs, pinned to float32 so both libraries work
# at the same precision. Uses NumPy's own dtype rather than `np.float32`, which
# in this notebook is jax.numpy's float32.
values = onp.array([[-5], [10]], dtype=onp.float32)
weights = onp.array([1/3, 2/3], dtype=onp.float32)

In [51]:
# NumPy's einsum on the same float32 inputs displays 5., with no visible round-off, unlike jnp.einsum.
onp.einsum("i,id->d", weights, values)

array([5.], dtype=float32)