# Influence of image cleaning on the reconstruction

In [None]:
import collections
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from ctapipe.visualization import CameraDisplay
from torchmetrics.classification import AUROC, Accuracy


import gammalearn.data.LST_dataset as dsets
from gammalearn.configuration.constants import GAMMA_ID, PROTON_ID

# Input paths and figure output are both anchored on the current working directory.
THIS_DIR = OUTPUT_DIR = Path.cwd()

In [None]:
def get_example_basename(path: str) -> str:
    """Return the DL1 file name contained in ``path``, without its ``.h5`` extension.

    The name is the text spanning from the first ``dl1`` found in ``path`` up to
    the last ``.h5``.

    Parameters
    ----------
    path : str
        String expected to contain a ``dl1*.h5`` file name.

    Returns
    -------
    str
        The matched name with the trailing ``.h5`` removed, or an empty string
        when ``path`` contains no match.
    """
    # Raw string: "\." inside a regular literal is an invalid escape sequence.
    match = re.search(r"dl1.*\.h5", path)
    if match is None:
        return ""
    return match.group()[: -len(".h5")]

In [None]:
# Example DL1 file (located relative to the working directory) and the kind of data it holds.
simu_type = "mc"
example_dl1_file = THIS_DIR.joinpath("../../share/data/MC_data/dl1_gamma_example.h5")
assert example_dl1_file.exists(), f"File {example_dl1_file} does not exist"
# Base name without extension, passed to print_images to name the saved figures.
example_name_mc = get_example_basename(example_dl1_file.as_posix())

In [None]:
# Define the parameters that are normally loaded from an experiment settings file.
# This cell can be used to test the default dataset parameters when no experiment file is available.

# Class index assigned to each particle type.
particle_dict = {
    GAMMA_ID: 0,
    PROTON_ID: 1,
}
"""particle_dict is mandatory and maps cta particle types with class id. e.g. gamma (0) is class 0"""
# One entry per objective; the keys of this dict are passed to the dataset as its "targets".
targets = collections.OrderedDict(
    {
        "energy": {
            "output_shape": 1,
            "loss": torch.nn.L1Loss(reduction="none"),
            "loss_weight": 1,
            "metrics": {
                # no metric function defined for this target
                # 'functions': ,
            },
            "mt_balancing": True,
        },
        "impact": {
            "output_shape": 2,
            "loss": torch.nn.L1Loss(reduction="none"),
            "loss_weight": 1,
            "metrics": {},
            "mt_balancing": True,
        },
        "direction": {
            "output_shape": 2,
            "loss": torch.nn.L1Loss(reduction="none"),
            "loss_weight": 1,
            "metrics": {},
            "mt_balancing": True,
        },
        "class": {
            "label_shape": 1,
            # one output per entry of particle_dict
            "output_shape": len(particle_dict),
            "loss": torch.nn.CrossEntropyLoss(),
            "loss_weight": 1,
            "metrics": {
                "Accuracy_particle": Accuracy(threshold=0.5, task="MULTICLASS", num_classes=len(particle_dict)),
                "AUC_particle": AUROC(
                    task="MULTICLASS",
                    num_classes=len(particle_dict),
                ),
            },
            "mt_balancing": True,
        },
    }
)
"""dict: mandatory, defines for every objectives of the experiment
the loss function and its weight
"""

dataset_class = dsets.MemoryLSTDataset
# Alternative dataset class (see the description below):
# dataset_class = dsets.FileLSTDataset
"""Dataset: mandatory, the Dataset class to load the data. Currently 2 classes are available, MemoryLSTDataset that 
loads images in memory, and FileLSTDataset that loads images from files during training.
"""
# Base parameters shared by all the cleaning methods compared in this notebook (no cleaning mask set here).
dataset_parameters = {
    "camera_type": "LST_LSTCam",
    "group_by": "image",
    "use_time": True,
    "particle_dict": particle_dict,
    "targets": list(targets.keys()),
    # optional sub-array selection, disabled here:
    # 'subarray': [1],
}
"""dict: mandatory, the parameters of the dataset.
camera_type is mandatory and can be:
'LST_LSTCam', 'MST_NectarCam', 'MST_FlashCam', 'SST_ASTRICam', 'SST1M_DigiCam', 'SST_CHEC', 'MST-SCT_SCTCam'.
group_by is mandatory and can be 'image', 'event_all_tels', 'event_triggered_tels'.
particle_dict is mandatory and maps cta particle types with class id. e.g. gamma (0) is class 0, 
proton (101) is class 1 and electron (1) is class 2.
use_time (optional): whether or not to use time information
subarray (optional): the list of telescope ids to select as a subarray
"""

In [None]:
# Per-method dataset parameters: the base parameters plus the cleaning-mask settings of each method.
dataset_parameters_dvr = {
    **dataset_parameters,
    "use_cleaning_masks": True,
    "mask_method": "data_reduction_mask",
}

