In [1]:
from loss_util import get_ptetaphi, preprocess_emdnn_input

In [2]:
import numpy as np
import pandas as pd
import math
import torch
from pyjet import cluster,DTYPE_PTEPM

def jet_particles(raw_path, n_events):
    """Cluster events into jets and return the constituents of the selected jets.

    Every row of the table is read as consecutive (pT, eta, phi) triplets, one
    per particle slot. Slots whose pT is not positive are skipped (presumably
    zero padding -- TODO confirm against the dataset description). The
    remaining particles are clustered with ``pyjet.cluster(R=1.0, p=-1)`` and
    only the first two inclusive jets of each event are considered.

    Args:
        raw_path: Path of the HDF5 file handed to ``pandas.read_hdf``.
        n_events: Passed as ``stop`` to ``pandas.read_hdf``, i.e. the number
            of rows (events) read from the start of the table.

    Returns:
        A 1-D ``numpy`` object array with one entry per selected jet. Each
        entry is a float array of shape ``(n_particles, 6)`` whose columns are
        ``(px, py, pz, pt, eta, phi)`` of the jet's constituents. Events with
        fewer than two jets contribute nothing, and so does any jet whose pt
        is below 200.
    """
    df = pd.read_hdf(raw_path, stop=n_events)
    all_events = df.values
    # Number of complete (pT, eta, phi) triplets in a row. A trailing column
    # that does not complete a triplet is ignored, so it can never be counted
    # as a particle.
    n_slots = all_events.shape[1] // 3
    X = []
    # cluster jets and store info
    for event in all_events:
        triplets = event[:n_slots * 3].reshape(n_slots, 3)
        # Compact the kept particles to the front: the clustering input has
        # exactly one entry per particle with pT > 0, wherever the skipped
        # slots sit inside the row.
        real = triplets[triplets[:, 0] > 0]
        pseudojets_input = np.zeros(len(real), dtype=DTYPE_PTEPM)
        pseudojets_input['pT'] = real[:, 0]
        pseudojets_input['eta'] = real[:, 1]
        pseudojets_input['phi'] = real[:, 2]
        sequence = cluster(pseudojets_input, R=1.0, p=-1)
        jets = sequence.inclusive_jets()[:2] # leading 2 jets only
        if len(jets) < 2: continue
        for jet in jets:
            if jet.pt < 200: continue
            particles = np.zeros((len(jet), 6))
            # store all the particles of this jet
            for p, part in enumerate(jet):
                particles[p, :] = (part.px, part.py, part.pz,
                                   part.pt, part.eta, part.phi)
            X.append(particles)
    # Fill a pre-sized object array element by element. np.array(X, dtype='O')
    # would collapse the list into a single 3-D object array whenever every
    # jet happens to have the same number of particles.
    out = np.empty(len(X), dtype='O')
    for k, particles in enumerate(X):
        out[k] = particles
    return out

In [3]:
# Cluster the first 70 events of the raw events file into jets.
X = jet_particles(
    raw_path='/anomalyvol/data/bb_train_sets/bb0/raw/events_LHCO2020_backgroundMC_Pythia.h5',
    n_events=70,
)

In [4]:
# Pick a single jet (index 2 of the clustered jets); the next cell converts it.
x = X[2]

In [5]:
# Run get_ptetaphi on the first three columns (px, py, pz) of the sample jet.
# All particles belong to one jet, so the batch vector is all zeros; it is
# created as int64 to match the batch vector this helper receives when the
# jet pair is preprocessed further down.
ptetaphi = get_ptetaphi(
    torch.from_numpy(x[:, :3]),
    torch.zeros(len(x), dtype=torch.int64),
)

In [6]:
# Keep only the jets that have exactly 66 particles; the cells below use the
# first two of them as a jet pair.
pair = [jet for jet in X if len(jet) == 66]

# gae + emdnn preprocessing

