In [1]:
%load_ext autoreload
%autoreload 2

import uproot
import awkward
import numpy as np
from tqdm import tqdm_notebook as tqdm
import torch
import os
import os.path as osp
import uproot_methods
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import Data
from dataset import  FalconDataset
import glob

ModuleNotFoundError: No module named 'uproot_methods'

In [183]:
# Absolute path of the single ROOT ntuple processed by this run, resolved against
# the kernel's current working directory.
# NOTE(review): the input file (#151) is hard-coded; presumably this is edited by
# hand for each ntuple — consider a config constant or a loop over files.
fname = osp.join(os.getcwd(),"data/ntup/","falcon_production_ttbar_gnn_151.root")


In [184]:
# Open the ntuple and select the 'RHTree' object under 'fevt'. Despite its name,
# `rootfile` holds that tree (it is used below via .keys() / .arrays()), not the file.
rootfile = uproot.open(fname)['fevt']['RHTree']

In [185]:
# List the branch names stored in the tree (returned as bytes, as the output shows).
rootfile.keys()

[b'eventId',
 b'runId',
 b'lumiId',
 b'jetSeed_iphi',
 b'jetSeed_ieta',
 b'jetPt',
 b'jetM',
 b'jetEta',
 b'seljet_genpart_collid',
 b'seljet_genpart_pdgid',
 b'seljet_genpart_charge',
 b'seljet_genpart_px',
 b'seljet_genpart_py',
 b'seljet_genpart_pz',
 b'seljet_genpart_energy',
 b'seljet_genpart_status',
 b'seljet_genpart_motherpdgid',
 b'seljet_genpart_dau1pdgid',
 b'seljet_genpart_dau2pdgid',
 b'seljet_px',
 b'seljet_py',
 b'seljet_pz',
 b'seljet_energy',
 b'seljet_pfcand_px',
 b'seljet_pfcand_py',
 b'seljet_pfcand_pz',
 b'seljet_pfcand_energy',
 b'seljet_pfcand_type',
 b'genpart_collid',
 b'genpart_pdgid',
 b'genpart_charge',
 b'genpart_px',
 b'genpart_py',
 b'genpart_pz',
 b'genpart_energy',
 b'genpart_status',
 b'genpart_motherpdgid',
 b'genpart_dau1pdgid',
 b'genpart_dau2pdgid',
 b'pfjet_px',
 b'pfjet_py',
 b'pfjet_pz',
 b'pfjet_energy',
 b'pfjet_pfcand_px',
 b'pfjet_pfcand_py',
 b'pfjet_pfcand_pz',
 b'pfjet_pfcand_energy',
 b'pfjet_pfcand_type']

In [186]:
# Read the four per-constituent branches of the selected jets in a single call.
# `rootfile.arrays()` with no branch list materialises every branch of the tree,
# so calling it once per component read the whole tree four times.
PFCAND_BRANCHES = [
    b'seljet_pfcand_px',
    b'seljet_pfcand_py',
    b'seljet_pfcand_pz',
    b'seljet_pfcand_energy',
]
pfcand_arrays = rootfile.arrays(PFCAND_BRANCHES)
# Each branch is nested per tree entry (presumably per event) -> per selected jet
# -> per constituent; flatten() drops the outermost level, leaving one
# variable-length constituent list per jet (indexed as px[jet_idx][pfcand_idx]).
px, py, pz, pe = (
    awkward.fromiter(pfcand_arrays[branch]).flatten() for branch in PFCAND_BRANCHES
)
del pfcand_arrays  # the raw nested copies are no longer needed

In [187]:
# Number of selected jets in this ntuple: after flattening, `px` holds exactly one
# constituent list per jet, so its length is the jet count.
n_jets = len(px)
n_jets

1099