dataset_parameters_default_tailcut = {
    **dataset_parameters,
    "use_cleaning_masks": True,
    "mask_method": "tailcuts_standard_analysis",
}

dataset_parameters_lstchain = {
    **dataset_parameters,
    "use_cleaning_masks": True,
    "mask_method": "precomputed_lstchain",
}

In [None]:
# Load one dataset per cleaning method from the same example file.
# `methods` and `dataset_parameters_list` are paired by position.
methods = ["no mask", "dvr", "tailcuts_standard_analysis", "precomputed lstchain"]
dataset_parameters_list = [
    dataset_parameters,
    dataset_parameters_dvr,
    dataset_parameters_default_tailcut,
    dataset_parameters_lstchain,
]

# The loop variable must not be called `dataset_parameters`: that would rebind the base
# parameters to the last entry of the list, and re-running this cell would then load the
# "no mask" dataset with a cleaning mask.
datasets = {
    method: dataset_class(example_dl1_file, **method_parameters)
    for method, method_parameters in zip(methods, dataset_parameters_list)
}

In [None]:
# Pick one event and look up its image index in every dataset, so that the same image is plotted
# for all methods (removing the black images shifts the image ids, so images[0] may not be the
# same image in all methods).

# The precomputed lstchain mask is the least conservative method, so an event taken from that
# dataset has a better chance of being found in the other ones.

image_ids_simus = {}
method_selected = "precomputed lstchain"
# index, within the reference dataset, of the event to display
selected_event_index = 1

event_id = datasets[method_selected].unique_event_ids[selected_event_index]

# index of that event in each dataset; the [0][0] lookup fails if a dataset does not contain it
image_ids = {
    method: np.where(datasets[method].unique_event_ids == event_id)[0][0] for method in methods
}
image_ids_simus[simu_type] = image_ids


In [None]:
# Select the image (and its mask) of the chosen event in every dataset.
# The ids are read with the same `simu_type` key they were stored under, rather than a
# hard-coded "mc", so this cell keeps working if `simu_type` is changed.
selected_ids = image_ids_simus[simu_type]
images = {method: dataset.images[selected_ids[method]] for method, dataset in datasets.items()}
# The "no mask" dataset is skipped when collecting masks.
masks = {
    method: dataset.images_masks[selected_ids[method]]
    for method, dataset in datasets.items()
    if method != "no mask"
}

In [None]:
def print_images(datasets, images, masks, save_fig=False, example_name="") -> None:
    """Plot, for each cleaning method, the selected image and its cleaning mask.

    The figure has one column per entry of ``datasets``: the first row shows the
    images, the second row the masks (left empty for the "no mask" method).

    Parameters
    ----------
    datasets : dict
        Maps a method name to its dataset; ``original_geometry`` is read from it.
    images : dict
        Maps a method name to the image to display.
    masks : dict
        Maps a method name to its cleaning mask; not read for "no mask".
    save_fig : bool
        If True, write the figure as PDF and PNG into the module-level ``OUTPUT_DIR``.
    example_name : str
        Prefix of the saved file name.

    Notes
    -----
    The saved file name also uses the module-level ``selected_event_index``.
    """
    # squeeze=False keeps `axes` two-dimensional even with a single method.
    fig, axes = plt.subplots(2, len(datasets), figsize=(10, 5), squeeze=False)
    for column, (method, image) in enumerate(images.items()):
        geometry = datasets[method].original_geometry
        CameraDisplay(geometry, image, ax=axes[0, column], title=method)
        if method != "no mask":
            CameraDisplay(
                geometry,
                masks[method],
                ax=axes[1, column],
                title=method,
                show_frame=False,
            )
    for ax in axes.flatten():
        ax.axis("off")

    # save output
    if save_fig:
        fig_name = example_name + "_image_" + str(selected_event_index).zfill(3)
        # Build the path with "/" so the file lands inside OUTPUT_DIR: plain string
        # concatenation glued the directory name and the file name together.
        for extension in (".pdf", ".png"):
            fig.savefig(OUTPUT_DIR / (fig_name + extension), dpi=300)

In [None]:
# Shape obtained after summing the "no mask" images over axis 1.
datasets["no mask"].images.sum(axis=1).shape

In [None]:
# Display the comparison figure; set save_figure to True to also write it to disk.
save_figure = False
print_images(datasets, images, masks, example_name=example_name_mc, save_fig=save_figure)

---

## Influence of `tailcut_clean` parameters

**Important**: if you use the `share/MC_data/dl1_gamma.h5` file, the influence
of the `min_number_picture_neighbors` parameter is very limited.
I recommend using a real MC dl1 file or an MC\* file (with extra noise) to observe the phenomenon with the classical cleaning parameters:
```python
PICTURE_THRESH = 8
BOUNDARY_THRESH = 4
```

To make the following study relevant, very high cleaning thresholds are used.
These values are never used in real-case scenarios (even on MC); they have only been tuned to show a difference.

