In [1]:
###################################################################################################
#
# PairIdentification.py
#
# Copyright (C) by Andreas Zoglauer & Harrison Costatino.
#
# Please see the file LICENSE in the main repository for the copyright-notice.
#
###################################################################################################



###################################################################################################

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

import numpy as np

#from mpl_toolkits.mplot3d import Axes3D
#import matplotlib.pyplot as plt

import random

import signal
import sys
import time
import math
import csv
import os
import argparse
import logging
import yaml
from datetime import datetime
from functools import reduce


print("\nPair Identification")
print("============================\n")



# Step 1: Input parameters
###################################################################################################


# Default parameters. Each of them is replaced further down by the corresponding command line
# value, provided that value passes the sanity check next to it.

# Training/testing split
TestingTrainingSplit = 0.8

# Maximum number of events to use
MaxEvents = 1000

# Batch size handed to the data loaders; defined up front so that it always exists, even when
# the command line value is rejected
BatchSize = 16

# File names
FileName = "PairIdentification.p1.sim.gz"
GeometryName = "$(MEGALIB)/resource/examples/geomega/GRIPS/GRIPS.geo.setup"

# Directory for the results; gets a time stamp suffix below if it already exists
OutputDirectory = "Results"


parser = argparse.ArgumentParser(description='Perform training and/or testing of the pair identification machine learning tools.')
parser.add_argument('-d', '--datatype', default='tm2', help='One of: tm1: toy modle #1, tm2: toy model #2, f: file')
parser.add_argument('-f', '--filename', default='PairIdentification.p1.sim.gz', help='File name used for training/testing')
# type=... lets argparse reject non-numeric input with a usage message instead of a traceback
parser.add_argument('-m', '--maxevents', type=int, default=100, help='Maximum number of events to use')
parser.add_argument('-s', '--testingtrainigsplit', type=float, default=0.1, help='Testing-training split')
parser.add_argument('-b', '--batchsize', type=int, default=16, help='Batch size')

# Command line arguments for build model, to remove dependency on .yaml
# (these are kept as strings here)
parser.add_argument('--model_type', default='gnn_segment_classifier', help='model_type')
parser.add_argument('--optimizer', default='Adam', help='optimizer')
parser.add_argument('--learning_rate', default='0.001', help='learning_rate')
parser.add_argument('--loss_func', default='BCELoss', help='loss_func')
parser.add_argument('--input_dim', default='3', help='input_dim')
parser.add_argument('--hidden_dim', default='64', help='hidden_dim')
parser.add_argument('--n_iters', default='100', help='n_iters')
# parser.add_argument('--hidden_activation', default='nn.Tanh', help='hidden_activation')


args = parser.parse_args()

DataType = args.datatype

if args.filename != "":
  FileName = args.filename

# Values outside the accepted range keep the default from above; say so instead of silently
# ignoring what the user asked for
if args.maxevents >= 10:
  MaxEvents = args.maxevents
else:
  print("Warning: Maximum number of events {} is below 10 - using {} instead".format(args.maxevents, MaxEvents))

# A batch needs at least one event
if args.batchsize >= 1:
  BatchSize = args.batchsize
else:
  print("Warning: Batch size {} is below 1 - using {} instead".format(args.batchsize, BatchSize))

if args.testingtrainigsplit >= 0.05:
  TestingTrainingSplit = args.testingtrainigsplit
else:
  print("Warning: Testing-training split {} is below 0.05 - using {} instead".format(args.testingtrainigsplit, TestingTrainingSplit))


# Never write into an already existing results directory: append a time stamp to get a fresh one
if os.path.exists(OutputDirectory):
  Now = datetime.now()
  OutputDirectory += Now.strftime("_%Y%m%d_%H%M%S")

os.makedirs(OutputDirectory)



###################################################################################################
# Step 2: Global functions
###################################################################################################


# Ctrl-C handling: the first Ctrl-C only raises a flag, the second one terminates the program
Interrupted = False
NInterrupts = 0


def signal_handler(signal, frame):
  """
  Handler for SIGINT.

  Sets the module-level flag Interrupted and counts the call in NInterrupts.
  From the second call on, the program is terminated via sys.exit(0).

  signal -- the signal number passed in by the signal module (not used)
  frame  -- the current stack frame passed in by the signal module (not used)
  """
  global Interrupted, NInterrupts

  NInterrupts += 1
  Interrupted = True

  if NInterrupts < 2:
    print("You pressed Ctrl+C - waiting for graceful abort, or press  Ctrl-C again, for quick exit.")
    return

  print("Aborting!")
  sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)


