# making / fixing stuff

In [1]:
import sys
sys.path.append("/home/lauro/code/msc-thesis/svgd")

import jax.numpy as np
from jax import grad, jit, vmap, random, lax, jacfwd
from jax import lax
from jax.ops import index_update, index
import matplotlib.pyplot as plt
import numpy as onp
from tqdm import tqdm

import utils
import metrics
import time
import plot
from svgd import SVGD
import svgd
import stein

rkey = random.PRNGKey(0)



# Phistar

In [2]:
def phistar(xs, logp, logh):
    def f(x, y):
        """evaluated inside the expectation"""
        kx = lambda y: utils.ard(x, y, logh)
        return stein.stein_operator(kx, y, logp, transposed=False)
    
    fv  = vmap(f,  (None, 0))
    fvv = vmap(fv, (0, None))
    phi_matrix = fvv(xs, xs)
    
    n = xs.shape[0]
    trace_indices = [list(range(n))]*2
    phi_matrix = index_update(phi_matrix, trace_indices, 0)
    
    return np.mean(phi_matrix, axis=1)
#     return phi_matrix
phistar = jit(phistar, static_argnums=1)

In [3]:
h = 1**2
n = 5
dist = metrics.Gaussian(0,1)
xs = dist.sample((n, 1))
print("phistar without diagonal", stein.phistar(xs, dist.logpdf, np.log(h)))
print("current with diagonal:", svgd.phistar(xs, dist.logpdf, np.log(h)))

phistar without diagonal [[-0.34314248]
 [-0.5031638 ]
 [ 0.10763108]
 [-0.5713571 ]
 [-0.13896786]]
current with diagonal: [[-0.4629477 ]
 [-0.54660666]
 [-0.3764267 ]
 [-0.5778905 ]
 [-0.38226676]]


# New KSD

Write KSD squared as
$$\text{KSD}(q \ \Vert \ p) = E_{Z, Z'}[g_p(Z, Z')]$$
where $Z$ and $Z'$ are independently distributed as $q$, and $g$ is defined as

$$\begin{aligned}
g_{p}(x, y):=& \nabla \log p(x)^{\top} \nabla \log p(y) k(x, y) \\
&+\nabla \log p(y)^{\top} \nabla_{x} k(x, y) \\
&+\nabla \log p(x)^{\top} \nabla_{y} k(x, y) \\
&+\left\langle\nabla_{x} k(x, \cdot), \nabla_{y} k(\cdot, y)\right\rangle_{\mathcal{F}^{d}}
\end{aligned}$$

The last term can also be written as
$$\sum_{i=1}^{d} \frac{\partial k(x, y)}{\partial x_{i} \partial y_{i}}.$$

Note that $g$ is equal to
$$g_p(x, y) = \mathcal A_p^T \mathcal A^{(y)}_p k(x, y).$$

In [4]:
# code moved to stein.py

time:

In [5]:
hgrid = np.logspace(-1, 3, num=100)
h = 1.5
dist = metrics.Gaussian(0,1)
x = dist.sample((100,1))
logp = dist.logpdf



In [6]:
%timeit np.array([metrics.ksd_squared(x, logp, np.log(h)) for h in hgrid]).block_until_ready

23.5 ms ± 326 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
%timeit np.array(  [stein.ksd_squared(x, logp, np.log(h)) for h in hgrid]).block_until_ready

29.4 ms ± 5.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
stein.ksd_squared(x, logp, np.log(h))

DeviceArray(-0.01484735, dtype=float32)

In [9]:
metrics.ksd_squared(x, logp, np.log(h))

DeviceArray(-0.01484735, dtype=float32)

In [10]:
slkfj

NameError: name 'slkfj' is not defined

# making new ard

In [None]:
x = np.array([1,2])
A = np.array([[1, 2], [3, 4]])

In [None]:
np.matmul(A, x)

In [None]:
utils.ard_m(x, y, A)

In [None]:
utils.ard_m(1, 2, 3)

# fixing phistar

In [None]:
dist = metrics.Gaussian(0, 1)

In [None]:
z = np.array([1.])
x = np.array([2.])
mu = 1
n = 100

def k(y):
    return np.exp(-1/2 * (y - z)**2)


print(stein.stein_operator(k, x, dist.logpdf))
print(k(x) * (z - 2*x))

In [None]:
def phis(x):
    return - 1 / np.sqrt(2) * mu * np.exp(- 1/4 * (z - mu)**2)

In [None]:
xs = dist.sample(shape=(n, 1)) + mu

print(svgd.phistar_i(z, xs, dist.logpdf, 1))
print(svgd._phistar_i(z, xs, dist.logpdf, 1))
print(phis(z))

In [None]:
salkdfj

# still fixing `stein.stein`

In [None]:
dist = metrics.Gaussian(0, 1)

In [None]:
def fun(x):
    return np.exp(-x**2)

n = 100
steins = []
mugrid = np.linspace(-6, 6, num=50)

for mean in mugrid:
    sam = dist.sample(shape=(n, 1)) + mean
    steins.append(stein.stein(fun, sam, dist.logpdf))

In [None]:
plt.plot(mugrid, steins)

Meanwhile, we know that for $f(x) = e^{-x^2}$ and $q = \mathcal N(\mu, 1)$,

$$
E_{x \sim q}[ \mathcal A_p [f] (x)] = - \frac{\mu}{\sqrt 3} \cdot e^{- \mu^2 / 3}
$$


In [None]:
def stein_true(mu):
    return mu / np.sqrt(3) * np.exp(- mu**2 / 3)

