# [Discrete Walk-Jump Sampling](https://arxiv.org/abs/2306.12360)


## Imports

In [None]:
from typing import Any, Callable, Iterable, Optional, Tuple, Union
import functools
from tqdm import trange

import jax
import jax.config
jax.config.update("jax_enable_x64", True)

import chex
import jax.numpy as jnp
import flax.linen as nn
from clu import parameter_overview
import optax
import tqdm
from tensorflow_probability.substrates import jax as tfp
import matplotlib.pyplot as plt
import matplotlib.animation

plt.rcParams["animation.html"] = "jshtml"


## Problem Setup

Fundamentally, discrete walk-jump sampling (dWJS) is a method for sampling noisy latents from an energy-based model, and denoising them with a neural network.

Thus, to understand dWJS, we first need to understand energy-based models and neural empirical Bayes.

## Energy-Based Models

Energy-based models (EBMs) model the probability of a data point $y$ with an energy function $E_\theta(y)$ parameterized by $\theta$:
$$
p_\theta(y) = \frac{\exp(-E_\theta(y))}{Z(\theta)}
$$
where $Z(\theta)$ is the normalizing constant (called the partition function):
$$
Z(\theta) = \int \exp(-E_\theta(y)) dy
$$
Low energy samples have high probability, and vice versa.


### How to train an EBM? 

We can train an EBM by minimizing the KL divergence between the true distribution $p$ and the model distribution $p_\theta$. This is equivalent to maximizing the expected value of log-likelihood of the data sampled from $p(y)$:
$$
\theta^* = \operatorname*{arg\,min}_\theta \text{KL}(p \ || \ p_\theta) = \operatorname*{arg\,max}_\theta \mathbb{E}_{y\sim p(y)} [\log p_\theta(y)] = \operatorname*{arg\,min}_\theta \mathbb{E}_{y\sim p(y)} [-\log p_\theta(y)]
$$

This is all standard so far. The problem is that $Z(\theta)$ is intractable to compute:
$$
\mathbb{E}_{y\sim p(y)} [-\log p_\theta(y)] = \mathbb{E}_{y\sim p(y)} [E_\theta(y)] + \log Z(\theta) 
$$ 
If we use gradient descent to optimize $\theta$, we need to compute:
$$
\nabla_\theta \mathbb{E}_{y\sim p(y)} [-\log p_\theta(y)] = \mathbb{E}_{y\sim p(y)} [\nabla_\theta E_\theta(y)] + \nabla_\theta\log Z(\theta) 
$$ 
The first term is easy to compute, but the second term is (usually) intractable. The common approach is to approximate the second term via MCMC sampling:
$$
\begin{aligned}
\nabla_\theta\log Z(\theta) &= \frac{\nabla_\theta Z(\theta)}{Z(\theta)} \\
&= \frac{\nabla_\theta \int \exp(-E_\theta(y)) dy}{Z(\theta)} \\
&= \frac{\int \nabla_\theta \exp(-E_\theta(y)) dy}{Z(\theta)} \\
&= \frac{\int - \nabla_\theta E_\theta(y) \exp(-E_\theta(y)) dy}{Z(\theta)} \\
&= \frac{\int - \nabla_\theta E_\theta(y) Z(\theta)  p_\theta(y) dy}{Z(\theta)} \\
&= \int - \nabla_\theta E_\theta(y) \ p_\theta(y) dy \\
&= \mathbb{E}_{y\sim p_\theta(y)} [- \nabla_\theta E_\theta(y)] \\
\\
\end{aligned}
$$

Thus, the gradient of the objective we seek to minimize is:
$$
\nabla_\theta \mathbb{E}_{y\sim p(y)} [-\log p_\theta(y)] = 
\mathbb{E}_{y\sim p(y)} [\nabla_\theta E_\theta(y)] - \mathbb{E}_{y\sim p_\theta(y)} [ \nabla_\theta E_\theta(y)]  
$$
We are seeking to minimize the energy of positive samples (from the data distribution) and maximize the energy of negative samples (from the model distribution). This is why this approach is also called contrastive divergence.

## Langevin MCMC

We have computed an estimator for the gradient of the negative log-likelihood (NLL) loss.

Note that this estimator requires us to sample from the model distribution $p_\theta(y)$ at each iteration. We perform this sampling from $p_\theta(y)$ via Langevin MCMC. Langevin MCMC is similar to a noisy version of gradient ascent on the log-likelihood.

Initialize a sample $y_0$ randomly.
Then, for each iteration $t$, compute:
$$
y_{t+1} = y_t + \delta \nabla_{y_t} \log p_\theta(y_t)  + \sqrt{2\delta} \epsilon_t
$$
where $\epsilon_t \sim \mathcal{N}(0, I)$.
Then, as $t \rightarrow \infty$, $y_t$ will appear to be sampled from $p_\theta(y)$.

Note that the partition function $Z(\theta)$ does not show up in the sampling procedure:
$$
\nabla_{y_t} \log p_\theta(y_t) = - \nabla_{y_t} E_\theta(y_t)
$$

