---
title: Fitting a line
jupyter:
  jupytext:
    text_representation:
      extension: .qmd
      format_name: quarto
      format_version: '1.0'
      jupytext_version: 1.17.2
  kernelspec:
    display_name: Python 3
    language: python
    name: python3
---

The [Getting started](./getting-started.ipynb) tutorial shows how to sample a 3D Gaussian with `simple`.
In this tutorial, we build on it with a more realistic scenario: fitting a line to data.

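## Simulated data

We simulate 100 points from a line with slope $m = 1.338$ and intercept $b = -0.45$, at $x$ values drawn uniformly in $[0, 10]$.
Each point has a reported uncertainty `yerr` between 0.1 and 0.6, but the noise actually added has a standard deviation of `2 * yerr`, so the error bars underestimate the true scatter by a factor of two.
As a result, the extra scatter parameter `sigma` introduced below has no single true value, which is why its entry in `truths` is `None`.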

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import loguniform

rng = np.random.default_rng(123)

x = np.sort(10 * rng.random(100))
m_true = 1.338
b_true = -0.45
truths = {"m": m_true, "b": b_true, "sigma": None}
y_true = m_true * x + b_true
yerr = 0.1 + 0.5 * rng.random(x.size)
y = y_true + 2 * yerr * rng.normal(size=x.size)

ax = plt.gca()
ax.plot(x, y_true, label="True signal")
ax.errorbar(x, y, yerr=yerr, fmt="k.", capsize=2, label="Simulated data")
ax.set_ylabel("y")
ax.set_xlabel("x")
plt.legend()
plt.show()
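
A quick, self-contained check of the simulated noise level is to look at the scatter of the residuals in units of `yerr`. Since the noise is drawn with standard deviation `2 * yerr`, this should come out close to 2 rather than 1. The sketch below simply re-creates the data with the same seed; `z` is an illustrative name.

```python
import numpy as np

# Re-create the simulated data exactly as in the cell above (same seed, same draw order)
rng = np.random.default_rng(123)
x = np.sort(10 * rng.random(100))
m_true, b_true = 1.338, -0.45
y_true = m_true * x + b_true
yerr = 0.1 + 0.5 * rng.random(x.size)
y = y_true + 2 * yerr * rng.normal(size=x.size)

# Residuals from the true line, in units of the reported errors
z = (y - y_true) / yerr
print(f"std of normalized residuals: {z.std():.2f}")  # expected near 2, not 1
```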

## Linear model

In the [Getting started](./getting-started.ipynb) tutorial, we did not have a forward model: the likelihood was specified directly as a distribution.
In most physical problems, a forward model maps the parameters to predicted observations, and the likelihood compares those predictions with the data.
Let's define these two functions.

In [None]:
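def forward_model(parameters, x):
    """Straight line y = m * x + b."""
    m, b = parameters["m"], parameters["b"]
    return m * x + b

def log_likelihood(parameters, x, y, yerr):
    """Gaussian log-likelihood with the reported errors and an extra scatter added in quadrature."""
    mu = forward_model(parameters, x)
    # Total per-point standard deviation: extra scatter "sigma" combined with yerr
    sigma = np.sqrt(parameters["sigma"]**2 + yerr**2)
    return -0.5 * np.sum(((y - mu) / sigma) ** 2 + np.log(2 * np.pi * sigma**2))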
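
Written out, the log-likelihood above is a Gaussian in which the reported uncertainty $e_i$ (`yerr`) and the extra scatter $\sigma$ (`parameters["sigma"]`) add in quadrature:

$$
\ln \mathcal{L}(m, b, \sigma) = -\frac{1}{2} \sum_{i=1}^{N} \left[ \frac{(y_i - m x_i - b)^2}{s_i^2} + \ln\left(2 \pi s_i^2\right) \right],
\qquad s_i^2 = \sigma^2 + e_i^2 .
$$

The $\ln(2\pi s_i^2)$ term cannot be dropped here: $s_i$ depends on the free parameter $\sigma$, and without this term the likelihood would keep increasing as $\sigma$ grows.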

We can use the log-likelihood to create a `simple` model.
The `Model` object accepts parameters either as a dictionary or as a vector ordered like `model.keys()`.
We can optionally also pass it the forward model, so that the `Model` is aware of it and can easily generate predictive samples; here, we only pass the log-likelihood.

In [None]:
from scipy.stats import loguniform, uniform
from simple.distributions import ScipyDistribution
from simple.model import Model

# uniform(loc, scale) is uniform on [loc, loc + scale], i.e. [-10, 10] here
parameters = {
    "m": ScipyDistribution(uniform(-10, 20)),
    "b": ScipyDistribution(uniform(-10, 20)),
    # Log-uniform prior on the extra scatter, between 1e-5 and 100
    "sigma": ScipyDistribution(loguniform(1e-5, 100)),
}

model = Model(parameters, log_likelihood)

We can check that the model behaves as expected by evaluating it at a test point.

In [None]:
test_point = {"m": 1.0, "b": 0, "sigma": 1.0}
print("Log prior", model.log_prior(test_point))
print("Log likelihood", model.log_likelihood(test_point, x, y, yerr))
print("Log probability", model.log_prob(test_point, x, y, yerr))
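
The log prior printed above can be checked by hand, assuming `log_prior` sums the log densities of the independent priors. `uniform(-10, 20)` is uniform on $[-10, 10]$, so each of $m$ and $b$ contributes $\ln(1/20)$. The log-uniform density on $[a, b]$ is $p(\sigma) = 1 / (\sigma \ln(b/a))$, so at $\sigma = 1$ it contributes $-\ln \ln(10^7)$. A small sketch comparing this closed form with `scipy`:

```python
import numpy as np
from scipy.stats import loguniform, uniform

test_point = {"m": 1.0, "b": 0.0, "sigma": 1.0}

# Closed-form log densities of the three priors at the test point
analytic = 2 * np.log(1 / 20) - np.log(test_point["sigma"] * np.log(100 / 1e-5))

# Same quantity from the scipy distributions used to build the model
from_scipy = (
    uniform(-10, 20).logpdf(test_point["m"])
    + uniform(-10, 20).logpdf(test_point["b"])
    + loguniform(1e-5, 100).logpdf(test_point["sigma"])
)
print(analytic, from_scipy)  # both about -8.77
```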

In [None]:
ax = plt.gca()
ax.plot(x, y_true, label="True signal")
ax.plot(x, forward_model(test_point, x), label="Test model")
ax.errorbar(x, y, yerr=yerr, fmt="k.", capsize=2, label="Simulated data")
ax.set_ylabel("y")
ax.set_xlabel("x")
plt.legend()
plt.show()

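## Prior checks

Before sampling, we check that the priors are sensible: first by plotting samples drawn from them, then by plotting the lines they imply in data space (prior predictive samples).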

In [None]:
import corner

n_prior = 1000
prior_samples = model.get_prior_samples(n_prior)

fig = corner.corner(prior_samples)
fig.suptitle("Prior samples")
plt.show()

In [None]:
n_pred = 100
# Reuse the seeded generator from the first cell so the figure is reproducible
show_idx = rng.choice(n_prior, n_pred, replace=False)
ax = plt.gca()
for i in show_idx:
    ypred = forward_model({k: prior_samples[k][i] for k in model.keys()}, x)
    ax.plot(x, ypred, "C1-", label="Prior samples" if i == show_idx[0] else None, alpha=0.1)
ax.plot(x, y_true, label="True signal")
ax.errorbar(x, y, yerr=yerr, fmt="k.", capsize=2, label="Simulated data")
ax.set_ylabel("y")
ax.set_xlabel("x")
ax.set_title("Prior predictive samples")
plt.legend()
plt.show()
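
The prior predictive lines are far more spread out than the data. A rough calculation shows why: with $m, b \sim \mathcal{U}(-10, 10)$, the prediction at $x = 10$ is $10m + b$, which can lie anywhere in $[-110, 110]$, whereas the true line only reaches about 13 there. The sketch below (variable names are illustrative) draws directly from these priors with `scipy` and summarizes the spread at $x = 10$.

```python
import numpy as np
from scipy.stats import uniform

rng = np.random.default_rng(0)
n = 100_000

# Draw slopes and intercepts from the U(-10, 10) priors used above
m = uniform(-10, 20).rvs(size=n, random_state=rng)
b = uniform(-10, 20).rvs(size=n, random_state=rng)

# Prior predictive value of the line at x = 10
y10 = 10 * m + b
lo, hi = np.percentile(y10, [5, 95])
print(f"90% of prior lines pass between {lo:.0f} and {hi:.0f} at x = 10")
```

Broad priors like these are not a problem when the data are informative, as here, but this kind of check makes it easy to spot priors that are much wider than intended.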

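## Sampling

We draw posterior samples with `zeus`, an ensemble slice sampler.
Its `EnsembleSampler` takes the number of walkers, the number of dimensions and a log-probability function; the data are forwarded to `model.log_prob` through `args`.
The walkers start around $m = 0$, $b = 0$ and $\sigma = 10$ with unit Gaussian scatter, in the parameter order given by `model.keys()`.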

In [None]:
import zeus

nwalkers = 100
nsteps = 1000
ndim = len(model.keys())
start = np.array([0.0, 0.0, 10.0]) + rng.standard_normal(size=(nwalkers, ndim))
sampler = zeus.EnsembleSampler(nwalkers, ndim, model.log_prob, args=(x, y, yerr))
sampler.run_mcmc(start, nsteps)
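
For intuition about what an MCMC sampler does with a log-probability function, here is a much simpler random-walk Metropolis sampler applied to the same problem. This is a swapped-in technique, not the algorithm `zeus` uses: it runs a single chain with hand-picked step sizes, and `log_prob` and `metropolis` are hypothetical helpers written only for this sketch. The data are re-created with the same seed so the block runs on its own.

```python
import numpy as np

rng = np.random.default_rng(123)

# Same simulated data as at the top of the tutorial
x = np.sort(10 * rng.random(100))
yerr = 0.1 + 0.5 * rng.random(x.size)
y = 1.338 * x - 0.45 + 2 * yerr * rng.normal(size=x.size)


def log_prob(theta):
    m, b, sigma = theta
    # Uniform priors on m, b in [-10, 10]; log-uniform prior on sigma in [1e-5, 100]
    if not (-10 <= m <= 10 and -10 <= b <= 10 and 1e-5 <= sigma <= 100):
        return -np.inf
    s2 = sigma**2 + yerr**2
    log_like = -0.5 * np.sum((y - m * x - b) ** 2 / s2 + np.log(2 * np.pi * s2))
    return log_like - np.log(sigma)  # log-uniform density is proportional to 1/sigma


def metropolis(log_prob, start, step, nsteps, rng):
    chain = np.empty((nsteps, len(start)))
    theta, lp = np.asarray(start, float), log_prob(start)
    for i in range(nsteps):
        proposal = theta + step * rng.standard_normal(len(theta))
        lp_new = log_prob(proposal)
        # Accept with probability min(1, exp(lp_new - lp))
        if np.log(rng.random()) < lp_new - lp:
            theta, lp = proposal, lp_new
        chain[i] = theta
    return chain


chain = metropolis(log_prob, [0.0, 0.0, 1.0], np.array([0.02, 0.1, 0.05]), 20_000, rng)
post_mean = chain[5_000:].mean(axis=0)  # drop the first 5000 steps as burn-in
print(post_mean)  # posterior means of m, b, sigma
```

The fixed step sizes are the weak point of this approach: too small and the chain crawls, too large and most proposals are rejected. Ensemble samplers such as `zeus` avoid hand-tuning them.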

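## Posterior distribution and predictions

We first inspect the chains for convergence.
We then discard the first 200 steps of each walker as burn-in, keep every fifth step, and use the flattened samples for the corner plot and the posterior predictive samples.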

In [None]:
from simple.plot import chainplot

chains = sampler.get_chain()
chainplot(chains, labels=list(model.keys()))
plt.show()
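
Besides eyeballing the traces, the integrated autocorrelation time $\tau$ gives a rough guide for choosing burn-in and thinning, such as the `discard=200` and `thin=5` values used below. The sketch follows the FFT-based estimator with Sokal's automatic windowing, as described in the `emcee` documentation; `autocorr_fn` and `integrated_time` are our own helper names. As a check, it is applied to an AR(1) process, for which the exact value is $\tau = (1 + \rho)/(1 - \rho)$.

```python
import numpy as np


def autocorr_fn(x):
    """Normalized autocorrelation function of a 1D chain, computed with an FFT."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    f = np.fft.rfft(x - x.mean(), n=2 * n)  # zero-pad to avoid circular wrap-around
    acf = np.fft.irfft(f * np.conj(f))[:n]
    return acf / acf[0]


def integrated_time(x, c=5.0):
    """Integrated autocorrelation time with Sokal's automatic windowing."""
    taus = 2.0 * np.cumsum(autocorr_fn(x)) - 1.0  # tau(M) = 1 + 2 * sum_{t=1..M} rho_t
    # Smallest window M such that M >= c * tau(M)
    below = np.arange(len(taus)) < c * taus
    window = np.argmin(below) if not below.all() else len(taus) - 1
    return taus[window]


# Check on an AR(1) process x_t = rho * x_{t-1} + noise
rng = np.random.default_rng(42)
rho, n = 0.5, 100_000
chain = np.empty(n)
chain[0] = rng.standard_normal()
for t in range(1, n):
    chain[t] = rho * chain[t - 1] + rng.standard_normal()
tau = integrated_time(chain)
print(tau)  # close to (1 + 0.5) / (1 - 0.5) = 3
```

For the walkers above, one could apply `integrated_time` to each walker's trace of each parameter and average, assuming `get_chain()` returns an array of shape `(nsteps, nwalkers, ndim)` as in `emcee`. A common rule of thumb is to discard several multiples of $\tau$ as burn-in.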

In [None]:
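flat_chains = sampler.get_chain(discard=200, flat=True, thin=5)
corner.corner(
    flat_chains,
    labels=list(model.keys()),
    # Order the true values like the parameters; sigma has no single true value (None)
    truths=[truths[k] for k in model.keys()],
)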
plt.show()
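
As an independent cross-check of the posterior for $m$ and $b$, a weighted least-squares fit gives point estimates in closed form. This sketch uses `np.polyfit` with weights `1 / yerr`, the weighting NumPy recommends for Gaussian uncertainties, and re-creates the data with the same seed. Unlike the model above, it ignores the extra scatter $\sigma$; since the real noise is twice `yerr`, its formal errors are expected to be about a factor of two too small, while the estimates themselves should land close to the posterior.

```python
import numpy as np

# Re-create the simulated data from the top of the tutorial
rng = np.random.default_rng(123)
x = np.sort(10 * rng.random(100))
yerr = 0.1 + 0.5 * rng.random(x.size)
y = 1.338 * x - 0.45 + 2 * yerr * rng.normal(size=x.size)

# Weighted linear fit: polyfit minimizes sum((w * (y - p(x)))**2), so w = 1 / yerr
(m_ls, b_ls), cov = np.polyfit(x, y, deg=1, w=1 / yerr, cov="unscaled")
m_err, b_err = np.sqrt(np.diag(cov))
print(f"m = {m_ls:.3f} +/- {m_err:.3f}, b = {b_ls:.3f} +/- {b_err:.3f}")
```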

In [None]:
n_pred = 100
# Reuse the seeded generator from the first cell so the figure is reproducible
show_idx = rng.choice(flat_chains.shape[0], n_pred, replace=False)
ax = plt.gca()
for i in show_idx:
    ypred = forward_model(dict(zip(model.keys(), flat_chains[i], strict=True)), x)
    ax.plot(x, ypred, "C1-", label="Posterior samples" if i == show_idx[0] else None, alpha=0.1)
ax.plot(x, y_true, label="True signal")
ax.errorbar(x, y, yerr=yerr, fmt="k.", capsize=2, label="Simulated data")
ax.set_ylabel("y")
ax.set_xlabel("x")
ax.set_title("Posterior predictive samples")
plt.legend()
plt.show()
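
Overplotting individual draws works well for a modest number of samples. An alternative summary is a pointwise credible band: evaluate the line for every posterior sample and take percentiles at each $x$. The sketch below uses a hypothetical helper, `predictive_band`, that takes an `(n_samples, 2)` array of `(m, b)` values; with the chains above, `flat_chains[:, :2]` would have that form, assuming the parameter order `m, b, sigma` used throughout. The toy samples here are synthetic and only exercise the function.

```python
import numpy as np


def predictive_band(mb_samples, x, q=(16, 50, 84)):
    """Pointwise percentiles of m * x + b over posterior samples of (m, b)."""
    m = mb_samples[:, 0][:, None]  # shape (n_samples, 1) for broadcasting
    b = mb_samples[:, 1][:, None]
    lines = m * x[None, :] + b  # shape (n_samples, len(x))
    return np.percentile(lines, q, axis=0)  # shape (len(q), len(x))


# Toy example: synthetic samples scattered around m = 1.3, b = -0.4
rng = np.random.default_rng(1)
samples = np.column_stack([
    1.3 + 0.03 * rng.standard_normal(2000),
    -0.4 + 0.15 * rng.standard_normal(2000),
])
x = np.linspace(0, 10, 50)
lo, mid, hi = predictive_band(samples, x)
print(mid[[0, -1]])  # median line near -0.4 at x = 0 and near 12.6 at x = 10
```

The band can then be drawn with `ax.fill_between(x, lo, hi, alpha=0.3)`. Note that it describes uncertainty in the underlying line only; it does not include the observational noise or the extra scatter $\sigma$.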