In [12]:
###################################################################################################
#
# PairIdentification.py
#
# Copyright (C) by Andreas Zoglauer & Harrison Costatino.
#
# Please see the file LICENSE in the main repository for the copyright-notice.
#
###################################################################################################



###################################################################################################

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

import tensorflow as tf
import numpy as np

#from mpl_toolkits.mplot3d import Axes3D
#import matplotlib.pyplot as plt

import random

import signal
import sys
import time
import math
import csv
import os
import argparse
from datetime import datetime
from functools import reduce


print("\nPair Identification")
print("============================\n")



# Step 1: Input parameters
###################################################################################################


# Default parameters

UseToyModel = True

# Split between training and testing data
TestingTrainingSplit = 0.1

MaxEvents = 10

# File names
FileName = "PairIdentification.p1.sim.gz"
GeometryName = "$(MEGALIB)/resource/examples/geomega/GRIPS/GRIPS.geo.setup"


# Set in stone later
TestingTrainingSplit = 0.8

OutputDirectory = "Results"


parser = argparse.ArgumentParser(description='Perform training and/or testing of the pair identification machine learning tools.')
parser.add_argument('-f', '--filename', default='PairIdentification.p1.sim.gz', help='File name used for training/testing')
parser.add_argument('-m', '--maxevents', default='1000', help='Maximum number of events to use')
parser.add_argument('-s', '--testingtrainigsplit', default='0.1', help='Testing-training split')
parser.add_argument('-b', '--batchsize', default='128', help='Batch size')

args = parser.parse_args()

if args.filename != "":
  FileName = args.filename

if int(args.maxevents) > 1000:
  MaxEvents = int(args.maxevents)

if int(args.batchsize) >= 16:
  BatchSize = int(args.batchsize)

if float(args.testingtrainigsplit) >= 0.05:
  TestingTrainingSplit = float(args.testingtrainigsplit)


if os.path.exists(OutputDirectory):
  Now = datetime.now()
  OutputDirectory += Now.strftime("_%Y%m%d_%H%M%S")

os.makedirs(OutputDirectory)



###################################################################################################
# Step 2: Global functions
###################################################################################################


# Take care of Ctrl-C
Interrupted = False
NInterrupts = 0
def signal_handler(signal, frame):
  global Interrupted
  Interrupted = True
  global NInterrupts
  NInterrupts += 1
  if NInterrupts >= 2:
    print("Aborting!")
    sys.exit(0)
  print("You pressed Ctrl+C - waiting for graceful abort, or press  Ctrl-C again, for quick exit.")
signal.signal(signal.SIGINT, signal_handler)


# Everything ROOT related can only be loaded here otherwise it interferes with the argparse
from EventData import EventData

# Load MEGAlib into ROOT so that it is usable
import ROOT as M
M.gSystem.Load("$(MEGALIB)/lib/libMEGAlib.so")
M.PyConfig.IgnoreCommandLineOptions = True



###################################################################################################
# Step 3: Create some training, test & verification data sets
###################################################################################################


# Read the simulation file data:
DataSets = []
NumberOfDataSets = 0

if UseToyModel == True:
  for e in range(0, MaxEvents):
    Data = EventData()
    Data.createFromToyModel(e)
    DataSets.append(Data)
    
    NumberOfDataSets += 1
    if NumberOfDataSets > 0 and NumberOfDataSets % 1000 == 0:
      print("Data sets processed: {}".format(NumberOfDataSets))
  
else:
  # Load geometry:
  Geometry = M.MDGeometryQuest()
  if Geometry.ScanSetupFile(M.MString(GeometryName)) == True:
    print("Geometry " + GeometryName + " loaded!")
  else:
    print("Unable to load geometry " + GeometryName + " - Aborting!")
    quit()


  Reader = M.MFileEventsSim(Geometry)
  if Reader.Open(M.MString(FileName)) == False:
    print("Unable to open file " + FileName + ". Aborting!")
    quit()


  print("\n\nStarted reading data sets")
  NumberOfDataSets = 0
  while NumberOfDataSets < MaxEvents:
    Event = Reader.GetNextEvent()
    if not Event:
      break

    if Event.GetNIAs() > 0:
      Data = EventData()
      if Data.parse(Event) == True:
        if Data.hasHitsOutside(XMin, XMax, YMin, YMax, ZMin, ZMax) == False:
          DataSets.append(Data)
          NumberOfDataSets += 1
          if NumberOfDataSets % 500 == 0:
            print("Data sets processed: {}".format(NumberOfDataSets))

print("Info: Parsed {} events".format(NumberOfDataSets))

# Split the data sets in training and testing data sets

TestingTrainingSplit = 0.75


numEvents = len(DataSets)

numTraining = int(numEvents * TestingTrainingSplit)

TrainingDataSets = DataSets[:numTraining]
TestingDataSets = DataSets[numTraining:]



# For testing/validation split
# ValidationDataSets = TestingDataSets[:int(len(TestingDataSets)/2)]
# TestingDataSets = TestingDataSets[int(len(TestingDataSets)/2):]

print("###### Data Split ########")
print("Training/Testing Split: {}".format(TestingTrainingSplit))
print("Total Data: {}, Training Data: {},Testing Data: {}".format(numEvents, len(TrainingDataSets), len(TestingDataSets)))
print("##########################")


