# Scratch

In [2]:
import sys
import os
import time

# Make the project's svgd/ modules importable. The location can be overridden with the
# SVGD_PATH environment variable instead of relying only on a user-specific absolute path.
sys.path.append(os.environ.get("SVGD_PATH", "/home/lauro/code/msc-thesis/svgd"))

import jax.numpy as np
# value_and_grad is needed by the SVGD/KSD training cell below.
from jax import grad, jit, vmap, random, lax, jacfwd, value_and_grad
from jax.ops import index_update, index
import matplotlib.pyplot as plt
import numpy as onp

import utils
import metrics
import plot
import svgd
import stein

# Global PRNG key; split off a fresh subkey before each independent random draw.
rkey = random.PRNGKey(0)

In [None]:
# NOTE(review): this cell appears to be pasted from a class method body — `self`, `key`,
# `init_svgd`, `opt_ksd`, `n_iter`, `svgd_steps` and `ksd_steps` are not defined
# anywhere in this notebook, so it cannot run as-is.
# Batched KSD^2: vmap maps over the third positional argument only (in_axes=(None, None, 0)).
# NOTE(review): the KSD loop below calls ksd_squared with two arguments
# (params, particle_batch) although in_axes lists three — confirm the signature
# of self.ksd_squared.
ksd_squared = vmap(self.ksd_squared, (None, None, 0)) # operate on batch of particles
hypernetwork = self.hypernetwork
def current_step(i, j, steps):
    """Global optimizer step for inner iteration `j` of outer iteration `i`.

    `steps` is the number of inner iterations performed per outer iteration.
    """
    completed = i * steps
    return completed + j
# Split the key: key1 seeds the particle initialisation, key2 the hypernetwork parameters.
key1, key2 = random.split(key)

particles = init_svgd(key1, self.particle_shape)
opt_svgd_state = self.opt.init(particles)

# NOTE(review): the SVGD optimizer state above holds particles WITHOUT this batch
# dimension; only the hypernetwork initialisation sees the expanded array.
particles = np.expand_dims(particles, 0) # batch dimension
params = hypernetwork.init(key2, particles)
opt_ksd_state = opt_ksd.init(params)

log = dict()
# Each outer iteration runs `svgd_steps` particle updates (hypernetwork params held
# fixed) followed by `ksd_steps` hypernetwork-parameter updates.
for i in range(n_iter):
    # update particles:
    params = opt_ksd.get_params(opt_ksd_state)
    particle_batch = []
    for j in range(svgd_steps):
        step = current_step(i, j, svgd_steps)

        particles = self.opt.get_params(opt_svgd_state)
        # The optimizer is fed the negated phistar output as its "gradient".
        gp = -self.phistar(particles, params) # TODO gradient has wrong extra batch dim
        # NOTE(review): leftover debug print; runs on every SVGD step.
        print(gp.shape)
        opt_svgd_state = self.opt.update(step, gp, opt_svgd_state)

        # Record the particles used at this step (read before the update was applied).
        particle_batch.append(particles)
        utils.warn_if_nan(gp)

    # update network params:
    # Stack the per-step particle snapshots into one float32 array.
    particle_batch = np.asarray(particle_batch, dtype=np.float32)
    log = metrics.append_to_log(log, {"particles": particle_batch})
    # NOTE(review): inner_updates is never appended to (its append is commented out below).
    inner_updates = []
    ksds = []
    gradients = []
    for j in range(ksd_steps):
        # The KSD optimizer gets its own step index, computed with ksd_steps.
        step = current_step(i, j, ksd_steps)
        params = opt_ksd.get_params(opt_ksd_state)
        # NOTE(review): requires `value_and_grad` imported from jax. value_and_grad also
        # needs a scalar-valued function — confirm the vmapped ksd_squared returns a scalar.
        ksd, gk = value_and_grad(ksd_squared)(params, particle_batch)
        opt_ksd_state = opt_ksd.update(step, gk, opt_ksd_state)

#                inner_updates.append(params)
        ksds.append(ksd)
        gradients.append(gk)
        utils.warn_if_nan(ksd)
        utils.warn_if_nan(gk)
    # Log the KSD values and gradients collected during this outer iteration.
    update_log = {
        "ksd": ksds,
        "gradients": gradients,
    }
    log = metrics.append_to_log(log, update_log)


# Bug: `jax.numpy.diff` on a Python list


In [3]:
# jax.numpy.diff on a plain Python list: the recorded output is the input list
# unchanged, not the consecutive differences (compare with numpy.diff below).
np.diff([1, 2, 4, 8])

[1, 2, 4, 8]

In [4]:
# Reference behaviour: numpy.diff returns the consecutive differences.
onp.diff([1, 2, 4, 8])

array([1, 2, 4])

# Closures over instance attributes

In [5]:
class Test():
    """Holds a single attribute `a` and hands out closures that read it."""

    def __init__(self, a=1):
        self.a = a

    def getf(self):
        """Return a zero-argument callable that looks up `self.a` each time it is called."""
        return lambda: self.a

In [6]:
t = Test(5)
# f closes over `self`, which is the same object as `t`.
f = t.getf()
f()

5

In [7]:
# Mutating the instance is visible through the closure: f looks up self.a at call time.
t.a = 3
f()

3

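Right, of course this works: the closure `f` captures `self` — the very object that is bound to `t` — and looks up `self.a` only when it is called. Mutating `t.a` therefore changes what `f()` returns; no global variable is involved.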

## jit `static_argnums` and classes

In [19]:
class Test():
    """Toy class for probing jit with `self` passed as a static argument.

    `tfun` is wrapped by utils.verbose_jit with static_argnums=0, so `self` is a
    static argument (presumably verbose_jit wraps jax.jit — confirm). Test defines
    no __eq__/__hash__, so the identity-based object defaults apply: calling
    `change` mutates `var` without making `self` compare or hash differently.
    """

    def __init__(self, y):
        self.var = 0
        self.y = y

    def tfun(self, x):
        # Reads self.var, whose changes the static-argument cache cannot see.
        return self.y + x**2 + self.var
    tfun = utils.verbose_jit(tfun, static_argnums=0)

    def change(self):
        self.var = self.var + 1

In [29]:
# Instance of the jit Test class defined above (rebinds the earlier `t`).
t = Test(5)

In [58]:
from copy import deepcopy

tcopy = deepcopy(t)
# The copy is a different object; Test defines no __eq__, so == falls back to identity.
print(t == tcopy)
# indeed:
print(id(tcopy) == id(t))

print("----------")
for obj in (t, tcopy):
    print(id(obj))

print("---------")
print("But the id of the class instance never changes, even if its objects do:")
# Mutating an attribute in place keeps the same object, hence the same id.
id_before_change = id(t)
t.change()
print(id(t) == id_before_change)
print()
print("This means you gotta be careful, since jit recompiles only if the object id changes.")

False
False
----------
139984662937048
139984663119408
---------
But the id of the class instance never changes, even if its objects do:
True

This means you gotta be careful, since jit recompiles only if the object id changes.


In [37]:
# Test does not override __eq__: this is the inherited object default (a method-wrapper).
t.__eq__

<method-wrapper '__eq__' of Test object at 0x7f50b81bbdd8>

In [38]:
# Likewise __hash__ is the inherited object default (a method-wrapper).
t.__hash__

<method-wrapper '__hash__' of Test object at 0x7f50b81bbdd8>

# numpy indexing

In [4]:
# 20 uniform draws from [0, 1).
# NOTE(review): uses rkey directly, and later cells split/consume the same key — split
# off a subkey first if these draws must be independent of later ones.
a = random.uniform(rkey, (20,))

In [5]:
# Display the 20 draws.
a

DeviceArray([0.85417664, 0.16620052, 0.27605474, 0.48728156, 0.9920441 ,
             0.03015983, 0.21629429, 0.37687123, 0.63070035, 0.96144474,
             0.15203023, 0.92090297, 0.30555236, 0.29931295, 0.6925707 ,
             0.8542826 , 0.46517384, 0.7869307 , 0.99605286, 0.28018546],            dtype=float32)

In [7]:
# Two overlapping index windows: rows [1, 2, 3] and [2, 3, 4].
idx = np.stack([np.arange(1, 4), np.arange(2, 5)])

In [8]:
# Integer-array (fancy) indexing: the result has idx's shape (2, 3).
a[idx]

DeviceArray([[0.16620052, 0.27605474, 0.48728156],
             [0.27605474, 0.48728156, 0.9920441 ]], dtype=float32)

# test `vmap` row column behaviour

In [3]:
def f(x, y):
    """Return 10*x + y; for single-digit inputs f(i, j) reads as the number "ij"."""
    return y + 10 * x

In [4]:
x = np.array([1,2])
# Inner vmap maps over y (columns), outer vmap over x (rows):
# f_matrix[i, j] = f(x[i], x[j]).
fv = vmap(f, in_axes=(None, 0))
fvv = vmap(fv, in_axes=(0, None))
f_matrix = fvv(x, x)

In [5]:
# f_matrix[i, j] = 10*x[i] + x[j]
f_matrix

DeviceArray([[11, 12],
             [21, 22]], dtype=int32)

# testing `stein.stein`

In [2]:
d = 3
# NOTE(review): the second argument is presumably a scalar (isotropic) covariance —
# confirm against metrics.Gaussian.
dist = metrics.Gaussian(np.zeros(d), 1)
# 100 samples from dist, passed to stein.stein in the cells below.
xs = dist.sample((100,))

### Case 1: 
Input is $f: \mathbb R^d \to \mathbb R^d$.
Then 
$$\mathcal A_p [f] (x) \in \mathbb R^{d \times d}$$
and 
$$\mathcal A_p^T [f] (x) \in \mathbb R$$

In [3]:
def f(x):
    """Test map R^d -> R^d: x -> 2*x + [1, 2, ..., d]."""
    x = np.array(x)
    assert x.shape == (d,)
    # Offset sized from d (identical to [1, 2, 3] for d = 3) so f works for any d.
    a = np.arange(1, d + 1)
    return x*2 + a
f([3.,2.,1])

DeviceArray([7., 6., 5.], dtype=float32)

In [4]:
# Case 1 (f: R^d -> R^d): result has shape (d, d).
stein.stein(f, xs, dist.logpdf).shape

(3, 3)

In [7]:
# transposed=True returns a scalar.
stein.stein(f, xs, dist.logpdf, transposed=True)

DeviceArray(-0.5033575, dtype=float32)

### Case 2:
Input is $f: \mathbb R^d \to \mathbb R$.
Then 
$$\mathcal A_p [f] (x) \in \mathbb R^{d}.$$

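Note that $\mathcal A_p^T [f]$ is not defined in this case: calling `stein.stein` with `transposed=True` raises an error (see the commented-out call below).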

In [9]:
def f(x):
    """Test map R^d -> R: x -> sum(2*x + [1, 2, ..., d])."""
    x = np.array(x)
    assert x.shape == (d,)
    # Offset sized from d (identical to [1, 2, 3] for d = 3) so f works for any d.
    a = np.arange(1, d + 1)
    return np.sum(x*2 + a)
f([3.,2.,1])

DeviceArray(18., dtype=float32)

In [10]:
# Case 2 (f: R^d -> R): result has shape (d,).
stein.stein(f, xs, dist.logpdf).shape

(3,)

In [12]:
# stein.stein(f, xs, dist.logpdf, transposed=True) # throws an error: not defined for scalar-valued f

In [13]:
def f(x):
    """Test map R^d -> R^{d x d}: x -> outer(2*x, [1, 2, ..., d])."""
    x = np.array(x)
    assert x.shape == (d,)
    # Offset sized from d (identical to [1, 2, 3] for d = 3) so f works for any d.
    a = np.arange(1, d + 1)
    return np.einsum("i,j->ij",2*x, a)
f([3.,2.,1])

DeviceArray([[ 6., 12., 18.],
             [ 4.,  8., 12.],
             [ 2.,  4.,  6.]], dtype=float32)

In [17]:
# With a Python list as input, jacfwd returns a list with one (3, 3) array per list element.
jacfwd(f)([3.,2.,1.])

[DeviceArray([[2., 4., 6.],
              [0., 0., 0.],
              [0., 0., 0.]], dtype=float32),
 DeviceArray([[0., 0., 0.],
              [2., 4., 6.],
              [0., 0., 0.]], dtype=float32),
 DeviceArray([[0., 0., 0.],
              [0., 0., 0.],
              [2., 4., 6.]], dtype=float32)]

In [24]:
# numpy stacks the list of three (3, 3) slices into a (3, 3, 3) array; "iii->i"
# picks the entries J[i][i, i].
onp.einsum("iii->i", jacfwd(f)([3.,2.,1.]))

array([2., 4., 6.], dtype=float32)

In [30]:
# Implicit-mode "ii" (no "->") sums the diagonal, i.e. the trace of f([3, 2, 1]).
np.einsum("ii", f([3,2,1]))

DeviceArray(20, dtype=int32)

# jit and autodiff experiments

In [None]:
def t(a, b):
    """Toy two-argument function for jacfwd: returns a + b."""
    total = a + b
    return total

In [None]:
# jacfwd differentiates w.r.t. the first positional argument only (argnums=0 by default),
# so this is d(a + b)/da.
# NOTE(review): this `t` (a function) shadows the Test instance `t` from earlier cells.
jacfwd(t)(1., 2.)

In [None]:
@jit
def outer(defval, const=None):
    """Return `const`, falling back to `defval` when `const` is None.

    None is an empty pytree under jit, so the `is None` test is resolved at trace
    time and each None / non-None call pattern is traced separately.
    """
    return defval if const is None else const
    

In [None]:
# const is None, so the result falls back to defval (5).
outer(5, None)

In [None]:
# const is given, so the result is 15.
outer(5, 15)

In [None]:
# Same call pattern as outer(5, None) with a different defval: should return 10.
outer(10, None)

## test mixture

In [None]:
d = 3
k = 5
# Draw means and covariance factors with independent subkeys: JAX PRNG keys must not
# be reused, or the two draws are not independent.
rkey, means_key, covs_key = random.split(rkey, 3)
means = random.uniform(means_key, shape=(k, d))
covs = random.uniform(covs_key, shape=(k, d, d))
# A @ A^T per component -> symmetric positive semi-definite covariance matrices.
covs = np.einsum("kil,kjl->kij", covs, covs)
# NOTE(review): these weights sum to 8/3, not 1 — presumably metrics.GaussianMixture
# normalises them; confirm.
weights = np.array([1/3, 2/3, 2/3, 2/3, 1/3])

mix = metrics.GaussianMixture(means, covs, weights)

In [None]:
# Convergence check: mean of (sample covariance - mix.cov)**2 / mix.cov for sample
# sizes 3, 9, ..., 3**14. Start at 3**1: the unbiased (ddof=1) sample covariance of a
# single sample is undefined and evaluates to NaN.
grid = 3**np.arange(1, 15)
diffs = []
for n in grid:
    # Convert the jax scalar to a Python int so it is a valid concrete shape.
    sample = mix.sample(shape=(int(n),))
    diffs.append(np.mean((np.cov(sample, rowvar=False) - mix.cov)**2 / mix.cov))
diffs = np.array(diffs)

In [None]:
# Log-log plot of the covariance error against sample size.
fig, ax = plt.subplots()
ax.plot(grid, diffs, ".")
ax.set(
    xscale="log",
    yscale="log",
    xlabel="number of samples",
    ylabel="mean((sample cov - mix.cov)**2 / mix.cov)",
    title="GaussianMixture.sample: covariance error vs sample size",
);

## test Gaussian

In [None]:
# Float parameters: integer arrays would carry an integer dtype into the
# Gaussian's mean/covariance arithmetic.
mean = np.array([1., 2.])
# Symmetric with det = 20 - 9 = 11 > 0, hence positive definite.
cov = np.array([[1., 3.], [3., 20.]])

In [None]:
# Gaussian with the mean and covariance defined above.
gauss = metrics.Gaussian(mean, cov)

In [None]:
# NOTE(review): leftover inspection cell; it only displays the class object.
metrics.Distribution

In [None]:
sample = gauss.sample(shape=(100,))
# Comparison sample from a different Gaussian, N(mean*3, cov/2). Use a fresh subkey:
# rkey itself has already been consumed by earlier random draws.
rkey, rsample_key = random.split(rkey)
rsample = random.multivariate_normal(rsample_key, mean*3, cov/2, shape=(100,))

In [None]:
# Metrics for a sample drawn from gauss itself.
gauss.compute_metrics(sample)

In [None]:
# Same metrics for rsample, drawn from N(mean*3, cov/2).
gauss.compute_metrics(rsample)

## jax einsum floating point round-off error

In [None]:
import jax.numpy as jnp
import numpy as onp

JAX

In [None]:
# Explicit float32 so the JAX and NumPy comparisons below use identical inputs.
values = jnp.array([[-5], [10]], dtype=jnp.float32)
weights = jnp.array([1/3, 2/3])

In [None]:
# Weighted sum over i: exactly -5/3 + 20/3 = 5; any deviation is float32 round-off.
jnp.einsum("i,id->d", weights, values)

In [None]:
# Same quantity via an elementwise product and sum.
jnp.sum(values.flatten() * weights)

NumPy

In [None]:
# Use NumPy's own float32 dtype (not the jax.numpy alias) for these NumPy arrays.
values = onp.array([[-5], [10]], dtype=onp.float32)
weights = onp.array([1/3, 2/3], dtype=onp.float32)

In [None]:
# NumPy einsum on the same float32 inputs, for comparison with the JAX result.
onp.einsum("i,id->d", weights, values)