In [None]:
n = 100
tsteins = []
for mean in mugrid:
    tsteins.append(stein_true(mean))

In [None]:
plt.plot(mugrid, tsteins)

In [None]:
n = 100
steins = []
sigmagrid = np.linspace(-6, 6, num=50)

for sigma in sigmagrid:
    sam = dist.sample(shape=(n, 1)) * sigma
    steins.append(stein.stein(fun, sam, dist.logpdf))

This is supposed to be constant = 0. Checks out.

In [None]:
plt.plot(sigmagrid, steins, "r.")

# fix `stein.stein`

In [None]:
dist = metrics.Gaussian(0,1)
ngrid = np.arange(10, 1500, 50)

def fun(x):
    return np.exp(-x**2)

steins = []
for n in tqdm(ngrid):
    xs = dist.sample(shape=(n,1))
    steins.append(stein.stein(fun, xs, dist.logpdf))

In [None]:
steins = np.array(steins)

In [None]:
np.argmin(steins)

In [None]:
plt.plot(ngrid, -steins, ".")
# plt.yscale("log")

# `phistar_i`

In [None]:
asdjfh

In [None]:
n = 200

dist = metrics.Gaussian([0, 0], [1, 1])
logp = dist.logpdf
bandwidth = 1

In [None]:
ngrid = np.arange(10, 1000, 30)
phis = []
x = random.normal(rkey, shape=(2,))
for n in tqdm(ngrid):
    rkey = random.split(rkey)[0]
    xs = random.normal(rkey, shape=(n,2))
    phis.append(svgd._phistar_i(x, xs, logp, bandwidth))

In [None]:
phis = np.array(phis)

In [None]:
plt.plot(ngrid, np.abs(phis), ".")
# plt.yscale("log")

In [None]:
svgd._phistar_i(xi, xs, logp, bandwidth) * n

In [None]:
dist = metrics.Gaussian([1, 2], [1, 2])
logp = dist.logpdf
bandwidth = 1.

In [None]:
n = 20

In [None]:
sample = dist.sample(shape=(n,))
sample.shape

In [None]:
svgd.phistar_i(x, sample, logp, bandwidth)

In [None]:
svgd._phistar_i(x, sample, logp, bandwidth)

# test both ksds

In [None]:
dist = metrics.Gaussian([1, 2], [1, 2])
logp = dist.logpdf
bandwidth = 1.

In [None]:
n = 27

In [None]:
sample = dist.sample(shape=(n,))
sample.shape

In [None]:
metrics.ksd(sample, logp, bandwidth)

In [None]:
metrics._ksd(sample, logp, bandwidth)

In [None]:
metrics._ksd(sample, logp, bandwidth) * n

In [None]:
sfjskjaf

# KSD thing

### short range

In [None]:
dist = metrics.Gaussian(0, 1)
ngrid = np.arange(2, 100, 3)
ksds = []
mses = []
for n in tqdm(ngrid):
    x = dist.sample(shape=(n,))
    ksds.append(metrics._ksd(x, dist.logpdf, 1))
    mses.append(np.mean(x**2 - dist.expectations[1])**2)

In [None]:
plt.plot(ngrid, ksds, "r.")
plt.yscale("log")

In [None]:
plt.plot(ngrid, mses, ".")
plt.yscale("log")

### long range

In [None]:
dist = metrics.Gaussian(0, 1)
ngrid = np.arange(5, 5000, 100)
ksds = []
mses = []
for n in tqdm(ngrid):
    x = dist.sample(shape=(n,))
    ksds.append(metrics._ksd(x, dist.logpdf, 1))
    mses.append(np.mean(x**2 - dist.expectations[1])**2)

In [None]:
plt.plot(ngrid, ksds, "r.")
plt.yscale("log")

In [None]:
plt.plot(ngrid, mses, ".")
plt.yscale("log")

## other samples

In [None]:
rkey = random.split(rkey)[0]
s = random.normal(rkey, shape=(100, 1))

In [None]:
metrics.ksd(s * 2, dist.logpdf, 1)

# other stuff

In [None]:
n = 10
dist13 = metrics.GaussianMixture([-2, 2], [7, 7], [1/3, 2/3])
dist13.compute_metrics_for_sample(10)



s = dist13.sample(shape=(n,1))
# s = s[:, np.newaxis]

In [None]:
dist14 = metrics.Gaussian(0, 1)
dist14.compute_metrics_for_sample(10)

In [None]:
dist13.compute_metrics_for_sample(10)

In [None]:
dist13.expectations

In [None]:
dist13.compute_metrics(s) # throws ValueError

In [None]:
s.reshape((n, 1))

In [None]:
x = s
logp = dist13.logpdf
bandwidth = 1.

i = 0
j = 5
xi = x[i]
xj = x[j]


In [None]:
def phistar_j(x, y, logp, bandwidth):
    """Individual summand needed to compute phi^*. That is, phistar_i = \sum_j phistar_j(xj, xi, logp, bandwidth).
    Arguments:
    * x: np.array of shape ()
    Returns:
    """
    kernel = lambda x, y: utils.ard(x, y, bandwidth)
    return grad(logp)(x) * kernel(x, y) + grad(kernel)(x, y)


# stuff


In [None]:
c = {"bldk": 3, "23": 8}
f = {"bldk": 2, "23": 4}

d = {"one":1, "two": 2, "three": c}
e = {"one":3, "two": 4, "three": f}

g = {"one":3, "test": "a", "two": 4, "three": f}



In [None]:
for k, l ro in zip(d, g):
    print(k)
    print(l)
    print()

In [None]:
utils.dict_divide(d, e)