# Autocalibration ideas based on current implementation

In [None]:
%matplotlib inline

import glob
import os
from collections.abc import Callable, Sequence, Iterator, Mapping
from abc import ABC, abstractmethod
from typing import Any

import matplotlib.pyplot as plt
import hist

import numpy as np
import awkward as ak

import scipy

from lgdo import lh5
from lgdo.lh5.exceptions import LH5DecodeError
from legendmeta import LegendMetadata
from dspeed.processors import get_multi_local_extrema

# Default figure size for all plots created below.
plt.rcParams["figure.figsize"] = (10, 4)

# Project root; metadata and data directories are built relative to it.
proj_dir = "/mnt/atlas02/projects/legend/sipm_qc"
lmeta  = LegendMetadata(os.path.join(proj_dir, "metadata/legend-metadata-schwarz"))
# Channel map for the given timestamp and its "spms" sub-map.
chmap = lmeta.channelmap("20250807T150028Z")
chmap_sipm = chmap.map("system", unique=False).spms
#requires recent legend-datasets
# daq.rawid keys of the channels grouped under analysis.usability == "on".
raw_keys = chmap_sipm.map("analysis.usability", unique=False).on.map("daq.rawid").keys()

In [None]:
# Directories (relative to proj_dir) holding the raw files, the manually
# generated DSP files and the original DSP files, respectively.
raw_dir = os.path.join(proj_dir, "data/tier/raw/phy/p15/r004_part")
dsp_dir = os.path.join(proj_dir, "manual_dsp/generated/p15r004dsp_part")
orig_dsp_dir = os.path.join(proj_dir, "data/tier/dsp/phy/p15/r004")
# All files in dsp_dir matching the DSP naming pattern, in sorted order.
dsp_files = glob.glob(dsp_dir+"/l200-*-tier_dsp.lh5")
dsp_files.sort()
def gimme_raw_filename_from_dsp(dspfilename: str):
    """Translate a DSP file path into the path of the matching raw file.

    Every occurrence of ``dsp_dir`` is replaced by ``raw_dir`` and, in the
    result, every occurrence of ``"tier_dsp"`` by ``"tier_raw"``.
    """
    relocated = dspfilename.replace(dsp_dir, raw_dir)
    return relocated.replace("tier_dsp", "tier_raw")
def gimme_orig_dsp_filename(dspfilename: str):
    """Translate a DSP file path into the path of the original DSP file.

    The original DSP files are needed because they carry the pulser
    information. Every occurrence of ``dsp_dir`` in the given path is
    replaced by ``orig_dsp_dir``.
    """
    orig_path = dspfilename.replace(dsp_dir, orig_dsp_dir)
    return orig_path

In [None]:
def get_nopulser_mask(orig_dsp_file: Sequence[str] | str) -> ak.Array:
    """Return a boolean mask that is True where the PULS01 ``trapTmax`` is below 100.

    ``trapTmax`` of the ``PULS01`` channel (looked up in the module-level
    ``chmap``) is read from ``orig_dsp_file`` as an awkward array and
    compared against the threshold.
    """
    pulser_rawid = chmap['PULS01'].daq.rawid
    pulser_trap_max = lh5.read_as(f"ch{pulser_rawid}/dsp/trapTmax", orig_dsp_file, "ak")
    return pulser_trap_max < 100

#trap_puls = lh5.read_as(f"ch{chmap['PULS01'].daq.rawid}/dsp/trapTmax", f_dsp, "np")
#selection = trap_puls < 100
#idx_not_pulser = np.where(selection)[0]
#idx_pulser = np.where(~selection)[0]

In [None]:
def plot_some_wfs(wfs, ax, num=10, sample_period=0.016):
    """Plot the first ``num`` waveforms of ``wfs`` against time on ``ax``.

    Args:
        wfs: Indexable collection of waveforms; ``len(wfs)`` and
            ``len(wfs[i])`` must be available. Waveforms may differ in length.
        ax: Axes-like object providing ``plot``, ``set_xlabel`` and
            ``set_ylabel``.
        num: Maximum number of waveforms to draw.
        sample_period: Time between two consecutive samples, in the unit of
            the x-axis label (µs). The default equals the value that used to
            be hard-coded.
    """
    n_plot = min(num, len(wfs))

    for i in range(n_plot):
        wf = wfs[i]
        # Build the time axis from integer sample indices: np.arange with a
        # float step can produce one element too many or too few, which would
        # no longer match the waveform length.
        t = np.arange(len(wf)) * sample_period
        ax.plot(t, wf)

    ax.set_xlabel("Time (µs)")
    ax.set_ylabel("ADC counts / sample")

# Check uncalibrated spectra

In [None]:
def get_energies(dsp_file: Sequence[str] | str, keys: Iterator[int], chmap, *, 
                 orig_dsp_file: Sequence[str] | str | None = None):
    """if orig_dsp_file is given: remove the pulser based on get_nopulser_mask"""
    keys: list[int] = list(keys) # I need to access element 0 separately

    def get_energy_object_name_function(dsp_file: str, raw_key: int, name_key: str) -> Callable[[int, str], str]:
        if not os.path.isfile(dsp_file):
            raise RuntimeError(f"ERROR: no file: {dsp_file}")
        fcns = [lambda rawid, name: f"ch{rawid}/dsp/energy", lambda rawid, name: f"ch{rawid}/dsp/energies",
                lambda rawid, name: f"{name}/dsp/energy", lambda rawid, name: f"{name}/dsp/energies"]
        for fcn in fcns:
            try:
                _ = lh5.read_as(fcn(raw_key, name_key), dsp_file, "ak")
                return fcn
            except LH5DecodeError:
                continue
        raise RuntimeError("Have no clue how to extract energy info")

    energy_object_name_fcn = get_energy_object_name_function(dsp_file if isinstance(dsp_file, str) else dsp_file[0], keys[0], chmap.map("daq.rawid")[keys[0]].name)
    energies_dict = {}
    #print(f"{len(keys)} keys in dsp files")
    for ch in keys:
        name = chmap.map("daq.rawid")[ch].name
        #energy = lh5.read_as(f"{name}/dsp/energy", f_dsp, "ak")
        energy = lh5.read_as(energy_object_name_fcn(ch, name), dsp_file, "ak")
        # remove pulser if we have original DSP files (containing pulser info)
        # TODO perf: cache nopulser_mask
        if orig_dsp_file is not None:
            nopulser_mask = get_nopulser_mask(orig_dsp_file)
            if len(nopulser_mask) < len(energy):
                raise RuntimeError("Nopulser mask too short")
            elif len(nopulser_mask) > len(energy):
                nopulser_mask = nopulser_mask[:len(energy)]
            energy = energy[nopulser_mask]

        energies = np.array(ak.flatten(energy))

        energies_dict[name] = energies
        
    energies_dict = dict(sorted(energies_dict.items()))
    
    return energies_dict


def gen_hist_by_quantile(data, quantile=0.99, nbins=200):
    """Histogram ``data`` from 0 up to its rounded ``quantile``.

    The upper edge is ``np.round(np.quantile(data, quantile))``; the range
    is divided into ``nbins`` equal-width bins.

    Returns:
        ``(counts, bin_edges)`` as returned by ``np.histogram``.
    """
    upper_edge = np.round(np.quantile(data, quantile))
    edges = np.linspace(0, upper_edge, nbins + 1)
    return np.histogram(data, edges)

