In [1]:
import os
import time
import random
import numpy as np

from scipy.stats import ortho_group

from typing import Optional, Tuple

from typing import Callable, Union
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
from torch.nn import Linear, ReLU, BatchNorm1d, Module, Sequential
from torch import Tensor

torch.set_default_dtype(torch.float64)

from torch_geometric.typing import (
    Adj,
    OptPairTensor,
    OptTensor,
    Size,
    SparseTensor,
    torch_sparse,
    PairTensor
)

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data import Batch
import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse, to_undirected
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool, knn_graph
from torch_geometric.datasets import QM9
from torch_scatter import scatter
from torch_cluster import knn

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import uproot
import vector
vector.register_awkward()
import awkward as ak

from IPython.display import HTML

print("PyTorch version {}".format(torch.__version__))
print("PyG version {}".format(torch_geometric.__version__))

PyTorch version 2.3.1
PyG version 2.5.3


In [2]:
class Jet_Dataset(data.Dataset):

    def __init__(self, dataset_path:str, tree_name:str = 'tree', k:int = 5) -> None:
        """
        Inputs:
            size - Number of data points we want to generate
            std - Standard deviation of the noise (see generate_continuous_xor function)
        """
        super(Jet_Dataset, self).__init__()
        
        
        self.dataset = uproot.open(dataset_path)
        self.tree = self.dataset[tree_name].arrays()
        
        self.num_entries = self.dataset[tree_name].num_entries
        
        self.part_feat = self.dataset[tree_name].keys(filter_name='part_*')
        self.jet_feat = self.dataset[tree_name].keys(filter_name='jet_*')
        self.labels = self.dataset[tree_name].keys(filter_name='labels_*')
        
        self.k = k
        
        
        #self.pc_dataset = [ self.transform_jet_to_point_cloud(idx) for idx in range(self.num_entries-1) ]
        

    def transform_jet_to_point_cloud(self, idx:int) -> Data :
    
        npart = self.tree['jet_nparticles'].to_numpy()[idx:idx+1]
        
        part_feat_list = [ak.flatten(self.tree[part_feat][idx:idx+1]).to_numpy() for part_feat in self.part_feat]
        
        jet_pt = self.tree['jet_pt'].to_numpy()[idx:idx+1]
        jet_eta = self.tree['jet_eta'].to_numpy()[idx:idx+1]
        jet_phi = self.tree['jet_phi'].to_numpy()[idx:idx+1]
        jet_energy = self.tree['jet_energy'].to_numpy()[idx:idx+1]
        jet_tau21 = self.tree['jet_tau2'].to_numpy()[idx:idx+1]/self.tree['jet_tau1'].to_numpy()[idx:idx+1]
        jet_tau32 = self.tree['jet_tau3'].to_numpy()[idx:idx+1]/self.tree['jet_tau2'].to_numpy()[idx:idx+1]
        jet_tau43 = self.tree['jet_tau4'].to_numpy()[idx:idx+1]/self.tree['jet_tau3'].to_numpy()[idx:idx+1]
        
        
        jet_sd_mass = self.tree['jet_sdmass'].to_numpy()[idx:idx+1]
        
        jet_feat = np.stack([jet_pt, jet_eta, jet_phi, jet_energy, jet_tau21, jet_tau32, jet_tau43]).T
              
        #jet_feat = np.repeat(jet_feat, int(npart), axis=0)
             
        part_feat = np.stack(part_feat_list).T
        
        total_jet_feat = part_feat #np.concatenate((part_feat, jet_feat), axis=-1)
        total_jet_feat[np.isnan(total_jet_feat)] = 0.
        
        #print(type(total_jet_feat), 'total_jet_feat shape : ', total_jet_feat.shape)
        
        jet_class = -1
        
        if(self.tree['label_QCD'].to_numpy()[idx:idx+1] == 1) : jet_class = 0
        
        if( (self.tree['label_Tbqq'].to_numpy()[idx:idx+1] == 1) or
            (self.tree['label_Tbl'].to_numpy()[idx:idx+1] == 1)) : jet_class = 2
        
        if( (self.tree['label_Zqq'].to_numpy()[idx:idx+1] == 1) or
            (self.tree['label_Wqq'].to_numpy()[idx:idx+1] == 1)) : jet_class = 0
        
        if( (self.tree['label_Hbb'].to_numpy()[idx:idx+1] == True) or
            (self.tree['label_Hcc'].to_numpy()[idx:idx+1] == True) or
            (self.tree['label_Hgg'].to_numpy()[idx:idx+1] == True) or
            (self.tree['label_H4q'].to_numpy()[idx:idx+1] == True) or
            (self.tree['label_Hqql'].to_numpy()[idx:idx+1] == True) ) : jet_class = 1
        
        part_eta = torch.tensor( ak.flatten(self.tree['part_deta'][idx:idx+1]).to_numpy() )
        part_phi = torch.tensor( ak.flatten(self.tree['part_dphi'][idx:idx+1]).to_numpy() )
        eta_phi_pos = torch.stack([part_eta, part_phi], dim=-1)
        
        edge_index = torch_geometric.nn.pool.knn_graph(x = eta_phi_pos, k = self.k)
        
        src, dst = edge_index
                
        part_del_eta = part_eta[dst] - part_eta[src]
        part_del_phi = part_phi[dst] - part_phi[src]
        
        part_del_R = torch.hypot(part_del_eta, part_del_phi).view(-1, 1) # -- why do we need this view function ? 
        
        data = Data(x=torch.tensor(total_jet_feat), edge_index=edge_index, edge_deltaR = part_del_R)
        data.label = torch.tensor([jet_class])
        data.sd_mass = torch.tensor(jet_sd_mass)
        data.global_data = torch.tensor(jet_feat)
        data.seq_length = torch.tensor(npart)
        
        return data    
        

    def __len__(self) -> int:
        # Number of data point we have. Alternatively self.data.shape[0], or self.label.shape[0]
        return self.num_entries#len(self.pc_dataset)
    
    def __getitem__(self, idx:int) -> Data :
        # Return the idx-th data point of the dataset
    
        return self.transform_jet_to_point_cloud(idx)#self.pc_dataset[idx]#data_point, data_label