In [188]:
def prepare_x(jet_idx, px_jets=None, py_jets=None, pz_jets=None):
    """Return the node-feature matrix ``(pt, eta, phi)`` for one selected jet.

    Parameters
    ----------
    jet_idx : int
        Index of the jet in the per-jet constituent arrays.
    px_jets, py_jets, pz_jets : indexable of 1-D float arrays, optional
        Per-jet constituent momentum components. They default to the
        notebook-level ``px``, ``py`` and ``pz`` arrays, so ``prepare_x(i)``
        keeps working as before.

    Returns
    -------
    np.ndarray of float32, shape (n_constituents, 3)
        Columns are transverse momentum, pseudorapidity and azimuthal angle.

    Notes
    -----
    pt, eta and phi depend only on the 3-momentum. They are therefore computed
    directly with vectorised NumPy instead of one ``uproot_methods.TLorentzVector``
    per constituent. This also removes this function's dependency on
    ``uproot_methods``, which the first cell's output shows is not installed.
    A constituent with pt == 0 yields eta = +/-inf (nan if pz == 0 as well);
    the corresponding NumPy warnings are suppressed.
    """
    px_jets = px if px_jets is None else px_jets
    py_jets = py if py_jets is None else py_jets
    pz_jets = pz if pz_jets is None else pz_jets

    cand_px = np.asarray(px_jets[jet_idx], dtype=np.float64)
    cand_py = np.asarray(py_jets[jet_idx], dtype=np.float64)
    cand_pz = np.asarray(pz_jets[jet_idx], dtype=np.float64)

    pt = np.hypot(cand_px, cand_py)
    with np.errstate(divide="ignore", invalid="ignore"):
        # eta = -ln(tan(theta / 2)) == arcsinh(pz / pt)
        eta = np.arcsinh(cand_pz / pt)
    phi = np.arctan2(cand_py, cand_px)
    # axis=1 gives shape (n, 3) directly, including (0, 3) for an empty jet.
    return np.stack((pt, eta, phi), axis=1).astype(np.float32)


In [189]:
def construct_edges(X, k=10, coord_cols=(0, 1)):
    """Build a symmetrised k-nearest-neighbour edge list for one jet graph.

    Parameters
    ----------
    X : np.ndarray, shape (n_nodes, n_features)
        Node-feature matrix, one row per constituent.
    k : int, default 10
        Neighbourhood size passed to ``kneighbors_graph``. The query points are
        the fitted points themselves, so a point typically appears among its own
        k neighbours; those self-loops are removed below, leaving up to ``k - 1``
        other neighbours per node. ``k`` is clamped to the number of nodes, so
        jets with fewer than ``k`` constituents no longer raise ValueError.
    coord_cols : sequence of int, default (0, 1)
        Columns of ``X`` used as kNN coordinates.
        NOTE(review): with ``prepare_x``'s column order these are (pt, eta);
        ``(1, 2)`` = (eta, phi) is presumably the intended angular space —
        confirm. The default keeps the original behaviour so that already
        processed graphs stay reproducible.

    Returns
    -------
    np.ndarray of int64, shape (2, n_edges)
        Row/column indices of the non-zero entries of ``A + A.T``, where ``A`` is
        the kNN connectivity matrix with self-loops removed.
    """
    n_nodes = X.shape[0]
    if n_nodes == 0:
        # NearestNeighbors cannot be fitted on zero samples.
        return np.empty((2, 0), dtype=np.int64)

    coords = X[:, list(coord_cols)]
    nbrs = NearestNeighbors(algorithm='kd_tree').fit(coords)
    adjacency = nbrs.kneighbors_graph(coords, n_neighbors=min(k, n_nodes))
    adjacency.setdiag(0)  # remove self-loop edges
    adjacency.eliminate_zeros()  # drop the explicit zeros setdiag leaves behind
    # Add the transpose so that an edge exists whenever either endpoint lists the
    # other as a neighbour (kNN relations are not symmetric).
    adjacency = adjacency + adjacency.T
    return np.array(adjacency.nonzero()).astype(np.int64)

In [190]:
# Convert every selected jet of the current ntuple into a PyG graph and save it as
# data/processed/data_<k>.pt. Numbering continues after the .pt files already in
# the directory, so graphs from several ntuples accumulate in one place.
# NOTE(review): re-running this cell on the same ntuple appends a second copy of
# every jet (the offset grows each time); clear the directory first if needed.
processed_dir = osp.join("data", "processed")
os.makedirs(processed_dir, exist_ok=True)  # torch.save does not create parent dirs
n_existing = len(glob.glob(osp.join(processed_dir, "*.pt")))
for jet_idx in range(n_jets):
    X = prepare_x(jet_idx)
    edge_index = construct_edges(X)
    outdata = Data(x=torch.from_numpy(X), edge_index=torch.from_numpy(edge_index))
    torch.save(outdata, osp.join(processed_dir, 'data_{}.pt'.format(n_existing + jet_idx)))

In [5]:

# Load the processed graphs through the project's dataset class (dataset.py).
# NOTE(review): the meaning of the '' argument (presumably the dataset root) is
# defined in dataset.py — confirm. The repr reports 9900 graphs, more than the
# 1099 jets written above, so graphs from other ntuples are presumably included.
dataset = FalconDataset('')
dataset

FalconDataset(9900)

In [3]:
from torch_geometric.data import DataLoader
import energyflow as ef

In [6]:
# Batch the graphs 32 at a time in dataset order. shuffle=False is the default and
# is spelled out here. NOTE(review): despite the name, this loader is not shuffled;
# pass shuffle=True if it is ever used for training.
train_loader = DataLoader(dataset, batch_size=32, shuffle=False)

In [7]:
# Take the first batch of graphs for inspection.
a = next(iter(train_loader))

In [8]:
# Display the batch summary (sizes of the node-to-graph index, edge list and features).
a

Batch(batch=[3765], edge_index=[2, 42436], x=[3765, 3])

In [22]:
# Number of batches per pass over the dataset: ceil(len(dataset) / 32).
len(train_loader)

310

In [13]:
def preprocess_emd(torch_batch, batch_size=None):
    """Split a batched PyG graph into one NumPy array per graph (EMD input).

    Parameters
    ----------
    torch_batch : torch_geometric.data.Batch
        Batched graphs. Only ``x`` (node features) and ``batch`` (graph index of
        every node) are read.
    batch_size : int, optional
        Number of graphs to extract, starting from graph 0. Defaults to every
        graph in the batch (largest graph index + 1). Indices with no nodes
        produce empty arrays.

    Returns
    -------
    list of np.ndarray
        ``ret[g]`` holds the rows of ``torch_batch.x`` whose graph index is ``g``.
    """
    # Use the argument, not a notebook-level batch variable, and move everything
    # to CPU so .numpy() also works for batches that live on a GPU.
    graph_ids = torch_batch.batch.cpu()
    if batch_size is None:
        # Graph indices are 0-based, so the count is the largest index plus one;
        # using batch[-1] alone would silently drop the last graph.
        batch_size = int(graph_ids.max()) + 1 if graph_ids.numel() > 0 else 0
    node_features = torch_batch.x.detach().cpu()
    return [node_features[graph_ids == graph_idx].numpy() for graph_idx in range(batch_size)]

In [12]:
# Compute EMDs among the graphs of the first batch (presumably the pairwise
# matrix; the recorded output shape is (32, 32)).
# NOTE(review): this passes a second argument (the number of graphs); confirm the
# current preprocess_emd signature accepts it, otherwise drop the 32.
ef.emd.emds(preprocess_emd(a, 32))

(32, 32)

In [16]:
# Inspect the per-graph (pt, eta, phi) arrays that are passed to the EMD computation.
# NOTE(review): this displays every constituent of every graph; slicing the result
# (e.g. [:1]) would keep the saved output short.
preprocess_emd(a)

[array([[ 1.49549011e+02,  4.56690490e-01, -2.82410526e+00],
        [ 3.84739418e+01,  4.37113076e-01, -2.83043599e+00],
        [ 2.37469940e+01,  4.23257679e-01, -2.83046412e+00],
        [ 2.32496452e+01,  5.26221871e-01, -2.15731645e+00],
        [ 2.08056774e+01,  4.87088591e-01, -2.81781197e+00],
        [ 1.93258305e+01,  5.48392057e-01, -2.17918539e+00],
        [ 1.51167116e+01,  5.01126707e-01, -2.14670563e+00],
        [ 1.49939814e+01,  4.48605776e-01, -2.87718701e+00],
        [ 1.20478439e+01,  4.95825410e-01, -2.15136075e+00],
        [ 1.04761763e+01,  2.00990047e-02, -2.51468086e+00],
        [ 9.01668358e+00,  5.18558919e-01, -2.12106204e+00],
        [ 8.74695396e+00, -5.23400540e-03, -2.56869054e+00],
        [ 8.67597389e+00,  4.90514189e-01, -2.17322230e+00],
        [ 8.04955101e+00,  4.94197965e-01, -2.15997243e+00],
        [ 7.36674833e+00,  1.74579471e-02, -2.66714382e+00],
        [ 6.76796246e+00,  5.85266054e-01, -2.66032553e+00],
        [ 6.65323257e+00