def gen_hist_by_range(data, range, nbins=200):
    """Histogram ``data`` into ``nbins`` equal-width bins spanning ``range``.

    Returns:
        ``(counts, bin_edges)`` as returned by ``np.histogram``.
    """
    counts, edges = np.histogram(data, nbins, range)
    return counts, edges

In [None]:
def plot_all_pe_spectra(energies_dict):
    """Draw one log-scale spectrum per channel on a 10x6 grid of panels.

    Each entry of ``energies_dict`` (channel name -> energies) is
    histogrammed with ``gen_hist_by_quantile`` at the 0.96 quantile and
    drawn into the next panel, titled with the channel name. The grid has
    60 panels, so at most 60 channels can be shown.
    """
    fig, grid = plt.subplots(10, 6, figsize=(20,20))
    panels = grid.ravel()

    for idx, (name, data) in enumerate(energies_dict.items()):
        counts, edges = gen_hist_by_quantile(data, 0.96)
        panel = panels[idx]
        panel.stairs(counts, edges)
        panel.set_yscale("log")
        panel.set_title(name, fontsize=10)

    fig.tight_layout()

In [None]:
# Energies of all channels in raw_keys, with the pulser removed via the
# matching original DSP files.
energies = get_energies(dsp_files, raw_keys, chmap, orig_dsp_file=[gimme_orig_dsp_filename(dsp) for dsp in dsp_files])

In [None]:
# Overlay the S002 spectrum (range 0-50) read from dsp_files with the one read
# from the corresponding original DSP files; no pulser removal in either case.
fig, ax = plt.subplots()
ax.set_yscale('log')
ax.stairs(*gen_hist_by_range(get_energies(dsp_files, raw_keys, chmap)["S002"], (0,50)))
ax.stairs(*gen_hist_by_range(get_energies(list(map(gimme_orig_dsp_filename, dsp_files)), raw_keys, chmap)["S002"], (0,50)))

In [None]:
# Uncalibrated spectra of all channels; the commented-out lines are
# alternatives that plot the original DSP files instead.
#plot_all_pe_spectra(get_energies(gimme_orig_dsp_filename(dsp_files[0]), list(raw_keys), chmap))
#plot_all_pe_spectra(get_energies(list(map(gimme_orig_dsp_filename, dsp_files)), list(raw_keys), chmap))
plot_all_pe_spectra(energies)

# Simple calibration

In [None]:
# Default inputs
peakfinder_defaults = {
    "a_delta_min_in": 5e-3,
    "a_delta_max_in": 5e-3,
    "search_direction": 3,
    "a_abs_min_in": 1000,
    "a_abs_max_in": 1e-4,
    "min_peak_dist": 6,
    "peakdist_compare_margin": 1,
    "strict": True
}
    
def find_pe_peaks_in_hist(n, be, params: Mapping[str, Any]) -> np.typing.NDArray[np.int_]:
    """Locate local maxima of a histogram with ``get_multi_local_extrema``.

    Args:
        n: Bin contents of the histogram.
        be: Bin edges; converted to an array but not used otherwise.
        params: Peak-finder settings. The ``a_delta_*_in`` and ``a_abs_*_in``
            entries are scaled by the largest bin content before being
            passed on; ``search_direction`` is passed through unchanged.

    Returns:
        Integer bin indices taken from the maxima output buffer, with all
        NaN entries of that buffer dropped.
    """
    counts = np.array(n)
    be = np.array(be)  # conversion kept from the original; not used further
    tallest = np.max(counts)

    # Output buffers filled in place by get_multi_local_extrema.
    n_slots = len(counts) - 1
    maxima = np.zeros(shape=n_slots)
    minima = np.zeros(shape=n_slots)

    get_multi_local_extrema(
        counts,
        params["a_delta_max_in"] * tallest,
        params["a_delta_min_in"] * tallest,
        params["search_direction"],
        params["a_abs_max_in"] * tallest,
        params["a_abs_min_in"] * tallest,
        maxima,
        minima,
        0,  # number-of-maxima output, not read back here
        0,  # number-of-minima output, not read back here
    ) # type: ignore

    return maxima[~np.isnan(maxima)].astype(np.int_)

class ResultCheckError(ValueError):
    """Raised when a peak-finder or fit result fails a plausibility check."""

def check_and_improve_PE_peaks(
        peakpos_indices: np.typing.NDArray[np.int_], 
        n: np.typing.NDArray[Any],
        be: np.typing.NDArray[Any],
        params: Mapping[str, Any]
        ) -> np.typing.NDArray[np.int_]:
    """Validate the found peak positions and, where possible, repair them.

    Args:
        peakpos_indices: Bin indices of the found peaks, in ascending order
            (index 0 is treated as the noise peak, index 1 as the 1pe peak).
        n: Bin contents of the histogram the peaks were found in.
        be: Bin edges of that histogram (currently not used).
        params: Settings. ``min_peak_dist`` and ``strict`` are required;
            ``peakdist_compare_margin`` is required in strict mode when more
            than two peaks are present; ``double_peak`` is optional.

    Returns:
        The (possibly modified) array of peak bin indices.

    Raises:
        ResultCheckError: If a double-peak channel has fewer than 3 peaks,
            or if any of the strict-mode checks fails.
    """
    if params.get("double_peak", False):
        if len(peakpos_indices) < 3:
            raise ResultCheckError(f"Require at least 3 found peaks for double-peak SiPMs; found only {len(peakpos_indices)}.")
        # Replace the 2nd and 3rd found peak by their (integer) mean.
        mean_1_2 = (peakpos_indices[1] + peakpos_indices[2]) // 2 # small bias but ok
        peakpos_indices = np.concatenate((np.array([peakpos_indices[0], mean_1_2], dtype=np.int_), peakpos_indices[3:]))

    min_peak_dist = params["min_peak_dist"]
    if params["strict"]:
        if len(peakpos_indices) < 2:
            raise ResultCheckError(f"Only {len(peakpos_indices)} peaks found. Either noise peak or 1pe peak not found.")

        if n[peakpos_indices[1]] > n[peakpos_indices[0]]:
            raise ResultCheckError("1pe peak larger than noise peak.")

        dist_0_1 = peakpos_indices[1] - peakpos_indices[0]
        if dist_0_1 < min_peak_dist:
            raise ResultCheckError(f"Noise peak and 1pe peak too close together (< {min_peak_dist} bins).")

        if len(peakpos_indices) > 2:
            dist_1_2 = peakpos_indices[2] - peakpos_indices[1]
            if dist_1_2 < min_peak_dist:
                raise ResultCheckError(f"1pe peak and 2pe peak too close together (< {min_peak_dist} bins).")
            # Looked up outside the f-string: reusing the enclosing quote
            # character inside an f-string is a SyntaxError before Python 3.12.
            margin = params["peakdist_compare_margin"]
            if (dist_1_2 + margin) < dist_0_1:
                raise ResultCheckError(f"Distance between 1pe and 2pe smaller than 0pe and 1pe (outside peakdist_compare_margin of {margin}).")
    else:
        if len(peakpos_indices) > 2:
            if peakpos_indices[2] - peakpos_indices[1] < min_peak_dist:
                print(f"1pe peak and 2pe peak too close together (< {min_peak_dist} bins). Removing '1pe' peak.")
                peakpos_indices = peakpos_indices[peakpos_indices != peakpos_indices[1]]
    return peakpos_indices

