# Tutorial 05: MCMC over v_turb on a synthetic RADJAX dataset

This tutorial illustrates how to:

1. Generate a synthetic observation from a parametric protoplanetary disk model.
2. Save the observation (spectral cube + metadata) to a FITS file.
3. Set up a simple MCMC sampler over the microturbulent velocity parameter `v_turb`.
4. Run [`emcee`](https://emcee.readthedocs.io/en/stable/) to recover `v_turb`.
5. Save all experimental settings and results to a YAML file for reproducibility.

We use the `broken_power_law` model implemented in RADJAX as our forward model.

In [1]:
import os
import jax
import numpy as np
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)  

from astropy.io import fits
import emcee
import yaml

from radjax import inference
from radjax import sensor
from radjax import utils
from radjax import chemistry as chem
from radjax.models import broken_power_law as disk_model
from radjax.models.broken_power_law import forward_model

## Load parameters and build baseline scene

We begin by loading disk, observation, and chemistry parameters, building the
baseline disk structure, and constructing rays and frequencies. These form the
static context of the SamplerState.

In [2]:
# --- Experiment configuration -------------------------------------------------
num_freqs = 20  # number of frequency channels requested from the sensor
npix = 200      # points per sky axis (the sky grid below is npix x npix)

# NOTE(review): `rng_seed_synth` is not referenced by any cell shown in this
# tutorial; the synthetic-noise cell seeds JAX with its own `jax_seed`.
rng_seed_synth = 1234  # for JAX noise
numpy_seed = 2025      # for emcee RNG

# Single YAML file holding disk, sensor and chemistry settings
# (produced in Tutorial 01/02).
params_path = "./tutorial01_params.yaml"

# --- Baseline parameters ------------------------------------------------------
disk_params = disk_model.disk_from_yaml(params_path)
obs_params = sensor.params_from_yaml(params_path)
chem_params = chem.chemistry_from_yaml_path(params_path)
mol = chem.load_molecular_tables(chem_params)

# --- Precomputed base disk grids ----------------------------------------------
# Only `base_disk` is stored in the SamplerState below.
temperature, v_phi, co_nd, base_disk = disk_model.co_disk_from_params(
    disk_params,
    chem_params,
)

# --- Frequency grid -----------------------------------------------------------
freqs = sensor.compute_camera_freqs(
    nu0=mol.nu0,
    width_kms=obs_params.velocity_width_kms,
    num_freqs=num_freqs,
)

# --- Sky grid and rays --------------------------------------------------------
# The same axis object serves as both x and y, giving a square field of view.
xaxis = np.linspace(-obs_params.fov / 2, obs_params.fov / 2, npix)
yaxis = xaxis
x_sky, y_sky = jnp.meshgrid(xaxis, yaxis, indexing="xy")
rays = sensor.rays_from_params(obs_params, x_sky, y_sky)

# --- SamplerState (static context) --------------------------------------------
# Everything bundled here is built once and reused by every forward-model call.
state = inference.SamplerState(
    obs_params=obs_params,
    disk_params=disk_params,
    chem_params=chem_params,
    base_disk=base_disk,
    rays=rays,
    mol=mol,
    beam=None,
    use_pressure_correction=True,
)

## Define θ → DiskParams adapter
We now define a minimal adapter that maps a one-dimensional parameter vector θ = [v_turb] to a new `DiskParams` object. <br>
This isolates the **dynamic** parameters (sampled each MCMC step) from the **static** state.

We add this adapter to the state, together with the noise level `sigma`, which is used both to generate the synthetic noise and to weight the residuals in the likelihood.

In [3]:
class VTurbAdapter:
    """Map a one-element parameter vector ``theta = [v_turb]`` onto disk parameters."""

    def apply(self, disk_params, theta: jnp.ndarray):
        """Return ``disk_params.replace(v_turb=theta[0])``.

        Only the first entry of ``theta`` is read; whatever ``replace``
        returns is handed back unchanged.
        """
        v_turb = theta[0]
        return disk_params.replace(v_turb=v_turb)

# Extend the static state with the noise level and the theta -> DiskParams adapter.
# `sigma` is in Jy/pixel.
state = state.replace(sigma=10.0, adapter=VTurbAdapter())

In [4]:
# Ground-truth value of v_turb used to synthesize the observation.
v_turb_true = 0.32

# Noise-free cube rendered at the true parameter value.
disk_true = state.adapter.apply(state.disk_params, jnp.array([v_turb_true]))
cube_true = forward_model(disk_params=disk_true, freqs=freqs, state=state, output="image")

# add Gaussian noise (use the same sigma as in state)
# `jax_seed` is the single source of truth for the noise seed: it builds the
# PRNG key here and is also the value written to the FITS metadata when the
# observation is saved, so the two can no longer drift apart.
jax_seed = 42
key = jax.random.PRNGKey(jax_seed)  # reproducible seed
noise = state.sigma * jax.random.normal(key, shape=cube_true.shape)
cube_obs = cube_true + noise

In [5]:
# save to FITS
fits_path = "./tutorial05_artifacts/synthetic_observation.fits"

# Make sure the artifacts directory exists before writing into it. This is a
# no-op when the directory is already there (or when the helper creates it).
os.makedirs(os.path.dirname(fits_path), exist_ok=True)

# Store the cube together with the metadata needed to regenerate it:
# frequency grid, `mol.nu0`, noise level and the noise seed.
utils.save_synthetic_observation(
    filepath=fits_path,
    cube_obs=cube_obs,
    freqs=freqs,
    nu0=mol.nu0,
    sigma=state.sigma,
    noise_type="jax.random.normal",
    seed=jax_seed,
)

✔️ Saved FITS → ./tutorial05_artifacts/synthetic_observation.fits


## Define prior, likelihood, and emcee target

We implement a flat prior on `v_turb`, a JIT-compiled (`jax.jit`) forward model, and a Gaussian per-pixel likelihood. <br>
Together these form the log-posterior used by `emcee`.

In [6]:
# Support of the flat prior on v_turb (inclusive bounds).
vmin, vmax = 0.0, 2.0


def logprior(theta_np: np.ndarray) -> float:
    """Log of a flat prior on ``v_turb = theta_np[0]``.

    Returns ``0.0`` when ``vmin <= v_turb <= vmax`` and ``-inf`` otherwise
    (a NaN input fails both comparisons and therefore also yields ``-inf``).
    """
    v_turb = float(theta_np[0])
    in_support = (v_turb >= vmin) and (v_turb <= vmax)
    if not in_support:
        return -np.inf
    return 0.0

@jax.jit
def render_image_from_theta(theta_jnp: jnp.ndarray) -> jnp.ndarray:
    """Render the model cube for the parameter vector ``theta_jnp``.

    The adapter stored on ``state`` turns ``theta_jnp`` into disk parameters,
    which are then passed to ``forward_model`` with ``output="image"``.

    ``state`` and ``freqs`` are read from the notebook scope rather than passed
    in as arguments.
    NOTE(review): under ``jax.jit`` such globals are presumably captured when
    the function is first traced; re-run this cell if either of them changes.
    """
    return forward_model(
        disk_params=state.adapter.apply(state.disk_params, theta_jnp),
        freqs=freqs,
        state=state,
        output="image",
    )

def loglikelihood(theta_np: np.ndarray) -> float:
    """Gaussian log-likelihood of ``cube_obs`` given ``theta_np``, up to a constant.

    The model cube is rendered at ``theta_np``, the residual against
    ``cube_obs`` is scaled by ``state.sigma``, and ``-0.5 * sum(residual**2)``
    is returned. The normalization term is omitted; it does not depend on
    ``theta_np`` because ``state.sigma`` is fixed.

    Returns:
        A built-in ``float``, as the annotation promises. The reduction runs
        on JAX arrays, so without the explicit conversion a 0-d JAX array
        would be handed to the sampler.
    """
    cube_model = render_image_from_theta(jnp.asarray(theta_np))
    residual = (cube_model - cube_obs) / state.sigma
    return float(-0.5 * jnp.sum(residual * residual))

# emcee target
def logprob(theta_np: np.ndarray) -> float:
    """Log-posterior (up to a constant): ``logprior + loglikelihood``.

    The forward model is only evaluated when the prior is finite; outside the
    prior support the function returns ``-inf`` immediately.
    """
    prior_term = logprior(theta_np)
    if np.isfinite(prior_term):
        return prior_term + loglikelihood(theta_np)
    return -np.inf

## Run MCMC with emcee

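We now run an ensemble sampler. To keep the tutorial fast, the cell below uses only 10 walkers and 2 steps; for a converged posterior use a longer run (e.g. 32 walkers for 100 steps or more). <br>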
The sampler explores the posterior distribution of `v_turb`.

In [7]:
# Seed NumPy's global RNG, which draws the initial walker positions below.
numpy_seed = 2025
np.random.seed(numpy_seed)

ndim = 1       # theta = [v_turb]
nwalkers = 10
nsteps = 2

# Set up a backend
# Don't forget to clear it in case the file already exists
filename = "tutorial05_artifacts/emcee_state.h5"
# The HDF5 file lives in a sub-directory; create it first so that opening the
# backend does not fail when this cell runs before any artifact was written.
os.makedirs(os.path.dirname(filename), exist_ok=True)
backend = emcee.backends.HDFBackend(filename)
backend.reset(nwalkers, ndim)

# Start the walkers uniformly inside the prior support [vmin, vmax].
p0 = np.random.uniform(vmin, vmax, size=(nwalkers, ndim))
sampler = emcee.EnsembleSampler(nwalkers, ndim, logprob, backend=backend)
emcee_state = sampler.run_mcmc(p0, nsteps, progress=True)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:18<00:00,  9.15s/it]


## Summarize results and posterior predictive check

We extract the MAP and median estimates for `v_turb`, compute the 68%
credible interval, and compare the posterior predictive cube to the synthetic
observation.

In [8]:
# Flattened samples of v_turb (the only sampled dimension) and their
# log-probabilities; both are requested with flat=True so the indices line up.
chain = sampler.get_chain(flat=True)[:, 0]
logp  = sampler.get_log_prob(flat=True)

# Point estimates and the central 68% interval (16th-84th percentiles).
v_map = chain[np.argmax(logp)]
v_med = np.median(chain)
v_lo, v_hi = np.percentile(chain, [16, 84])

print(f"True v_turb: {v_turb_true:.2f}")
print(f"MAP  v_turb: {v_map:.2f} ")
print(f"MED  v_turb: {v_med:.2f} (68% CI: {v_lo:.1f}, {v_hi:.1f})")

# posterior predictive at MAP
disk_map = state.adapter.apply(state.disk_params, jnp.array([v_map]))
cube_map = forward_model(disk_params=disk_map, freqs=freqs, state=state, output="image")
# Convert BOTH cubes to NumPy: `cube_obs` is a JAX array as well, so converting
# only `cube_map` would still leave a mixed NumPy/JAX expression. `rms` is then
# a plain Python float, which formats and serializes without surprises.
residual_map = np.asarray(cube_map) - np.asarray(cube_obs)
rms = float(np.sqrt(np.mean(residual_map**2)))
print(f"Posterior predictive RMS (MAP): {rms:.4f}")

True v_turb: 0.32
MAP  v_turb: 0.32 
MED  v_turb: 0.95 (68% CI: 0.3, 1.8)
Posterior predictive RMS (MAP): 9.9942


## Save MCMC configuration and results to YAML

Finally, we save all experimental parameters, priors, sampler settings,
and results into a compact YAML file for reproducibility.

In [9]:
# Sampler configuration needed to reproduce the run.
mcmc_params = {
    "method": "emcee",
    "numpy_seed": numpy_seed,
    "nwalkers": nwalkers,
    "nsteps": nsteps,
    "sampled_params": ["v_turb"],
    "bounds": {"v_turb": [vmin, vmax]},
}

# The summary statistics come out of NumPy as NumPy scalar types. Cast them to
# built-in floats so the YAML stays plain numbers: PyYAML's safe dumper rejects
# NumPy scalars and the default dumper writes Python-specific object tags.
# NOTE(review): harmless if `append_inference_to_yaml` already sanitizes values.
results = {
    "v_turb_true": float(v_turb_true),
    "v_turb_map": float(v_map),
    "v_turb_med": float(v_med),
    "v_turb_ci68": [float(v_lo), float(v_hi)],
    "rms_map": float(rms),
}

# All three calls below target the same output file.
output_yaml_path = "tutorial05_artifacts/params-and-mcmc-results.yaml"
disk_model.params_to_yaml_path(disk_true, output_yaml_path)
chem.chemistry_to_yaml_path(chem_params, output_yaml_path)
inference.append_inference_to_yaml(output_yaml_path, mcmc_params, results)