# Scratch

In [1]:
import sys
sys.path.append("/home/lauro/code/msc-thesis/svgd")

import jax.numpy as np
from jax import grad, jit, vmap, random, lax, jacfwd
from jax import lax
from jax.ops import index_update, index
import matplotlib.pyplot as plt
import numpy as onp

import utils
import metrics
import time
import plot
import svgd
import stein

rkey = random.PRNGKey(0)



# test `vmap` row column behaviour

In [2]:
def f(x, y):
    """Asymmetric two-argument test function: ``10*x - y``.

    The two arguments enter with different weights so that, once the function
    is vmapped over both, rows and columns of the result can be told apart.
    """
    return 10*x - y

In [None]:
x = np.array([1, 2])

# Inner map: hold the first argument fixed and map over the second.
fv = vmap(f, in_axes=(None, 0))
# Outer map: map over the first argument, passing the second through whole.
fvv = vmap(fv, in_axes=(0, None))

# f_matrix[i, j] == f(x[i], x[j]): the first argument indexes rows,
# the second indexes columns.
f_matrix = fvv(x, x)

# testing `stein.stein`

In [2]:
# Dimension used throughout this section; the `f` definitions below assert
# that their input has shape (d,).
d = 3
# Gaussian with a zero mean vector of length d. The second argument (1) is
# presumably the (co)variance -- confirm against metrics.Gaussian.
dist = metrics.Gaussian(np.zeros(d), 1)
# Draw with sample shape (100,); presumably yields a (100, d) array -- TODO confirm.
xs = dist.sample((100,))

### Case 1: 
Input is $f: \mathbb R^d \to \mathbb R^d$.
Then 
$$\mathcal A_p [f] (x) \in \mathbb R^{d \times d}$$
and 
$$\mathcal A_p^T [f] (x) \in \mathbb R$$

In [3]:
def f(x):
    """Vector-valued test function R^d -> R^d: ``f(x) = 2*x + (1, ..., d)``.

    Args:
        x: array-like of shape (d,), where ``d`` is the notebook-level dimension.

    Returns:
        Array of shape (d,).

    Raises:
        AssertionError: if ``x`` does not have shape (d,).
    """
    x = np.array(x)
    assert x.shape == (d,)
    # Derive the offset from d instead of hard-coding [1, 2, 3]: with a fixed
    # length-3 offset the function breaks (or silently broadcasts to the wrong
    # shape when d == 1) as soon as d is changed.
    a = np.arange(1, d + 1)
    return x*2 + a
f([3.,2.,1])

DeviceArray([7., 6., 5.], dtype=float32)

In [4]:
stein.stein(f, xs, dist.logpdf).shape

(3, 3)

In [7]:
stein.stein(f, xs, dist.logpdf, transposed=True)

DeviceArray(-0.5033575, dtype=float32)

### Case 2:
Input is $f: \mathbb R^d \to \mathbb R$.
Then 
$$\mathcal A_p [f] (x) \in \mathbb R^{d}.$$

Note that the transposed operator $\mathcal A_p^T [f]$ is not defined for scalar-valued $f$.

In [9]:
def f(x):
    """Scalar-valued test function R^d -> R: ``f(x) = sum(2*x + (1, ..., d))``.

    Args:
        x: array-like of shape (d,), where ``d`` is the notebook-level dimension.

    Returns:
        Scalar array.

    Raises:
        AssertionError: if ``x`` does not have shape (d,).
    """
    x = np.array(x)
    assert x.shape == (d,)
    # Derive the offset from d instead of hard-coding [1, 2, 3]; a fixed
    # length-3 offset fails to broadcast for most d != 3 and, for d == 1,
    # silently sums three terms instead of one.
    a = np.arange(1, d + 1)
    return np.sum(x*2 + a)
f([3.,2.,1])

DeviceArray(18., dtype=float32)

In [10]:
stein.stein(f, xs, dist.logpdf).shape

(3,)

In [12]:
# stein.stein(f, xs, dist.logpdf, transposed=True) # throws an error

In [13]:
def f(x):
    """Matrix-valued test function R^d -> R^(d x d): outer product of ``2*x`` and ``(1, ..., d)``.

    Args:
        x: array-like of shape (d,), where ``d`` is the notebook-level dimension.

    Returns:
        Array of shape (d, d) with entries ``out[i, j] = 2 * x[i] * (j + 1)``.

    Raises:
        AssertionError: if ``x`` does not have shape (d,).
    """
    x = np.array(x)
    assert x.shape == (d,)
    # Derive the second factor from d instead of hard-coding [1, 2, 3], so the
    # result stays square (d, d) rather than (d, 3) when d is changed.
    a = np.arange(1, d + 1)
    return np.einsum("i,j->ij", 2*x, a)
f([3.,2.,1])

DeviceArray([[ 6., 12., 18.],
             [ 4.,  8., 12.],
             [ 2.,  4.,  6.]], dtype=float32)

In [17]:
jacfwd(f)([3.,2.,1.])

[DeviceArray([[2., 4., 6.],
              [0., 0., 0.],
              [0., 0., 0.]], dtype=float32),
 DeviceArray([[0., 0., 0.],
              [2., 4., 6.],
              [0., 0., 0.]], dtype=float32),
 DeviceArray([[0., 0., 0.],
              [0., 0., 0.],
              [2., 4., 6.]], dtype=float32)]

In [24]:
# jacfwd on a list input returns one (3, 3) array per input element (see the
# previous output). NumPy stacks that list into a (3, 3, 3) array whose leading
# axis is the list index, and "iii->i" picks the entries where all three
# indices coincide.
onp.einsum("iii->i", jacfwd(f)([3.,2.,1.]))