In [None]:
def simple_calibration(energies, gen_hist_params: Mapping[str, Any], 
                       peakfinder_params: Mapping[str, Any],
                       calibration_params: Mapping[str, Any], *, 
                       ax = None, verbosity = 0) -> dict[str, float]:
    """Generate histogram and perform a simple peakfinder-based calibration. 
    If an axis is provided: plot on that (otherwise don't plot)
    Does this for 1 SiPM; i.e. energies has to be a 1-d array of energies"""
    match gen_hist_params:
        case {"quantile": quantile, "nbins": nbins}:
            n, be = gen_hist_by_quantile(energies, quantile, nbins)
        case {"range": r, "nbins": nbins}:
            n, be = gen_hist_by_range(energies, r, nbins)
        case _:
            raise TypeError("gen_hist_params does not match valid histogram type")
    peakpos_indices = find_pe_peaks_in_hist(n, be, peakfinder_params)
    failed_checks = False
    try:
        peakpos_indices = check_and_improve_PE_peaks(peakpos_indices, n, be, peakfinder_params)  # might raise in case of failures
    except ResultCheckError:
        failed_checks = True
        peaks = be[peakpos_indices] # use old indices for peaks
        raise # runs finally before raise
    else: 
        peaks = be[peakpos_indices]

        if len(peaks) > 2: # use 1PE and 2PE
            gain = peaks[2] - peaks[1]
            c = 1/gain
            offset = 1 - peaks[1] * c # 1pe peak at 1
        else: # fallback: use 0 PE and 1 PE
            gain = peaks[1] - peaks[0]
            c = 1/gain
            offset = 1 - peaks[1] * c # 1pe peak at 1   
        
        return {"slope": c, "offset": offset} # runs finally before return
    finally: # draw in any case (for debugging); but choose color
         if ax is not None:
            hist_color = "red" if failed_checks else "blue"
            line_color = "grey" if failed_checks else "red"
            # uncalibrated histogram and peaks
            ax.stairs(n, be, color=hist_color)
            for p in peaks: # type: ignore
                ax.axvline(x=p, color=line_color, ls=":")



In [None]:
def multi_simple_calibration(energies_dict, 
                             gen_hist_defaults: dict[str, Any], 
                             peakfinder_defaults: dict[str, Any],
                             calibration_defaults: dict[str, Any], *,
                             gen_hist_overrides: dict[str, dict[str, Any]] = {}, 
                             peakfinder_overrides: dict[str, dict[str, Any]] = {},
                             calibration_overrides: dict[str, dict[str, Any]] = {},
                             draw = False, 
                             nodraw_axes = False, 
                             verbosity = 0
                             ) -> dict[str, dict[str, float]]:
    """Performs simple_calibration for all channels present in energies_dict"""
    ret = {}
    if draw:
        fig, ax = plt.subplots(10, 6, figsize=(20,20))
        ax_iter = iter(ax.ravel())
    nr_unsuccessful_calibs = 0
    for name, energies in energies_dict.items():
        if draw:
            ax = next(ax_iter)
            ax.set_yscale("log")
            if nodraw_axes:
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
            ax.set_title(name, fontsize=10)
        else:
            ax = None
        try:
            calib_results = simple_calibration(
                energies,
                gen_hist_defaults | gen_hist_overrides.get(name, {}),
                peakfinder_defaults | peakfinder_overrides.get(name, {}),
                calibration_defaults | calibration_overrides.get(name, {}),
                ax=ax,  verbosity=verbosity)
            ret[name] = calib_results
        except ResultCheckError as e:
            print(f"Calibration failed for {name}: {e}")
            ret[name] = {"slope": np.nan, "offset": np.nan}
            nr_unsuccessful_calibs += 1
    
    if nr_unsuccessful_calibs > 0 and verbosity >= -1:
        print(f"WARNING: {nr_unsuccessful_calibs} calibrations failed!")
    elif verbosity >= 0:
        print("Info: All calibrations successful! :)")

    if draw:
        fig.tight_layout()
        if nodraw_axes: # have to do this also for non-drawn plots
            for ax in ax_iter:
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
            fig.subplots_adjust(wspace=0) # , hspace=0)
    return ret

In [None]:
# Single-channel check: calibrate S100 with the default peak-finder settings
# and draw the histogram together with the found peaks.
fig, ax = plt.subplots()
ax.set_yscale('log')
simple_calibration(energies["S100"], {"quantile": 0.99, "nbins": 200}, peakfinder_defaults,
                   {}, ax=ax)

In [None]:
# Same single-channel check for S002, histogrammed up to the 0.98 quantile.
fig, ax = plt.subplots()
ax.set_yscale('log')
simple_calibration(energies["S002"], {"quantile": 0.98, "nbins": 200}, peakfinder_defaults, {}, ax=ax)

In [None]:
# Per-channel peak-finder settings; multi_simple_calibration merges each entry
# on top of peakfinder_defaults. "double_peak": True makes
# check_and_improve_PE_peaks replace the 2nd and 3rd found peak by their mean.
peakfinder_overrides={
    "S046": {"a_delta_min_in": 8e-3, "a_delta_max_in": 8e-3,},
    "S015": {"a_delta_min_in": 1e-2, "a_delta_max_in": 2e-2,},
    "S083": {"double_peak": True},
    "S090": {"double_peak": True},
    "S095": {"double_peak": True},
    "S096": {"double_peak": True, "a_delta_min_in": 1e-3, "a_delta_max_in": 1e-3,},
    "S098": {"double_peak": True},
}

In [None]:
# Simple calibration of all channels, drawn on the overview grid with hidden axes.
simple_calib_output = multi_simple_calibration(energies,  {"quantile": 0.98, "nbins": 200}, peakfinder_defaults, {}, peakfinder_overrides=peakfinder_overrides, draw=True, nodraw_axes=True)

In [None]:
def get_calibrated_histograms(energies, calib_output, range: tuple[float, float], nbins:int):
    """Histogram the calibrated energies of every successfully calibrated channel.

    Channels that are missing from ``calib_output`` or whose slope or offset
    is NaN are left out. For the others, ``energy * slope + offset`` is
    histogrammed with ``gen_hist_by_range(..., range, nbins)``.

    Returns:
        Channel name -> ``{"n": bin contents, "be": bin edges}``.
    """
    histograms: dict[str, dict[str, np.typing.NDArray[Any]]] = {}
    for name, energy in energies.items():
        if name in calib_output:
            slope = calib_output[name]["slope"]
            offset = calib_output[name]["offset"]
            if not (np.isnan(slope) or np.isnan(offset)):
                counts, edges = gen_hist_by_range(energy * slope + offset, range, nbins)
                histograms[name] = {"n": counts, "be": edges}
    return histograms

