## Import libraries

In [10]:
# Libraries
from typing import List, Tuple, Dict, Union, Any, Iterable, Optional

# root I/O
import uproot
import awkward as ak
import numpy as np

# # plotting
# import matplotlib.pyplot as plt
# from matplotlib.ticker import ScalarFormatter

# plotting style
import ROOT as r

# useful
import math
import sys
import os

## Definition of constants

In [11]:
# Directory handed to EventFilter as its output location.
# NOTE(review): this path says "notar" while INPUT_ROOT_PATH says "taron" --
# confirm that the mismatch is intentional.
OUTPUT_DIR = "./data/old_800mm_notar"
# debug data (100k events)
INPUT_ROOT_PATH = "/home/shiraiwa/myg4work/olddetector/old_800mm_taron/MT-build/develop_data/100k.root"
# all data (60M events)
# INPUT_ROOT_PATH = "/home/shiraiwa/myg4work/olddetector/old_800mm_taron/MT-build/MT1/all.root"
# Verbosity flag. The original spelling "VERVOSE" is kept as an alias so that
# any code elsewhere that still refers to the misspelled name keeps working.
VERBOSE = True
VERVOSE = VERBOSE
# Geometry / binning parameters; none of them is used in the code shown here.
# NOTE(review): units are presumably mm (cf. "800mm" in the paths) -- confirm.
Y_MAX = 100
Z_POS = 2500
DIST_FR = 800
RESO_X = 32
RESO_Y = 32
MARGIN = 0
# Branches read from the ROOT tree by EventFilter._root_to_numpy.
LIST_OF_BRANCHES = ['ScPosX','ScPosY','ScPosZ','ScEdep']


## Definition of the event-filter class