In [7]:
# One batch vector serves both jets: every particle is assigned to graph 0 and
# the vector has one entry per particle of the first jet.
batch = torch.zeros(len(pair[0]), dtype=torch.int64)
# Pass the (px, py, pz) columns of each jet through get_ptetaphi, first jet first.
j1, j2 = (
    get_ptetaphi(torch.from_numpy(jet[:, :3]), batch)
    for jet in (pair[0], pair[1])
)
# Note the argument order: the second jet is passed first.
data = preprocess_emdnn_input(j2, j1, batch)

In [8]:
# Show the node-feature tensor returned by preprocess_emdnn_input for the pair.
data.x

tensor([[ 1.9498e-03, -8.3790e-01, -5.3505e-01,  1.0000e+00],
        [ 2.5306e-03,  2.8482e-01, -9.0499e-01,  1.0000e+00],
        [ 2.1590e-03, -2.1717e-01, -9.1990e-01,  1.0000e+00],
        [ 1.1247e-03, -2.9198e-01, -8.6075e-01,  1.0000e+00],
        [ 2.5057e-03,  1.8746e-02, -8.0145e-01,  1.0000e+00],
        [ 9.2517e-04,  2.8071e-01, -6.3669e-01,  1.0000e+00],
        [ 1.1366e-03, -1.2008e-02, -6.4753e-01,  1.0000e+00],
        [ 8.0097e-04, -6.4124e-01, -7.8903e-02,  1.0000e+00],
        [ 2.8408e-03, -3.2278e-01, -5.3327e-01,  1.0000e+00],
        [ 2.3148e-03, -6.0289e-01, -9.6227e-02,  1.0000e+00],
        [ 5.2615e-03, -2.1426e-01,  4.4504e-01,  1.0000e+00],
        [ 6.5133e-04,  4.9098e-01,  1.9240e-02,  1.0000e+00],
        [ 1.8773e-03, -4.2714e-01,  2.0753e-01,  1.0000e+00],
        [ 5.7509e-03, -2.6688e-01, -3.6500e-01,  1.0000e+00],
        [ 1.7419e-03, -5.7282e-02, -4.4025e-01,  1.0000e+00],
        [ 2.2708e-03, -4.1457e-01, -1.3709e-01,  1.0000e+00],
        

# emd-training preprocessing

In [9]:
import torch
import itertools
import numpy as np
import energyflow as ef
from torch_geometric.data import Dataset, Data

In [10]:
# Take the last three columns (pt, eta, phi) of the first two selected jets.
# The slices are copied: a bare p[:, 3:] is a view, so any later in-place edit
# of X2 would also rewrite the arrays stored in `pair`.
X2 = [p[:, 3:].copy() for p in pair[:2]]

In [11]:
R = 0.4                # passed as R to ef.emd.emd when the jet pairs are compared
ONE_HUNDRED_GEV = 100  # scale the EMD value, the flow and the pt sums are divided by
# clean and store list of jets as particles (pt, eta, phi)
indices = []  # position in X2 of every jet that was kept
Js = []       # centred copy of every kept jet
for i, jet in enumerate(X2):
    print(len(jet))
    # A jet without particles has no pt-centroid (np.average cannot normalise
    # a zero total weight), so skip it before doing any arithmetic.
    if len(jet) == 0: continue
    # Centre a copy instead of the array itself: the jets in X2, and anything
    # they share memory with, stay untouched and re-running the cell always
    # starts from the same input.
    centred = jet.copy()
    # center jet according to pt-centroid
    yphi_avg = np.average(centred[:, 1:3], weights=centred[:, 0], axis=0)
    centred[:, 1:3] -= yphi_avg
    # The cut on particles farther than R from the centre is currently disabled:
    # centred = centred[np.linalg.norm(centred[:,1:3], axis=1) <= R]
    # add to list
    Js.append(centred)
    indices.append(i)

# calc emd between all ordered jet pairs (a jet paired with itself included)
# and save one graph datum per pair
jetpairs = [[i, j] for (i, j) in itertools.product(range(len(Js)), repeat=2)]
datas = []
for i, j in jetpairs:
    nparticles_i = len(Js[i])
    nparticles_j = len(Js[j])

    emdval, G = ef.emd.emd(Js[i], Js[j], R=R, return_flow=True)
    emdval = emdval/ONE_HUNDRED_GEV
    G = G/ONE_HUNDRED_GEV
    Ei = np.sum(Js[i][:,0])
    Ej = np.sum(Js[j][:,0])

    # Node features per particle: (pt / jet pt sum, eta, phi, tag), where the
    # tag is -1 for particles of jet i and +1 for particles of jet j.
    jiNorm = np.column_stack([Js[i][:,0]/Ei, Js[i][:,1:3], -np.ones(nparticles_i)])
    jjNorm = np.column_stack([Js[j][:,0]/Ej, Js[j][:,1:3], np.ones(nparticles_j)])
    jetpair = np.concatenate([jiNorm, jjNorm], axis=0)

    # One edge from every particle of jet i (nodes 0 .. n_i-1) to every
    # particle of jet j (nodes n_i .. n_i+n_j-1). The source index varies
    # slowest, which is the order itertools.product would give.
    src = torch.arange(nparticles_i).repeat_interleave(nparticles_j)
    dst = torch.arange(nparticles_i, nparticles_i + nparticles_j).repeat(nparticles_i)
    edge_index = torch.stack([src, dst]).contiguous()

    u = torch.tensor([[Ei/ONE_HUNDRED_GEV, Ej/ONE_HUNDRED_GEV]], dtype=torch.float)
    # Flow on every edge as a column vector. Flattening the n_i x n_j corner
    # of G row by row yields the entries in the same order as edge_index.
    edge_y = torch.tensor(G[:nparticles_i, :nparticles_j].reshape(-1, 1), dtype=torch.float)

    x = torch.tensor(jetpair, dtype=torch.float)
    y = torch.tensor([[emdval]], dtype=torch.float)

    d = Data(x=x, edge_index=edge_index, y=y, u=u, edge_y=edge_y)
    datas.append(d)

66
66


In [12]:
# Keep only the data whose EMD target y is non-zero (presumably this drops the
# jets that were paired with themselves).
pre = list(filter(lambda datum: datum.y != 0, datas))

In [13]:
# Show the node-feature tensor of the second datum with a non-zero EMD.
pre[1].x

tensor([[ 1.9498e-03, -8.3790e-01, -5.3505e-01, -1.0000e+00],
        [ 2.5306e-03,  2.8482e-01, -9.0499e-01, -1.0000e+00],
        [ 2.1590e-03, -2.1717e-01, -9.1990e-01, -1.0000e+00],
        [ 1.1247e-03, -2.9198e-01, -8.6075e-01, -1.0000e+00],
        [ 2.5057e-03,  1.8746e-02, -8.0145e-01, -1.0000e+00],
        [ 9.2517e-04,  2.8071e-01, -6.3669e-01, -1.0000e+00],
        [ 1.1366e-03, -1.2008e-02, -6.4753e-01, -1.0000e+00],
        [ 8.0097e-04, -6.4124e-01, -7.8903e-02, -1.0000e+00],
        [ 2.8408e-03, -3.2278e-01, -5.3327e-01, -1.0000e+00],
        [ 2.3148e-03, -6.0289e-01, -9.6227e-02, -1.0000e+00],
        [ 5.2615e-03, -2.1426e-01,  4.4504e-01, -1.0000e+00],
        [ 6.5133e-04,  4.9098e-01,  1.9240e-02, -1.0000e+00],
        [ 1.8773e-03, -4.2714e-01,  2.0753e-01, -1.0000e+00],
        [ 5.7509e-03, -2.6688e-01, -3.6500e-01, -1.0000e+00],
        [ 1.7419e-03, -5.7282e-02, -4.4025e-01, -1.0000e+00],
        [ 2.2708e-03, -4.1457e-01, -1.3709e-01, -1.0000e+00],
        