In [None]:
@functools.partial(jax.jit, static_argnames=("grad_log_prob_fn", "num_steps"))
def langevin_sample(grad_log_prob_fn: Callable[[chex.Array], chex.Array], init: chex.Array, delta: float, rng: chex.PRNGKey, num_steps: int) -> chex.Array:
    """Runs an (unadjusted) Langevin chain and returns every visited state.

    Each step applies
        y_{t+1} = y_t + delta * grad_log_prob_fn(y_t) + sqrt(2 * delta) * eps_t,
    where eps_t is standard normal noise with the same shape as the state.

    Args:
      grad_log_prob_fn: Maps a state to the gradient of the log-density at
        that state. It is a static jit argument, so it must be hashable.
      init: Initial state y_0. It is not part of the returned samples.
      delta: Step size of the chain.
      rng: PRNG key; it is split into one independent key per step.
      num_steps: Number of steps to take (static jit argument).

    Returns:
      The states y_1, ..., y_{num_steps}, stacked along a new leading axis.
    """

    def one_step_langevin(y_t: chex.Array, step_rng: chex.PRNGKey):
        # Draw the noise in the state's dtype so the update does not silently
        # promote the chain (e.g. float32 state with x64 enabled).
        eps = jax.random.normal(step_rng, y_t.shape, dtype=y_t.dtype)
        y_next = y_t + delta * grad_log_prob_fn(y_t) + jnp.sqrt(2 * delta) * eps
        # `lax.scan` requires the carry to keep a fixed dtype across steps.
        y_next = y_next.astype(y_t.dtype)
        # The new state is both the carry and the per-step output.
        return y_next, y_next

    sampling_rngs = jax.random.split(rng, num_steps)
    # The number of steps is implied by the leading axis of `sampling_rngs`.
    _, samples = jax.lax.scan(one_step_langevin, init, xs=sampling_rngs)
    return samples

## Example Time!

To illustrate the math, we will use a simple 1D example. Let's assume that the data distribution is a mixture of two Gaussians centered at -1 and 1, respectively:
$$
p(y) = \frac{1}{2} \mathcal{N}(y; -1, 0.5^2) + \frac{1}{2} \mathcal{N}(y; 1, 0.5^2)
$$

In [None]:
# Shorthand for the TFP distributions module.
tfd = tfp.distributions
# Uniform mixing distribution over the two component indices (0 and 1).
px = tfd.Categorical(probs=[0.5, 0.5])
# Data distribution: an equal-weight mixture of two Normals centred at -1 and
# +1, each with standard deviation (`scale`) 0.5.
p = tfd.Mixture(
    cat=px,
    components=[
        tfd.Normal(loc=-1., scale=0.5),
        tfd.Normal(loc=+1., scale=0.5),
    ]
)

# Plot the PDF on a dense grid over [-5, 5].
y = jnp.linspace(-5., 5., int(1e4))
plt.grid()
plt.plot(y, p.prob(y))
plt.xlabel('y')
plt.ylabel('p(y)')
plt.title('True PDF')
plt.show();

We can visualize the Langevin MCMC sampling process below:

In [None]:
# Root PRNG key for the notebook.
rng = jax.random.PRNGKey(0)
# Score of the data distribution, d/dy log p(y). `squeeze` reduces the
# log-probability to a scalar so that `jax.grad` can differentiate it.
grad_log_prob_fn = jax.jit(jax.grad(lambda y: p.log_prob(y).squeeze()))

In [None]:
# Split off a key for this chain and keep the remainder in `rng`.
sampling_rng, rng = jax.random.split(rng)
delta = 0.1
# 500 Langevin steps on the true score, starting from y = 2.
langevin_samples_from_2 = langevin_sample(grad_log_prob_fn, init=2 * jnp.ones((1,)), delta=delta, rng=sampling_rng, num_steps=500)

fig, ax = plt.subplots()
ax.grid()
ax.set_xlim(-5., 5.)
ax.set_ylim(0., 1.)
ax.set_xlabel('y')
ax.set_ylabel('p(y)')
# Empty scatter artist; `animate` fills in its points frame by frame.
scatter = ax.scatter([], [], lw=2, color='C0')

def animate(i: int):
    """Draws the first `i` chain samples at (y, p(y)), fading earlier ones."""
    offsets = [langevin_samples_from_2[:i], p.prob(langevin_samples_from_2[:i])]
    # Pair each sample with its density and drop the singleton feature axis.
    offsets = jnp.stack(offsets, axis=-1).squeeze()
    scatter.set_offsets(offsets)
    # Adjust opacity: point k of i gets alpha (k / i)^2, so the most recent
    # points are the most opaque. Skipped at i == 0 to avoid dividing by zero.
    if i > 0:
        scatter.set_alpha(jnp.arange(i) ** 2 / i ** 2)
    ax.set_title(r'Langevin Sampling Starting from 2 with $\delta={}$: Step {}'.format(delta, i))
    return (scatter,)