# Everything ROOT related can only be loaded here otherwise it interferes with the argparse
from EventData import EventData

# Load MEGAlib into ROOT so that it is usable
import ROOT as M
# Tell PyROOT to leave the command line alone. The flag has to be set directly after the import
# and before anything else of ROOT is touched (such as gSystem below), otherwise it is too late.
# NOTE(review): if EventData already uses ROOT at import time, the flag would have to be set
# before that import as well - confirm against EventData.py
M.PyConfig.IgnoreCommandLineOptions = True
M.gSystem.Load("$(MEGALIB)/lib/libMEGAlib.so")



###################################################################################################
# Step 3: Create some training, test & verification data sets
###################################################################################################


# Create or read the data sets, depending on DataType:
#   tm1 / tm2: create MaxEvents events with the toy model of realism level 1 / 2
#   f:         read up to MaxEvents accepted events from the simulation file FileName
# In all cases a Ctrl-C (flag Interrupted) stops the loop early and keeps the events so far.
DataSets = []
NumberOfDataSets = 0

if DataType == "tm1" or DataType == "tm2":
  for e in range(0, MaxEvents):
    if Interrupted:
      print("Info: Interrupted - stopping the data set creation early")
      break

    Data = EventData()
    # The loop index is handed to the toy model creator
    if DataType == "tm1":
      Data.createFromToyModelRealismLevel1(e)
    else:
      Data.createFromToyModelRealismLevel2(e)
    DataSets.append(Data)

    NumberOfDataSets += 1
    if NumberOfDataSets % 1000 == 0:
      print("Data sets processed: {}".format(NumberOfDataSets))

elif DataType == "f":
  # Load geometry:
  Geometry = M.MDGeometryQuest()
  if Geometry.ScanSetupFile(M.MString(GeometryName)) == True:
    print("Geometry " + GeometryName + " loaded!")
  else:
    print("Unable to load geometry " + GeometryName + " - Aborting!")
    # sys.exit instead of the interactive helper quit(); non-zero status since this is an error
    sys.exit(1)


  Reader = M.MFileEventsSim(Geometry)
  if Reader.Open(M.MString(FileName)) == False:
    print("Unable to open file " + FileName + ". Aborting!")
    sys.exit(1)


  print("\n\nStarted reading data sets")
  while NumberOfDataSets < MaxEvents and not Interrupted:
    Event = Reader.GetNextEvent()
    if not Event:
      # No more events in the file
      break

    # Only events with at least one IA are considered, and only those which EventData can parse
    if Event.GetNIAs() > 0:
      Data = EventData()
      if Data.parse(Event) == True:
        # NOTE(review): XMin, XMax, YMin, YMax, ZMin, ZMax are not defined in the visible part
        # of this file - they have to be provided before this branch can run. TODO confirm
        if Data.hasHitsOutside(XMin, XMax, YMin, YMax, ZMin, ZMax) == False:
          DataSets.append(Data)
          NumberOfDataSets += 1
          if NumberOfDataSets % 500 == 0:
            print("Data sets processed: {}".format(NumberOfDataSets))

  if Interrupted:
    print("Info: Interrupted - stopping the data set reading early")

else:
  print("Unknown data type \"{}\" Must be one of tm1, tm2, f".format(DataType))
  sys.exit(1)

print("Info: Parsed {} events".format(NumberOfDataSets))

# Divide the events into a leading training part and a trailing testing part

# Here the value is the fraction of events which goes into the training part.
# NOTE(review): this fixed value replaces whatever TestingTrainingSplit was set to before,
# including a value given via --testingtrainigsplit - confirm that this is intended
TestingTrainingSplit = 0.75

numEvents = len(DataSets)
numTraining = int(TestingTrainingSplit * numEvents)

TrainingDataSets, TestingDataSets = DataSets[:numTraining], DataSets[numTraining:]

# For testing/validation split
# ValidationDataSets = TestingDataSets[:int(len(TestingDataSets)/2)]
# TestingDataSets = TestingDataSets[int(len(TestingDataSets)/2):]