###################################################################################################
# Step 4: Setting up the neural network
###################################################################################################



###################################################################################################
# Step 5: Training and evaluating the network
###################################################################################################



Pair Identification

Welcome to JupyROOT 6.18/04
Start: -14.259855655301408, -1.2565185962869618, -13
Event ID: 0
  Origin Z: -13
  Gamma Energy: 10000.0
  Hit 1 (origin: 0): type=m, pos=(-14.259855655301408, -1.2565185962869618, -13.0)cm, E=1354.940004802481keV
  Hit 2 (origin: 1): type=e, pos=(-13.228281084902289, -0.19005484885596213, -14.0)cm, E=884.5847308188852keV
  Hit 3 (origin: 2): type=e, pos=(-12.930140788912544, 0.5864995631932872, -15.0)cm, E=922.7296837069443keV
  Hit 4 (origin: 3): type=e, pos=(-11.977518702782286, 2.111337005313529, -16.0)cm, E=373.8244157345232keV
  Hit 5 (origin: 1): type=p, pos=(-14.312850122625893, -1.8049921164810432, -14.0)cm, E=604.3826092141628keV
  Hit 6 (origin: 5): type=p, pos=(-15.1601139111823, -2.1569808344196817, -15.0)cm, E=637.076078469488keV
  Hit 7 (origin: 6): type=p, pos=(-13.665741072652311, 1.3228162909278502, -16.0)cm, E=685.0203891572861keV
  Hit 8 (origin: 7): type=p, pos=(-12.423470747942654, 2.9473029324374624, -17.0)cm, E=7

In [16]:
def generate_incidence(edges, pos_data):
    n_hits = len(pos_data)
    n_edges = len(edges)
    Ri = np.zeros((n_hits, n_edges), dtype=np.uint8)
    Ro = np.zeros((n_hits, n_edges), dtype=np.uint8)
    
    for i in range(len(edges)):
        point = edges[i]
        from_pt = point[0]
        to_pt = point[1]
        Ro[from_pt][i] = 1
        Ri[to_pt][i] = 1
    
    return Ri, Ro

In [17]:
def connect_pos(pos_data):
    edges = []

    for i in range(len(pos_data)):
        point_A = pos_data[i]
        z_A = point_A[2]

        for j in range(len(pos_data)):
            point_B = pos_data[j]
            z_B = point_B[2]

            if z_B == z_A + 1:
                edges.append((i, j))
                edges.append((j, i))
    print(edges)
    
    return generate_incidence(edges, pos_data)

In [18]:
def vectorize_data(eventArr):
    Ri, Ro = [], []
    xyz = []
    t = []
    E = []
    GE = []
    
    max_hits = 0
    max_edges = 0
    
    #parse events
    for event in eventArr:
        edges = []
        max_hits = max(max_hits, len(event.X))
        
        pos = np.swapaxes(np.vstack((event.X, event.Y, event.Z)), 0, 1)
        for i in range(1,len(event.Origin+1)):
            edges.append((i-1,event.Origin[i-1]-1))
        
        max_edges = max(max_edges, len(edges))
        
        e_Ri, e_Ro = generate_incidence(edges,pos)
        
        Ri.append(e_Ri)
        Ro.append(e_Ro)
        xyz.append(np.hstack((event.X, event.Y, event.Z)))
        t.append(2*(event.Type=='m')+(event.Type=='p'))
        E.append(event.E)
        GE.append(event.GammaEnergy)
    
    #padding
    for i in range(len(Ri)):
        arr = Ri[i]
        padded_arr = np.zeros((max_hits,max_edges))
        padded_arr[:arr.shape[0],:arr.shape[1]] = arr
        Ri[i] = padded_arr
        
        arr = Ro[i]
        padded_arr = np.zeros((max_hits,max_edges))
        padded_arr[:arr.shape[0],:arr.shape[1]] = arr
        Ro[i] = padded_arr
        
        arr = xyz[i]
        padded_arr = np.zeros((max_hits*3))
        padded_arr[:arr.shape[0]] = arr
        xyz[i] = padded_arr
        
        arr = t[i]
        padded_arr = np.zeros((max_hits))
        padded_arr[:arr.shape[0]] = arr
        t[i] = padded_arr
        
        arr = E[i]
        padded_arr = np.zeros((max_hits))
        padded_arr[:arr.shape[0]] = arr
        E[i] = padded_arr
    
    return np.array(Ri), np.array(Ro), np.array(xyz), np.array(t), np.array(E), np.array(GE)

In [19]:
Ri, Ro, xyz, t, E, GE = vectorize_data(TrainingDataSets)

In [20]:
from heptrkx-gnn-tracking.trainers import get_trainer

SyntaxError: invalid syntax (<ipython-input-20-8191b9609366>, line 1)

In [21]:
Ri.shape

(7, 18, 17)

In [22]:
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Locals
from datasets import get_data_loaders
from trainers import get_trainer

In [23]:
trainer = get_trainer('gnn')
trainer.build_model()
trainer.print_model_summary()

In [24]:
train_data_loader = DataLoader(train_dataset, batch_size=1)
test_data_loader = DataLoader(valid_dataset, batch_size=1)

NameError: name 'train_dataset' is not defined

In [None]:
summary = trainer.train(train_data_loader=train_data_loader,
                        valid_data_loader=valid_data_loader,
                        **train_config)