In [None]:
# Histograms of the simply-calibrated energies (range 0-5, 200 bins) for all
# channels with a valid simple calibration, drawn on a 10x6 overview grid.
calibrated_histos = get_calibrated_histograms(energies, simple_calib_output, (0, 5), 200)
fig, ax = plt.subplots(10, 6, figsize=(20, 20))
ax = ax.ravel()
for i, (name, histo) in enumerate(calibrated_histos.items()):
    ax[i].set_yscale('log')
    ax[i].stairs(histo["n"], histo["be"])
    ax[i].set_title(name)
fig.tight_layout()
plt.show()

In [None]:
class ModelParameter:
    def __init__(self, init: tuple[float,float,float] | float):
        if isinstance(init, tuple):
            self.init = init[0]
            self.min = init[1]
            self.max = init[2]
        else:
            self.init = init
            self.min = -np.inf
            self.max = np.inf
        self.result: np.float64 = np.float64(np.nan)
    def set_result(self, result):
        self.result = result

class ModelComponent(ABC):
    """Base class of all fit model components.

    A component owns an ordered dict of named ``ModelParameter`` objects and
    is evaluated with a flat sequence of parameter values in that order.
    """

    def __init__(self, params: dict[str, ModelParameter]):
        self.params = params

    def nr_params(self) -> int:
        """Number of parameters of this component."""
        return len(self.params)

    @abstractmethod
    def eval(self, x, params) -> np.float64:
        """Evaluate the component at ``x`` for the given parameter values."""

    def set_result_params(self, params):
        """Store fitted values in the parameters, in declaration order."""
        for fitted_value, parameter in zip(params, self.params.values()):
            parameter.set_result(fitted_value)

    def get_result_params(self) -> Sequence[np.float64]:
        """Return the stored results of all parameters, in declaration order."""
        results = []
        for parameter in self.params.values():
            results.append(parameter.result)
        return results
        
class Gauss(ModelComponent):
    """Scaled normal distribution with parameters (mean, sigma, scale)."""

    def __init__(self, mean, sigma, scale):
        parameters = {}
        parameters["mean"] = ModelParameter(mean)
        parameters["sigma"] = ModelParameter(sigma)
        parameters["scale"] = ModelParameter(scale)
        super().__init__(parameters)

    def eval(self, x, params):
        """Return ``scale * normal_pdf(x; mean, sigma)``."""
        scale = params[2]
        density = scipy.stats.norm.pdf(x, loc=params[0], scale=params[1])
        return scale * density
    
class ExpoDec(ModelComponent):
    """Exponential decay with parameters (lamb, scale)."""

    def __init__(self, lamb, scale):
        parameters = {}
        parameters["lamb"] = ModelParameter(lamb)
        parameters["scale"] = ModelParameter(scale)
        super().__init__(parameters)

    def eval(self, x, params):
        """Return ``scale * exp(-lamb * x)``."""
        exponent = -1*params[0]*x
        return params[1]*np.exp(exponent)
    
class TwoHyperbole(ModelComponent):
    """Constant plus 1/x and 1/x^2 terms with parameters (p0, p1, p2)."""

    def __init__(self, p0, p1, p2):
        parameters = {}
        parameters["p0"] = ModelParameter(p0)
        parameters["p1"] = ModelParameter(p1)
        parameters["p2"] = ModelParameter(p2)
        super().__init__(parameters)

    def eval(self, x, params):
        """Return ``p0 + p1/x + p2/x**2``."""
        constant = params[0]
        first_order = params[1]/x
        second_order = params[2]/(x*x)
        return constant + first_order + second_order
    
class Linear(ModelComponent):
    """Straight line with parameters (p0, p1)."""

    def __init__(self, p0, p1):
        parameters = {}
        parameters["p0"] = ModelParameter(p0)
        parameters["p1"] = ModelParameter(p1)
        super().__init__(parameters)

    def eval(self, x, params):
        """Return ``p0 + p1 * x``."""
        intercept = params[0]
        return intercept + params[1]*x
    
class SumModel(ModelComponent):
    """Sum of several named components.

    The parameters of all components are exposed under the names
    ``"<component name>.<parameter name>"``, in component order; ``eval``
    expects the flat parameter values in that same order.
    """

    def __init__(self, components: dict[str, ModelComponent]):
        flattened = {
            m_name + "." + p_name: param
            for m_name, model in components.items()
            for p_name, param in model.params.items()
        }
        super().__init__(flattened)
        self.components = components

    def eval(self, x, params) -> np.float64:
        """Add up all components, handing each its own slice of ``params``."""
        total = np.float64(0)
        start = 0
        for component in self.components.values():
            stop = start + component.nr_params()
            total += component.eval(x, params[start:stop])
            start = stop
        return total

def evaluate(model_component: ModelComponent, x, param_values: list[Any]):
    """Evaluate a component with the leading values of ``param_values`` and consume them.

    The first ``model_component.nr_params()`` entries are passed to
    ``model_component.eval`` and afterwards removed from ``param_values``
    in place, so consecutive calls walk through one flat parameter list.

    Raises:
        ValueError: If ``param_values`` is not a list.
    """
    if not isinstance(param_values, list):
        raise ValueError("param_values has to be a list so it can be modified")
    own_values = param_values[:model_component.nr_params()]
    value = model_component.eval(x, own_values)
    del param_values[:model_component.nr_params()]
    return value

def get_inits(model_components: list[ModelComponent]) -> list[float]:
    """Collect the start values of all parameters of all components, in order."""
    return [
        parameter.init
        for component in model_components
        for parameter in component.params.values()
    ]

def get_upper_bounds(model_components: list[ModelComponent]) -> list[float]:
    """Collect the upper bounds of all parameters of all components, in order."""
    return [
        parameter.max
        for component in model_components
        for parameter in component.params.values()
    ]

def get_lower_bounds(model_components: list[ModelComponent]) -> list[float]:
    """Collect the lower bounds of all parameters of all components, in order."""
    return [
        parameter.min
        for component in model_components
        for parameter in component.params.values()
    ]


In [None]:
class Fittable:
    """A model component together with the x-range it is fitted and drawn in."""

    def __init__(self, model: ModelComponent, fit_range: tuple[float, float]):
        self.model = model
        self.fit_range = fit_range

    def fit(self, bin_weights, bin_centers): # raises RuntimeError if fit failed
        """Fit the model to the bins whose centers lie inside ``fit_range``.

        Start values and bounds are taken from the model's parameters.

        Returns:
            The return value of ``scipy.optimize.curve_fit``.
        """
        lower, upper = self.fit_range[0], self.fit_range[1]
        selected = (bin_centers >= lower) & (bin_centers <= upper)

        def model_fcn(x, *p):
            remaining = list(p)
            value = evaluate(self.model, x, remaining)
            assert not remaining  # every fit parameter must have been consumed
            return value

        components = [self.model]
        bounds = (get_lower_bounds(components), get_upper_bounds(components))
        return scipy.optimize.curve_fit(
            model_fcn,
            bin_centers[selected],
            bin_weights[selected],
            p0=get_inits(components),
            bounds=bounds,
        )

    def draw(self, ax, params, color):
        """Plot the model for ``params`` on 1000 points spanning ``fit_range``."""
        grid = np.linspace(self.fit_range[0], self.fit_range[1], 1000)
        ax.plot(grid, self.model.eval(grid, params), color=color)
    

