In [1]:
!pip install matplotlib
!pip install ..
#

[31mERROR: Directory '..' is not installable. Neither 'setup.py' nor 'pyproject.toml' found.[0m


In [2]:
import numpy as onp

import jax
import jax.numpy as jnp

from jax import random, vmap

import matplotlib

from jax.scipy.stats import norm
from jax.scipy.stats import uniform

from functools import partial

from jax_sgmc import distributions
from jax_sgmc import potential
from jax_sgmc import data

Linear Regression
=========================

Problem
---------

$ y_i = \sum_{j=1}^N w_jx_i + \mathcal{N}(0, \sigma^2),\ i = 1, \ldots, M$

A multiple of samples are used to
deal with overfitting.


Reference Data
---------------

1. Draw weights $w_j$
2. For each sample $i$:
    1. Draw $N$ values of $x_j$
    2. Draw noise $\sigma$
    3. Calculate $y_j$

In [3]:
# Todo: This is not the correct form to handle batches? -> It is possible but it
# is kind of multiple batching

N = 200
M = 2

samples = 1000

sigma = 5.0

key = random.PRNGKey(0)
split, key = random.split(key)

w = random.uniform(split, minval=-1, maxval=1, shape=(N, 1))

def draw_samples(key, w):
    # For each sample draw a new set of predictor variables x and calculate
    # the response

    split1, split2, key = random.split(key, 3)

    noise = sigma * random.normal(split2, shape=(samples, 1))
    x = random.uniform(split1, minval=-10, maxval=10, shape=(samples, N))
    y = jnp.squeeze(jnp.matmul(x,w) + noise)

    return x, y

X, Y = draw_samples(key, w)

print(X)
print(Y)

print(X.shape)
print(Y.shape)





[[-6.82467    -3.1377554   9.639656   ...  8.063868   -0.37513018
   3.1858087 ]
 [ 9.66708    -1.7472482  -5.376272   ...  7.7683187  -1.1192489
   6.4086723 ]
 [-9.826147    8.313009   -6.7009068  ...  8.503883    2.4307108
  -3.287027  ]
 ...
 [-4.551344    6.4393926   6.4434814  ... -2.0396042   7.6109457
  -0.45919418]
 [-6.56801     3.8853645   7.398677   ... -9.011286    8.149996
   7.561407  ]
 [-5.969689    9.452255    6.8755054  ...  4.520278   -1.3018942
   1.1680746 ]]
[-5.55043488e+01  7.88076477e+01 -5.13533974e+01  5.21996078e+01
  9.96427727e+00  7.86780548e+00  1.11416138e+02 -5.57197456e+01
 -7.56359711e+01  3.69333344e+01  2.49893646e+01 -4.63896847e+00
  1.28366528e+01 -5.64106483e+01 -4.35241508e+01 -6.16251793e+01
 -3.77799072e+01  2.70470505e+01  2.62485504e+01 -5.53533287e+01
 -6.76102448e+01 -6.19348621e+00 -5.56775703e+01  7.30562508e-01
 -1.60615711e+01  2.46327248e+01  8.35883617e+00 -6.22692909e+01
 -1.86468410e+01  2.80861740e+01  3.23601875e+01  1.0333461

The reference data and the desired batch shape must be load into a data loader

In [4]:
data_loader = data.NumpyDataLoader(M, x=X, y=Y)
init_fn, batch_fn = data.random_reference_data(data_loader,
                                               cached_batches_count=1000)

Solution
---------

### Exact

In [5]:
w_sol, res, _, _ = jnp.linalg.lstsq(X, Y)

print(f"Max error: {jnp.max(jnp.abs(w - w_sol))}")

Max error: 1.9698679447174072


### SGLD


In [6]:
# First we need to define the (deterministic) model

def model(sample, parameters):
    weights = sample["w"]

    print(weights)
    print(parameters)

    predictions = jnp.dot(parameters, weights)
    return {"y": predictions}

# The combination of prior and distribution form the posterior, from which we
# sampling is performed.

def likelihood(model_results, extra_parameters, reference_data):
    sigma = extra_parameters["sigma"]
    y = model_results["y"]
    return norm.logpdf(y, loc=reference_data, scale=sigma)

def prior(sample):
    return uniform.logpdf(sample,
                          loc=-10 * jnp.ones(sample.shape),
                          scale=-10 * jnp.ones(sample.shape))

test_w = random.normal(key, (N, 1))

# Get a reference batch


state = init_fn()

state, (batch, _) = batch_fn(state)
print(state.current_line)

batch_x = batch['x']

print(batch_x)

state, (batch, _) = batch_fn(state)

batch_x = batch['x']

print(batch_x)

# We need to define how the posterior can be evaluated by a mini-batch of data
# instead of jsut for a single sample.

# minibatch_potential = potential.minibatch_potential_function(
#     prior,
#     likelihood,
#     strategy="map"
# )

# We test the potential evaluation for a single

1
[[ 6.573794    2.222278    8.376123    2.5250196  -1.9425583  -9.000201
  -1.7400074  -5.9429455  -9.914832   -5.158725    9.765239   -8.263466
  -2.797103   -6.7114472   3.99925     6.295717    3.5480452  -4.3435884
   7.129545    8.648238    9.031294    6.110046    4.4556665  -2.502308
   1.3491225  -4.4407296   8.3746     -6.2891316   7.1446013   6.109264
  -7.8511906   5.287485   -7.9173255   4.6929526   9.485822   -0.14535427
  -7.526872    8.142977   -8.109644    4.4521165   9.723131   -3.629911
  -0.68686247 -2.9518151  -0.2883196   8.191729    2.6522827  -3.9305782
  -3.896687   -1.034348    3.5864258  -4.6831703  -4.1329217   3.475809
   4.511397   -0.48801422 -5.920012   -1.4000297  -1.219554    3.2512093
  -5.9770727   4.486289   -6.9371033  -7.455523    7.515621   -7.8590083
   9.035309    2.2870755  -4.119897   -4.1224074  -2.786374   -5.377989
   4.689269    9.627218    1.9808292  -5.228958   -9.240934   -3.1479669
  -5.1801014  -7.6068544   6.768017   -9.016706   -8.32