# Making stuff: scratch experiments (JAX, Haiku, plotting, utilities)

In [None]:
import sys
import os
sys.path.append("/home/lauro/code/msc-thesis/svgd")
import json
import collections
import itertools

import numpy as onp
from jax.config import config
# Turn on 64-bit (float64/int64) types in JAX before jax.numpy is imported
# below. NOTE: the name `config` is rebound by the later `import config`, so
# subsequent `config.*` uses refer to the local project module instead.
config.update("jax_enable_x64", True)

import jax.numpy as np
from jax import grad, jit, vmap, random, lax, jacfwd, value_and_grad
from jax import lax
from jax.ops import index_update, index
import matplotlib.pyplot as plt

import numpy as onp
import jax
import pandas as pd
import haiku as hk
import ot

import config

import utils
import metrics
import time
import plot
import svgd
import stein
import kernels

from jax.experimental import optimizers

# Global PRNG key shared by the cells below; some cells advance it with
# `random.split`.
rkey = random.PRNGKey(0)

# `vmap` with an auxiliary (unbatched) output

In [2]:
def f(x):
    """Return ``(2 * x, 5)``: a doubled value plus a constant auxiliary output."""
    doubled = x * 2
    return doubled, 5

In [5]:
# vmap also maps the constant second output 5 to a length-5 array, so the
# mean over both outputs is (2 + 5) / 2 = 3.5. Stack the tuple explicitly:
# newer jax.numpy reductions reject tuple/list arguments.
np.mean(np.stack(vmap(f)(np.ones(5))))



DeviceArray(3.5, dtype=float64)

# Haiku: transforming an identity function

In [23]:
def id_fn(x):
    """Identity network body: hand the input straight back."""
    return x


identity = hk.transform(id_fn)

In [24]:
# Initialise the transformed identity. It has no parameters, so the result is
# an empty mapping (shown below as `frozendict({})`).
identity.init(rkey, 1.)

frozendict({})

# Timing

In [2]:
# NOTE(review): this rebinds the module name `svgd` to an SVGD instance.
# Re-running this cell then fails (the instance presumably has no `SVGD`
# attribute), and the `svgd` module is no longer reachable by that name.
# Later cells use `svgd.target`, so rename consistently if changing this.
svgd = svgd.SVGD(**config.get_svgd_args(config.config))

In [19]:
class Time:
    """Minimal stopwatch for timing notebook steps.

    ``start`` resets the clock. ``lap`` prints the time since the previous lap
    (or since ``start``), and ``stop`` prints the total time since ``start``.
    Durations are printed as ``MM:SS``, truncated to whole seconds. Minutes
    keep counting past 59 instead of silently wrapping at one hour, which
    ``time.strftime("%M:%S", ...)`` did.
    """

    def __init__(self):
        # Start on construction so lap()/stop() work even if start() is never
        # called explicitly; start() can still be used to reset the clock.
        self.start()

    @staticmethod
    def _format_duration(seconds):
        """Format a non-negative duration in seconds as ``MM:SS``."""
        minutes, secs = divmod(int(seconds), 60)
        return f"{minutes:02d}:{secs:02d}"

    def start(self):
        # perf_counter is monotonic, unlike time.time(), so wall-clock
        # adjustments cannot skew the measured durations.
        self.lapstart = self.zero = time.perf_counter()

    def lap(self, name):
        now = time.perf_counter()
        print(name, self._format_duration(now - self.lapstart))
        self.lapstart = now

    def stop(self):
        total = self._format_duration(time.perf_counter() - self.zero)
        print("Total time elapsed:", total)

In [20]:
# Sample used by the timing experiment below.
# NOTE(review): presumably a Gaussian with mean [0, 1] and scale 1; confirm
# against the signature of metrics.Gaussian.
d = metrics.Gaussian([0, 1], 1)
particles = d.sample(5000)

In [21]:
# Time each evaluation metric on `particles` against an equally sized target
# sample.
# NOTE(review): jax.numpy calls dispatch asynchronously, so laps for JAX
# results may not include the full compute time unless the result is blocked
# on (e.g. `.block_until_ready()`). The first call also includes any
# compilation.
n = len(particles)
target_sample = svgd.target.sample(n)
t = Time()

