In [131]:
!pip install matplotlib
!pip install ..
#

Processing /Users/Paul/Lokal/SGMC
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
Building wheels for collected packages: jax-sgmc
  Building wheel for jax-sgmc (setup.py) ... [?25l- done
[?25h  Created wheel for jax-sgmc: filename=jax_sgmc-0.0.1-py3-none-any.whl size=5957 sha256=1d70afd551c868cf269d094c352a34507bd64736c3cfae7394e97822cc006e5f
  Stored in directory: /private/var/folders/hr/fcqkhn7x07b8h6h_glqwb_cw0000gn/T/pip-ephem-wheel-cache-7b6_u0oc/wheels/db/4c/b1/9039d8f42e60cd073ec30f51b1820f15c7bd929ae2a04a08b6
Successfully built jax-sgmc
Installing collected packages: jax-sgmc
  Attempting uninstall: jax-sgmc
    

In [125]:
import numpy as onp

import jax
import jax.numpy as jnp

from jax import random, vmap

import matplotlib

from jax.scipy.stats import norm
from jax.scipy.stats import uniform

from functools import partial

from jax_sgmc import distributions
from jax_sgmc import potential

from jax_sgmc import data

Linear Regression
=========================

Problem
---------

$ y_i = \sum_{j=1}^N w_jx_i + \mathcal{N}(0, \sigma^2),\ i = 1, \ldots, M$

A multiple of samples are used to
deal with overfitting.


Reference Data
---------------

1. Draw weights $w_j$
2. For each sample $i$:
    1. Draw $N$ values of $x_j$
    2. Draw noise $\sigma$
    3. Calculate $y_j$

In [126]:
# Todo: This is not the correct form to handle batches? -> It is possible but it
# is kind of multiple batching

N = 200
M = 100

samples = 1000

sigma = 5.0

key = random.PRNGKey(0)
split, key = random.split(key)

w = random.uniform(split, minval=-1, maxval=1, shape=(N, 1))

@partial(vmap, in_axes=(0, None), out_axes=(0, 0))
def draw_samples(key, w):
    # For each sample draw a new set of predictor variables x and calculate
    # the response

    split1, split2, key = random.split(key, 3)

    noise = sigma * random.normal(split2, shape=(M, 1))
    x = random.uniform(split1, minval=-10, maxval=10, shape=(M, N))
    y = jnp.matmul(x,w) + noise

    return x, y

keys = random.split(key, num=samples)

X, Y = draw_samples(keys, w)

The reference data and the desired batch shape must be load into a data loader

In [127]:
data_loader = data.PreloadReferenceData(Y,
                                        parameters=X,
                                        batch_size=1)

Solution
---------

### Exact

In [128]:
X = jnp.reshape(X, newshape=(-1, N))
Y = jnp.reshape(Y, newshape=(-1, 1))

w_sol, res, _, _ = jnp.linalg.lstsq(X, Y)

print(f"Max error: {jnp.max(jnp.abs(w - w_sol))}")

Max error: 0.007030069828033447


### SGLD


In [129]:
# First we need to define the (deterministic) model

def model(sample, parameters):
    weights = sample["w"]

    print(weights)
    print(parameters)

    predictions = jnp.dot(parameters, weights)
    return {"y": predictions}

# The combination of prior and distribution form the posterior, from which we
# sampling is performed.

def likelihood(model_results, extra_parameters, reference_data):
    sigma = extra_parameters["sigma"]
    y = model_results["y"]
    return norm.logpdf(y, loc=reference_data, scale=sigma)

def prior(sample):
    return uniform.logpdf(sample,
                          loc=-10 * jnp.ones(sample.shape),
                          scale=-10 * jnp.ones(sample.shape))

test_w = random.normal(key, (N, 1))

# Get a reference batch

batch_y, batch_x = data_loader.get_random_batch()

print(model({"w": test_w}, batch_x[0,::]))

# We need to define how the posterior can be evaluated by a mini-batch of data
# instead of jsut for a single sample.

# minibatch_potential = potential.minibatch_potential_function(
#     prior,
#     likelihood,
#     strategy="map"
# )

# We test the potential evaluation for a single


[[ 8.69366527e-01]
 [ 9.02669370e-01]
 [ 5.03285646e-01]
 [ 5.47213495e-01]
 [ 6.59128487e-01]
 [-5.30288458e-01]
 [-3.96815211e-01]
 [-4.01674286e-02]
 [ 1.08454514e+00]
 [ 3.12125366e-02]
 [-4.71260190e-01]
 [ 1.05747533e+00]
 [ 1.03348887e+00]
 [-3.87305133e-02]
 [-5.99555731e-01]
 [-4.14275289e-01]
 [ 8.30999166e-02]
 [-3.58195417e-02]
 [-3.61374915e-01]
 [-4.32906561e-02]
 [ 7.34955132e-01]
 [ 1.83077300e+00]
 [ 1.56501675e+00]
 [-7.87699163e-01]
 [ 6.97018087e-01]
 [-1.41609466e+00]
 [-1.94630516e+00]
 [-5.82420588e-01]
 [ 1.26510775e+00]
 [-3.26003194e-01]
 [ 6.16869092e-01]
 [-1.23756659e+00]
 [ 4.71711636e-01]
 [-1.05441004e-01]
 [ 6.57151759e-01]
 [ 4.87126470e-01]
 [ 6.23257458e-01]
 [ 2.44133770e-01]
 [-3.44233155e+00]
 [-1.38190782e+00]
 [-1.31327879e+00]
 [ 2.43646771e-01]
 [ 1.43044114e+00]
 [ 8.46537709e-01]
 [ 1.87248278e+00]
 [-1.83469105e+00]
 [ 3.42182964e-01]
 [ 6.95943832e-01]
 [-5.25920868e-01]
 [ 6.45238638e-01]
 [-1.25352204e+00]
 [-1.47869229e+00]
 [-1.3817981

AssertionError: Currently not implemented