# Only the first 100 of the 500 samples are animated.
anim = matplotlib.animation.FuncAnimation(fig, animate, frames=100, interval=100, blit=True)
# Presumably closes the figure so that only the animation is displayed.
plt.close()
anim

In [None]:
# Split off a key for this chain and keep the remainder in `rng`.
sampling_rng, rng = jax.random.split(rng)
# 500 Langevin steps on the true score, starting from y = -2.
langevin_samples_from_neg_2 = langevin_sample(grad_log_prob_fn, init=-2*jnp.ones((1,)), delta=delta, rng=sampling_rng, num_steps=500)

fig, ax = plt.subplots()
ax.grid()
ax.set_xlim(-5., 5.)
ax.set_ylim(0., 1.)
ax.set_xlabel('y')
ax.set_ylabel('p(y)')
# Empty scatter artist; `animate` fills in its points frame by frame.
scatter = ax.scatter([], [], lw=2, c='C1')

def animate(i: int):
    """Draws the first `i` chain samples at (y, p(y)), fading earlier ones."""
    offsets = [langevin_samples_from_neg_2[:i], p.prob(langevin_samples_from_neg_2[:i])]
    # Pair each sample with its density and drop the singleton feature axis.
    offsets = jnp.stack(offsets, axis=-1).squeeze()
    scatter.set_offsets(offsets)
    # Adjust opacity: point k of i gets alpha (k / i)^2, so the most recent
    # points are the most opaque. Skipped at i == 0 to avoid dividing by zero.
    if i > 0:
        scatter.set_alpha(jnp.arange(i) ** 2 / i ** 2)
    ax.set_title(r'Langevin Sampling Starting from -2 with $\delta={}$: Step {}'.format(delta, i))
    return (scatter,)

# Only the first 100 of the 500 samples are animated.
anim = matplotlib.animation.FuncAnimation(fig, animate, frames=100, interval=100, blit=True)
# Presumably closes the figure so that only the animation is displayed.
plt.close()
anim

We can check that the histogram of the samples from Langevin MCMC matches the data distribution somewhat closely:

In [None]:
# Overlay density-normalized histograms of the two chains (started at +2 and
# -2) so they can be compared on the same axes.
plt.hist(langevin_samples_from_2.flatten(), bins=20, density=True, alpha=0.5, color='C0', label='Starting from 2')
plt.hist(langevin_samples_from_neg_2.flatten(), bins=20, density=True, alpha=0.5, color='C1', label='Starting from -2')
plt.grid()
plt.xlabel('y')
plt.ylabel('p(y)')
plt.legend()
plt.title('Histograms of Langevin Samples')
plt.show();

For our EBM, we will use a small neural network with two hidden layers for the energy function:
$$
E_\theta(y) = W_2\,\text{softplus}(W_1\,\text{softplus}(W_0y + b_0) + b_1) + b_2
$$

In [None]:
class EnergyBasedModel(nn.Module):
    """Scalar energy network: two softplus hidden layers and a linear head."""
    # Width of each of the two hidden layers.
    hidden_size: int

    @nn.compact
    def __call__(self, y: chex.Array) -> chex.Array:
        """Returns the energy of `y` with the trailing feature axis removed.

        Inputs with at most one dimension are given a leading axis first, so
        they are processed as a single point.
        """
        features = jnp.expand_dims(y, 0) if y.ndim <= 1 else y
        # Two identical hidden blocks: Dense followed by softplus.
        for _ in range(2):
            features = jax.nn.softplus(nn.Dense(self.hidden_size)(features))
        # Linear head producing one energy value per point.
        energy = nn.Dense(1)(features)
        return energy.squeeze(axis=-1)

# Initialize the model.
model = EnergyBasedModel(hidden_size=10)
dummy_input = jnp.ones((1,))
# Split off a dedicated key for parameter initialization. Passing `rng`
# itself to `init` and splitting it again afterwards would reuse the key.
model_init_rng, rng = jax.random.split(rng)
init_params = model.init(model_init_rng, dummy_input)
# Jitted forward pass with signature energy_fn(params, y).
energy_fn = jax.jit(model.apply)

# Overview of the model parameters.
print(parameter_overview.get_parameter_overview(init_params))

We can visualize the unnormalized probability distribution of the EBM below:

In [None]:
# Evaluate the untrained model's unnormalized density exp(-E(y)) on a grid.
y = jnp.linspace(-10., 10., int(1e4))
y_probs = jax.vmap(lambda y: jnp.exp(-energy_fn(init_params, y)))(y)
# Divide by a Riemann-sum estimate of the integral over the plotted range,
# so the curve integrates to roughly one on [-10, 10].
y_probs /= y_probs.sum() * (y[-1] - y[0]) / len(y)
plt.grid()
plt.plot(y, y_probs, color='C2')
plt.xlabel('y')
plt.ylabel('p_model(y)')
plt.title('(Approximately Normalized) Model PDF')
plt.show();

