In [None]:
import pickle as pkl

import corner
import matplotlib.pyplot as plt
import numpy as np
from getdist import MCSamples, plots
from pyoperators import MPI
from qubic.lib.Qfoldertools import  create_folder_if_not_exists
from qubic.lib.Qmpi_tools import MpiTools

In [None]:
# Pickled MCMC output; the cells below read its "samples", "samples_flat",
# "parameters" and "fit_parameters" entries.
# NOTE(review): absolute, user-specific path — consider making it configurable.
# SECURITY: pickle.load can execute arbitrary code — only load files you produced/trust.
MCMC_PATH = "/pbs/home/t/tlaclave/sps/qubic/qubic/scripts/MapMaking/src/FMM/test_bis/Fit/test_qubic_only/mcmc.pkl"

# `with` guarantees the file handle is closed (the bare open() leaked it).
with open(MCMC_PATH, "rb") as mcmc_fh:
    mcmc_file = pkl.load(mcmc_fh)

In [None]:
# Display the "parameters" entry stored in the MCMC pickle (shown as the cell output).
mcmc_file["parameters"]

# Parameters

In [None]:
# Unpack the MCMC output.
# `samples` is indexed as samples[step, walker, parameter] by the trace-plot cell;
# `samples_flat` is indexed as samples_flat[sample, parameter] by the posterior cells.
samples = mcmc_file["samples"]
samples_flat = mcmc_file["samples_flat"]
parameters = mcmc_file["parameters"]
fit_params = mcmc_file["fit_parameters"]

# Names of the fitted parameters, in the order of the chains' last axis.
fit_parameters_names = ["r"]

# Fail early if the hard-coded names do not match the stored chains. Otherwise the
# plotting cells would raise IndexError or silently mislabel axes.
assert samples.shape[-1] == samples_flat.shape[-1] == len(fit_parameters_names), (
    f"fit_parameters_names has {len(fit_parameters_names)} entries but the chains "
    f"have {samples.shape[-1]} (samples) / {samples_flat.shape[-1]} (samples_flat) parameters"
)

print("samples", samples.shape)
print("samples_flat", samples_flat.shape)

In [None]:
# MPI communicator; it is handed to the folder-creation helper below.
comm = MPI.COMM_WORLD
mpi = MpiTools(comm)

### Folders, all rooted at the simulation path recorded in the fit configuration
folder = fit_params["simulation_path"]
folder_spectrum = folder + "/Spectrum/"
folder_save = folder + "/Fit/" + fit_params["name"] + "/"
create_folder_if_not_exists(comm, folder_save)

### Spectrum settings
spectrum_cfg = fit_params["Spectrum"]
# TODO: should we really modify these values
nbins = spectrum_cfg["nbins"]
sample_variance = spectrum_cfg["sample_variance"]
diagonal = spectrum_cfg["diagonal"]
# TODO: compute fsky from coverage
fsky = spectrum_cfg["fsky"]
dl = spectrum_cfg["dl"]

### MCMC settings
mcmc_cfg = fit_params["MCMC"]
# Frequencies used for the fit: the QUBIC entries followed by the Planck ones.
# NOTE(review): assumes both are Python lists, so `+` concatenates — TODO confirm.
nus_qubic = mcmc_cfg["nus_qubic"]
nus_planck = mcmc_cfg["nus_planck"]
nus_index = np.array(nus_qubic + nus_planck)

# Sampler settings, looked up in the same order as the original cell.
discard, nwalkers, nsteps, verbose = (
    mcmc_cfg[key] for key in ("discard", "nwalkers", "nsteps", "verbose")
)

# MCMC

In [None]:
# Trace plot: one panel per fitted parameter, with every walker drawn in black.
# Each panel also shows the post-burn-in mean and ±1σ, plus a marker at the burn-in step.
# `samples` is indexed as samples[step, walker, parameter].
n_chain_steps, n_chain_walkers, n_params = samples.shape