In [None]:
def check_fit_results(gausses: list[Gauss]) -> None:
    """Check that the fitted Gaussian means sit where the PE peaks are expected.

    The i-th Gaussian (counting from 1) must have its fitted mean within
    0.45 of i, and consecutive means must be 1 apart within 0.2.

    Raises:
        ResultCheckError: If fewer than two Gaussians are given or one of
            the conditions above is violated.
    """
    means = [gauss.params["mean"].result for gauss in gausses]
    if len(means) < 2:
        raise ResultCheckError(f"Too little gausses {len(means)}")

    previous = None
    for peak_expect, mean in enumerate(means, start=1):
        if mean < peak_expect-0.45 or mean > peak_expect+0.45:
            raise ResultCheckError(f"Mean of PE peak #{peak_expect} out of range: {mean}")
        if previous is not None:
            distance = mean - previous
            if abs(distance - 1) > 0.2:
                raise ResultCheckError(f"Distance between mean of PE peaks {peak_expect-1},{peak_expect} too far off 1: {distance}")
        previous = mean



def advanced_calibration(
        precalibrated_histo: dict[str, np.typing.NDArray[Any]],
        params: Mapping[str, Any], *,
        ax = None, nofit=False, verbosity = 0
        ) -> dict[str, float]:
    """Fit Gaussians to the 1, 2 and 3 PE peaks of a pre-calibrated histogram.

    Args:
        precalibrated_histo: ``{"n": bin contents, "be": bin edges}`` of a
            histogram in (approximate) PE units.
        params: Optional ``fit_range_begin`` (default 0.5) and
            ``fit_range_end`` (default 3.5); the range is used to pick the
            maximum bin content for the start values.
        ax: If given, histogram and models are drawn on it in any case:
            green on success, red if the fit failed (start values are drawn),
            orange if the result check failed.
        nofit: If True, skip the fit and treat it as a failed fit.
        verbosity: Currently unused.

    Returns:
        ``{"slope": nan, "offset": nan}`` for now; the actual calibration is
        still a TODO.

    Raises:
        ResultCheckError: If the fit raises ``RuntimeError`` (or ``nofit`` is
            set) or if ``check_fit_results`` rejects the fitted means.
    """
    n = precalibrated_histo["n"]
    be = precalibrated_histo["be"]
    be_mid = (be[:-1] + be[1:]) / 2
    assert len(be_mid) == len(be) - 1

    range_begin = params.get("fit_range_begin", 0.5)
    range_end = params.get("fit_range_end", 3.5)
    fit_range = (range_begin, range_end)
    in_fit_range = (be_mid >= fit_range[0]) & (be_mid <= fit_range[1])

    max_in_range = np.max(n[in_fit_range])

    gauss1 = Gauss((1, 0.5, 1.5), 0.1, (max_in_range/5, 0, np.inf))
    gauss2 = Gauss((2, 1.5, 2.5), 0.1, (max_in_range/5, 0, np.inf))
    gauss3 = Gauss((3, 2.5, 3.5), 0.1, (max_in_range/10, 0, np.inf))
    expodec = ExpoDec((2, 0, np.inf), (max_in_range/2, 0, np.inf))
    linear = Linear((max_in_range/100, 0, np.inf), -10)
    th = TwoHyperbole(max_in_range/2, 100, (0,-1,1))  # currently not part of any model

    # Either one combined model over the whole fit range, or (current
    # choice) one narrow fit per peak.
    use_combo = False
    fittables: list[Fittable] = []
    if use_combo:
        combined = SumModel({
            "gauss1": gauss1, "gauss2": gauss2, "gauss3": gauss3, "expodec": expodec, "linear": linear
            })
        fittables.append(Fittable(combined, fit_range))
    else:
        fittables += [Fittable(gauss1, (0.85, 1.15)), Fittable(gauss2, (1.85, 2.15)), Fittable(gauss3, (2.85, 3.15))]

    failure: str = ""
    try:
        try:
            if nofit:
                raise RuntimeError("No fit performed, as requested")
            for fittable in fittables:
                fitted_params, pcov = fittable.fit(n, be_mid)
                fittable.model.set_result_params(fitted_params)
        except RuntimeError as e:
            # Fall back to the start values so that something can be drawn.
            for fittable in fittables:
                fittable.model.set_result_params(get_inits([fittable.model]))
            failure = "fit"
            raise ResultCheckError(e) from e

        try:
            check_fit_results([gauss1, gauss2, gauss3])
        except ResultCheckError:
            failure = "check"
            raise

        #TODO: do calibration in this case!
        return {"slope": np.nan, "offset": np.nan}
    finally: # runs in any case; exception or not
        if ax is not None:
            ax.stairs(n, be)
            color = {"": "green", "fit": "red", "check": "orange"}[failure]
            for fittable in fittables:
                fittable.draw(ax, fittable.model.get_result_params(), color)
            lowest = np.min(n) if np.min(n) > 0 else 0.5
            ax.set_ylim((lowest*0.9, np.max(n)*1.1))

In [None]:
# Single-channel check of the advanced calibration on the pre-calibrated S007 histogram.
fig, ax = plt.subplots()
ax.set_yscale('log')
advanced_calibration(calibrated_histos["S007"], {"fit_range_end": 3.8}, ax=ax)

In [None]:
def multi_advanced_calibration(calibrated_histo_dict, 
                             calibration_defaults: dict[str, Any], 
                             calibration_overrides: dict[str, dict[str, Any]] = {},
                             draw = False, 
                             nodraw_axes = False, 
                             verbosity = 0
                             ) -> dict[str, dict[str, float]]:
    """Performs advanced_calibration for all channels present in calibrated_histo_dict"""
    ret = {}
    if draw:
        fig, ax = plt.subplots(10, 6, figsize=(20,20))
        ax_iter = iter(ax.ravel())
    nr_unsuccessful_calibs = 0
    for name, calibrated_histo in calibrated_histo_dict.items():
        if draw:
            ax = next(ax_iter)
            ax.set_yscale("log")
            if nodraw_axes:
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
            ax.set_title(name, fontsize=10)
        else:
            ax = None
        try:
            calib_results = advanced_calibration(
                calibrated_histo,
                calibration_defaults | calibration_overrides.get(name, {}),
                ax=ax,  verbosity=verbosity)
            ret[name] = calib_results
        except ResultCheckError as e:
            print(f"Calibration failed for {name}: {e}")
            ret[name] = {"slope": np.nan, "offset": np.nan}
            nr_unsuccessful_calibs += 1
    
    if nr_unsuccessful_calibs > 0 and verbosity >= -1:
        print(f"WARNING: {nr_unsuccessful_calibs} calibrations failed!")
    elif verbosity >= 0:
        print("Info: All calibrations successful! :)")

    if draw:
        fig.tight_layout()
        if nodraw_axes: # have to do this also for non-drawn plots
            for ax in ax_iter:
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
            fig.subplots_adjust(wspace=0) # , hspace=0)
    return ret

