In [None]:
import os
from glob import glob

# Print the versions of the LST analysis packages in use, so that the plots
# below can be related to the exact software stack that produced them.
import lstmcpipe
print(f"lstmcpipe {lstmcpipe.__version__}")
import lstchain
print(f"lstchain {lstchain.__version__}")

In [None]:
from lstmcpipe.scripts.script_compare_irfs import plot_comparison

In [None]:
# Reference IRF file produced with the standard configuration.
# The path components encode the production: prod5 trans_80, zenith 20 deg,
# south pointing, lstchain v0.7.5 with dynamic cleaning, gamma offset 0.4 deg.
# NOTE(review): absolute cluster path (fefs) - only readable on the LST cluster.
std_irf_file = (
    '/fefs/aswg/data/mc/IRF/20200629_prod5_trans_80'
    '/zenith_20deg/south_pointing'
    '/20210923_v0.7.5_prod5_trans_80_dynamic_cleaning'
    '/off0.4deg'
    '/20210923_v075_prod5_trans_80_dynamic_cleaning_gamma_off04deg_sensitivity.fits.gz'
)


In [None]:
# Config directories are named `config_<max_depth>_<min_samples_split>`,
# e.g. `results/config_10_5/` holds the IRFs for max_depth=10, min_samples_split=5.
# `glob` returns files in arbitrary (filesystem-dependent) order; sort them so the
# list, and therefore legend order and any index-based baseline, is reproducible.
fits_list = sorted(glob('results/config_*/*.fits.gz'))

In [None]:
import logging
import matplotlib.pyplot as plt
import ctaplot
import astropy.units as u
# from lstmcpipe.plots.plot_irfs import plot_summary_from_file
from lstmcpipe.plots.plot_irfs import (read_sensitivity_table, 
                                       plot_sensitivity_from_table, 
                                       plot_effective_area_from_file, 
                                       plot_angular_resolution_from_file, 
                                       plot_energy_resolution_from_file,
                                      )

def plot_summary_from_file(filename, axes=None, **kwargs):
    """
    Plot the four IRF summary panels read from a single IRF file.

    Panels are drawn on the flattened `axes` in this order:
    0 - sensitivity, 1 - effective area, 2 - angular resolution, 3 - energy resolution.

    Parameters
    ----------
    filename: str or Path
        IRF file, passed to the lstmcpipe readers/plotters.
    axes: numpy array of `matplotlib.axes.Axes` or None
        Array providing at least 4 axes (it is flattened with `.ravel()`).
        If None, a new 2x2 figure is created.
    kwargs:
        Forwarded unchanged to every plotting function (e.g. `label`, `ls`).

    Returns
    -------
    axes: the array of axes that was drawn on
    """
    if axes is None:
        _, axes = plt.subplots(2, 2, figsize=(15, 15))

    sens_table = read_sensitivity_table(filename)
    flat_axes = axes.ravel()

    plot_sensitivity_from_table(sens_table, ax=flat_axes[0], **kwargs)
    plot_effective_area_from_file(filename, ax=flat_axes[1], **kwargs)
    plot_angular_resolution_from_file(filename, ax=flat_axes[2], **kwargs)
    plot_energy_resolution_from_file(filename, ax=flat_axes[3], **kwargs)

    # Re-layout the whole figure the panels belong to.
    flat_axes[0].get_figure().tight_layout()

    return axes