# squeeze=False always yields a 2D array of axes, so a single parameter needs no special case.
fig, axes = plt.subplots(n_params, 1, figsize=(12, 3 * n_params), squeeze=False)
fig.suptitle("MCMC Chain Evolution", fontsize=14)

# Reference statistics exclude the burn-in steps, which the plot marks as discarded.
# NOTE(review): assumes `samples` is the full, un-discarded chain (the burn-in line is
# drawn on it) — TODO confirm. Fall back to the whole chain if `discard` leaves no steps.
burn_in = discard if 0 <= discard < n_chain_steps else 0

for iparam, ax in enumerate(axes[:, 0]):
    # A 2D (step, walker) slice is drawn as one line per column, i.e. per walker.
    # This uses the walkers actually stored instead of the configured `nwalkers`.
    ax.plot(samples[:, :, iparam], alpha=0.5, color="k")

    post_burn_in = samples[burn_in:, :, iparam]
    mean = np.mean(post_burn_in)
    std = np.std(post_burn_in)

    # Mean and ±1σ lines
    ax.axhline(mean, color="r", linestyle="--", label=f"Mean: {mean:.3f}")
    ax.axhline(mean + std, color="b", linestyle=":", label=f"±1σ: {std:.3f}")
    ax.axhline(mean - std, color="b", linestyle=":")

    # Burn-in line
    if discard > 0:
        ax.axvline(discard, color="g", linestyle="--", alpha=0.5, label="Burn-in")

    # Fall back to a generic label if fewer names than parameters were provided.
    if iparam < len(fit_parameters_names):
        ylabel = fit_parameters_names[iparam]
    else:
        ylabel = f"param {iparam}"
    ax.set_ylabel(ylabel)
    ax.set_xlim(0, n_chain_steps)
    ax.grid(True, alpha=0.3)

    # Legend outside the panel, on the right
    ax.legend(loc="right", bbox_to_anchor=(1.25, 0.5))

axes[-1, 0].set_xlabel("Step number")

plt.tight_layout()
plt.show()

# Triangle Plots

In [None]:
### Posterior summary shared by both triangle plots
# Per-parameter mean and standard deviation over the flat chain (column = parameter).
means = np.mean(samples_flat, axis=0)
stds = np.std(samples_flat, axis=0)

# Take the parameter count from the data, consistent with the trace-plot cell, and make
# sure the labels agree with it. A mismatch would break the corner axes reshape and the
# GetDist `names` below.
n_params = samples_flat.shape[1]
assert n_params == len(fit_parameters_names), (
    f"samples_flat has {n_params} parameters but fit_parameters_names has "
    f"{len(fit_parameters_names)} entries"
)

## Corner

In [None]:
if n_params == 1:
    # A single parameter has no triangle to draw: plot its 1D posterior histogram
    # directly, with the mean (dashed) and mean ± 1σ (dotted) marked.
    mu, sigma = means[0], stds[0]
    fig, ax = plt.subplots(figsize=(5, 4))
    ax.hist(samples_flat[:, 0], bins=40, density=True, alpha=0.6)
    ax.axvline(mu, linestyle="--", lw=1.5)
    for edge in (mu - sigma, mu + sigma):
        ax.axvline(edge, linestyle=":", lw=1)
    ax.set_xlabel(fit_parameters_names[0])
    ax.set_ylabel("Density")
    summary = rf"$\mu={mu:.3f}$" + "\n" + rf"$\sigma={sigma:.3f}$"
    ax.text(0.05, 0.95, summary, transform=ax.transAxes, ha="left", va="top",
            bbox=dict(facecolor="white", alpha=0.8))
    fig.suptitle("Posterior — Corner (1D)", fontsize=16, fontweight="bold", y=1.01)