In [None]:
def create_grad_log_prob_fn(params: optax.Params, energy_fn: Callable[[optax.Params, chex.Array], float]) -> Callable[[chex.Array], chex.Array]:
    """Builds the score function of the EBM, i.e. minus the energy gradient.

    Args:
      params: Parameters passed as the first argument of `energy_fn`.
      energy_fn: Energy function called as `energy_fn(params, y)`.

    Returns:
      A function mapping `y` to `-d energy_fn(params, y) / dy`.
    """
    def scalar_energy(point: chex.Array) -> chex.Array:
        # `jax.grad` needs a scalar output, so squeeze away singleton axes.
        return energy_fn(params, point).squeeze()

    energy_gradient = jax.grad(scalar_energy)

    def grad_log_prob_fn(y: chex.Array) -> chex.Array:
        # log p(y) = -E(y) - log Z, and Z does not depend on y.
        return -energy_gradient(y)

    return grad_log_prob_fn

# Sample from the untrained model: 10,000 Langevin steps with step size 1,
# starting from y = 0 and following the score of the initial EBM.
sampling_rng, rng = jax.random.split(rng)
langevin_samples_from_model = langevin_sample(create_grad_log_prob_fn(init_params, energy_fn), init=jnp.zeros((1,)), delta=1, rng=sampling_rng, num_steps=10000)
plt.hist(langevin_samples_from_model.flatten(), bins=20, density=True, alpha=0.5, color='C2', label='Model Samples')
plt.grid()
plt.xlabel('y')
plt.ylabel('p(y)')
plt.legend()
plt.title('Histograms of Langevin Samples')
plt.show();

We now train the EBM using the gradient estimator from above. We use the samples from a short Langevin MCMC chain (after discarding burn-in samples and thinning) as the negative samples.
We can simply use automatic differentiation to compute the gradient of the loss!

In [None]:
@functools.partial(jax.jit, static_argnames=("energy_fn", "num_sampling_steps", "take_every_sample", "burn_in_samples"))
def ebm_loss_fn(params: optax.Params, energy_fn: Callable[[optax.Params, chex.Array], float], y_true_samples: chex.Array, rng: chex.PRNGKey,
            delta: float, num_sampling_steps: int, take_every_sample: int, burn_in_samples: int) -> float:
    """Contrastive EBM loss: mean data energy minus mean model-sample energy.

    Args:
      params: EBM parameters.
      energy_fn: Energy function called as `energy_fn(params, y)`.
      y_true_samples: Batch of data (positive) samples along the leading axis.
      rng: PRNG key for the chain's starting point and its noise.
      delta: Langevin step size.
      num_sampling_steps: Length of the Langevin chain.
      take_every_sample: Thinning stride applied after burn-in.
      burn_in_samples: Number of initial chain samples to discard.

    Returns:
      The scalar loss.
    """
    init_rng, remaining_rng = jax.random.split(rng)
    chain_rng, _ = jax.random.split(remaining_rng)

    # Start the chain from standard normal noise shaped like one data point.
    chain_start = jax.random.normal(init_rng, y_true_samples[0].shape)
    chain = langevin_sample(
        create_grad_log_prob_fn(params, energy_fn),
        init=chain_start,
        delta=delta,
        rng=chain_rng,
        num_steps=num_sampling_steps,
    )
    # Discard burn-in, thin the rest, and treat the negatives as constants:
    # we don't differentiate through the sampling procedure.
    negatives = jax.lax.stop_gradient(chain[burn_in_samples:][::take_every_sample])

    def mean_energy(batch: chex.Array) -> chex.Array:
        return jax.vmap(functools.partial(energy_fn, params))(batch).mean()

    return mean_energy(y_true_samples) - mean_energy(negatives)

We are ready to train our model!