array([2., 4., 6.], dtype=float32)

In [30]:
# "ii" with no output index sums the diagonal, i.e. the trace of f(x):
# 6 + 8 + 6 = 20 for this input.
np.einsum("ii", f([3,2,1]))

DeviceArray(20, dtype=int32)

# jit stuff 

In [None]:
def t(a, b):
    """Return ``a + b``; a minimal two-argument function for trying out ``jacfwd``."""
    total = a + b
    return total

In [None]:
jacfwd(t)(1., 2.)

In [None]:
@jit
def outer(defval, const=None):
    """Return ``const`` if it is given, otherwise ``defval``, via a nested closure.

    Experiment: checks that a jitted function can pick a default with a plain
    Python ``is None`` test and then hand the chosen value to an inner function
    through closure capture.

    Args:
        defval: fallback value, used when ``const`` is None.
        const: optional value that takes precedence over ``defval``.

    Returns:
        The selected value, as produced by the jitted computation.
    """
    # Ordinary Python branch on `const`; not a traced conditional.
    chosen = defval if const is None else const

    def loss():
        # Reads `chosen` from the enclosing scope rather than taking an argument.
        return chosen

    return loss()
    

In [None]:
outer(5, None)

In [None]:
outer(5, 15)

In [None]:
outer(10, None)

## test mixture

In [None]:
d = 3  # dimension of each mixture component
k = 5  # number of mixture components

# Advance the notebook-level key and split off one independent subkey per
# random draw. Using the same key for both draws (as before) makes `means`
# and `covs` functions of the same random bits, i.e. correlated.
rkey, means_key, covs_key = random.split(rkey, 3)
means = random.uniform(means_key, shape=(k, d))

# Random square factors A_k; A_k @ A_k.T is symmetric positive semi-definite,
# so every component gets a valid covariance matrix.
cov_factors = random.uniform(covs_key, shape=(k, d, d))
covs = np.einsum("kil,kjl->kij", cov_factors, cov_factors)

# NOTE(review): these weights sum to 8/3, not 1. Whether GaussianMixture
# normalises them is not visible from this notebook -- confirm.
weights = np.array([1/3, 2/3, 2/3, 2/3, 1/3])

mix = metrics.GaussianMixture(means, covs, weights)

In [None]:
# Sample sizes 3**0, 3**1, ..., 3**14.
grid = 3 ** np.arange(15)

# For each sample size, compare the empirical covariance of a fresh sample
# against the mixture's own covariance `mix.cov`.
diffs = []
for i in grid:
    sample = mix.sample(shape=(i,))
    sample_cov = np.cov(sample, rowvar=False)  # rows are observations
    scaled_sq_err = (sample_cov - mix.cov)**2 / mix.cov
    diffs.append(np.mean(scaled_sq_err))
diffs = np.array(diffs)

In [None]:
# Log-log scatter of the covariance estimation error against the sample size.
plt.plot(grid, diffs, ".")
plt.yscale("log")
plt.xscale("log")
# Label the figure so it can be read without the cell that computed `diffs`.
plt.xlabel("number of samples")
plt.ylabel("mean((sample cov - mix.cov)^2 / mix.cov)")
plt.title("Sample covariance error vs. sample size");  # trailing ';' hides the repr

## test Gaussian

In [None]:
# Mean vector and covariance matrix for a 2-d test Gaussian.
mean = np.array([1, 2])
# Symmetric, positive diagonal, determinant 1*20 - 3*3 = 11 > 0,
# hence positive definite.
cov = np.array([[1, 3], [3, 20]])

In [None]:
gauss = metrics.Gaussian(mean, cov)

In [None]:
metrics.Distribution

In [None]:
# 100 draws from the reference Gaussian itself.
sample = gauss.sample(shape=(100,))

# 100 draws from a different Gaussian (mean scaled by 3, covariance halved).
# Split off a fresh subkey instead of consuming `rkey` directly: a JAX PRNG key
# must not be reused across sampling calls, otherwise the draws are correlated
# with whatever else was generated from the same key.
rkey, subkey = random.split(rkey)
rsample = random.multivariate_normal(subkey, mean*3, cov/2, shape=(100,))

In [None]:
gauss.compute_metrics(sample)

In [None]:
gauss.compute_metrics(rsample)

## jax einsum floating point round-off error

In [None]:
import jax.numpy as jnp
import numpy as onp

JAX

In [None]:
values = jnp.array([[-5], [10]])
weights = jnp.array([1/3, 2/3])

In [None]:
# Weighted sum over i: (1/3)*(-5) + (2/3)*10, which is exactly 5 in real
# arithmetic; any deviation in the printed value is floating-point round-off.
jnp.einsum("i,id->d", weights, values)

In [None]:
# Same weighted sum written as an elementwise product followed by a sum,
# for comparison with the einsum result.
jnp.sum(values.flatten() * weights)

NumPy

In [None]:
# Same inputs as the JAX cells, built with plain NumPy in single precision.
# Take the dtype from `onp` itself rather than from `np` (which is jax.numpy
# in this notebook), so the NumPy side of the comparison does not depend on JAX.
values = onp.array([[-5], [10]], dtype=onp.float32)
weights = onp.array([1/3, 2/3], dtype=onp.float32)

In [None]:
# NumPy version of the same contraction ("i,id->d": weighted sum over i),
# for comparison with the JAX result.
onp.einsum("i,id->d", weights, values)