In [1]:
import os
import glob
import numpy as np
import pandas as pd
import h5py
import time
import tables as tb
from matplotlib import pyplot as plt
import itertools as it
import torch
from torch_geometric.data import Data

import torch
from torch_geometric.data import InMemoryDataset

#### Let's define the dataset class

In [2]:
class MyDataset(InMemoryDataset):
    """In-memory graph dataset backed by one collated file on disk.

    :meth:`process` collates ``data_list`` and writes the result to
    ``processed_paths[0]``; the constructor then loads that file into
    ``self.data`` / ``self.slices``.

    Parameters
    ----------
    root : str
        Dataset root directory handed to ``InMemoryDataset``.
    name : str
        Label used by ``__repr__``.
    data_list : list, optional
        Graph objects to collate. Required whenever :meth:`process` runs.
    transform : callable, optional
        Forwarded unchanged to ``InMemoryDataset``.

    NOTE(review): the constructor always loads ``processed_paths[0]``. If that
    file already exists from an earlier run, the ``data_list`` passed here is
    presumably ignored by the base class — confirm against the installed
    torch_geometric version and remove the stale file when the inputs change.
    """

    def __init__(self, root, name, data_list=None, transform=None):
        # Both attributes are set before super().__init__() so that they are
        # already available if the base-class constructor calls process().
        self.data_list = data_list
        self.name = name
        super().__init__(root, transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """Name of the single processed file inside the processed directory."""
        return 'data.pt'

    @property
    def num_node_labels(self) -> int:
        """Width of the trailing one-hot block in ``data.x``.

        Scans for the first column offset ``i`` such that ``x[:, i:]`` contains
        only 0/1 and every row sums to exactly 1. Returns the number of columns
        in that block, or 0 if ``x`` is missing or no such block exists.
        """
        if self.data.x is None:
            return 0
        for i in range(self.data.x.size(1)):
            x = self.data.x[:, i:]
            if ((x == 0) | (x == 1)).all() and (x.sum(dim=1) == 1).all():
                return self.data.x.size(1) - i
        return 0

    @property
    def num_node_attributes(self) -> int:
        """Number of columns of ``data.x`` that are not part of the one-hot block."""
        if self.data.x is None:
            return 0
        return self.data.x.size(1) - self.num_node_labels

    @property
    def num_edge_labels(self) -> int:
        """Width of the trailing block of ``data.edge_attr`` whose total sum equals the edge count.

        Returns 0 if ``edge_attr`` is missing or no column offset satisfies the
        condition.
        """
        if self.data.edge_attr is None:
            return 0
        for i in range(self.data.edge_attr.size(1)):
            if self.data.edge_attr[:, i:].sum() == self.data.edge_attr.size(0):
                return self.data.edge_attr.size(1) - i
        return 0

    @property
    def num_edge_attributes(self) -> int:
        """Number of columns of ``data.edge_attr`` outside the edge-label block."""
        if self.data.edge_attr is None:
            return 0
        return self.data.edge_attr.size(1) - self.num_edge_labels

    def process(self):
        """Collate ``self.data_list`` and save it to ``processed_paths[0]``.

        Raises
        ------
        ValueError
            If no (or an empty) ``data_list`` was supplied; collating nothing
            would otherwise fail with an unrelated, hard-to-read error.
        """
        if not self.data_list:
            raise ValueError(
                f'{type(self).__name__}.process() needs a non-empty data_list, '
                f'got {self.data_list!r}'
            )
        torch.save(self.collate(self.data_list), self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.name}({len(self)})'

In [3]:
# Particle types available in the simulation folder; only the first two are globbed below.
cr_type = ['gamma','proton','electron']
folder_in = '/home/saturn/caph/mpp228/HESS_data/HESS_data_MC/sim_telarray/phase2d/NSB1.00/Desert/Proton_Electron_Gamma-diffuse/20deg/180deg/0.0deg-ws0/Data_h5'
# glob.glob() gives no ordering guarantee, so the lists are sorted: later slices
# such as f_list_gamma[0:10] then pick the same files on every run and machine.
f_list_gamma = sorted(glob.glob(folder_in + f'/{cr_type[0]}*'))
f_list_proton = sorted(glob.glob(folder_in + f'/{cr_type[1]}*'))

In [None]:
# Print every item yielded by fin.walk_groups().
# NOTE(review): no cell above this one assigns `fin`; it is only created further
# down via tb.open_file(...). On a fresh kernel this cell fails unless a file is
# opened first.
for group in fin.walk_groups():
    print(group)

In [137]:
# Print every item yielded by fin.walk_nodes() (groups and tables, see the output below).
# NOTE(review): depends on a `fin` opened in another cell; the execution count
# of this cell shows it was run out of order.
for node in fin.walk_nodes():
    print(node)

/ (RootGroup) ''
/configuration (Group) ''
/dl1 (Group) ''
/simulation (Group) ''
/configuration/instrument (Group) ''
/configuration/simulation (Group) ''
/dl1/event (Group) ''
/dl1/monitoring (Group) ''
/simulation/event (Group) ''
/simulation/event/subarray (Group) ''
/simulation/event/telescope (Group) ''
/simulation/event/subarray/shower (Table(1602,), fletcher32, shuffle, blosc:zstd(5)) 'Storage of EventIndexContainer,SimulatedShowerContainer'
/simulation/event/telescope/images (Group) ''
/simulation/event/telescope/parameters (Group) ''
/simulation/event/telescope/images/tel_001 (Table(530,), fletcher32, shuffle, blosc:zstd(5)) 'Storage of TelEventIndexContainer,SimulatedCameraContainer'
/simulation/event/telescope/images/tel_002 (Table(535,), fletcher32, shuffle, blosc:zstd(5)) 'Storage of TelEventIndexContainer,SimulatedCameraContainer'
/simulation/event/telescope/images/tel_003 (Table(573,), fletcher32, shuffle, blosc:zstd(5)) 'Storage of TelEventIndexContainer,SimulatedCamer

#### Get the camera geometries (pixel positions relative to the array centre)

In [4]:
def get_hess_geom(fin):
    """Return the camera pixel positions of every telescope, shifted by the telescope position.

    Parameters
    ----------
    fin : object
        Open file handle providing ``get_node(path)``; the returned nodes must
        provide ``col('pix_x')`` and ``col('pix_y')``.

    Returns
    -------
    dict
        Maps ``'ct1'`` ... ``'ct5'`` to an array of shape ``(2, n_pixels)``:
        row 0 is ``telescope_x + pix_x`` and row 1 is ``telescope_y + pix_y``.
        ``ct1``-``ct4`` use the ``geometry_HESS-I`` table, ``ct5`` uses
        ``geometry_HESS-II``. The telescope z coordinate is not used.

    NOTE(review): the telescope positions are converted to cm (``* 100``) while
    ``pix_x`` / ``pix_y`` are added as stored — this assumes the pixel
    coordinates are in cm as well; TODO confirm the unit in the file.
    """
    geom_root = '/configuration/instrument/telescope/camera'
    pixel_xy = dict()
    for cam_name in ('HESS-I', 'HESS-II'):
        cam = fin.get_node(f'{geom_root}/geometry_{cam_name}')
        pixel_xy[cam_name] = (cam.col('pix_x'), cam.col('pix_y'))

    tel_loc = {'ct1': np.array([-0.16, -85.04, 0.97])*100, #x, y, z in cm
               'ct2': np.array([85.07, -0.37, 0.33])*100,
               'ct3': np.array([0.24, 85.04, -0.82])*100,
               'ct4': np.array([-85.04, 0.28, -0.48])*100,
               'ct5': np.array([0., 0., 0.])*100}

    # Pixel coordinates measured from the centre of the array - no z axis at the moment.
    cam_pixels_in_array = dict()
    for tel, loc in tel_loc.items():
        # ct5 carries the HESS-II camera, every other telescope a HESS-I camera.
        cam_x, cam_y = pixel_xy['HESS-II' if tel == 'ct5' else 'HESS-I']
        cam_pixels_in_array[tel] = np.array([loc[0] + cam_x, loc[1] + cam_y])
    return cam_pixels_in_array

#### Get the telescope nodes and store them

In [5]:
def get_tel_nodes(fin):
    """Collect the per-telescope image tables under ``/dl1/event/telescope/images``.

    Parameters
    ----------
    fin : object
        Open file handle providing ``get_node(path)``; the returned group is
        iterated to obtain its children.

    Returns
    -------
    dict
        Maps ``'ct<N>'`` to the child node. ``N`` is taken from the node name
        when it has the form ``tel_<digits>`` (``tel_003`` -> ``'ct3'``), so a
        file that lacks one telescope table does not shift the names of the
        following telescopes. Children without such a name fall back to their
        1-based position in the iteration, as before.
    """
    tel_nodes = dict()
    for inum, node in enumerate(fin.get_node('/dl1/event/telescope/images')):
        node_name = getattr(node, '_v_name', '') or ''
        prefix, _, suffix = node_name.rpartition('_')
        if prefix == 'tel' and suffix.isdecimal():
            tel_id = int(suffix)
        else:
            # Unknown naming scheme: keep the original positional numbering.
            tel_id = inum + 1
        tel_nodes[f'ct{tel_id}'] = node
    return tel_nodes

In [None]:
# Read the 'tels_with_trigger' flags of one hand-picked event (obs_id 9050, event_id 5908115).
# NOTE(review): depends on a `fin` opened in another cell.
trig_events = fin.get_node('/dl1/event/subarray/trigger')
trig_events.read_where(f'(obs_id == {9050}) & (event_id == {5908115})')['tels_with_trigger']

array([[ True,  True,  True,  True,  True]])

In [29]:
# Number of 'image' rows stored in the ct3 table for this event; the recorded
# output is 2, i.e. the same (obs_id, event_id) matches two rows.
# NOTE(review): `tel_nodes` is not assigned at top level in the cells visible here.
tel_nodes['ct3'].read_where(f'(obs_id == {9050}) & (event_id == {5908115})')['image'].shape[0]

2

In [30]:
# Keep only the first of the matching 'image' rows for the same ct3 event.
test = tel_nodes['ct3'].read_where(f'(obs_id == {9050}) & (event_id == {5908115})')['image'][0]

In [32]:
# Length of that single image along its first axis (recorded output: 960).
test.shape[0]

960

Closing remaining open files:/home/saturn/caph/mpp228/HESS_data/HESS_data_MC/sim_telarray/phase2d/NSB1.00/Desert/Proton_Electron_Gamma-diffuse/20deg/180deg/0.0deg-ws0/Data_h5/gamma_20deg_180deg_run9050___phase2d2_desert-ws0-nsb1.00_cone5.h5...done/home/saturn/caph/mpp228/HESS_data/HESS_data_MC/sim_telarray/phase2d/NSB1.00/Desert/Proton_Electron_Gamma-diffuse/20deg/180deg/0.0deg-ws0/Data_h5/gamma_20deg_180deg_run9050___phase2d2_desert-ws0-nsb1.00_cone5.h5...done


In [17]:
# 'tels_with_trigger' flags of a second hand-picked event (obs_id 9050, event_id 7933520).
# NOTE(review): depends on a `fin` opened in another cell.
fin.get_node('/dl1/event/subarray/trigger').read_where(f'(obs_id == {9050}) & (event_id == {7933520})')['tels_with_trigger']

array([[False, False,  True, False,  True]])

In [12]:
# Exploration of a single event: for every triggered telescope among ct1-ct4,
# collect the pixel charges above threshold together with the matching pixel
# x / y positions.
# NOTE(review): `fns_in`, `tel_nodes` and `cam_pixels_in_array` are not assigned
# at top level in the cells visible here, so this cell relies on kernel state
# from elsewhere. The file opened below is also never closed in this cell.
print('filename: ', fns_in[0])
fin = tb.open_file(fns_in[0], mode="r")
# Per-telescope arrays of selected pixel charges and pixel coordinates.
pe = list()
x = list()
y = list()
#1305 9050 7933520 ct3
obs_id = 9050
event_id = 7933520
# Iterate over the flattened 'tels_with_trigger' flags of the selected event;
# tel_num is the 0-based position of the flag.
for tel_num, tel in enumerate(fin.get_node('/dl1/event/subarray/trigger').read_where(f'(obs_id == {obs_id}) & (event_id == {event_id})')['tels_with_trigger'].flatten()):
    print(tel)
    if(tel and tel_num < 4): #only doing ct1-4 for now, just for testing..
        tel_name = f'ct{tel_num+1}'
        # All matching rows are flattened into one 1-D array here.
        image = tel_nodes[tel_name].read_where(f'(obs_id == {obs_id}) & (event_id == {event_id})')['image'].flatten()
        print(image)
        # Skip telescopes whose summed image is below 100.
        if (image.sum() < 100):
            continue
        # Keep only pixels with a value above 5.
        pix_pe_theshold_mask = image > 5    
        pe.append(image[pix_pe_theshold_mask])
        if (tel_num == 2):
            print('tel_name:', tel_name, 'image shape:', image.shape)
        # The same boolean mask selects the pixel positions, so image and
        # geometry must have the same length.
        x.append(cam_pixels_in_array[tel_name][0][pix_pe_theshold_mask])
        y.append(cam_pixels_in_array[tel_name][1][pix_pe_theshold_mask])

filename:  /home/saturn/caph/mpp228/HESS_data/HESS_data_MC/sim_telarray/phase2d/NSB1.00/Desert/Proton_Electron_Gamma-diffuse/20deg/180deg/0.0deg-ws0/Data_h5/gamma_20deg_180deg_run9050___phase2d2_desert-ws0-nsb1.00_cone5.h5
[]


#### Now we loop over triggered events in the file and prepare a data list

In [6]:
def _build_edge_index(pe, x, y, min_edge_dist=1.e-2, max_edge_dist=10):
    """Return the ``(2, n_edges)`` long tensor of edges for one event.

    Two kinds of edges are created, in this order:

    1. ``[brightest, k]`` for every pixel ``k`` whose charge is more than 90 %
       of, but strictly below, the largest charge. The idea is that these
       signals come from the same part of the shower.
    2. ``[i, j]`` with ``i < j`` for every pixel pair whose distance lies
       strictly between ``min_edge_dist`` and ``max_edge_dist``.

    An event without any edge yields an empty tensor of shape ``(2, 0)``.
    """
    n_pix = len(pe)
    max_pe = np.max(pe)
    max_pe_index = np.argmax(pe)
    bright_edges = [[max_pe_index, sig] for sig in range(n_pix)
                    if 0.9 < pe[sig] / max_pe < 1]

    # All index pairs i < j, in the same order as itertools.combinations(range(n), 2).
    # This allocates O(n_pix**2) temporaries, which replaces the Python-level pair loop.
    first, second = np.triu_indices(n_pix, k=1)
    dx = x[second] - x[first]
    dy = y[second] - y[first]
    dist = np.sqrt(dx * dx + dy * dy)
    is_close = (dist > min_edge_dist) & (dist < max_edge_dist)
    near_edges = np.column_stack((first[is_close], second[is_close])).astype(np.int64)

    # reshape(-1, 2) keeps the array two-dimensional even when there are no edges.
    edges = np.concatenate((np.array(bright_edges, dtype=np.int64).reshape(-1, 2), near_edges))
    return torch.tensor(edges, dtype=torch.long).t().contiguous()


def get_data_list(fin, cam_pixels_in_array, tel_nodes, cr_type,
                  pix_pe_threshold=5, min_image_pe=100, min_event_pe=1000,
                  max_edge_dist=10, max_tels=4):
    """Turn every sufficiently bright triggered event of one file into a graph.

    Parameters
    ----------
    fin : object
        Open file handle; ``fin.get_node('/dl1/event/subarray/trigger')`` must
        be iterable and yield rows with ``'obs_id'``, ``'event_id'`` and
        ``'tels_with_trigger'``.
    cam_pixels_in_array : dict
        ``'ct<N>'`` -> pair ``(x, y)`` of per-pixel position arrays.
    tel_nodes : dict
        ``'ct<N>'`` -> table providing ``read_where(condition)['image']``.
    cr_type : str
        Particle type; ``'gamma'`` is labelled ``y=1``, anything else ``y=0``.
    pix_pe_threshold : float, default 5
        Only pixels with a value strictly above this become graph nodes.
    min_image_pe : float, default 100
        Telescope images whose sum is below this are ignored.
    min_event_pe : float, default 1000
        Events whose summed selected pixel values do not exceed this are ignored.
    max_edge_dist : float, default 10
        Pixel pairs closer than this (and farther apart than 0.01) are connected.
        The value is in the units of ``cam_pixels_in_array`` — presumably cm,
        TODO confirm. The default is an arbitrary first choice.
    max_tels : int, default 4
        Only the first ``max_tels`` entries of ``tels_with_trigger`` are used
        (ct1-ct4 by default, just for testing).

    Returns
    -------
    list
        One ``Data`` object per accepted event with node features
        ``[x, y, pe]`` (float), ``edge_index`` of shape ``(2, n_edges)`` and
        the integer label ``y``.
    """
    label = 1 if cr_type == 'gamma' else 0
    data_list = list()
    trig_events = fin.get_node('/dl1/event/subarray/trigger')
    for ev in trig_events:
        obs_id = ev['obs_id']
        event_id = ev['event_id']
        query = f'(obs_id == {obs_id}) & (event_id == {event_id})'
        pe = list()
        x = list()
        y = list()
        for tel_num, tel in enumerate(ev['tels_with_trigger']):
            if not tel or tel_num >= max_tels:
                continue
            tel_name = f'ct{tel_num+1}'
            if tel_name not in tel_nodes:
                # The file has no image table for this telescope.
                continue
            image = tel_nodes[tel_name].read_where(query)['image']
            n_rows = image.shape[0]
            if n_rows == 0: #don't know why it is stored as triggered then!
                continue
            if n_rows > 1: #don't know why this happens at all either; keep the first row
                image = image[0]
            if image.sum() < min_image_pe:
                continue
            image = image.flatten()
            cam_x = cam_pixels_in_array[tel_name][0]
            cam_y = cam_pixels_in_array[tel_name][1]
            if image.shape[0] != cam_x.shape[0]:
                # The boolean mask below would not fit the geometry arrays.
                print("this should not happen a bug in h5 file", tel_name, 'image_shape:', image.shape[0], cam_x.shape)
                continue

            pix_pe_threshold_mask = image > pix_pe_threshold
            pe.append(image[pix_pe_threshold_mask])
            x.append(cam_x[pix_pe_threshold_mask])
            y.append(cam_y[pix_pe_threshold_mask])
        if not pe:
            continue
        pe = np.concatenate(pe).flatten()
        x = np.concatenate(x).flatten()
        y = np.concatenate(y).flatten()
        # pe.size == 0 happens when every image passed the sum cut but no single
        # pixel passed the per-pixel threshold.
        if pe.size == 0 or not np.sum(pe) > min_event_pe:
            continue

        # Every edge index comes from range(len(pe)), so it is valid by
        # construction; an event without edges gives an empty (2, 0) tensor
        # instead of raising.
        edge_index = _build_edge_index(pe, x, y, max_edge_dist=max_edge_dist)
        nodes = torch.tensor(np.column_stack((x, y, pe)), dtype=torch.float)
        data_list.append(Data(x=nodes, edge_index=edge_index, y=label))
    return data_list

In [None]:
%%time
gamma_data_list = list()
for file_in in f_list_gamma[0:10]:
    fin = tb.open_file(file_in, mode="r")
    gamma_data_list += get_data_list(fin, get_hess_geom(fin), get_tel_nodes(fin), 'gamma')

proton_data_list = list()
for file_in in f_list_proton[0:10]:
    fin = tb.open_file(file_in, mode="r")
    proton_data_list += get_data_list(fin, get_hess_geom(fin), get_tel_nodes(fin), 'proton')
final_data_list = gamma_data_list + proton_data_list

In [None]:
# Total number of graphs (gamma + proton) collected by the loading cell.
len(final_data_list)

In [None]:
dataset = MyDataset('/home/woody/caph/mppi067h/gamma_ray_reconstruction_with_ml/gnn/hdf5_10cm_1000pe','test',final_data_list)
# Force the processed file to be rewritten from final_data_list, then reload it.
# Without the reload, `dataset` would keep whatever the constructor loaded,
# which is stale if a processed file from an earlier run already existed.
dataset.process()
dataset.data, dataset.slices = torch.load(dataset.processed_paths[0])

In [78]:
# Count the entries of `image` above 2 (recorded output: 40).
# NOTE(review): scratch cell; `image` comes from kernel state of another cell.
image[image > 2].shape

(40,)

In [75]:
# NOTE(review): the recorded output is (1, 960), whereas the single-event
# exploration cell flattens `image` to 1-D — so this output stems from kernel
# state that is no longer reproducible from the visible cells.
image.shape

(1, 960)