# Report the split
SplitReport = [
  "###### Data Split ########",
  "Training/Testing Split: {}".format(TestingTrainingSplit),
  "Total Data: {}, Training Data: {},Testing Data: {}".format(numEvents, len(TrainingDataSets), len(TestingDataSets)),
  "##########################",
]
print("\n".join(SplitReport))


###################################################################################################
# Step 4: Vectorize data using preprocess.py
###################################################################################################

from preprocess import generate_incidence, connect_pos, vectorize_data
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Locals
from datasets import get_data_loaders
from trainers import get_trainer

def _pair_features_with_labels(xyz, man_ri, man_ro, edge_labels):
  """
  Arrange the vectorized data event by event.

  Returns a tuple (features, dataset):
    features -- one entry [xyz[k], man_ri[k], man_ro[k]] per event
    dataset  -- one entry [features[k], edge_labels[k]] per event
  The number of events is taken from xyz.shape[0].
  """
  count = xyz.shape[0]
  features = [[xyz[k], man_ri[k], man_ro[k]] for k in range(count)]
  dataset = [[features[k], edge_labels[k]] for k in range(count)]
  return features, dataset


# Vectorize the training and the testing events
(train_Edge_Labels, train_Man_Ri, train_Man_Ro, train_XYZ,
 train_Type, train_Energy, train_GammaEnergy) = vectorize_data(TrainingDataSets)
(test_Edge_Labels, test_Man_Ri, test_Man_Ro, test_XYZ,
 test_Type, test_Energy, test_GammaEnergy) = vectorize_data(TestingDataSets)

# The edge labels are the training target
train_labels = train_Edge_Labels
test_labels = test_Edge_Labels

train_features, train_dataset = _pair_features_with_labels(train_XYZ, train_Man_Ri, train_Man_Ro, train_labels)
test_features, test_dataset = _pair_features_with_labels(test_XYZ, test_Man_Ri, test_Man_Ro, test_labels)

# The testing events double as validation data during the training
train_data_loader = DataLoader(train_dataset, batch_size=BatchSize)
valid_data_loader = DataLoader(test_dataset, batch_size=BatchSize)


Pair Identification

Welcome to JupyROOT 6.18/04
Added Bremsstrahlung hits
Eliminate hit 1 at 10.5109386307273 12.013107486674315 -3.0
Eliminate hit 2 at 9.976084237680977 11.852026159465815 -4.0
Eliminate hit 8 at 16.114948234500257 9.287914134928519 -8.0
Eliminate hit 13 at 10.5109386307273 12.013107486674315 -3.0
Event ID: 0
  Origin Z: -3
  Gamma Energy: 10000.0
  Hit 1 (origin: 0): type=e, pos=(8.672824947043862, 13.97904892181593, -5.0)cm, E=731.8111664171286keV
  Hit 2 (origin: 1): type=e, pos=(12.932603790342569, 2.3287458271205246, -4.0)cm, E=785.8521747575676keV
  Hit 3 (origin: 2): type=e, pos=(15.51689823385883, 6.09029499591748, -5.0)cm, E=839.1756828347962keV
  Hit 4 (origin: 3): type=e, pos=(16.918593410438984, 5.998528441315643, -6.0)cm, E=891.3568557563303keV
  Hit 5 (origin: 4): type=e, pos=(16.871642445623003, 6.153408544514, -7.0)cm, E=930.4581869772868keV
  Hit 6 (origin: 0): type=p, pos=(10.281155699776162, 11.152052157530202, -2.0)cm, E=847.1931809836892keV
  Hi

Event ID: 80
  Origin Z: 19
  Gamma Energy: 10000.0
  Hit 1 (origin: 0): type=e, pos=(0.3955293917915803, -13.441287875120626, 19.0)cm, E=291.10810307820753keV
  Hit 2 (origin: 1): type=e, pos=(-1.0174171735405153, -13.392749043372827, 18.0)cm, E=782.323618540683keV
  Hit 3 (origin: 2): type=e, pos=(1.4871154867576402, -16.28589293664775, 19.0)cm, E=836.6976313810334keV
  Hit 4 (origin: 3): type=e, pos=(-0.4167947995607919, -14.808159070640437, 18.0)cm, E=883.4772154153314keV
  Hit 5 (origin: 4): type=e, pos=(-0.8680146259907235, -12.975537282714026, 19.0)cm, E=923.414646537178keV
  Hit 6 (origin: 5): type=e, pos=(1.5480926815834943, -13.788867075190863, 20.0)cm, E=420.30822361909463keV
  Hit 7 (origin: 1): type=p, pos=(0.24186738518574008, -13.80616897501541, 20.0)cm, E=631.5223001388924keV
  Hit 8 (origin: 7): type=p, pos=(0.2554779229861913, -13.18223602725052, 21.0)cm, E=678.3322966465247keV
  Hit 9 (origin: 8): type=p, pos=(1.1064355911418449, -13.98422753543472, 22.0)cm, E=744.46