t.start()
emd = metrics.wasserstein_distance(particles, target_sample)
t.lap("EMD")
# POT entropic OT divergence with squared-Euclidean cost; the positional `1`
# is presumably the entropic regularisation strength.
sinkhorn_divergence = ot.bregman.empirical_sinkhorn_divergence(particles, target_sample, 1, metric="sqeuclidean")
t.lap("Sinkhorn")
# NOTE(review): `particles` is passed twice here; confirm this is the intended
# pair of arguments for stein.ksd_squared.
ksd = stein.ksd_squared(particles, particles, svgd.target.logpdf, kernels.ard(0))
t.lap("KSD")
# Mean squared error of the sample mean and sample covariance versus the
# target's mean and covariance.
se_mean = np.mean((np.mean(particles, axis=0) - svgd.target.mean)**2)
t.lap("Mean")
se_var = np.mean((np.cov(particles, rowvar=False) - svgd.target.cov)**2)
t.lap("Var")
t.stop()

EMD 00:02
Sinkhorn 00:09
KSD 00:09
Mean 00:00
Var 00:01
Total time elapsed: 00:23


# Haiku: reuse a subnetwork elsewhere

In [None]:
def encoder_fn(x, layer_sizes=(4, 4, 2)):
    """Haiku forward function: apply a ReLU MLP named "encoder" to ``x``.

    (Original note: "can take kernel_params", i.e. intended as a learned
    feature map for kernel parameters; the function itself only takes ``x``.)

    Args:
        x: input passed to the MLP.
        layer_sizes: output sizes of the MLP layers; the last entry is the
            output dimension. Defaults to ``(4, 4, 2)``, the previously
            hard-coded sizes. Use the same value for ``init`` and ``apply``.

    Returns:
        The MLP output. No activation is applied after the final layer.
    """
    encoder = hk.nets.MLP(output_sizes=list(layer_sizes),
                          w_init=hk.initializers.VarianceScaling(scale=2.0),
                          activation=jax.nn.relu,
                          activate_final=False,
                          name="encoder")
    return encoder(x)


encoder = hk.transform(encoder_fn)

# unpacking `value_and_grad`

In [None]:
def f(x, y):
    """Return ``x + y`` as the differentiable output plus the string ``"aux"``."""
    total = x + y
    return total, "aux"

# With has_aux=True the result is ((value, aux), (df/dx, df/dy)).
value_and_grad(f, argnums=(0, 1), has_aux=True)(1., 2.)

# Plot shaded error bands

In [None]:
def errorfill(x, y, yerr, color="r", alpha_fill=0.3, ax=None):
    """Plot ``y`` against ``x`` with a shaded error band.

    Args:
        x, y: sequences of equal length.
        yerr: one of
            - a scalar, or a sequence of the same length as ``y``, giving the
              symmetric band ``y - yerr`` .. ``y + yerr``;
            - a pair ``(ymin, ymax)`` of explicit lower and upper bounds.
        color: colour for both the line and the band. ``None`` lets
            matplotlib pick the next colour of the axes' property cycle.
        alpha_fill: opacity of the band.
        ax: axes to draw on; defaults to the current axes (``plt.gca()``).

    Raises:
        ValueError: if ``yerr`` has neither of the supported forms. Previously
            this surfaced as an UnboundLocalError.
    """
    ax = ax if ax is not None else plt.gca()
    # Resolve the bounds first, so a malformed yerr fails before drawing.
    if np.isscalar(yerr) or len(yerr) == len(y):
        ymin = y - yerr
        ymax = y + yerr
    elif len(yerr) == 2:
        ymin, ymax = yerr
    else:
        raise ValueError(
            "yerr must be a scalar, have the same length as y, or be a "
            "(ymin, ymax) pair; got length %d" % len(yerr))
    # With color=None, matplotlib assigns the next cycle colour to the line.
    # Reusing the line's resolved colour keeps line and band matching. (The
    # old `ax._get_lines.color_cycle.next()` was Python-2 syntax on a private
    # attribute that matplotlib has since removed.)
    line, = ax.plot(x, y, color=color)
    ax.fill_between(x, ymax, ymin, color=line.get_color(), alpha=alpha_fill)