In [3]:
def build_mlp(in_size, layer_size, depth):
    layers = []

    layers.append(nn.Linear(in_size * 2, layer_size))
    layers.append(nn.BatchNorm1d(layer_size))
    layers.append(nn.ReLU())

    for i in range(depth):
        layers.append(nn.Linear(layer_size, layer_size))
        layers.append(nn.BatchNorm1d(layer_size))
        layers.append(nn.ReLU())

    return nn.Sequential(*layers)

## What is an edge convolution?
If the message function $m_{ij} = h_{\bf \Theta}(x_i, x_j)$, it is called **Edgeconvolution**. 
<img src="edgeconv_cartoon.png" alt="Alternative text" />

The DGCNN paper (https://arxiv.org/pdf/1801.07829.pdf) proposed $m_{ij} = \sigma \Big( \theta_{m}(x_j - x_i) + \phi_m x_i\Big)$

In [4]:
class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(EdgeConv, self).__init__(aggr='mean') #  "Mean" aggregation.
        
        self.theta = Sequential(Linear(in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))
        
        self.phi = Sequential(Linear(in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]
        
        out = self.theta(x_j - x_i) + self.phi(x_i)

        return out

<center width="500%"><img src="edgeconv_cartoon.png" width="600px"></center>

## For DynamicEdgeConv, the adjacency is determined on fly

In [5]:
class DynamicEdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels, k):
        super(DynamicEdgeConv, self).__init__(aggr='mean') #  "Mean" aggregation.
        
        self.k = k
        
        self.theta = Sequential(Linear(in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))
        
        self.phi = Sequential(Linear(in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

    def forward(self, x, batch=None):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
        return self.propagate(edge_index=edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]
        
        out = self.theta(x_j - x_i) + self.phi(x_i)

        return out

## The CMS model 
https://cms-ml.github.io/documentation/inference/particlenet.html
<center width="700%"><img src="particlenet_full_arch.png" width="800px"></center>

## Reference : https://github.com/farakiko/xai4hep/tree/main

In [6]:
class EdgeConv_lrp(MessagePassing):
    """
    Copied from pytorch_geometric source code, with the following edits
    1. torch.cat([x_i, x_j - x_i], dim=-1)) -> torch.cat([x_i, x_j], dim=-1))
    2. retrieve edge_activations
    """

    def __init__(self, nn: Callable, aggr: str = "max", **kwargs):
        super().__init__(aggr=aggr, **kwargs)
        self.nn = nn

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:
        if isinstance(x, Tensor):
            x: PairTensor = (x, x)
        # propagate_type: (x: PairTensor)
        return (
            self.propagate(edge_index, x=x, size=None),
            self.edge_activations,
        )

    def message(self, x_i: Tensor, x_j: Tensor) -> Tensor:
        # self.edge_activations = self.nn(torch.cat([x_i, x_j - x_i], dim=-1))
        # return self.nn(torch.cat([x_i, x_j - x_i], dim=-1))
        self.edge_activations = self.nn(torch.cat([x_i, x_j], dim=-1))
        return self.nn(torch.cat([x_i, x_j], dim=-1))

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(nn={self.nn})"


In [7]:
class EdgeConvBlock(nn.Module):
    def __init__(self, in_size, layer_size, depth):
        super(EdgeConvBlock, self).__init__()
        
        edge_mlp = build_mlp(in_size=in_size, layer_size=layer_size, depth=depth)
        self.edge_conv = EdgeConv_lrp(edge_mlp, aggr="mean")

    def forward(self, x, edge_index):
        return self.edge_conv(x, edge_index)


In [8]:
class ParticleNet(nn.Module):
    def __init__(
        self,
        for_LRP,
        node_feat_size,
        num_classes=1,
        k=16,
        depth=2,
        dropout=False,
        ):
        super(ParticleNet, self).__init__()
        self.for_LRP = for_LRP

        self.node_feat_size = node_feat_size
        self.num_classes = num_classes

        self.k = k
        self.num_edge_conv_blocks = 3

        self.kernel_sizes = [self.node_feat_size, 64, 128, 256]
        self.input_sizes = np.cumsum(self.kernel_sizes)  # [4, 4+64, 4+64+128, 4+64+128+256]

        self.fc_size = 256

        if dropout:
            self.dropout = 0.1
            self.dropout_layer = nn.Dropout(p=self.dropout)
        else:
            self.dropout = None

        # define the edgeconvblocks
        self.edge_conv_blocks = nn.ModuleList()
        for i in range(0, self.num_edge_conv_blocks):
            self.edge_conv_blocks.append(EdgeConvBlock(self.input_sizes[i], self.kernel_sizes[i + 1], depth=depth))

        # define the fully connected networks (post-edgeconvs)
        self.fc1 = nn.Linear(self.input_sizes[-1], self.fc_size)
        self.fc2 = nn.Linear(self.fc_size, self.num_classes)

        self.sig = nn.Sigmoid()

    def forward(self, batch):
        x = batch.x
        y = batch.label
        batch = batch.batch

        # input transformations
        # x[:, 2] = (x[:, 2] - 1.7) * 0.7  # part_pt_log
        # x[:, 3] = (x[:, 3] - 2.0) * 0.7  # part_e_log
        # x[:, 4] = (x[:, 4] + 4.7) * 0.7  # part_logptrel
        # x[:, 5] = (x[:, 5] + 4.7) * 0.7  # part_logerel
        # x[:, 6] = (x[:, 6] - 0.2) * 4.7  # part_deltaR

        # useful placeholders for LRP studies
        edge_activations = {}
        edge_block_activations = {}
        edge_index = {}

        for i in range(self.num_edge_conv_blocks):
            # using only angular coords for knn in first edgeconv block
            edge_index[f"edge_conv_{i}"] = knn_graph(x[:, :2], self.k, batch) if i == 0 else knn_graph(x, self.k, batch)

            out, edge_activations[f"edge_conv_{i}"] = self.edge_conv_blocks[i](x, edge_index[f"edge_conv_{i}"])

            x = torch.cat((out, x), dim=1)  # concatenating with latent features i.e. skip connections per EdgeConvBlock

            edge_block_activations[f"edge_conv_{i}"] = x

        x = global_mean_pool(x, batch)

        x = F.relu(self.fc1(x))
        if self.dropout:
            x = self.dropout_layer(x)
        x = self.fc2(x)
        x = self.sig(x)

        # save different objects if you are running lrp studies
        if self.for_LRP:
            return x, edge_activations, edge_block_activations, edge_index
        else:
            return x, y

## Let's make the dataset and try the forward pass

In [9]:
dataset_path = '/Users/sanmay/Documents/ICTS_SCHOOL/Main_School/JetDataset/'
file_name = dataset_path + 'JetClass_example_100k.root' # -- from -- "https://hqu.web.cern.ch/datasets/JetClass/example/" #
jet_dataset = Jet_Dataset(dataset_path=file_name)

In [10]:
type(jet_dataset)

__main__.Jet_Dataset

In [11]:
data_loader = DataLoader(dataset=jet_dataset, batch_size=5, shuffle = True)

In [12]:
gr_b = next(iter(data_loader))

In [13]:
gr_b

DataBatch(x=[177, 16], edge_index=[2, 885], edge_deltaR=[885, 1], label=[5], sd_mass=[5], global_data=[5, 7], seq_length=[5], batch=[177], ptr=[6])

In [14]:
model_kwargs = {
        "for_LRP": True,
        "node_feat_size": 16,
        "num_classes": 3,
        "k": 5,
        "depth": 3,
        "dropout": True,
    }

model = ParticleNet(**model_kwargs)


In [15]:
x, edge_activations, edge_block_activations, edge_index = model(gr_b)

In [16]:
print(x, edge_activations, edge_block_activations, edge_index)

tensor([[0.4836, 0.4314, 0.4910],
        [0.6756, 0.2997, 0.2838],
        [0.5480, 0.3951, 0.4750],
        [0.5495, 0.4021, 0.4766],
        [0.6125, 0.3088, 0.3717]], grad_fn=<SigmoidBackward0>) {'edge_conv_0': tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.9782, 3.4269],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.6033, 3.8062],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.1886, 3.7710],
        ...,
        [0.2477, 0.1399, 0.0701,  ..., 0.1405, 0.0000, 0.0000],
        [0.2030, 0.1233, 0.0116,  ..., 0.2435, 0.0000, 0.0000],
        [0.3155, 0.0900, 0.2734,  ..., 0.0000, 0.0000, 0.0000]],
       grad_fn=<ReluBackward0>), 'edge_conv_1': tensor([[4.7206, 2.5384, 0.0000,  ..., 0.3769, 0.0000, 6.0670],
        [5.6480, 3.5259, 0.8511,  ..., 0.9331, 0.0000, 5.0319],
        [5.6867, 3.5029, 1.0101,  ..., 0.9252, 0.0000, 4.9931],
        ...,
        [0.0000, 0.0000, 0.4648,  ..., 0.0000, 0.5216, 0.0000],
        [0.0000, 0.0000, 0.4495,  ..., 0.0000, 0.5068, 0.0000],
      

## HW try to visualize the adjacency after each layer of EdgeConv. Please do a training