In [4]:
###################################################################################################
# Step 5: Setting up the neural network
###################################################################################################

# trainer = get_trainer(distributed=args.distributed, output_dir=output_dir,
#                           device=args.device, **experiment_config)

# Use the first CUDA device if torch reports one as available, otherwise the CPU
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Using", device_name, "for training.")

trainer = get_trainer(name='gnn', device=device)

# Build the model
# trainer.build_model(**model_config)

# The model configuration comes from the command line instead of a .yaml file. For reference,
# the former .yaml section:
#
# model_config:
#     model_type: 'gnn_segment_classifier'
#     input_dim: 3
#     hidden_dim: 64
#     n_iters: 4
#     loss_func: 'BCELoss'
#     optimizer: 'Adam'
#     learning_rate: 0.001
model_type = args.model_type
optimizer = args.optimizer
learning_rate = float(args.learning_rate)
loss_func = args.loss_func
input_dim = int(args.input_dim)
hidden_dim = int(args.hidden_dim)
# NOTE(review): the --n_iters option is parsed but not read here; this fixed value is handed to
# the model AND used as the number of training epochs below - confirm that both are intended
n_iters = 5

# Pass the parsed input dimension instead of a literal, so that --input_dim takes effect
trainer.build_model(model_type=model_type, optimizer=optimizer, learning_rate=learning_rate, loss_func=loss_func,
  input_dim=input_dim, hidden_dim=hidden_dim, n_iters=n_iters)

# if not args.distributed or (dist.get_rank() == 0):
#     trainer.print_model_summary()

###################################################################################################
# Step 6: Training the network
###################################################################################################

summary = trainer.train(train_data_loader=train_data_loader,
                        valid_data_loader=valid_data_loader, n_epochs=n_iters)

print('Train Loss Log: ', summary['train_loss'])
print('Final Test Accuracy: ', summary['valid_acc'][-1])
print('Max Test Accuracy: ', max(summary['valid_acc']))


# Write into the directory created for this run; OutputDirectory carries a time stamp suffix
# whenever a plain "Results" directory already existed at start-up
trainer.write_summaries(os.path.join(OutputDirectory, "result"), summary)

###################################################################################################
# Step 7: Evaluating the network
###################################################################################################



Using cpu for training.
Batch 0 Loss: 0.7268414497375488
Batch 1 Loss: 0.7170198559761047
Batch 2 Loss: 0.7126250863075256
Batch 3 Loss: 0.6874683499336243
Batch 4 Loss: 0.715714156627655
Batch 0 Loss: 0.6921581029891968
Batch 1 Loss: 0.7009883522987366
Batch 2 Loss: 0.690922200679779
Batch 3 Loss: 0.661226749420166
Batch 4 Loss: 0.7071027755737305
Batch 0 Loss: 0.6775506734848022
Batch 1 Loss: 0.6853464245796204
Batch 2 Loss: 0.6763859987258911
Batch 3 Loss: 0.6451206803321838
Batch 4 Loss: 0.6902924180030823
Batch 0 Loss: 0.6630220413208008
Batch 1 Loss: 0.6678400635719299
Batch 2 Loss: 0.6614580750465393
Batch 3 Loss: 0.63213711977005
Batch 4 Loss: 0.6687606573104858
Batch 0 Loss: 0.6463096737861633
Batch 1 Loss: 0.6501309275627136
Batch 2 Loss: 0.6459763646125793
Batch 3 Loss: 0.6177757978439331
Batch 4 Loss: 0.6474794745445251
Train Loss Log:  [0.7119337797164917, 0.6904796361923218, 0.674939239025116, 0.6586435914039612, 0.6415344476699829]
Final Test Accuracy:  0.872499999999956