# NOTE: this definition shadows `plot_comparison` imported from
# lstmcpipe.scripts.script_compare_irfs in an earlier cell.
def plot_comparison(filelist, baseline_index=0, outfile=None, cta_north=False, **plot_kwargs):
    """
    Compare the IRFs of several files on a 3x2 grid of panels.

    Panels (row-major order): sensitivity, effective area, angular resolution,
    energy resolution, sensitivity ratio to the baseline file; the 6th panel is hidden.
    A single legend is attached to the energy-resolution panel.

    Parameters
    ----------
    filelist: list
        IRF files to be compared. Must not be empty.
    baseline_index: int
        Index in `filelist` of the file used as reference in the sensitivity-ratio
        panel. Negative indices are allowed (e.g. -1 for the last file).
    outfile: str or Path or None
        path to the output file to be saved.
        if None, the figure is not saved and `fig.show()` is called instead.
    cta_north: Bool
        Flag to superpose/add (True) or not (False - Default) the CTA North
        performance curves. Imported from ctaplot.
    plot_kwargs:
        Extra matplotlib keyword arguments (e.g. `ls='--'`) forwarded to every curve.
        Must not contain `label`: each file's curves are labelled with its basename.

    Returns
    -------
    axes: numpy array (3, 2) of `matplotlib.axes.Axes`

    Raises
    ------
    ValueError
        If `filelist` is empty, or if a file's sensitivity table does not have the
        same number of energy bins as the baseline's.
    IndexError
        If `baseline_index` is out of range for `filelist`.
    """
    if len(filelist) == 0:
        raise ValueError("filelist must contain at least one IRF file")

    fig, axes = plt.subplots(3, 2, figsize=(15, 15))
    flat_axes = axes.ravel()

    for file in filelist:
        plot_summary_from_file(file, axes=axes, label=os.path.basename(file), **plot_kwargs)

    if cta_north:
        ctaplot.plot_sensitivity_cta_performance('north', ax=flat_axes[0])
        ctaplot.plot_effective_area_cta_performance('north', ax=flat_axes[1])
        ctaplot.plot_angular_resolution_cta_performance('north', ax=flat_axes[2])
        ctaplot.plot_energy_resolution_cta_performance('north', ax=flat_axes[3])

    # Sensitivity ratio w.r.t. the baseline. All tables are plotted against the
    # baseline's energy bins, so they are assumed to share its binning (only the
    # number of bins is checked). The E^2 factor then cancels and the ratio reduces
    # to flux_sensitivity / flux_sensitivity_baseline.
    sens_table_baseline = read_sensitivity_table(filelist[baseline_index])
    energy = sens_table_baseline["reco_energy_center"].to_value(u.TeV)
    half_width = (
        (sens_table_baseline["reco_energy_high"] - sens_table_baseline["reco_energy_low"]) / 2
    ).to_value(u.TeV)

    ratio_ax = flat_axes[4]
    for file in filelist:
        sens_table = read_sensitivity_table(file)
        if len(sens_table) != len(sens_table_baseline):
            raise ValueError(
                f"{file} has {len(sens_table)} energy bins, "
                f"the baseline has {len(sens_table_baseline)}"
            )
        # Convert explicitly to a plain number: dividing Quantities with different
        # but compatible units (e.g. cm-2 vs m-2) keeps a composite unit whose raw
        # value is NOT the true ratio, and matplotlib would plot that raw value.
        ratio = u.Quantity(sens_table["flux_sensitivity"] / sens_table_baseline["flux_sensitivity"])
        ratio_ax.errorbar(
            energy,
            ratio.to_value(u.dimensionless_unscaled),
            xerr=half_width,
            **plot_kwargs,
        )
    ratio_ax.set_title('Sensitivity ratio (lower is better)')
    ratio_ax.set_xscale('log')
    ratio_ax.set_xlabel('Reconstructed energy [TeV]')
    ratio_ax.set_ylabel('Sensitivity / baseline sensitivity')
    ratio_ax.grid(True, which='both')

    # Keep a single legend for the figure: drop the per-panel ones on panels 0-2.
    for ax in flat_axes[:3]:
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    fig.tight_layout()
    # Panel 3 sits directly above panel 5: anchor its legend below it, in the
    # space left free by hiding panel 5.
    flat_axes[3].legend(loc='lower center', bbox_to_anchor=(0.5, -0.5))
    flat_axes[5].set_visible(False)

    if outfile is not None:
        fig.savefig(outfile, dpi=200, bbox_inches='tight')
    else:
        fig.show()

    return axes

In [None]:
# As a function of min_samples_split, at max_depth=10.
# Ratio-panel baseline: the standard production, appended last (index -1).
# Files are sorted (lexicographically) so the legend order is reproducible.
fits_list = sorted(glob('results/config_10_*/*.fits.gz'))
fits_list.append(std_irf_file)
axes = plot_comparison(filelist=fits_list, baseline_index=-1, cta_north=True, ls='--')

In [None]:
# As a function of min_samples_split, at max_depth=30.
# Note: the standard production is NOT included in this comparison; the ratio-panel
# baseline (index 0) is the first config_30_* file in sorted order, which makes it
# reproducible (raw glob order is filesystem-dependent).
fits_list = sorted(glob('results/config_30_*/*.fits.gz'))
axes = plot_comparison(filelist=fits_list, baseline_index=0, cta_north=False, ls='--')
# Zoom on the angular-resolution panel; trailing `;` hides the returned limits.
axes.ravel()[2].set_ylim(0, 0.4);

In [None]:
# As a function of min_samples_split, at max_depth=50.
# Ratio-panel baseline: the standard production, appended last (index -1).
fits_list = sorted(glob('results/config_50_*/*.fits.gz'))
fits_list.append(std_irf_file)
axes = plot_comparison(filelist=fits_list, baseline_index=-1, cta_north=False, ls='--')
# Zoom on the angular-resolution panel; trailing `;` hides the returned limits.
axes.ravel()[2].set_ylim(0, 0.4);

In [None]:
# max_depth=30 and max_depth=50 together, all min_samples_split values.
# Ratio-panel baseline: the standard production, appended last (index -1).
fits_list = sorted(glob('results/config_30_*/*.fits.gz'))
fits_list.extend(sorted(glob('results/config_50_*/*.fits.gz')))
fits_list.append(std_irf_file)
axes = plot_comparison(filelist=fits_list, baseline_index=-1, cta_north=False, ls='--')
# Zoom on the angular-resolution panel; trailing `;` hides the returned limits.
axes.ravel()[2].set_ylim(0, 0.4);

In [None]:
# As a function of min_samples_split, at max_depth=50 and max_depth=100.
# Ratio-panel baseline: the standard production, first in the list (index 0).
fits_list = [std_irf_file]
fits_list.extend(sorted(glob('results/config_50_*/*.fits.gz')))
fits_list.extend(sorted(glob('results/config_100_*/*.fits.gz')))
axes = plot_comparison(filelist=fits_list, baseline_index=0, cta_north=False, ls='--')
# Zoom on the angular-resolution panel; trailing `;` hides the returned limits.
axes.ravel()[2].set_ylim(0, 0.4);

In [None]:
# As a function of max_depth, at min_samples_split=5.
# Ratio-panel baseline: the standard production (default baseline_index=0).
fits_list = [std_irf_file]
fits_list.extend(sorted(glob('results/config_*_5/*.fits.gz')))
plot_comparison(filelist=fits_list, cta_north=True, ls='--');

In [None]:
# As a function of max_depth, at min_samples_split=10.
# Ratio-panel baseline: the standard production (default baseline_index=0).
fits_list = [std_irf_file]
fits_list.extend(sorted(glob('results/config_*_10/*.fits.gz')))
plot_comparison(filelist=fits_list, cta_north=True, ls='--');