else:
    # Full triangle with 16/50/84 % quantiles, then annotate each diagonal
    # (marginal) panel with the flat-chain mean and standard deviation.
    fig = corner.corner(
        samples_flat,
        labels=[rf"${name}$" for name in fit_parameters_names],
        show_titles=True,
        title_fmt=".3f",
        quantiles=[0.16, 0.5, 0.84],
        title_kwargs={"fontsize": 12},
        label_kwargs={"fontsize": 14},
        color="black",
        hist_kwargs={"density": True, "lw": 1.5},
        smooth=1.0,
    )
    fig.suptitle("Posterior Triangle — Corner", fontsize=16, fontweight="bold")

    axes = np.array(fig.axes).reshape(n_params, n_params)
    for idx, (mu, sigma) in enumerate(zip(means, stds)):
        ax = axes[idx, idx]
        summary = rf"$\mu = {mu:.3f}$" + "\n" + rf"$\sigma = {sigma:.3f}$"
        ax.text(0.05, 0.95, summary, transform=ax.transAxes, ha="left", va="top",
                fontsize=9, bbox=dict(facecolor="white", alpha=0.8, edgecolor="none"))
plt.show()

## GetDist

In [None]:
def _draw_mean_std_lines(ax, center, spread, orientation):
    """Mark `center` with a dashed line and `center ± spread` with dotted lines on `ax`.

    orientation "v" draws vertical lines (axvline); "h" draws horizontal lines (axhline).
    Lines are drawn in the order: center, center - spread, center + spread.
    """
    draw_line = ax.axvline if orientation == "v" else ax.axhline
    draw_line(center, linestyle="--", lw=1.5)
    draw_line(center - spread, linestyle=":", lw=1)
    draw_line(center + spread, linestyle=":", lw=1)


# GetDist labels are LaTeX, so literal underscores in parameter names are escaped.
gds_labels = [name.replace("_", r"\_") for name in fit_parameters_names]
gds_samples = MCSamples(
    samples=samples_flat,
    names=fit_parameters_names,
    labels=gds_labels,
    settings={"smooth_scale_1D": 0.5, "smooth_scale_2D": 0.6,
              "mult_bias_correction_order": 0},
)

if n_params == 1:
    # Single parameter: 1D marginal with mean / ±1σ lines and a "mean ± std" caption.
    g = plots.get_single_plotter()
    g.plot_1d(gds_samples, fit_parameters_names[0], filled=True)
    ax = g.fig.axes[0]
    _draw_mean_std_lines(ax, means[0], stds[0], "v")
    ax.text(0.5, 1.0, rf"${means[0]:.3f} \pm {stds[0]:.3f}$",
            transform=ax.transAxes, ha="center", va="bottom",
            fontsize=12)
    plt.suptitle("Posterior — GetDist", fontsize=16, fontweight="bold", y=1.05)

else:
    g = plots.get_subplot_plotter()
    plotter_settings = {
        "num_plot_contours": 2,
        "alpha_filled_add": 0.6,
        "linewidth": 2,
        "figure_legend_frame": False,
        "axes_fontsize": 12,
        "axes_labelsize": 14,
        "legend_fontsize": 12,
        "title_limit_fontsize": 12,
    }
    for setting_name, setting_value in plotter_settings.items():
        setattr(g.settings, setting_name, setting_value)

    g.triangle_plot(gds_samples, params=fit_parameters_names, filled=True)

    # subplots[row, col]: off-diagonal panels have x = parameter `col`, y = parameter `row`.
    axes = g.subplots
    for row in range(n_params):
        for col in range(row):
            off_diag_ax = axes[row, col]
            _draw_mean_std_lines(off_diag_ax, means[row], stds[row], "h")
            _draw_mean_std_lines(off_diag_ax, means[col], stds[col], "v")
        diag_ax = axes[row, row]
        _draw_mean_std_lines(diag_ax, means[row], stds[row], "v")
        diag_ax.text(0.5, 1.02, rf"${means[row]:.3f} \pm {stds[row]:.3f}$",
                     transform=diag_ax.transAxes, ha="center", va="bottom",
                     fontsize=10, bbox=dict(facecolor="white", alpha=0.8, edgecolor="none"))

    plt.suptitle("Posterior Triangle — GetDist", fontsize=16, fontweight="bold", y=1.05)
plt.show()