In [None]:
# Advanced calibration of all pre-calibrated channels, drawn on the overview grid.
multi_advanced_calibration(calibrated_histos, {"fit_range_end": 3.8}, draw=True)

In [None]:
# energies list was introduced here as sorted list of name, data tuples

In [None]:

# OLD BACKUP
def advanced_calibration(
        precalibrated_histo: dict[str, np.typing.NDArray[Any]],
        params: Mapping[str, Any], *,
        ax = None, nofit=False, verbosity = 0
        ) -> dict[str, float]:
    """Backup variant of ``advanced_calibration``: one combined fit over the fit range.

    Fits the sum of three Gaussians and an exponential-plus-linear background
    to the bins inside the fit range with a single ``curve_fit`` call and
    optionally draws histogram and model on ``ax``.

    NOTE(review): executing this cell rebinds the name
    ``advanced_calibration`` and thereby replaces the definition given
    earlier in the notebook.
    NOTE(review): ``check_fit_results`` is called with a slice of the raw
    fitted values, whereas the ``check_fit_results`` defined earlier reads
    ``.params["mean"].result`` from each element - the two no longer match.
    NOTE(review): ``fitted_params`` is only assigned by a successful fit or
    in the ``RuntimeError`` handler; any other exception from the fit leaves
    it unbound for the ``finally`` block.
    """
    
    n = precalibrated_histo["n"]
    be = precalibrated_histo["be"]
    be_mid = (be[:-1] + be[1:]) / 2
    assert len(be_mid) == len(be) - 1

    # Restrict the fit to the bins whose centers lie inside the fit range.
    fit_range = (params.get("fit_range_begin", 0.5), params.get("fit_range_end", 3.5))
    fit_range_mask = (be_mid >= fit_range[0]) & (be_mid <= fit_range[1])
    range_n = n[fit_range_mask]
    range_be_mid = be_mid[fit_range_mask]

    # Start values of the scales are derived from the largest bin content in range.
    gauss1 = Gauss((1, 0.5, 1.5), 0.1, (np.max(range_n)/5, 0, np.inf))
    gauss2 = Gauss((2, 1.5, 2.5), 0.1, (np.max(range_n)/5, 0, np.inf))
    gauss3 = Gauss((3, 2.5, 3.5), 0.1, (np.max(range_n)/10, 0, np.inf))
    expodec = ExpoDec((2, 0, np.inf), (np.max(range_n)/2, 0, np.inf))
    linear = Linear((np.max(range_n)/100, 0, np.inf), -10)
    th = TwoHyperbole(np.max(range_n)/2, 100, (0,-1,1))
    background = SumModel({"expodec": expodec, "linear": linear})
    model_components = [gauss1, gauss2, gauss3, background]

    def model(x, *p):
        # Each evaluate() call consumes its component's values from the
        # front of `params`; the order must match model_components.
        params = list(p)
        ret = evaluate(gauss1, x, params)
        ret += evaluate(gauss2, x, params)
        ret += evaluate(gauss3, x, params)
        #ret += evaluate(expodec, x, params)
        #ret += evaluate(linear, x, params)
        ret += evaluate(background, x, params)
        if len(params) != 0:
            raise ValueError("Model messed up")
        return ret

    # "" = success, "fit" = fit failed, "check" = result check failed;
    # selects the drawing color in the finally block.
    failure: str = ""
    try:
        try:
            if nofit:
                raise RuntimeError("No fit performed, as requested")
            fitted_params, pcov = scipy.optimize.curve_fit(
                model, range_be_mid, range_n, p0=get_inits(model_components), 
                bounds=(get_lower_bounds(model_components), get_upper_bounds(model_components)))
        except RuntimeError as e:
            #print(f"ERROR: Fit failed: {e}")
            # Fall back to the start values so that something can be drawn.
            fitted_params = get_inits(model_components)
            failure = "fit"
            raise ResultCheckError(e) from e
        else:
            try:
                check_fit_results(fitted_params[:9])
            except ResultCheckError as e:
                #print(f"ERROR: Fit result check failed: {e}")
                failure = "check"
                raise
    except ResultCheckError:
        raise
    else:
        #TODO: do calibration in this case!
        return {"slope": np.nan, "offset": np.nan}
    finally: # runs in any case; exception or not
        if verbosity >= 1:
            print(*fitted_params)
        if ax is not None:
            ax.stairs(n, be)
            xx = np.linspace(range_be_mid[0], range_be_mid[-1], 1000)
            match failure:
                case "":
                    color="green"
                case "fit":
                    color="red"
                case "check":
                    color="orange"
            ax.plot(xx, model(xx, *fitted_params), color=color)
            ax.set_ylim(((np.min(n) if np.min(n) > 0 else 0.5)*0.9, np.max(n)*1.1))

## Unrefactored code below here ------

In [None]:
# Pre-refactoring version of the simple calibration loop.
# NOTE(review): this cell uses names that are not defined in the visible part
# of the notebook before this point (energies_list, overwrites_genhist,
# quantile, quantile_all, gen_hist, overwrites_peaksearch, defaults,
# double_1pe_chs, out_dict), and it calls find_pe_peaks_in_hist with keyword
# arguments that do not match its current (n, be, params) signature.
fig, ax = plt.subplots(10, 6, figsize=(20,20))
ax = ax.ravel()  # Flatten the 2D axes array for easier indexing

for i, (name, data) in enumerate(energies_list):

    # Per-channel histogram settings; "quantile" is pulled out of the kwargs
    # (note that pop() modifies the dict stored in overwrites_genhist).
    kwargs = overwrites_genhist.get(name, {})
    quantile = kwargs.get("quantile", quantile)
    kwargs.pop("quantile", None)

    n, be, bins = gen_hist(data, quantile=quantile, **kwargs)

    # reset quantile value
    quantile = quantile_all

    kwargs = overwrites_peaksearch.get(name, {})
    try:
        peakpos_indices = find_pe_peaks_in_hist(n, be, defaults=defaults, **kwargs)
    except ValueError as e:
        print(f"{name}: {e}")
        continue
        
    peaks = be[peakpos_indices]
    
    # if 1pe peak has double peak structure, take the mean of the double peak as 1pe peak position
    if name in double_1pe_chs:
        mean_1_2 = np.mean(peaks[1:3])
        peaks = np.array([peaks[0], mean_1_2, peaks[3]])

    # Gain from the 1pe-2pe distance if available, else from 0pe-1pe.
    if len(peaks) > 2:
        gain = peaks[2] - peaks[1]
        c = 1/gain
        offset = 1 - peaks[1] * c # 1pe peak at 1
    else:
        gain = peaks[1] - peaks[0]
        c = 1/gain
        offset = 1 - peaks[1] * c # 1pe peak at 1    
        
        
    # save vals
    if name not in out_dict:
        out_dict[name] = {}
        
    out_dict[name]["slope"] = c
    out_dict[name]["offset"] = offset
    
    # Apply the calibration and draw the calibrated spectrum.
    data = c * data + offset

    bins = np.linspace(0,4.5,100)
    ax[i].hist(data, bins, histtype="step", linewidth=2)
    bin_width = bins[1] - bins[0]
    
    # Guide lines at 1, 2, 3 and 4 on the calibrated axis.
    for x in range(1,5):
        ax[i].axvline(x=x, ls="--", color="grey")
    ax[i].set_yscale("log")
    ax[i].set_title(name, fontsize=10)