In [18]:
class EventFilter:
    """
    Load events from a ROOT file, reorder the layers and drop empty events.

    On construction the branches listed in the module-level
    ``LIST_OF_BRANCHES`` are read from the ``"MT"`` tree of the input file
    into a dict of ``numpy.ndarray`` (one row per event), the four layer
    blocks along axis 1 are reordered, and events whose summed ``ScEdep``
    is not positive are removed.

    Attributes
    ----------
    input_root_path : str
        The path to the input ROOT file.
    output_dir : str
        The directory where the output will be stored (only stored here;
        nothing in this class writes to it).
    th_MeV : float
        The threshold energy in MeV for a channel to be considered active.
    events : dict
        Branch name -> numpy array holding the (reordered, filtered) events.

    Methods
    -------
    _root_to_numpy():
        Reads the ROOT tree into a dict of stacked numpy arrays.
    _correct_layer_order():
        Corrects the layer order of the events.
    _get_is_active_channel():
        Returns a boolean array indicating if the channel energy is > th_MeV.
    _filter_empty_events():
        Filters out events with no energy deposition.
    get_active_channel_idx():
        Returns a list of arrays with the channel indices of active channels.
    """

    def __init__(self, input_root_path:str, output_dir:str, th_MeV:float=0.):
        """
        Read, reorder and filter the events of a ROOT file.

        The number of events before and after the empty-event filter is
        printed to stdout.

        Parameters
        ----------
            input_root_path : str
                The path to the input ROOT file.
            output_dir : str
                The directory where the output will be stored.
            th_MeV : float
                The threshold energy in MeV for a channel to be considered active.
        """
        self.input_root_path = input_root_path
        self.output_dir = output_dir
        self.th_MeV = th_MeV
        self.events = self._root_to_numpy()
        self.events = self._correct_layer_order()
        print("n_events before filtering: ", len(self.events['ScEdep']))
        self.events = self._filter_empty_events()
        print("n_events after filtering: ", len(self.events['ScEdep']))

    def _root_to_numpy(self)->Dict[str, np.ndarray]:
        """
        Read the branches in ``LIST_OF_BRANCHES`` from the ``"MT"`` tree.

        Each branch comes back from uproot as an array with one entry per
        event; the entries are stacked along axis 0 so that every branch
        becomes a single regular numpy array.

        Returns
        -------
        dict
            Branch name -> stacked numpy array.
        """
        # The arrays are fully materialised inside the ``with`` block, so the
        # file handle can be released as soon as the read is done.
        with uproot.open(self.input_root_path) as root_file:
            raw = root_file["MT"].arrays(LIST_OF_BRANCHES, library="np")
        return {name: np.stack(column, axis=0) for name, column in raw.items()}

    def _correct_layer_order(self)->Dict[str, np.ndarray]:
        """
        Corrects the layer order of the events from (0, 1, 2, 3) to (3, 2, 0, 1).

        Every array in ``self.events`` is split into four equal blocks along
        axis 1 and the blocks are concatenated back in the new order.
        ``self.events`` is updated in place.

        Returns
        -------
        dict
            The events data with corrected layer order.
        """
        for name in self.events:
            layers = np.split(self.events[name], 4, axis=1)
            self.events[name] = np.concatenate(
                [layers[3], layers[2], layers[0], layers[1]], axis=1
            )
        return self.events

    def _get_is_active_channel(self)->np.ndarray:
        """
        Returns a boolean array indicating if the channel energy is > th_MeV.

        Returns
        -------
        np.ndarray
            A boolean array, same shape as ``events['ScEdep']``, where True
            indicates that the channel energy is > th_MeV.
        """
        return self.events['ScEdep'] > self.th_MeV

    def _filter_empty_events(self)->Dict[str, np.ndarray]:
        """
        Filters out events with no energy deposition.

        An event is kept when the sum of its ``ScEdep`` values along axis 1
        is > 0; ``th_MeV`` is not used for this decision. ``self.events`` is
        updated in place.

        Returns
        -------
        dict
            The events data with empty events filtered out.
        """
        has_deposit = self.events['ScEdep'].sum(axis=1) > 0
        for name in self.events:
            self.events[name] = self.events[name][has_deposit]
        return self.events

    def get_active_channel_idx(self)->List[np.ndarray]:
        """
        Returns a list of arrays with the channel indices of active channels.

        Returns
        -------
        list
            A list of numpy arrays where each array contains the indices of the active channels for an event.
        """
        # np.nonzero(row)[0] is what single-argument np.where(row)[0] returns.
        return [np.nonzero(row)[0] for row in self._get_is_active_channel()]

    # NOTE: disabled draft kept for reference; `arr` is not defined as a
    # parameter in this signature.
    # def split_array(self)->List[np.ndarray]:
    #     """
    #     Find consecutive hits.

    #     Parameters:
    #     arr: hit_pos_A or hit_pos_B from hit_finder function.

    #     Returns:
    #     An array of contiguous blocks of consecutive hit channels.

    #     Example:

    #     ```
    #     arr = [1,2,3,10,11,12,29,30]
    #     diff = [1,1,7,1,1,17,1]
    #     split_indices = [3,6] # This is the index that is not consecutive
    #     [array([1, 2, 3]), array([10, 11, 12]), array([29, 30])]
    #     ```

    #     """
    #     diff = np.diff(arr)
    #     split_indices = np.where(diff != 1)[0] + 1
    #     return np.split(arr, split_indices)

## Execution


In [13]:
# Build the filter: the constructor reads INPUT_ROOT_PATH, reorders the layers,
# drops empty events, and prints the event count before and after filtering.
event_filter = EventFilter(INPUT_ROOT_PATH, OUTPUT_DIR)

# Filtered events as a dict keyed by branch name.
events = event_filter.events
# One index array per event: the channels whose ScEdep exceeds the threshold
# (th_MeV is left at its default of 0 here).
active_channel_idx = event_filter.get_active_channel_idx()

n_events before filtering:  90000
n_events after filtering:  30750


In [14]:
# Recompute the per-event active-channel index arrays into a separate list
# (a fresh list is built on every call to get_active_channel_idx).
active_channel_list=event_filter.get_active_channel_idx()

In [15]:
# Append one extra entry holding channel indices 0..14 to the list.
# NOTE(review): presumably a hand-made test entry for downstream code -- confirm;
# re-running this cell appends the entry again.
active_channel_list.append(np.array([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]))