In [None]:
# We would like to observe the impact of the last parameter of the tailcut method
# (`min_number_picture_neighbors`).

# Take the input images from the "no mask" dataset and
# apply tailcut with different values of that parameter.
from ctapipe.image import tailcuts_clean

# Remark: as mentioned before, these values are deliberately very high, to show the impact of the tailcut on clean data.
# In practice, the classic values are 8 and 4.
PICTURE_THRESH = 150
BOUNDARY_THRESH = 120
KEEP_ISOLATED_PIXELS = False
# values of `min_number_picture_neighbors` to compare
min_number_picture_neighbors = [0, 1, 2, 3]

# number of images shown in the comparison plot
NB_IMAGES_PLOT_TAILCUT = 5


In [None]:
# For a few images, compute one tailcut mask per `min_number_picture_neighbors` value.
# Layout: masks[<position of the image in the selection, as str>][<parameter value, as str>]
offset_id_images = 5
selected_images = datasets["no mask"].images[offset_id_images : offset_id_images + NB_IMAGES_PLOT_TAILCUT]
masks = {
    str(i_image): {
        str(neighbour_param): tailcuts_clean(
            datasets["no mask"].original_geometry,
            image,
            picture_thresh=PICTURE_THRESH,
            boundary_thresh=BOUNDARY_THRESH,
            keep_isolated_pixels=KEEP_ISOLATED_PIXELS,
            min_number_picture_neighbors=neighbour_param,
        )
        for neighbour_param in min_number_picture_neighbors
    }
    for i_image, image in enumerate(selected_images)
}


In [None]:
# Figure layout: first row = input images, then one row of masks per
# `min_number_picture_neighbors` value; one column per image.
no_mask_dataset = datasets["no mask"]
fig, axes = plt.subplots(1 + len(min_number_picture_neighbors), NB_IMAGES_PLOT_TAILCUT, figsize=(20, 10))

# first row: the images the masks were computed from
for i in range(NB_IMAGES_PLOT_TAILCUT):
    CameraDisplay(
        no_mask_dataset.original_geometry,
        no_mask_dataset.images[offset_id_images + i],
        ax=axes[0, i],
        title=f"{i}",
    )

# following rows: the masks, titled "<image>,<parameter value>"
for row, neighbour_param in enumerate(min_number_picture_neighbors, start=1):
    for i in range(NB_IMAGES_PLOT_TAILCUT):
        CameraDisplay(
            no_mask_dataset.original_geometry,
            masks[str(i)][str(neighbour_param)],
            ax=axes[row, i],
            title=f"{i},{neighbour_param}",
            show_frame=False,
        )

for ax in axes.flat:
    ax.axis("off")

In [None]:
# Count, for each `min_number_picture_neighbors` value, the images whose tailcut mask is empty
# ("black" images), and keep some information about them.
images = datasets["no mask"].images
nb_image_total = images.shape[0]

nb_black = {}
info = {}
for param in min_number_picture_neighbors:
    nb_black[str(param)] = 0
    info[str(param)] = {"dl1": [], "max_signal": [], "image_id": []}

for i, image in enumerate(images):
    for neighbour_param in min_number_picture_neighbors:
        mask = tailcuts_clean(
            datasets["no mask"].original_geometry,
            image,
            picture_thresh=PICTURE_THRESH,
            boundary_thresh=BOUNDARY_THRESH,
            keep_isolated_pixels=KEEP_ISOLATED_PIXELS,
            min_number_picture_neighbors=neighbour_param,
        )
        # at least one pixel survives the cleaning: not a black image
        if mask.sum() != 0:
            continue
        key = str(neighbour_param)
        nb_black[key] += 1
        black_info = info[key]
        black_info["dl1"].append(datasets["no mask"].dl1_params[i])
        black_info["max_signal"].append(image.max())
        black_info["image_id"].append(i)

In [None]:
# Fraction of the images that are black, for each `min_number_picture_neighbors` value.
ratio_black_images = {param: count / nb_image_total for param, count in nb_black.items()}

In [None]:
# Report the number and the fraction of black images for each parameter value.
print("Number of black images for min_number neighbors thresholds")
for param, nb_black_image in nb_black.items():
    ratio = ratio_black_images[param]
    print(f"{param} = {nb_black_image} ({ratio:.1%})")

In [None]:
# Compare the mean of the per-image maximum over the whole dataset with the same quantity
# restricted to the black images of each parameter value.
max_per_image = datasets["no mask"].images.max(axis=1)
mean_image = np.mean(max_per_image)

print("Mean max per image in the whole mc dataset : {:0.2f}".format(mean_image))

for neighbour_param in info:
    print(f"neighbor parameter: {neighbour_param} -> {np.mean(info[neighbour_param]['max_signal']):0.2f}")
    # print(f"dl1: {np.mean(images_infos[mc_type][neighbour_param]['dl1'])}")