fig.tight_layout()
# fig.savefig(folder_path + "/spectra_fom.png")
# fig.savefig(folder_path + "/spectra_fom.png")

In [None]:
import yaml

# Helper to recursively convert NumPy scalars to Python scalars
def convert_numpy(obj):
    """Recursively replace NumPy scalars and arrays by plain Python objects.

    Dicts and lists are rebuilt with converted contents, arrays become
    (nested) lists via ``tolist()``, NumPy scalars become Python scalars via
    ``item()``. Anything else (including tuples) is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: convert_numpy(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return list(map(convert_numpy, obj))
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):  # np.float64, np.int64, etc.
        return obj.item()
    return obj

# Convert the calibration results to plain Python types and write them to YAML.
# NOTE(review): out_dict is only filled by the unrefactored loop above and is
# not created anywhere in the visible part of the notebook.
clean_dict = convert_numpy(out_dict)

with open("./r004_simple_cal_params.yaml", "w") as yaml_file:
    yaml.dump(clean_dict, yaml_file, default_flow_style=False, sort_keys=False)


# Accurate calibration

In [None]:
from sklearn.mixture import GaussianMixture
from scipy.stats import norm

from iminuit import Minuit
from iminuit.cost import LeastSquares

In [None]:
import yaml 
# DSP file to calibrate and the simple-calibration parameters written earlier.
dsp_file = "./l200-p15-r004-phy-20250807T150028Z-tier_dsp_sigma2_snr.lh5"

with open("./r004_simple_cal_params.yaml", "r") as yaml_file:
    params_dict = yaml.safe_load(yaml_file)

In [None]:
# Read the energies of every channel from dsp_file and apply the simple
# calibration (slope/offset per channel name) loaded into params_dict.
# NOTE(review): `keys` is not defined in the visible part of the notebook
# (presumably raw_keys is meant - confirm). This loop also rebinds `energies`,
# which earlier holds the dict returned by get_energies, to a flat array.
energies_list = []
energies_dict = {}

names = []

for ch in keys:
    energies = lh5.read_as(f"ch{ch}/dsp/energy", dsp_file, library="ak")
    energies = np.array(ak.flatten(energies))
    
    name = chmap.map("daq.rawid")[ch].name
    names.append(name)
    
    c = params_dict[name]["slope"]
    offset = params_dict[name]["offset"]

    data = c * energies + offset

    energies_list.append((name, data))
    energies_dict[name] = data

# Sort the (name, data) tuples by channel name.
energies_list.sort(key=lambda x: x[0])

In [None]:
# Same threshold value of 0.7 for every channel name collected above.
pe_thresh = {name: 0.7 for name in names}

In [None]:
def apply_GMM_fit(
    amp_simple_cal,
    pe_range=(0.7, 4.5),
    n_components=12,
    pe_peaks_model=None,
    pe_range_peak=0.4,
    max_iter=50,
    random_state=None,
):
    """Fit a Gaussian mixture to the amplitudes and locate the p.e. peaks.

    Parameters
    ----------
    amp_simple_cal
        Array-like of amplitudes after the simple calibration.
    pe_range
        ``(low, high)`` inclusive window; only amplitudes inside it are fitted.
    n_components
        Number of Gaussian components of the mixture.
    pe_peaks_model
        Mapping whose *values* are the expected peak positions (in units of
        ``amp_simple_cal``). The keys are not used here; the mapping is
        returned unchanged. ``None`` (default) means ``{1: 1, 2: 2, 3: 3, 4: 4}``.
    pe_range_peak
        Half-width of the window around each expected position; mixture
        components whose mean falls inside are attributed to that peak.
    max_iter
        Maximum number of EM iterations passed to ``GaussianMixture``.
    random_state
        Seed passed to ``GaussianMixture``; set it for reproducible fits.

    Returns
    -------
    gm
        The fitted ``GaussianMixture``.
    dmat
        The selected amplitudes as an ``(N, 1)`` column array.
    pe_peaks_model
        The peak model that was used.
    pe_peak_mean, pe_peak_std
        Lists with one entry per value of ``pe_peaks_model``: the
        weight-averaged mean and the weight-averaged sigma of the components
        attributed to that peak. NaN when no component lies in the window.
    """
    import warnings

    # A fresh dict per call: the model is handed back to the caller, so a
    # shared default object could be modified from outside.
    if pe_peaks_model is None:
        pe_peaks_model = {1: 1, 2: 2, 3: 3, 4: 4}

    amp_simple_cal = np.asarray(amp_simple_cal)

    # apply range mask
    in_range = (amp_simple_cal >= pe_range[0]) & (amp_simple_cal <= pe_range[1])
    amp = amp_simple_cal[in_range]

    # one sample per row, single feature
    dmat = amp.reshape(-1, 1)
    gm = GaussianMixture(
        n_components=n_components,
        covariance_type='diag',
        max_iter=max_iter,
        random_state=random_state,
    )
    gm.fit(dmat)

    means = gm.means_.flatten()
    std_devs = np.sqrt(gm.covariances_.flatten())
    weights = gm.weights_

    pe_peak_mean = []
    pe_peak_std = []
    for PE in pe_peaks_model.values():
        near_peak = (means >= PE - pe_range_peak) & (means <= PE + pe_range_peak)

        if near_peak.any():
            total_weight = np.sum(weights[near_peak])
            mean_PE = np.dot(means[near_peak], weights[near_peak]) / total_weight
            # NOTE(review): weighted average of the component sigmas; the
            # spread between the component means is not included.
            std_PE = np.dot(std_devs[near_peak], weights[near_peak]) / total_weight
        else:
            # Previously this was a silent 0/0; keep the NaN result but say
            # which peak is missing.
            warnings.warn(
                f"no GMM component within +/-{pe_range_peak} of the expected "
                f"peak at {PE}; mean and std are NaN for this peak",
                RuntimeWarning,
                stacklevel=2,
            )
            mean_PE = np.nan
            std_PE = np.nan

        pe_peak_mean.append(mean_PE)
        pe_peak_std.append(std_PE)

    return gm, dmat, pe_peaks_model, pe_peak_mean, pe_peak_std


def plot_GMM_fit(gm, dmat, pe_peak_mean, pe_peak_std, ax=None):
    """Draw a fitted Gaussian mixture on top of the histogram of its data.

    Parameters
    ----------
    gm
        Fitted mixture providing ``means_``, ``covariances_``, ``weights_``
        and ``score_samples``.
    dmat
        Data the mixture was fitted to; histogrammed with 100 bins and
        normalised to a density.
    pe_peak_mean, pe_peak_std
        Peak positions and widths (expected to have equal lengths). Each peak
        is drawn as a vertical line with a band of +/- one width around it.
    ax
        Axes to draw on; a new figure is created when ``None``.
    """
    if ax is None:
        _, ax = plt.subplots()

    component_means = gm.means_.flatten()
    component_sigmas = np.sqrt(gm.covariances_.flatten())
    component_weights = gm.weights_

    # Column vector spanning the data range; score_samples is called on it.
    grid = np.linspace(np.min(dmat), np.max(dmat), 1000).reshape(-1, 1)

    ax.hist(dmat, bins=100, density=True, alpha=0.6, color='gray', label="Data")

    # One curve per component, scaled by its mixture weight.
    for mu, sigma, weight in zip(component_means, component_sigmas, component_weights):
        ax.plot(grid, weight * norm.pdf(grid, mu, sigma))

    # Mark every peak position together with its +/- width band.
    for center, half_width in zip(pe_peak_mean, pe_peak_std):
        ax.axvline(center)
        ax.axvspan(xmin=center - half_width, xmax=center + half_width, alpha=0.4)

    # score_samples returns log-densities, hence the exponential.
    total_density = np.exp(gm.score_samples(grid))
    ax.plot(grid, total_density, color='black', lw=2, label='Overall GMM')

    ax.set_xlabel('amplitude simple cal')
    ax.set_ylabel('density')
    ax.legend()
    
    
def line(x, α, β):
    """Straight line with intercept ``α`` and slope ``β``, evaluated at ``x``."""
    scaled = x * β
    return α + scaled

  
def apply_calibration_line_fit(data_x, data_y, data_yerr):
    """Fit ``data_y = α + β * data_x`` and invert it into a calibration.

    Parameters
    ----------
    data_x
        x values of the fit (the caller passes the true p.e. values).
    data_y
        Measured peak positions corresponding to ``data_x``.
    data_yerr
        Uncertainties on ``data_y`` used by the least-squares cost.

    Returns
    -------
    m
        The ``Minuit`` object after ``migrad`` and ``hesse``.
    slope, offset
        Constants of the inverted line, such that
        ``amp_actual = slope * amp_data + offset``.

    Warns
    -----
    RuntimeWarning
        If the minimisation did not end in a valid minimum; the returned
        constants should not be trusted in that case.
    """
    import warnings

    least_squares = LeastSquares(data_x, data_y, data_yerr, line)
    m = Minuit(least_squares, α=0, β=1)
    m.migrad()
    m.hesse()

    # A failed fit (e.g. NaN peak positions) still yields numbers; flag it
    # rather than let them pass as a calibration unnoticed.
    if not m.valid:
        warnings.warn(
            "calibration line fit did not reach a valid minimum; "
            "slope and offset are unreliable",
            RuntimeWarning,
            stacklevel=2,
        )

    β = m.values["β"]
    α = m.values["α"]

    # reverse engineer slope and offset such that amp_actual = slope * amp_data + offset
    slope = 1/β
    offset = - α/β

    return m, slope, offset


def plot_calibration_line_fit(data_x, data_y, data_yerr, m, ax=None):
    """Plot the calibration points with the fitted line and a fit summary.

    Parameters
    ----------
    data_x, data_y, data_yerr
        Points and y-uncertainties that were fitted.
    m
        Fit result; its ``values``, ``errors``, ``parameters``, ``fval``,
        ``ndof`` and ``fmin.reduced_chi2`` are read.
    ax
        Axes to draw on; a new figure is created when ``None``.
    """
    if ax is None:
        _, ax = plt.subplots()

    data_x = np.array(data_x)

    ax.errorbar(data_x, data_y, data_yerr, fmt="ok", label="data")
    ax.plot(data_x, line(data_x, *m.values), label="fit")

    # The legend title doubles as a fit report: chi2 first, then one line
    # per parameter with its uncertainty.
    chi2_line = f"$\\chi^2$/$n_\\mathrm{{dof}}$ = {m.fval:.1f} / {m.ndof:.0f} = {m.fmin.reduced_chi2:.1f}"
    param_lines = [
        f"{p} = ${v:.3f} \\pm {e:.3f}$"
        for p, v, e in zip(m.parameters, m.values, m.errors)
    ]

    ax.legend(title="\n".join([chi2_line, *param_lines]), frameon=False)
    ax.set_xlabel("true p.e.")
    ax.set_ylabel("p.e. data after simple cal")
    
    
def apply_accurate_cal(amp_simple_cal, slope, offset):
    """Return ``slope * amp_simple_cal + offset``."""
    rescaled = slope * amp_simple_cal
    return rescaled + offset

In [None]:
# Per-channel overrides for the GMM fit, keyed by channel name.  Each value is
# a dict that may contain "pe_peaks_model", "pe_range" and/or "pe_range_peak";
# any key left out keeps the default used in the fit loop.
overrides_GMM = {}

In [None]:
from tqdm import tqdm

In [None]:
# Refined ("accurate") calibration, one panel per channel:
#   fig  - GMM fit of the simple-calibrated spectrum
#   fig2 - spectrum before and after the refined calibration
n_cols = 6
# Keep the 10x6 layout and only add rows beyond 60 channels, so that a longer
# channel list does not run past the available axes (ceil division).
n_rows = max(10, -(-len(energies_list) // n_cols))

fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 2 * n_rows))
ax = ax.ravel()

fig2, ax2 = plt.subplots(n_rows, n_cols, figsize=(20, 2 * n_rows))
ax2 = ax2.ravel()

out_dict = {}

# Same binning for every channel
bins = np.linspace(0,4.5,100)

for i, (name, amp_simple_cal) in tqdm(enumerate(energies_list), total=len(energies_list)):

    # Defaults of the fit
    pe_peaks_model={1:1,2:2,3:3,4:4}
    pe_range = (pe_thresh[name], 4.5)
    pe_range_peak = 0.4

    # Update parameters from overrides if they exist
    if name in overrides_GMM:
        params = overrides_GMM[name]
        pe_peaks_model = params.get("pe_peaks_model", pe_peaks_model)
        pe_range = params.get("pe_range", pe_range)
        pe_range_peak = params.get("pe_range_peak", pe_range_peak)

    # Pass the parameters to apply_GMM_fit
    gm, dmat, pe_peaks_model, pe_peak_mean, pe_peak_std = apply_GMM_fit(
        amp_simple_cal,
        pe_range=pe_range,
        n_components=12,
        pe_peaks_model=pe_peaks_model,
        pe_range_peak=pe_range_peak
    )

    plot_GMM_fit(gm, dmat, pe_peak_mean, pe_peak_std, ax=ax[i])

    # Line fit of the fitted peak positions against the keys of the peak model
    data_x = np.array(list(pe_peaks_model.keys()))
    m, slope, offset = apply_calibration_line_fit(data_x, pe_peak_mean, pe_peak_std)

    amp_accurate_cal = apply_accurate_cal(amp_simple_cal, slope, offset)

    ax2[i].hist(amp_simple_cal, bins, histtype="step", label="simple cal")
    ax2[i].hist(amp_accurate_cal, bins, histtype="step", label="GMM cal")
    ax2[i].legend()
    ax2[i].set_yscale("log")

    ax[i].set_title(name, fontsize=10)
    ax2[i].set_title(name, fontsize=10)

    # save vals
    # NOTE: slope/offset act on the simple-calibrated amplitudes
    # (see apply_accurate_cal above), not on the uncalibrated energies.
    out_dict[name] = {"slope": slope, "offset": offset}

fig.tight_layout()
fig2.tight_layout()