# Sample from Posterior

_scarlet2_ can provide samples from the posterior distribution. These samples can be passed on to downstream operations, and they are the most complete option for uncertainty quantification. Any sampler can be used, because it only needs to evaluate the log-posterior distribution. In this guide we use the Hamiltonian Monte Carlo sampler from numpyro, for which _scarlet2_ provides a convenient front-end.
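
To make the point about "any sampler" concrete: by Bayes' theorem the log-posterior is the log-likelihood plus the log-prior, up to a constant, so a sampler only ever needs that sum. The following is a toy sketch in plain numpy, not scarlet2 code. It uses a hypothetical single flux parameter `f` with Gaussian noise and a wide Gaussian prior, and shows that even a brute-force grid over `log_posterior` recovers the analytic posterior.

```python
import numpy as np

# Hypothetical toy problem (not scarlet2): one flux parameter f, measured in
# N pixels with known Gaussian noise sigma, and a wide Gaussian prior on f.
rng = np.random.default_rng(0)
true_f, sigma, N = 10.0, 2.0, 50
data = true_f + sigma * rng.standard_normal(N)
prior_mu, prior_sigma = 0.0, 100.0


def log_likelihood(f):
    # Gaussian log-likelihood, dropping terms that do not depend on f
    return -0.5 * np.sum((data - f) ** 2) / sigma**2


def log_prior(f):
    return -0.5 * (f - prior_mu) ** 2 / prior_sigma**2


def log_posterior(f):
    # Bayes' theorem in log space: log p(f | d) = log p(d | f) + log p(f) + const
    return log_likelihood(f) + log_prior(f)


# Brute-force "sampler": evaluate log_posterior on a grid and normalize
grid = np.linspace(0.0, 20.0, 4001)
logp = np.array([log_posterior(f) for f in grid])
weights = np.exp(logp - logp.max())
weights /= weights.sum()
grid_mean = np.sum(weights * grid)
grid_std = np.sqrt(np.sum(weights * (grid - grid_mean) ** 2))

# Analytic conjugate-Gaussian posterior for comparison
precision = N / sigma**2 + 1 / prior_sigma**2
post_mean = (data.sum() / sigma**2 + prior_mu / prior_sigma**2) / precision
post_std = 1 / np.sqrt(precision)
print(grid_mean, post_mean, grid_std, post_std)
```

A grid only works in very few dimensions; MCMC samplers such as HMC scale to many parameters while still requiring nothing but the log-posterior (and, for HMC, its gradient).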

We start from the [quickstart tutorial](../0-quickstart), loading the same data and the best-fitting model.

```python
# Import packages and set up
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scarlet2 import *
```

## Create Observation

We need to create the {py:class}`~scarlet2.Observation` because it contains the {py:func}`~scarlet2.Observation.log_likelihood` method we need for the posterior:

```python
# load the data
from huggingface_hub import hf_hub_download

filename = hf_hub_download(
    repo_id="astro-data-lab/scarlet-test-data", filename="hsc_cosmos_35.npz", repo_type="dataset"
)
file = jnp.load(filename)
data = jnp.asarray(file["images"])
channels = [str(f) for f in file["filters"]]
centers = jnp.array([(src["y"], src["x"]) for src in file["catalog"]])
weights = jnp.asarray(1 / file["variance"])
psf = jnp.asarray(file["psfs"])

# create the observation
obs = Observation(
    data,
    weights,
    psf=ArrayPSF(psf),
    channels=channels,
)
```
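
The weights loaded above are inverse variances, `1 / variance`. A common form of the Gaussian log-likelihood with such weights is `-0.5 * sum(w * (d - m)**2)`, up to a constant. We assume, without reproducing the scarlet2 source here, that {py:func}`~scarlet2.Observation.log_likelihood` is of this kind. A minimal numpy sketch with hypothetical arrays:

```python
import numpy as np


def gaussian_log_likelihood(model, data, weights):
    # weights are inverse variances; a zero weight removes a pixel entirely.
    # Constant normalization terms are dropped.
    return -0.5 * np.sum(weights * (data - model) ** 2)


# Hypothetical (channels, height, width) images with a single bright pixel
rng = np.random.default_rng(1)
shape = (3, 9, 9)
variance = np.full(shape, 0.25)
weights = 1.0 / variance
truth = np.zeros(shape)
truth[:, 4, 4] = 5.0
data = truth + np.sqrt(variance) * rng.standard_normal(shape)

ll_true = gaussian_log_likelihood(truth, data, weights)  # expectation: -0.5 * data.size
ll_empty = gaussian_log_likelihood(np.zeros(shape), data, weights)  # misses the source
print(ll_true, ll_empty)
```

Because each term is weighted by its inverse variance, noisy pixels pull less on the fit, and masked pixels (weight zero) do not contribute at all.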

## Load Model

We can make use of the best-fit model from the Quickstart guide as the starting point of the sampler.

```python
import scarlet2.io

model_id = 35  # avoid shadowing Python's built-in id()
filename = "hsc_cosmos.h5"
scene = scarlet2.io.model_from_h5(filename, path="..", id=model_id)
```

Let's have a look:

```python
norm = plot.AsinhAutomaticNorm(obs)
plot.scene(scene, observation=obs, norm=norm, add_boxes=True)
plt.show()
```

## Define Parameters with Prior

In principle, we can get posterior samples for every parameter. We demonstrate this by sampling the spectrum and the center position of the point source #0. We therefore need to set the `prior` attribute of each of these parameters. When `prior` is set, the attribute `stepsize` is ignored, and `constraint` cannot be used.

```python
import numpyro.distributions as dist

C = len(channels)
parameters = scene.make_parameters()

# rough guess of source brightness across bands
p1 = scene.sources[0].spectrum
prior1 = dist.Uniform(low=jnp.zeros(C), high=500 * jnp.ones(C))
parameters += Parameter(p1, name="spectrum", prior=prior1)

# initial position was integer pixel coordinate
# assume 0.5 pixel uncertainty
p2 = scene.sources[0].center
prior2 = dist.Normal(centers[0], scale=0.5)
parameters += Parameter(p2, name="center", prior=prior2)
```
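
Both priors treat the components of a parameter as independent, so the joint log-prior is a sum of per-element terms. The following numpy sketch spells out the textbook densities of the two prior families used above. It is illustrative only and is not numpyro's implementation; `C = 5` and the test values are hypothetical.

```python
import numpy as np


def log_uniform(x, low, high):
    # Joint log-density of independent Uniform(low, high) elements
    x, low, high = map(np.asarray, (x, low, high))
    if np.all((x >= low) & (x < high)):
        return -np.sum(np.log(high - low))
    return -np.inf  # outside the support


def log_normal(x, loc, scale):
    # Joint log-density of independent Normal(loc, scale) elements
    x, loc, scale = map(np.asarray, (x, loc, scale))
    z = (x - loc) / scale
    return np.sum(-0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi))


C = 5  # hypothetical number of channels
inside = log_uniform(np.full(C, 100.0), np.zeros(C), 500 * np.ones(C))
outside = log_uniform(np.full(C, -1.0), np.zeros(C), 500 * np.ones(C))
at_center = log_normal(np.array([20.0, 30.0]), np.array([20.0, 30.0]), 0.5)
print(inside, outside, at_center)
```

The flat prior contributes the same constant everywhere inside its bounds, so it does not favor any spectrum value there. The Normal prior on the center penalizes positions more than about a pixel away from the catalog position.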

```{warning}
You are responsible for setting reasonable priors, which describe what you know about the parameter before having looked at the data. In the example above, the spectrum gets a wide flat prior, and the center prior is centered on `centers[0]`, the position from the original detection catalog. Neither uses information from the optimized `scene`.

If in doubt about how much the choice of prior matters, vary it within reason and compare the resulting posteriors.
```
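
A prior sensitivity check can be illustrated with a case that has a closed-form answer: Gaussian data with known noise and a Gaussian prior on the mean. This is a hypothetical toy, not the scarlet2 model. The sketch varies the prior and records how far the posterior moves.

```python
import numpy as np


def gaussian_posterior(data, sigma, prior_mu, prior_sigma):
    # Conjugate update for the mean of Gaussian data with known noise sigma
    precision = len(data) / sigma**2 + 1 / prior_sigma**2
    mean = (np.sum(data) / sigma**2 + prior_mu / prior_sigma**2) / precision
    return mean, 1 / np.sqrt(precision)


# Hypothetical data: 100 measurements of a quantity near 3 with noise 0.5
rng = np.random.default_rng(2)
data = 3.0 + 0.5 * rng.standard_normal(100)

# Vary the prior within reason and record how much the posterior moves
priors = [(0.0, 10.0), (3.0, 1.0), (5.0, 2.0)]
results = [gaussian_posterior(data, 0.5, mu, s) for mu, s in priors]
for (mu, s), (mean, std) in zip(priors, results):
    print(f"prior N({mu}, {s}) -> posterior mean {mean:.3f}, std {std:.3f}")
```

If the posterior shifts by much less than its own width, as here, the data dominate. A large shift means the prior matters and deserves more care.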

## Run Sampler

Then we can run numpyro's {py:class}`~numpyro.infer.hmc.NUTS` sampler with a call to {py:func}`~scarlet2.Scene.sample`, which is analogous to {py:func}`~scarlet2.Scene.fit`:

```python
mcmc = scene.sample(
    obs,
    parameters,
    num_warmup=100,
    num_samples=1000,
    progress_bar=False,
)
mcmc.print_summary()
```
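
NUTS chooses its trajectory length automatically and, during the `num_warmup` iterations, adapts its step size (and by default its mass matrix); only the following `num_samples` draws are kept. To show the underlying mechanics, here is a minimal plain HMC sampler in numpy. It is not NUTS: step size and number of leapfrog steps are fixed, and warmup draws are simply discarded without any adaptation. The target is a hypothetical 2D Gaussian standing in for the log-posterior.

```python
import numpy as np

# Hypothetical target standing in for a log-posterior: a 2D Gaussian
mu = np.array([1.0, -2.0])
std = np.array([0.5, 2.0])


def log_prob(x):
    return -0.5 * np.sum(((x - mu) / std) ** 2)


def grad_log_prob(x):
    return -(x - mu) / std**2


def hmc(x0, num_warmup, num_samples, step_size=0.2, num_steps=20, seed=0):
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    draws = []
    for i in range(num_warmup + num_samples):
        p = rng.standard_normal(x.shape)  # fresh momentum each iteration
        x_new, p_new = x.copy(), p.copy()
        # Leapfrog integration of Hamiltonian dynamics
        p_new += 0.5 * step_size * grad_log_prob(x_new)
        for _ in range(num_steps - 1):
            x_new += step_size * p_new
            p_new += step_size * grad_log_prob(x_new)
        x_new += step_size * p_new
        p_new += 0.5 * step_size * grad_log_prob(x_new)
        # Metropolis accept/reject on the total energy H = -log_prob + |p|^2 / 2
        h_old = -log_prob(x) + 0.5 * p @ p
        h_new = -log_prob(x_new) + 0.5 * p_new @ p_new
        if np.log(rng.uniform()) < h_old - h_new:
            x = x_new
        if i >= num_warmup:  # keep only post-warmup draws
            draws.append(x.copy())
    return np.array(draws)


samples = hmc([0.0, 0.0], num_warmup=100, num_samples=2000)
print(samples.mean(axis=0), samples.std(axis=0))
```

The gradient of the log-posterior is what lets HMC make long, informed moves; in scarlet2 that gradient comes from JAX's automatic differentiation rather than a hand-written function like `grad_log_prob`.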

## Access Samples

The samples can be accessed from the MCMC chain. They are stored as arrays under the names given to the respective `Parameter` above.

```python
import pprint

samples = mcmc.get_samples()
pprint.pprint(samples)
```
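
`print_summary` reports statistics such as the posterior mean, standard deviation and a 90% interval for each parameter (plus convergence diagnostics like `n_eff` and `r_hat`, which are not reproduced here). A sketch of computing comparable statistics directly from a samples dict, using an equal-tailed 90% interval; numpyro's own interval may be computed differently. The arrays below are hypothetical stand-ins with the same layout as `samples`.

```python
import numpy as np


def summarize(samples):
    # samples: dict mapping parameter name -> array of shape (num_samples, ...)
    summary = {}
    for name, x in samples.items():
        x = np.asarray(x)
        summary[name] = {
            "mean": x.mean(axis=0),
            "std": x.std(axis=0),
            "median": np.median(x, axis=0),
            "q5": np.percentile(x, 5, axis=0),  # lower end of equal-tailed 90% interval
            "q95": np.percentile(x, 95, axis=0),  # upper end
        }
    return summary


# Hypothetical stand-in with the same layout as the samples above:
# one spectrum value per channel and a (y, x) center per draw
rng = np.random.default_rng(3)
fake_samples = {
    "spectrum": 100.0 + 5.0 * rng.standard_normal((1000, 5)),
    "center": np.array([20.0, 30.0]) + 0.05 * rng.standard_normal((1000, 2)),
}
stats = summarize(fake_samples)
print(stats["center"]["mean"], stats["center"]["std"])
```

With the real chain, `summarize(mcmc.get_samples())` should work the same way, since `np.asarray` converts the JAX arrays.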

To create versions of the scene for any of the samples, we first select a few at random and then use the method {py:func}`scarlet2.Module.replace` to set their values at the locations identified by `parameters`:

```python
# get values for three random samples
import jax.random

S = 3
seed = 42
key = jax.random.key(seed)
num_draws = samples["spectrum"].shape[0]  # total number of stored draws
idxs = jax.random.randint(key, shape=(S,), minval=0, maxval=num_draws)

values = [[spectrum, center] for spectrum, center in zip(samples["spectrum"][idxs], samples["center"][idxs])]

# create versions of the scene with these posterior samples
scenes = [scene.replace(parameters, v) for v in values]

# display the source model
fig, axes = plt.subplots(1, S, figsize=(10, 4))
for s in range(S):
    source_array = scenes[s].sources[0]()
    axes[s].imshow(plot.img_to_rgb(source_array, norm=norm))
plt.show()
```
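
`randint` draws indices independently, so the same sample can be picked twice. If distinct samples are wanted, draw without replacement; in JAX, `jax.random.choice(key, num_draws, shape=(S,), replace=False)` does this. The numpy sketch below shows the same idea with hypothetical sample arrays laid out like `samples["spectrum"]` and `samples["center"]`.

```python
import numpy as np

rng = np.random.default_rng(42)
num_draws, S = 1000, 3

# Drawing without replacement guarantees S distinct indices
idxs = rng.choice(num_draws, size=S, replace=False)

# Hypothetical sample arrays with the same layout as `samples` above
spectra = rng.uniform(0, 500, size=(num_draws, 5))
centers = rng.normal([20.0, 30.0], 0.5, size=(num_draws, 2))

# Pair up the values in the order the parameters were added
values = [[spectrum, center] for spectrum, center in zip(spectra[idxs], centers[idxs])]
print(idxs, values[0])
```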

The differences are imperceptible for this source, which tells us that the data are highly informative. We can still quantify them, e.g. by measuring the total flux of each sample:

```python
print(f"-------------- {channels}")
# use a new loop variable so that `scene` still refers to the best-fit model
for i, sample_scene in enumerate(scenes):
    print(f"Flux Sample {i}: {measure.flux(sample_scene.sources[0])}")
```
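
Three samples only hint at the spread. With fluxes measured for many posterior samples (for instance by looping `scene.replace` and `measure.flux` over all draws, which is slower), the uncertainty can be summarized per channel by the median and the 16th/84th percentiles. A sketch with a hypothetical `fluxes` array of shape `(num_draws, C)` and hypothetical channel labels:

```python
import numpy as np

rng = np.random.default_rng(4)
channels = ["g", "r", "i", "z", "y"]  # hypothetical labels
# Hypothetical stand-in for fluxes measured on many posterior samples: (N, C)
fluxes = np.array([120.0, 180.0, 240.0, 260.0, 270.0]) + 3.0 * rng.standard_normal((2000, 5))

# Median and central 68% interval per channel
lo, med, hi = np.percentile(fluxes, [16, 50, 84], axis=0)
for c, m, l, h in zip(channels, med, lo, hi):
    print(f"{c}: {m:.1f} +{h - m:.1f} / -{m - l:.1f}")
```

Percentile-based intervals make no Gaussian assumption, so they remain meaningful if a flux posterior turns out to be skewed.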

## Visualize Posterior

We can also visualize the posterior distributions, e.g. with the [`corner`](https://corner.readthedocs.io/en/latest/) package:

```python
import corner

corner.corner(mcmc);
```
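
`corner.corner` also accepts a plain 2-D array of shape `(num_samples, num_dims)` together with a `labels` list, which makes it easy to choose which parameters and components to show. A sketch of flattening a samples dict like the one from `mcmc.get_samples()`, here filled with hypothetical random numbers:

```python
import numpy as np


def flatten_samples(samples):
    # Turn {name: (N, ...) array} into an (N, D) matrix plus one label per column
    columns, labels = [], []
    for name, x in samples.items():
        x = np.asarray(x)
        x = x.reshape(x.shape[0], -1)  # (N, k) for vector-valued parameters
        columns.append(x)
        labels += [f"{name}[{j}]" for j in range(x.shape[1])]
    return np.concatenate(columns, axis=1), labels


# Hypothetical stand-in with the same layout as the real samples
rng = np.random.default_rng(5)
fake_samples = {"spectrum": rng.normal(size=(1000, 5)), "center": rng.normal(size=(1000, 2))}
flat, labels = flatten_samples(fake_samples)
print(flat.shape, labels)
```

With the real chain, `flat, labels = flatten_samples(mcmc.get_samples())` followed by `corner.corner(flat, labels=labels)` should produce the same kind of plot, and slicing `flat` (e.g. only the `center` columns) restricts it to a subset of parameters.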