In [None]:
def train_ebm_model(init_params: optax.Params, rng: chex.PRNGKey, num_training_steps: int, **loss_kwargs) -> Tuple[optax.Params, dict]:
    """Train the EBM model using the Adam optimizer.

    Relies on the module-level `energy_fn`, `ebm_loss_fn` and data
    distribution `p`.

    Args:
      init_params: Initial EBM parameters.
      rng: PRNG key; split every step into a loss key and a data key.
      num_training_steps: Number of optimizer updates to perform.
      **loss_kwargs: Extra keyword arguments forwarded to `ebm_loss_fn`.

    Returns:
      A pair `(params, params_at_steps)`: the final parameters and a dict
      mapping the number of completed updates (0, 500, 1000, ...) to the
      parameters at that point. Key 0 holds `init_params`.
    """
    tx = optax.adam(5e-4)

    @jax.jit
    def train_step(params: optax.Params, opt_state: optax.OptState, y_true_samples: chex.Array, rng: chex.PRNGKey) -> Tuple[optax.Params, optax.OptState, float]:
        loss, grad = jax.value_and_grad(ebm_loss_fn, has_aux=False)(params, energy_fn, y_true_samples, rng, **loss_kwargs)
        updates, opt_state = tx.update(grad, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    opt_state = tx.init(init_params)

    params_at_steps = {
        0: init_params,
    }

    params = init_params
    for step in tqdm.trange(num_training_steps):
        step_rng, samples_rng, rng = jax.random.split(rng, num=3)
        # Fresh batch of 32 positive samples from the data distribution.
        y_true_samples = p.sample(32, seed=samples_rng)

        params, opt_state, loss = train_step(params, opt_state, y_true_samples, step_rng)

        # Snapshot by the number of updates applied so far, so that key 0
        # keeps the untouched initial parameters instead of being overwritten.
        num_updates = step + 1
        if num_updates % 500 == 0:
            params_at_steps[num_updates] = params

        # Log the training progress.
        if step % 2000 == 0:
            print('Step {}: Loss = {}'.format(step, loss))

    return params, params_at_steps

# Train for 30,000 steps. Each loss evaluation runs a 10-step Langevin chain
# with step size 0.1, drops its first sample and keeps every second one.
train_rng = jax.random.PRNGKey(0)
energy_params, energy_params_at_steps = train_ebm_model(init_params, train_rng, num_training_steps=30000, delta=0.1, num_sampling_steps=10, take_every_sample=2, burn_in_samples=1)

We can visualize the learnt (unnormalized) probability distribution of the EBM:

In [None]:
fig, ax = plt.subplots()
ax.grid()
ax.set_xlim(-5., 5.)
ax.set_ylim(0., 1.)
ax.set_xlabel('y')
ax.set_ylabel('p_model(y)')
# Empty line artist; `animate` sets its data for each recorded snapshot.
line, = ax.plot([], [], lw=2, color='C2')
# Keys of the recorded parameter snapshots, in increasing order.
steps = sorted(energy_params_at_steps.keys())

def animate(i: int):
    """Plots the model density for the `i`-th recorded parameter snapshot."""
    y = jnp.linspace(-5., 5., int(1e3))
    # Unnormalized density exp(-E(y)), evaluated point by point.
    y_probs = jax.vmap(lambda y: jnp.exp(-energy_fn(energy_params_at_steps[steps[i]], y)))(y)
    # Divide by a Riemann-sum estimate of the integral over the plotted grid.
    y_probs /= y_probs.sum() * (y[-1] - y[0]) / len(y)
    line.set_data(y, y_probs)
    ax.set_title('(Approximately Normalized) Model PDF: Step {}'.format(steps[i]))
    return (line,)

# One frame per recorded snapshot.
anim = matplotlib.animation.FuncAnimation(fig, animate, frames=len(steps), interval=100, blit=True)
# Presumably closes the figure so that only the animation is displayed.
plt.close()
anim

## Neural Empirical Bayes

This section follows the work of [Saeed Saremi and Aapo Hyvarinen](https://arxiv.org/abs/1903.02334).

Consider an observation denoted by the random variable $X \in \mathbb{R}^d$, and a noisy observation of $X$ denoted by $Y \in \mathbb{R}^d$:
$$
Y = X + \epsilon, \quad \epsilon \sim \mathcal{N}(0, \sigma^2 I_d)
$$
Thus, we are given that $p_{Y|X}$ is a Gaussian.

Given the observation $Y = y$, the Bayes' least squares estimator for $X$ is:
$$
\hat{X}(y) = \mathbb{E}[X \ | \ Y = y]
$$
Computing this estimator requires knowledge of $p_{X|Y} \propto p_{Y|X}\cdot p_X$ by Bayes' rule.
Thus, it seems that we need to know $p_X$ to compute this estimator.

It turns out (as figured out by [Robbins](https://link.springer.com/chapter/10.1007/978-1-4612-0919-5_26) and [Miyasawa](https://mit.primo.exlibrisgroup.com/discovery/openurl?institution=01MIT_INST&vid=01MIT_INST:MIT&rft.epage=188&rft_val_fmt=info:ofi%2Ffmt:kev:mtx:journal&rft.stitle=B%20INT%20STATIST%20INST&rft.volume=38&rfr_id=info:sid%2Fwebofscience.com:WOS:WOS&rft.jtitle=BULLETIN%20OF%20THE%20INTERNATIONAL%20STATISTICAL%20INSTITUTE&rft.aufirst=K&rft.genre=article&rft.issue=4&rft.pages=181-188&url_ctx_fmt=info:ofi%2Ffmt:kev:mtx:ctx&rft.aulast=MIYASAWA&url_ver=Z39.88-2004&rft.auinit=K&rft.date=1960&rft.spage=181&rft.atitle=AN%20EMPIRICAL%20BAYES%20ESTIMATOR%20OF%20THE%20MEAN%20OF%20A%20NORMAL%20POPULATION&rft.issn=0074-8609)) that we can compute this estimator without knowing $p_X$!

For all $x$ and $y$, we have:
$$
p_{Y|X}(y|x) = \mathcal{N}(y; x, \sigma^2) = \frac{1}{(2\pi\sigma^2)^{\frac{d}{2}}} \exp\left(-\frac{\|y - x\|^2}{2\sigma^2} \right)
$$
so:
$$
\begin{aligned}
\nabla_y p_{Y|X}(y|x) &= -\frac{y - x}{\sigma^2} p_{Y|X}(y|x) \\
\implies (x - y) p_{Y|X}(y|x) &= \sigma^2 \nabla_y p_{Y|X}(y|x) \\
\implies \int (x - y) p_{Y|X}(y|x) p_X(x) dx &= \sigma^2 \int \nabla_y p_{Y|X}(y|x) p_X(x) dx
\end{aligned}
$$
Note that, by Bayes' rule: $p_{Y|X}(y|x) p_X(x) = p_{X,Y}(x, y) = p_{X|Y}(x|y) p_Y(y)$.  
Also, by definition of the marginals: $\int p_{X,Y}(x, y) dx = p_{Y}(y)$.  
For the left-hand side, we have:
$$
\begin{aligned}
\int (x - y) p_{Y|X}(y|x) p_X(x) dx &= \int x p_{Y|X}(y|x) p_X(x) dx - \int  y p_{Y|X}(y|x) p_X(x) dx \\
&= \int x p_{X,Y}(x, y) dx - \int  y p_{X,Y}(x, y) dx \\
&=  p_Y(y) \left(\int x p_{X|Y}(x|y) dx - y \int  p_{X|Y}(x|y) dx\right) \\
&=  p_Y(y) \left(\mathbb{E}[X \ | \ Y = y] - y \right) \\
&=  p_Y(y) \left(\hat{X}(y) - y \right)
\end{aligned}
$$
For the right-hand side, we have:
$$
\begin{aligned}
\sigma^2 \int \nabla_y p_{Y|X}(y|x) p_X(x) dx &= \sigma^2 \nabla_y \int p_{Y|X}(y|x) p_X(x) dx \\
&=  \sigma^2 \nabla_y \int p_{X,Y}(x, y) dx  \\
&=  \sigma^2 \nabla_y p_{Y}(y)
\end{aligned}
$$
Thus,
$$
\begin{aligned}
p_Y(y) \left(\hat{X}(y) - y \right) &= \sigma^2 \nabla_y p_{Y}(y)
\\
\implies \hat{X}(y) &= y + \sigma^2 \frac{\nabla_y p_{Y}(y)}{p_Y(y)}
\\
\implies \hat{X}(y) &= y + \sigma^2 \nabla_y \log p_{Y}(y)
\end{aligned}
$$
Thus, the estimator $\hat{X}(y)$ can be computed without knowledge of $p_X$, only the knowledge of the score $\nabla_y \log p_{Y}(y)$ is required.

Now, there are two approaches to learning the score function $\nabla_y \log p_{Y}(y)$.
The first is to approximate $p_{Y}(y)$ by an EBM:
$$
p_{Y}(y) \approx \frac{\exp(-E_\theta(y))}{Z(\theta)} \implies \nabla_y \log p_{Y}(y) = - \nabla_y E_\theta(y)
$$
The EBM can be learned using contrastive divergence described before.
Then, the learned EBM can be used as a denoiser, by denoising $Y$ to obtain an estimate of $X$:
$$
\hat{X}(y) = y - \sigma^2 \nabla_y E_\theta(y)
$$

The second approach, proposed in this Discrete Walk-Jump Sampling paper, is to directly parametrize the score function by a 'denoising' neural network:
$$
g_\phi(y) \approx \nabla_y \log p_{Y}(y)
$$
Denoising $Y$ as before:
$$
\hat{X}(y) = y + \sigma^2 g_\phi(y)
$$
The denoising network $g_\phi$ can be trained from observations of $X$ and adding noise to obtain examples $Y$:
* Sample $X_i \sim p_X$.
* Sample $\epsilon_j \sim \mathcal{N}(0, \sigma^2 I_d)$.
* Compute $Y_{ij} = X_i + \epsilon_j$.
* Optimize $\phi$:
$$
\phi^* = \operatorname*{arg\,min}_\phi \sum_{i,j} \|X_i - (Y_{ij} + \sigma^2 g_\phi(Y_{ij})) \|^2
$$

In [None]:
# Define the score model that predicts the score.
class ScoreNetwork(nn.Module):
    """Score network: two softplus hidden layers and a linear output layer."""
    # Width of each of the two hidden layers.
    hidden_size: int

    @nn.compact
    def __call__(self, y: chex.Array) -> chex.Array:
        """Returns a prediction with the same trailing dimension as `y`.

        Inputs with at most one dimension are given a leading axis first.
        """
        inputs = jnp.expand_dims(y, 0) if y.ndim <= 1 else y
        # The output has as many features as the (possibly expanded) input.
        out_features = inputs.shape[-1]
        hidden = inputs
        # Two identical hidden blocks: Dense followed by softplus.
        for _ in range(2):
            hidden = jax.nn.softplus(nn.Dense(self.hidden_size)(hidden))
        return nn.Dense(out_features)(hidden)

# Initialize the score model.
score_model = ScoreNetwork(hidden_size=10)
dummy_input = jnp.ones((1,))
# Split off a dedicated key for parameter initialization. Passing `rng`
# itself to `init` and splitting it again afterwards would reuse the key.
score_init_rng, rng = jax.random.split(rng)
init_score_params = score_model.init(score_init_rng, dummy_input)
# Jitted forward pass with signature score_fn(params, y).
score_fn = jax.jit(score_model.apply)

# Overview of the model parameters.
print(parameter_overview.get_parameter_overview(init_score_params))

In [None]:
@functools.partial(jax.jit, static_argnames=("score_fn", "num_noise_samples"))
def score_loss_fn(
    params: optax.Params, score_fn: Callable[[optax.Params, chex.Array], chex.Array], x_true_samples: chex.Array, rng: chex.PRNGKey, noise_std: float, num_noise_samples: int) -> float:
    """Computes the least-squares denoising loss.

    Every clean sample is corrupted with the same `num_noise_samples` Gaussian
    noise draws, denoised as `y + noise_std**2 * score_fn(params, y)`, and
    scored by the squared Euclidean distance to the clean sample.

    Args:
      params: Parameters of the score network.
      score_fn: Score network called as `score_fn(params, y)`; it must return
        an array with the same shape as `y`.
      x_true_samples: Clean samples of shape (num_true_samples, num_dims).
      rng: PRNG key used to draw the noise.
      noise_std: Standard deviation of the Gaussian noise.
      num_noise_samples: Number of noise draws (static jit argument).

    Returns:
      The squared denoising error averaged over clean samples and noise draws.

    Raises:
      ValueError: If `x_true_samples` is not two-dimensional.
    """
    if x_true_samples.ndim != 2:
        raise ValueError(
            'x_true_samples must have shape (num_true_samples, num_dims), got shape {}'.format(x_true_samples.shape))
    num_true_samples, num_dims = x_true_samples.shape

    noise_rng, rng = jax.random.split(rng)
    noise = noise_std * jax.random.normal(noise_rng, (num_noise_samples, num_dims))
    # Broadcast so that every clean sample meets every noise draw.
    y_samples = x_true_samples[:, None, :] + noise[None, ...]
    assert y_samples.shape == (num_true_samples, num_noise_samples, num_dims)

    predictions = score_fn(params, y_samples)
    assert predictions.shape == (num_true_samples, num_noise_samples, num_dims)

    denoised = y_samples + noise_std ** 2 * predictions
    residuals = x_true_samples[:, None, :] - denoised
    # Squared norm: the least-squares objective is minimized by the
    # conditional mean E[X | Y], which an unsquared norm would not target.
    squared_errors = jnp.sum(residuals ** 2, axis=-1)
    assert squared_errors.shape == (num_true_samples, num_noise_samples)

    return squared_errors.mean()
    

In [None]:
def train_score_model(init_params: optax.Params, rng: chex.PRNGKey, num_training_steps: int, **loss_kwargs) -> Tuple[optax.Params, dict]:
    """Train the score model using the Adam optimizer.

    Relies on the module-level `score_fn`, `score_loss_fn` and mixing
    distribution `px`.

    Args:
      init_params: Initial score-network parameters.
      rng: PRNG key; split every step into a loss key and a data key.
      num_training_steps: Number of optimizer updates to perform.
      **loss_kwargs: Extra keyword arguments forwarded to `score_loss_fn`.

    Returns:
      A pair `(params, params_at_steps)`: the final parameters and a dict
      mapping the number of completed updates (0, 500, 1000, ...) to the
      parameters at that point. Key 0 holds `init_params`.
    """
    tx = optax.adam(5e-4)

    @jax.jit
    def train_step(params: optax.Params, opt_state: optax.OptState, x_true_samples: chex.Array, rng: chex.PRNGKey) -> Tuple[optax.Params, optax.OptState, float]:
        loss, grad = jax.value_and_grad(score_loss_fn, has_aux=False)(params, score_fn, x_true_samples, rng, **loss_kwargs)
        updates, opt_state = tx.update(grad, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    opt_state = tx.init(init_params)

    params_at_steps = {
        0: init_params,
    }

    params = init_params
    for step in tqdm.trange(num_training_steps):
        step_rng, samples_rng, rng = jax.random.split(rng, num=3)
        # `px` yields mixture component indices (0 or 1), not clean values.
        # Map them to the component means of the data mixture, -1 and +1, so
        # the denoiser is trained on the same distribution the EBM models.
        component_ids = px.sample(32, seed=samples_rng)
        x_true_samples = jnp.where(component_ids == 0, -1.0, 1.0)
        # Add the feature axis expected by `score_loss_fn`.
        x_true_samples = x_true_samples[:, None]

        params, opt_state, loss = train_step(params, opt_state, x_true_samples, step_rng)

        # Snapshot by the number of updates applied so far, so that key 0
        # keeps the untouched initial parameters instead of being overwritten.
        num_updates = step + 1
        if num_updates % 500 == 0:
            params_at_steps[num_updates] = params

        # Log the training progress.
        if step % 2000 == 0:
            print('Step {}: Loss = {}'.format(step, loss))

    return params, params_at_steps

train_rng = jax.random.PRNGKey(0)
# Standard deviation of the noise added to clean samples; the later
# denoising and walk-jump cells read this same value.
noise_std = 0.5
score_params, score_params_at_steps = train_score_model(init_score_params, train_rng, num_training_steps=30000, noise_std=noise_std, num_noise_samples=10)

Let's check that the score function works!

In [None]:
noise_rng, rng = jax.random.split(rng)
# Ten noise draws, shared by both clean points below.
noise = noise_std * jax.random.normal(noise_rng, (10, 1))
x_true = jnp.asarray([[-1.], [1.]])
# Broadcast to (10, 2, 1): every noise draw is added to every clean point.
y = x_true[None, :] + noise[:, None]
# Reorder to (2, 10, 1) and flatten to (20, 1), grouping rows by clean point.
y = y.transpose((1, 0, 2)).reshape((-1, 1))
preds = score_fn(score_params, y)
# Denoise: x_hat = y + sigma^2 * score(y).
x = y + noise_std ** 2 * preds
# Colour each point by which side of the threshold its denoised value falls.
# NOTE(review): `x_true` is +/-1, whose midpoint is 0 rather than 0.5;
# confirm the intended threshold.
labels = jnp.where(x < 0.5, 0, 1)
plt.grid()
plt.scatter(y, x, color=['C0' if label == 0 else 'C1' for label in labels])
plt.xlabel('y')
plt.ylabel('x')
plt.title('Denoising Model Predictions')
plt.show();

## Walk-Jump Sampling

The idea behind Walk-Jump Sampling is that it is easier to walk in the space of noisy observations $Y$ than in the space of clean observations $X$. The noise helps connect different modes of the distribution. Given any noisy observation $Y$, we can always go back to the clean observation $X$ by denoising.

* Walk in noisy observation space with Langevin MCMC:
$$
    y_t = y_{t-1} + \delta \nabla_{y_{t-1}} \log p_Y(y_{t-1})  + \sqrt{2\delta} \epsilon_t
$$
* Jump to clean observation (at any time $\tau$):
$$
    x_\tau = y_\tau + \sigma^2 \nabla_{y_\tau} \log p_Y(y_\tau)
$$

Note that both the walk and jump steps need an estimate of the score function $\nabla_{y} \log p_Y(y)$.
We have choices for how we parametrize these in each step. Here, they find that using an EBM for the walker, and a denoising network for the jumper works best:
* EBM:
$$
p_Y(y) = \frac{\exp(-E_\theta(y))}{Z(\theta)}
$$
* Denoiser:
$$
\nabla_y \log p_{Y}(y) \approx g_\phi(y)
$$

Note that unlike diffusion, every single sample from walk-jump sampling is approximately from the data distribution $p_X$.

In [None]:
@functools.partial(jax.jit, static_argnames=("energy_fn", "score_fn", "num_steps", "noise_std"))
def walk_jump_sampling(energy_fn_params: optax.Params, energy_fn: Callable[[optax.Params, chex.Array], float], score_fn_params: optax.Params, score_fn: Callable[[optax.Params, chex.Array], chex.Array], rng: chex.PRNGKey, delta: float, num_steps: int, noise_std: float):
    """Walks with Langevin MCMC on the EBM, then jumps with the denoiser.

    Args:
      energy_fn_params: Parameters of the EBM used for the walk.
      energy_fn: Energy function called as `energy_fn(params, y)`.
      score_fn_params: Parameters of the score network used for the jump.
      score_fn: Score network called as `score_fn(params, y)`.
      rng: PRNG key for the Langevin chain.
      delta: Langevin step size.
      num_steps: Number of Langevin steps.
      noise_std: Noise standard deviation used in the jump.

    Returns:
      A pair `(noisy_observations, denoised_observations)` with one entry per
      Langevin step.
    """
    # Walk: a Langevin chain started at zero, driven by the EBM's score.
    walk = langevin_sample(
        create_grad_log_prob_fn(energy_fn_params, energy_fn),
        init=jnp.zeros((1,)),
        delta=delta,
        rng=rng,
        num_steps=num_steps,
    )
    # Jump: denoise every visited state with the score network.
    jump = walk + noise_std ** 2 * score_fn(score_fn_params, walk)
    return walk, jump


# Run 1,000 walk steps with the trained EBM and denoise each one with the
# trained score network.
walk_jump_sampling_rng, rng = jax.random.split(rng)
noisy_observations, denoised_observations = walk_jump_sampling(energy_params, energy_fn, score_params, score_fn, rng=walk_jump_sampling_rng, delta=0.1, num_steps=1000, noise_std=noise_std)

In [None]:
# Density-normalized histogram of the walk (noisy) samples.
plt.hist(noisy_observations.flatten(), bins=20, density=True, alpha=0.5, color='C0', label='Noisy Observations')
plt.legend()
plt.grid()
plt.title('Histogram of Noisy Observations')
plt.show();

In [None]:
# Density-normalized histogram of the jump (denoised) samples.
plt.hist(denoised_observations.flatten(), bins=20, density=True, alpha=0.5, color='C1', label='Denoised Observations')
plt.legend()
plt.grid()
plt.title('Histogram of Denoised Observations')
plt.show();