In [None]:
# Demo data for errorfill. Start just above 0: log(0) = -inf would make the
# band bounds -inf - (-inf) = NaN at the first point.
x = np.linspace(0.1, 10, 100)
y = np.log(x)
# Used as the symmetric error `yerr` below; despite the name it is not a
# variance. It is negative where x < 1, which only swaps the band edges.
var = np.log(x) / 3

In [None]:
# Draw the log curve with a symmetric band of half-width `var` on the current
# axes.
errorfill(x, y, var)

# Subsampling

In [None]:
# Split a (10, 3) sample into two (5, 3) halves along axis 0. Arrays have no
# `.split` method, so use the `split` function instead.
np.split(random.normal(rkey, (10, 3)), 2)

In [None]:
def subsample(key, array, n_subsamples, replace=True, axis=0):
    """
    Draw ``n_subsamples`` random entries of ``array`` along ``axis``.

    Arguments
    ----------
    key : jax PRNGKey
        Source of randomness for the draw. (Previously this argument was
        ignored and the module-level ``rkey`` was used instead, so the same
        global key always produced the same indices.)
    array : array
        Array to subsample.
    n_subsamples : int
        Number of entries to draw.
    replace : bool
        Whether to sample with replacement. Without replacement,
        ``n_subsamples`` must not exceed ``array.shape[axis]``.
    axis : int
        Axis along which entries are drawn.

    Returns
    ----------
    Array of the same shape as ``array``, except that ``axis`` has length
    ``n_subsamples``. It consists of randomly chosen entries of ``array``.
    """
    subsample_idx = random.choice(key, array.shape[axis], shape=(n_subsamples,), replace=replace)
    return array.take(indices=subsample_idx, axis=axis)



In [None]:
particles = random.normal(rkey, (10, 2))
rkey = random.split(rkey)[0]

# Draw 5 distinct rows (replace=False); pass replace=True to allow repeats.
subsample_idx = random.choice(rkey, len(particles), shape=(5,), replace=False)
# Use a distinct name so the module-level `subsample` function is not
# overwritten by an array.
particle_subsample = particles[subsample_idx]
rkey = random.split(rkey)[0]

print(particle_subsample)
rkey = random.split(rkey)[0]


# Finiteness (inf and NaN)

In [None]:
# 1 / 0 gives inf, and inf / inf gives NaN, so `t` ends up as a 0-d NaN array.
# (This rebinds `t`, which earlier held a Time instance.)
t = np.array(1) / 0
t = t / t
t

In [None]:
# False for NaN (and for +/-inf).
np.isfinite(t)

# Generate random means, covariances, and weights for Gaussian mixtures

In [None]:
# Random 4x4 matrix with uniform [0, 1) entries from NumPy's global RNG.
A = onp.random.rand(4,4)

In [None]:
# Inspect A.
A

In [None]:
# Inspect the transpose of A.
A.T

In [None]:
def generate_pd_matrix(dim, jitter=0.0):
    """Return a random symmetric ``(dim, dim)`` matrix ``A @ A.T (+ jitter * I)``.

    ``A`` has i.i.d. uniform [0, 1) entries drawn from NumPy's global RNG.
    ``A @ A.T`` is only guaranteed positive *semi*-definite. It is positive
    definite with probability one, but it can be badly conditioned. Pass
    ``jitter > 0`` to add ``jitter * I``; this guarantees positive
    definiteness, with every eigenvalue at least ``jitter``. The default
    ``jitter=0.0`` reproduces the original output exactly.
    """
    A = onp.random.rand(dim, dim)
    gram = onp.matmul(A, A.T)
    if jitter:
        gram = gram + jitter * onp.eye(dim)
    return gram

