# Imports

### Plotting functions ([`pyglotaran_extras`](https://github.com/s-weigand/pyglotaran-extras/commit/20da3593105fb839f86e668dc12dc9ca87c3b9ce) + `matplotlib`)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from pyglotaran_extras.io.boilerplate import setup_case_study
from pyglotaran_extras.plotting.plot_overview import plot_overview
from pyglotaran_extras.plotting.style import PlotStyle

# Use the pyglotaran_extras color cycle for every axes and a large default figure size.
plot_style = PlotStyle()
plt.rcParams["axes.prop_cycle"] = plot_style.cycler
plt.rc("figure", figsize=(21, 14))

### Analysis functions

In [None]:
from glotaran.analysis.optimize import optimize
from glotaran.io import load_dataset
from glotaran.io import load_model
from glotaran.io import load_parameters
from glotaran.project.scheme import Scheme

# Analysis setup

In [None]:
# Resolve the output and script folders via pyglotaran_extras; this notebook's
# results belong in a "target_analysis" sub-folder of the results folder.
# NOTE(review): results_folder is not referenced again in the visible cells.
results_folder, script_folder = setup_case_study(output_folder_name="pyglotaran_examples_results")
results_folder = results_folder / "target_analysis"

# Load files

In [None]:
# Load the measured data, the model specification and the starting parameters,
# all relative to the script folder returned by setup_case_study.
dataset = load_dataset(script_folder / "data/2008Polli_betacar_chex_sim.ascii")
model = load_model(script_folder / "models/model.yml")
parameters = load_parameters(script_folder / "models/parameters.yml")

# Optimization

In [None]:
# Bundle model, parameters and data into a fitting scheme, then optimize it.
scheme = Scheme(
    model=model,
    parameters=parameters,
    # Results are looked up under this key ("dataset1") in the cells below.
    data={"dataset1": dataset},
    non_negative_least_squares=False,
    optimization_method="TrustRegionReflection",
    # Upper limit on the number of optimizer function evaluations.
    maximum_number_function_evaluations=5,
)
result = optimize(scheme)

# Results

## Result summary (including RMS)

In [None]:
# Display the optimization result summary.
result

## Optimized parameters

In [None]:
# Display the parameter values after optimization.
result.optimized_parameters

# Plots

## Overview

In [None]:
# Overview figure of dataset1; linlog=True presumably selects a lin-log time axis.
# The trailing ";" suppresses the echoed return value in the notebook output.
plot_overview(result.data["dataset1"], linlog=True);

## Data and fits

### Function definitions

In [None]:
import numpy as np
import xarray as xr
from glotaran.project.result import Result
from matplotlib.pyplot import Axes


def plot_data_and_fits(
    result: Result, wavelength: float, axis: Axes, linlog=False, linthresh=1, use_label=True
):
    """Draw measured and fitted traces of every dataset at one wavelength.

    For each dataset in ``result.data``, the trace at the ``spectral`` point
    nearest to ``wavelength`` is drawn on ``axis`` twice: once from the
    ``data`` variable and once from ``fitted_data``.

    Parameters
    ----------
    result : Result
        Optimization result whose ``data`` mapping holds one result per dataset.
    wavelength : float
        Target position on the ``spectral`` coordinate (nearest point is used).
    axis : Axes
        Matplotlib axes to draw on.
    linlog : bool
        If true, switch the x axis to a symlog scale.
    linthresh : float
        Linear range of the symlog scale around zero (used only with ``linlog``).
    use_label : bool
        If false, lines get an empty label, which matplotlib leaves out of
        automatically collected legends.
    """
    for name, dataset_result in result.data.items():
        # A one-element list keeps the spectral dimension in the selection.
        trace = dataset_result.sel(spectral=[wavelength], method="nearest")
        data_label = f"{name}_data" if use_label else ""
        fit_label = f"{name}_fit" if use_label else ""
        trace.data.plot(x="time", ax=axis, label=data_label)
        trace.fitted_data.plot(x="time", ax=axis, label=fit_label)
    if linlog:
        axis.set_xscale("symlog", linthresh=linthresh)
    axis.set_ylabel("Intensity")


def get_overlap(result: "Result", coord_name: str = "spectral") -> tuple[float, float]:
    """Return the ``coord_name`` range covered by every dataset in ``result``.

    The overlap is the intersection of the per-dataset coordinate ranges: the
    largest of the minima and the smallest of the maxima. For a single dataset
    this is simply that dataset's (min, max).

    Parameters
    ----------
    result : Result
        Object whose ``data`` mapping holds datasets exposing ``coords[coord_name]``.
    coord_name : str
        Name of the coordinate whose common range is computed.

    Returns
    -------
    tuple[float, float]
        ``(lower, upper)`` bounds shared by all datasets.

    Raises
    ------
    ValueError
        If ``result.data`` is empty or the coordinate ranges do not overlap.
    """
    if not result.data:
        raise ValueError("result contains no datasets to compute an overlap from")
    lower_bounds = []
    upper_bounds = []
    for dataset in result.data.values():
        values = dataset.coords[coord_name].values
        lower_bounds.append(values.min())
        upper_bounds.append(values.max())
    # Intersection, not union: values outside any dataset's range would be
    # silently clamped by a later ``sel(..., method="nearest")``.
    lower, upper = max(lower_bounds), min(upper_bounds)
    if lower > upper:
        raise ValueError(f"the {coord_name!r} ranges of the datasets do not overlap")
    return lower, upper


def plot_fit_overview(
    result: Result, axes_shape=(4, 4), linlog=False, linthresh=1, figsize=(30, 15), *args, **kwargs
):
    """Plot data-vs-fit time traces at evenly spaced wavelengths in a subplot grid.

    One panel is drawn per subplot. The wavelengths are spread evenly (``np.linspace``)
    over the spectral range returned by ``get_overlap``.

    Parameters
    ----------
    result : Result
        Optimization result to visualize.
    axes_shape : tuple[int, int]
        ``(nrows, ncols)`` of the subplot grid.
    linlog : bool
        If true, every panel uses a symlog x axis.
    linthresh : float
        Linear range of the symlog scale around zero.
    figsize : tuple[float, float]
        Figure size passed to ``plt.subplots``.
    *args, **kwargs
        Accepted for backward compatibility and ignored.

    Returns
    -------
    tuple
        ``(fig, axes)`` exactly as returned by ``plt.subplots``.
    """
    fig, axes = plt.subplots(*axes_shape, figsize=figsize)
    fig.patch.set_facecolor("white")
    # plt.subplots returns a bare Axes (not an array) for a (1, 1) grid;
    # normalize to a flat array so iteration works for every grid shape.
    flat_axes = np.atleast_1d(axes).ravel()
    wavelengths = np.linspace(*get_overlap(result), num=flat_axes.size)
    for index, (wavelength, axis) in enumerate(zip(wavelengths, flat_axes)):
        plot_data_and_fits(
            result=result,
            wavelength=wavelength,
            axis=axis,
            linlog=linlog,
            linthresh=linthresh,
            # Label only the first panel so the figure legend lists each line once.
            use_label=index == 0,
        )
    fig.legend()
    fig.suptitle("Fit overview", fontsize=28)
    fig.tight_layout()
    return fig, axes

### Plots

In [None]:
# Scratch check: number of panels in the 4x4 fit-overview grid used below.
np.prod((4, 4))

In [None]:
# Scratch check: number of points on dataset1's spectral coordinate.
len(result.data["dataset1"].coords["spectral"])

In [None]:
# 4x4 grid of data-vs-fit time traces at evenly spaced wavelengths, with a symlog
# time axis that is linear for |t| < linthresh.
fig, axes = plot_fit_overview(result, axes_shape=(4, 4), linlog=True, linthresh=1)

## Result dataset

In [None]:
# Display the full result dataset of dataset1 (data variables and coordinates).
result.data["dataset1"]

In [None]:
# Values of the clp_label coordinate; the coherent-artifact cell below filters these.
result.data["dataset1"].clp_label

## IRF

In [None]:
# Index-based (not time-value-based) windows: the IRF analysis keeps the first
# time_slice_max time points; plots start at time index 30.
time_slice_max = 200
plot_time_slice = slice(30, time_slice_max)

In [None]:
# IRF restricted to the first time_slice_max time points, plus its first and
# second finite differences along time (diff(n=2) equals applying diff twice).
irf_data = result.data["dataset1"]["irf"].isel(time=slice(None, time_slice_max))
irf_data_diff1 = irf_data.diff("time")
irf_data_diff2 = irf_data.diff("time", n=2)

In [None]:
# Overlay the IRF and its first and second finite differences along time.
irf_data.isel(time=plot_time_slice).plot.line(x="time")
irf_data_diff1.isel(time=plot_time_slice).plot.line(x="time")
irf_data_diff2.isel(time=plot_time_slice).plot.line(x="time")
# Legend entries are assigned in drawing order; ";" keeps the Legend repr out of
# the cell output, matching the other plotting cells.
plt.legend(("IRF", "IRF'", "IRF''"));

In [None]:
# Overlay the clp of every coherent-artifact component, one labelled line each.
clp_labels = [label.item() for label in result.data["dataset1"].clp_label]
coherent_artifact_labels = [
    label for label in clp_labels if label.startswith("coherent_artifact")
]
for label in coherent_artifact_labels:
    result.data["dataset1"]["clp"].sel(clp_label=label).plot(label=label)
if coherent_artifact_labels:
    plt.legend()

## DOAS related data variables
- `damped_oscillation_associated_spectra`
- `damped_oscillation_phase`
- `damped_oscillation_sin`
- `damped_oscillation_cos`

### damped_oscillation_phase

In [None]:
# Phase of every damped oscillation versus wavelength (one line per oscillation).
result.data["dataset1"]["damped_oscillation_phase"].plot.line(x="spectral");

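Solvent is … *(note left unfinished in the original; presumably it identifies the solvent oscillation, e.g. `osc6` plotted below — TODO complete)*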

In [None]:
# Phase of oscillation "osc6" alone; ";" suppresses the echoed plot return value.
result.data["dataset1"]["damped_oscillation_phase"].sel(damped_oscillation="osc6").plot();

In [None]:
# Phases of the first six oscillations (index 0-5), y axis fixed to [-12, 8].
result.data["dataset1"]["damped_oscillation_phase"].isel(damped_oscillation=slice(0, 6)).plot.line(
    x="spectral", ylim=(-12, 8)
);

### damped_oscillation_associated_spectra

In [None]:
# All DOAS as a 2D color plot; ";" suppresses the echoed artist repr.
result.data["dataset1"]["damped_oscillation_associated_spectra"].plot(shading="auto");

In [None]:
# Every DOAS as a line versus wavelength.
result.data["dataset1"]["damped_oscillation_associated_spectra"].plot.line(x="spectral");

In [None]:
# DOAS of osc2 and osc18 only.
result.data["dataset1"]["damped_oscillation_associated_spectra"].sel(
    damped_oscillation=["osc2", "osc18"]
).plot.line(x="spectral");

Scaling: Multiply maxima of time and spectral

In [None]:
# First six DOAS (index 0-5), y axis limited to [0, 4.8].
result.data["dataset1"]["damped_oscillation_associated_spectra"].isel(
    damped_oscillation=slice(0, 6)
).plot.line(x="spectral", ylim=(0, 4.8));

In [None]:
# DOAS with oscillation index 14-18 (the slice end is exclusive).
result.data["dataset1"]["damped_oscillation_associated_spectra"].isel(
    damped_oscillation=slice(14, 19)
).plot.line(x="spectral");

### damped_oscillation_sin

In [None]:
# Wavelength for the single-wavelength traces below; the nearest spectral point is used.
center_wl = 500

In [None]:
# Inspect (without plotting) the sin component at the spectral point nearest to center_wl.
result.data["dataset1"]["damped_oscillation_sin"].sel(
    spectral=center_wl, method="nearest"
)

In [None]:
# Sin-component time traces of all oscillations at center_wl.
result.data["dataset1"]["damped_oscillation_sin"].sel(
    spectral=center_wl, method="nearest"
).plot.line(x="time");

In [None]:
# Same traces restricted to time indices 30-349.
result.data["dataset1"]["damped_oscillation_sin"].isel(time=slice(30, 350)).sel(
    spectral=center_wl, method="nearest"
).plot.line(x="time");

In [None]:
# Sin traces of oscillations 14-18 at center_wl over plot_time_slice.
result.data["dataset1"]["damped_oscillation_sin"].isel(time=plot_time_slice).isel(
    damped_oscillation=slice(14, 19)
).sel(spectral=center_wl, method="nearest").plot.line(x="time");

In [None]:
# Sin traces of oscillations 0-3 at center_wl over plot_time_slice.
result.data["dataset1"]["damped_oscillation_sin"].isel(time=plot_time_slice).isel(
    damped_oscillation=slice(0, 4)
).sel(spectral=center_wl, method="nearest").plot.line(x="time");

Note: the dispersion still needs to be subtracted from these traces.

In [None]:
# 2D map of osc6's sin component over plot_time_slice; center=False avoids a
# diverging colormap centered on zero. ";" suppresses the echoed artist repr.
result.data["dataset1"]["damped_oscillation_sin"].isel(time=plot_time_slice).sel(
    damped_oscillation="osc6"
).plot(center=False);

In [None]:
# Cos counterpart of the map above (osc6 over plot_time_slice), for side-by-side
# comparison; ";" suppresses the echoed artist repr.
result.data["dataset1"]["damped_oscillation_cos"].isel(time=plot_time_slice).sel(
    damped_oscillation="osc6"
).plot(center=False);

### damped_oscillation_cos

In [None]:
# Cos-component traces of all oscillations at center_wl over plot_time_slice.
result.data["dataset1"]["damped_oscillation_cos"].isel(time=plot_time_slice).sel(
    spectral=center_wl, method="nearest"
).plot.line(x="time");

In [None]:
# Cos trace of the single oscillation at index 12 over time indices 25-124;
# slice(12, 13) (rather than 12) keeps the damped_oscillation dimension.
result.data["dataset1"]["damped_oscillation_cos"].isel(time=slice(25, 125)).isel(
    damped_oscillation=slice(12, 13)
).sel(spectral=center_wl, method="nearest").plot.line(x="time");

In [None]:
# Cos traces of oscillations 0-3 at center_wl over plot_time_slice.
result.data["dataset1"]["damped_oscillation_cos"].isel(time=plot_time_slice).isel(
    damped_oscillation=slice(0, 4)
).sel(spectral=center_wl, method="nearest").plot.line(x="time");

In [None]:
# Cos traces of oscillations 14-18 at center_wl over plot_time_slice.
result.data["dataset1"]["damped_oscillation_cos"].isel(time=plot_time_slice).isel(
    damped_oscillation=slice(14, 19)
).sel(spectral=center_wl, method="nearest").plot.line(x="time");

In [None]:
# Cos and sin traces of osc2 and osc18 at center_wl, overlaid on the current axes.
result.data["dataset1"]["damped_oscillation_cos"].isel(time=plot_time_slice).sel(
    damped_oscillation=["osc2", "osc18"]
).sel(spectral=center_wl, method="nearest").plot.line(x="time")
result.data["dataset1"]["damped_oscillation_sin"].isel(time=plot_time_slice).sel(
    damped_oscillation=["osc2", "osc18"]
).sel(spectral=center_wl, method="nearest").plot.line(x="time");

# Interactive Plots

In [None]:
from __future__ import annotations

from ipywidgets import interact
from ipywidgets import widgets

In [None]:
# DataArray that sets the slider ranges below. NOTE: this rebinds `dataset`,
# which earlier held the loaded input data.
dataset = result.data["dataset1"]["damped_oscillation_cos"]


def _slider_display_options():
    """Return display options shared by both sliders (a fresh Layout per call)."""
    return {
        "continuous_update": False,
        "orientation": "horizontal",
        "readout": True,
        "readout_format": ".2f",
        "layout": widgets.Layout(width="50%"),
    }


# Time window: lower bound from the data, upper bound fixed at 2.
time_slider = widgets.FloatRangeSlider(
    value=[-0.1, 0.5],
    min=dataset.time.min(),
    max=2,
    step=0.01,
    description="time:",
    **_slider_display_options(),
)
# Wavelength picker spanning the full spectral coordinate, starting at 500.
spectral_slider = widgets.FloatSlider(
    value=500,
    min=dataset.spectral.min(),
    max=dataset.spectral.max(),
    step=0.1,
    description="spectral:",
    **_slider_display_options(),
)

In [None]:
# Multi-select of damped oscillations to plot; the first five are preselected.
doas_options = list(dataset.damped_oscillation.values)
doas_select = widgets.SelectMultiple(
    options=doas_options,
    value=doas_options[:5],
    description="DOAS",
    # ~1.2em per entry, presumably so every option is visible at once.
    layout=widgets.Layout(height=f"{1.2 * len(doas_options):.1f}em"),
)

In [None]:
def plot_func_factory(dataset: xr.DataArray):
    """Build an ``interact`` callback that plots time traces of ``dataset``.

    The returned callable keeps the time points inside ``time_range``
    (inclusive on both ends) and only the damped oscillations named in
    ``doas``. It then picks the spectral point nearest to ``spectral``, draws
    the result against time with ``plot.line`` and shows the figure. It assumes
    ``dataset`` has ``time``, ``spectral`` and ``damped_oscillation``
    dimensions, the ones it indexes.
    """
    time_coord = dataset.time

    def wrapper(spectral: float, time_range: tuple[float, float], doas: list[str]):
        start, stop = time_range[0], time_range[1]
        in_window = (time_coord >= start) & (time_coord <= stop)
        chosen = np.isin(dataset.damped_oscillation, doas)
        subset = dataset.isel(time=in_window, damped_oscillation=chosen)
        subset.sel(spectral=spectral, method="nearest").plot.line(x="time")
        plt.show()

    return wrapper


# Interactive cos-component traces: sliders choose wavelength and time window,
# the multi-select chooses which damped oscillations are drawn.
plot_cos = plot_func_factory(result.data["dataset1"]["damped_oscillation_cos"])

# The trailing ";" stops Jupyter from echoing interact()'s return value below the widget.
interact(plot_cos, spectral=spectral_slider, time_range=time_slider, doas=doas_select);