In [1]:
import jax
import jax.numpy as jnp
from jax import random, jit

# Define the target distribution (Gaussian distribution in this case)
def target_distribution(x):
    return jnp.exp(-0.5 * x ** 2)

# Define the Metropolis-Hastings step
@jit
def metropolis_step(key, x, target_log_prob):
    key, subkey = random.split(key)
    x_new = x + random.normal(subkey)  # Propose a new state from a normal distribution
    
    log_acceptance_ratio = target_log_prob(x_new) - target_log_prob(x)
    acceptance_prob = jnp.exp(log_acceptance_ratio)
    
    key, subkey = random.split(key)
    uniform_sample = random.uniform(subkey)
    
    x = jnp.where(uniform_sample < acceptance_prob, x_new, x)  # Accept or reject the new state
    return key, x

# JIT compiled Metropolis-Hastings sampler
@jit
def run_metropolis(key, x0, target_log_prob, n_steps):
    xs = []
    x = x0
    
    for _ in range(n_steps):
        key, x = metropolis_step(key, x, target_log_prob)
        xs.append(x)
    
    return jnp.array(xs)

# Set up the sampling
key = random.PRNGKey(42)
x0 = 0.0  # Initial state
n_steps = 10000

# Define the log of the target distribution
target_log_prob = jit(lambda x: -0.5 * x ** 2)

# Run the Metropolis sampler
samples = run_metropolis(key, x0, target_log_prob, n_steps)

# Print the first few samples
print(samples[:10])


TypeError: Cannot interpret value of type <class 'jaxlib.xla_extension.PjitFunction'> as an abstract array; it does not have a dtype attribute