In [None]:
def generate_parameters_for_gaussian(dim, k, mixture=True):
    """Draw random parameters for a ``k``-component Gaussian mixture in ``dim`` dims.

    Returns ``(means, covs, weights)``:

    - ``means``: array of shape ``(k, dim)`` with entries uniform in [0, 10);
    - ``covs``: list of ``k`` random ``(dim, dim)`` matrices from
      ``generate_pd_matrix``;
    - ``weights``: length-``k`` array that sums to 1, proportional to random
      integers in 1..4.

    All randomness comes from NumPy's global RNG.
    NOTE(review): ``mixture`` is accepted but not used by the body.
    """
    means = 10 * onp.random.rand(k, dim)
    covs = list(map(lambda _: generate_pd_matrix(dim), range(k)))
    counts = onp.random.randint(1, 5, k)
    return means, covs, counts / counts.sum()

In [None]:
# Example draw: 3 components in 2 dimensions.
generate_parameters_for_gaussian(2, 3)

In [None]:
# Random 3-component mixture in 2 dimensions (rebinds `d`).
d = metrics.GaussianMixture(*generate_parameters_for_gaussian(2, 3))

In [None]:
# Contour plot of the mixture log-density, with the component means overlaid.
# NOTE(review): (-5., 15.) is presumably the plotting range per axis; it
# covers the means, which are drawn from [0, 10).
plot.plot_pdf(d.logpdf, (-5., 15.), "contour", num_gridpoints=500)
plt.scatter(d.means[:, 0], d.means[:, 1])

In [None]:
# 1000 samples from the mixture.
s = d.sample(10**3)

In [None]:
# Bivariate histogram of the samples.
plot.bivariate_hist(s)

# Wasserstein (pairwise cost matrices via `scipy.spatial.distance.cdist`)

In [None]:
from scipy.spatial import distance

In [None]:
# Pairwise distance matrix between four points, presumably (latitude,
# longitude) pairs. 'minkowski' with cdist's default p=2 is the plain
# Euclidean distance in coordinate units, not a geodesic distance.
# The expected output is recorded below.
coords = [(35.0456, -85.2672),
          (35.1174, -89.9711),
          (35.9728, -83.9422),
          (36.1667, -86.7833)]
distance.cdist(coords, coords, 'minkowski')
# array([[ 0.    ,  4.7044,  1.6172,  1.8856],
#        [ 4.7044,  0.    ,  6.0893,  3.3561],
#        [ 1.6172,  6.0893,  0.    ,  2.8477],
#        [ 1.8856,  3.3561,  2.8477,  0.    ]])

# Cartesian products of parameter dicts

In [None]:
# Keyword-argument sets for utils.dict_cartesian_product (rebinds `d`).
d = {"chars": "ab", "nums": [1, 2]}
e = {"words": ["bake", "tree"]}

In [None]:
# Cartesian product over the values of `d`; presumably one dict per
# combination. Confirm in utils.dict_cartesian_product.
p = utils.dict_cartesian_product(**d)
# NOTE(review): if `p` is an iterator, either line below consumes it, and the
# product loop further down would then see nothing.
# list(p)
# [x for x in p]

In [None]:
# Cartesian product over the values of `e`.
q = utils.dict_cartesian_product(**e)
# NOTE(review): if `q` is an iterator, this consumes it.
# list(q)

In [None]:
# Print every (item of p, item of q) pair on two lines, followed by a blank line.
for left, right in itertools.product(p, q):
    print(left, right, sep="\n", end="\n\n")

In [None]:
# All 9 (letter, digit) pairs, with the letter varying slowest.
list(itertools.product(("a", "b", "c"), (1, 2, 3)))

# Check whether a key occurs anywhere in a nested dict

In [None]:
def nested_dict_contains_key(ndict: dict, key) -> bool:
    """Return True if ``key`` is a key of ``ndict`` or of any nested mapping.

    The function recurses only into values that are mappings. Mappings stored
    inside lists or other containers are not searched.
    """
    # `collections.Mapping` was removed in Python 3.10; the ABC lives in
    # collections.abc. It is imported locally to keep this cell self-contained.
    from collections.abc import Mapping

    if key in ndict:
        return True
    return any(
        nested_dict_contains_key(value, key)
        for value in ndict.values()
        if isinstance(value, Mapping)
    )

In [None]:
# Inspect the `config` attribute of the local project `config` module.
config.config