<a href="https://colab.research.google.com/github/sznajder/Notebooks/blob/master/Set2Graph_GNN_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Set2Graph Demo

### A novel set-to-graph model that takes into account information from all tracks in a jet to determine whether pairs of tracks originated from a common vertex. It can be used in jet identification for b- and c-jet tagging (E. Gross, K. Cranmer, J. Shlomi et al., https://arxiv.org/pdf/2008.02831.pdf)


Adapted from J.-R. Vlimant's notebook:
https://github.com/vlimant/NNArchTeraScale2021/blob/master/graphs/set2graph.ipynb


Credits to **Jonathan Shlomi**. The model is from the paper https://arxiv.org/abs/2008.02831; more code is available at https://github.com/jshlomi/SetToGraphPaper

Download the jet dataset from https://zenodo.org/record/4044628 (this only needs to be done once):

In [None]:
!curl -O https://zenodo.org/record/4044628/files/valid_data.root

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  114M  100  114M    0     0  3641k      0  0:00:32  0:00:32 --:--:-- 3660k


## The dataloader code

Each jet has a different number of tracks, but the model needs all entries in a batch to have the same shape. We therefore use a batch sampler that only groups jets with the same number of tracks (jets with fewer than two tracks are skipped).

In [None]:
import os
import uproot
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader, Sampler
import setGPU

setGPU: Setting GPU to: 1


In [None]:
class JetGraphDataset(Dataset):
    """Jets read from a ROOT file, fully preprocessed into memory at construction time.

    ``ds[i]`` returns ``(set_, partition, partition_as_graph)`` for jet ``i``:

    * ``set_``: ``(n_tracks, 10)`` array -- the 6 standardized track features followed
      by the 4 standardized jet features (repeated on every track). Rows are sorted by
      the standardized ``trk_pt`` column.
    * ``partition``: ``(n_tracks,)`` vertex index of each track, in the same row order
      as ``set_``.
    * ``partition_as_graph``: ``(n_tracks, n_tracks)`` 0/1 matrix, 1 where two tracks
      have the same vertex index (so the diagonal is all ones).
    """

    def __init__(self, fname):
        """Read the object named ``'tree'`` from ``fname`` and preprocess every jet."""
        self.node_features_list = ['trk_d0', 'trk_z0', 'trk_phi', 'trk_ctgtheta', 'trk_pt', 'trk_charge']
        self.jet_features_list = ['jet_pt', 'jet_eta', 'jet_phi', 'jet_M']

        # NOTE(review): `tree.numentries` and the bytes-keyed result of `tree.arrays`
        # (used in get_all_items) follow the uproot3 API -- confirm the installed version.
        with uproot.open(fname) as f:
            tree = f['tree']
            self.n_jets = int(tree.numentries)
            self.n_nodes = np.array([len(x) for x in tree['trk_vtx_index'].array()])
            self.jet_arrays = tree.arrays(self.jet_features_list + self.node_features_list + ['trk_vtx_index'])

        self.sets, self.partitions, self.partitions_as_graphs = [], [], []
        for set_, partition, partition_as_graph in self.get_all_items():
            self.sets.append(set_)
            self.partitions.append(partition)
            self.partitions_as_graphs.append(partition_as_graph)

    def __len__(self):
        return self.n_jets

    def get_all_items(self):
        """Yield ``(set_i, partition_i, partition_as_graph_i)`` for every jet, in file order."""
        node_feats = np.array([self.jet_arrays[x.encode()] for x in self.node_features_list])
        jet_feats = np.array([self.jet_arrays[x.encode()] for x in self.jet_features_list])
        n_labels = np.array(self.jet_arrays['trk_vtx_index'.encode()])

        for i in range(self.n_jets):
            n_nodes = self.n_nodes[i]
            node_feats_i = np.stack(node_feats[:, i], axis=0)  # shape (6, n_nodes)
            jet_feats_i = jet_feats[:, i][:, np.newaxis]  # shape (4, 1)

            node_feats_i = transform_features(FeatureTransform.node_feature_transform_list, node_feats_i)
            jet_feats_i = transform_features(FeatureTransform.jet_features_transform_list, jet_feats_i)

            jet_feats_i = np.repeat(jet_feats_i, n_nodes, axis=1)  # shape (4, n_nodes)
            set_i = np.concatenate([node_feats_i, jet_feats_i]).T  # shape (n_nodes, 10)

            # Sort tracks by standardized pt (row 4 = 'trk_pt'). The SAME permutation must
            # be applied to the vertex labels; otherwise row j of set_i and entry j of the
            # targets would describe different tracks.
            sort_order = np.argsort(node_feats_i[4])
            set_i = set_i[sort_order]
            partition_i = np.asarray(n_labels[i])[sort_order]

            # Edge target: 1 where tracks j and k share a vertex index, else 0.
            same_vertex = partition_i[:, np.newaxis] == partition_i[np.newaxis, :]
            partition_as_graph_i = np.where(same_vertex, 1, 0)

            yield set_i, partition_i, partition_as_graph_i

    def __getitem__(self, idx):
        return self.sets[idx], self.partitions[idx], self.partitions_as_graphs[idx]


class JetsBatchSampler(Sampler):
    """Batch sampler that only groups jets with the same number of tracks.

    Jets with ``n_nodes <= 1`` are never yielded. Batches are formed once, at
    construction time (shuffled with numpy's global RNG); each ``__iter__`` call only
    shuffles the order in which those fixed batches are yielded.
    """

    def __init__(self, n_nodes_array, batch_size):
        """
        Initialization
        :param n_nodes_array: numpy array with the number of tracks of each jet
        :param batch_size: maximum number of jets per batch
        """
        # The Sampler base class does not use its constructor argument (recent torch
        # versions warn when one is given), so pass None.
        super().__init__(None)

        self.dataset_size = n_nodes_array.size
        self.batch_size = batch_size

        self.index_to_batch = {}
        self.node_size_idx = {}

        for n_nodes_i in set(n_nodes_array):
            if n_nodes_i <= 1:
                continue
            same_size_idx = np.where(n_nodes_array == n_nodes_i)[0]

            # Ceil, so that no batch is larger than batch_size. (A float section count is
            # truncated by np.array_split, which could create batches of ~2x batch_size.)
            n_batches = max(int(np.ceil(len(same_size_idx) / self.batch_size)), 1)

            self.node_size_idx[n_nodes_i] = np.array_split(np.random.permutation(same_size_idx), n_batches)
            for batch in self.node_size_idx[n_nodes_i]:
                self.index_to_batch[len(self.index_to_batch)] = batch

        self.n_batches = len(self.index_to_batch)

    def __len__(self):
        return self.n_batches

    def __iter__(self):
        batch_order = np.random.permutation(np.arange(self.n_batches))
        for i in batch_order:
            yield self.index_to_batch[i]

def transform_features(transform_list, arr):
    """Standardize the leading rows of ``arr``, one ``(mean, std)`` pair per row.

    Row ``k`` of the result is ``(arr[k, :] - mean_k) / std_k`` for the k-th pair of
    ``transform_list``. Rows without a pair are left at 0. The result has the same
    shape and dtype as ``arr``.
    """
    result = np.zeros_like(arr)
    row = 0
    for mean, std in transform_list:
        result[row, :] = (arr[row, :] - mean) / std
        row += 1
    return result

class FeatureTransform:
    """(mean, std) pairs consumed by ``transform_features``.

    Based on mean and std values of the TRAINING set only. Each list follows the order
    of the matching feature-name list in ``JetGraphDataset``.
    """

    # Track-level features.
    node_feature_transform_list = [
        (0.0006078152, 14.128961),    # trk_d0
        (0.0038490593, 10.688491),    # trk_z0
        (-0.0026713554, 1.8167108),   # trk_phi
        (0.0047640945, 1.889725),     # trk_ctgtheta
        (5.237357, 7.4841413),        # trk_pt
        (-0.00015662189, 1.0),        # trk_charge
    ]

    # Jet-level features.
    jet_features_transform_list = [
        (75.95093, 49.134453),        # jet_pt
        (0.0022607117, 1.2152709),    # jet_eta
        (-0.0023569583, 1.8164033),   # jet_phi
        (9.437994, 6.765137),         # jet_M
    ]

In [None]:
# Read valid_data.root and preprocess every jet up front (the whole dataset is kept in memory).
ds = JetGraphDataset('valid_data.root')

In [None]:
# First jet: (track-feature set, per-track vertex index, track-pair edge target)
ds[0]

(array([[-0.0018891 , -0.00236357,  0.42327377, -0.21854699,  0.5066499 ,
          1.0001566 , -0.27195626, -0.3864688 ,  0.41081807, -0.53265476],
        [ 0.00198516, -0.00227709,  0.38614634, -0.24547258,  1.5638206 ,
         -0.99984336, -0.27195626, -0.3864688 ,  0.41081807, -0.53265476]],
       dtype=float32),
 array([0, 0], dtype=int32),
 array([[1, 1],
        [1, 1]]))

The dataset outputs an $n_{tracks} \times d_{in}$ array of track features, an $n_{tracks}$-long target array of vertex indices, and an $n_{tracks} \times n_{tracks}$ array of edge targets (1 where two tracks share a vertex).

In [None]:
# Group jets into batches of roughly `batch_size` jets that all have the same number of
# tracks, so the default collate function can stack each batch into a single tensor.
batch_size = 10
batch_sampler = JetsBatchSampler(ds.n_nodes, batch_size)
data_loader = DataLoader(ds, batch_sampler=batch_sampler)

## The model

The model consists of three parts: the set function $\phi$, the broadcasting part $\beta$, and the edge-prediction function $\psi$.

$\phi$ is the DeepSet module

$\beta$ is implemented in the forward of the SetToGraph module

$\psi$ is the PsiSuffix module.


The SetToGraph module combines all three.

In [None]:
import torch
import torch.nn as nn
import math


class DeepSet(nn.Module):
    """The set function phi: a stack of DeepSetLayer blocks with a ReLU between them.

    :param in_features: number of input channels C of the (B, C, N) input.
    :param feats: output width of each DeepSetLayer, in order (must be non-empty).
    """

    def __init__(self, in_features, feats):
        super().__init__()

        widths = [in_features] + list(feats)
        modules = [DeepSetLayer(widths[0], widths[1])]
        for d_in, d_out in zip(widths[1:-1], widths[2:]):
            modules += [nn.ReLU(), DeepSetLayer(d_in, d_out)]

        self.sequential = nn.Sequential(*modules)

    def forward(self, x):
        return self.sequential(x)


class DeepSetLayer(nn.Module):
    """One DeepSet block acting on a set laid out as (B, C, N).

    ``layer1`` maps every element independently (kernel size 1); ``layer2`` maps the
    attention-mixed features. Their sum is rescaled so that each of the N output
    columns has unit L2 norm over the channel dimension.

    :param in_features: input channels C.
    :param out_features: output channels.
    """

    def __init__(self, in_features, out_features):
        super(DeepSetLayer, self).__init__()

        self.attention = Attention(in_features)
        self.layer1 = nn.Conv1d(in_features, out_features, 1)
        self.layer2 = nn.Conv1d(in_features, out_features, 1)

    def forward(self, x):
        # x.shape = (B,C,N)
        x_T = x.transpose(2, 1)  # B,C,N -> B,N,C
        x = self.layer1(x) + self.layer2(self.attention(x_T).transpose(1, 2))

        # Normalize each element's feature vector (BxCxN / Bx1xN). The clamp avoids
        # 0/0 = NaN for an all-zero column, which would otherwise poison every
        # downstream value and gradient; for ordinary norms the result is unchanged.
        norm = torch.norm(x, p='fro', dim=1, keepdim=True).clamp_min(1e-12)
        return x / norm

class Attention(nn.Module):
    """Attention that mixes the N elements of a (B, N, C) set; output is (B, N, C).

    Queries (tanh-squashed) and keys are linear projections to width
    ``d_k = max(C // 10, 1)``; the raw inputs act as values.
    """

    def __init__(self, in_features):
        super().__init__()
        self.d_k = max(in_features // 10, 1)

        self.query = nn.Sequential(nn.Linear(in_features, self.d_k), nn.Tanh())
        self.key = nn.Linear(in_features, self.d_k)

    def forward(self, inp):
        queries = self.query(inp)  # (B, N, d_k)
        keys = self.key(inp)       # (B, N, d_k)

        # scores[b, i, j] = <queries_j, keys_i> / sqrt(d_k), i.e. (q @ k^T)^T
        scores = (queries @ keys.transpose(1, 2) / math.sqrt(self.d_k)).transpose(1, 2)
        weights = torch.softmax(scores, dim=2)  # normalize along the last axis
        return weights @ inp  # (B, N, C)
    


In [None]:
class PsiSuffix(nn.Module):
    """The edge function psi: 1x1 Conv2d layers with a ReLU between consecutive layers.

    :param features: channel widths ``[c_in, h_1, ..., c_out]``; needs at least two entries.
    """

    def __init__(self, features):
        super().__init__()
        width_pairs = list(zip(features[:-1], features[1:]))

        layers = []
        for c_in, c_out in width_pairs[:-1]:
            layers.extend([nn.Conv2d(c_in, c_out, 1), nn.ReLU()])
        last_in, last_out = width_pairs[-1]
        layers.append(nn.Conv2d(last_in, last_out, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
    

class SetToGraph(nn.Module):
    """Set-to-graph model: phi (DeepSet) -> beta (pair broadcasting) -> psi (PsiSuffix).

    :param in_features: features per set element.
    :param out_features: output channels per edge.
    :param set_fn_feats: layer widths of the DeepSet phi.
    :param hidden_mlp: list of hidden widths of psi.
    """

    def __init__(self, in_features, out_features, set_fn_feats, hidden_mlp):
        super().__init__()

        self.set_model = DeepSet(in_features=in_features, feats=set_fn_feats)

        # psi sees three embeddings per edge (element i, element j, set sum) and maps
        # them down to out_features channels.
        psi_widths = [3 * set_fn_feats[-1]] + hidden_mlp + [out_features]
        self.suffix = PsiSuffix(psi_widths)

    def forward(self, x):
        """Map a (B, N, in_features) set to edge outputs of shape (B, out_features, N, N).

        Dim 1 is squeezed away when out_features == 1, giving (B, N, N).
        """
        u = self.set_model(x.transpose(2, 1))  # BxNxC -> BxCxN -> Bx(C')xN
        n = u.size(2)

        u_j = u.unsqueeze(2).repeat(1, 1, n, 1)  # u_j[b, c, i, j] = u[b, c, j]
        u_i = u.unsqueeze(3).repeat(1, 1, 1, n)  # u_i[b, c, i, j] = u[b, c, i]
        u_sum = u.sum(dim=2, keepdim=True).unsqueeze(3).repeat(1, 1, n, n)  # set sum on every pair

        edge_vals = self.suffix(torch.cat([u_j, u_i, u_sum], dim=1))
        return edge_vals.squeeze(1)




In [None]:
# 10 input features per track (6 track + 4 jet), 1 output logit per track pair,
# DeepSet widths [256, 256, 256], and one hidden psi layer of width 256.
model = SetToGraph(10,1,[256,256,256],[256])

In [None]:
# Input set of jet 34: one row per track, 10 features per row.
ds[34][0].shape

(14, 10)

### Putting one element from the dataset through the model:

In [None]:
# Add a batch dimension of 1 without going through torch's slow list-of-ndarrays path,
# and skip building the autograd graph since this is inference only.
single_set = torch.from_numpy(ds[34][0]).unsqueeze(0)  # (1, n_tracks, 10)
with torch.no_grad():
    single_pred = model(single_set)
single_pred.shape

torch.Size([1, 14, 14])

### Basic training loop

In [None]:
import torch.nn.functional as F
def loss_func(y_hat, y):
    """Edge loss: BCE-with-logits minus the soft F1 score summed over the batch.

    :param y_hat: edge logits of shape (B, N, N). Not modified.
    :param y: float 0/1 edge targets with the same shape as ``y_hat``.
    :returns: scalar loss tensor.
    """
    _, n, _ = y_hat.shape

    # No loss on the diagonal: force self-edge logits to the largest finite value so they
    # are "1" after the sigmoid. masked_fill is out-of-place, so the caller's prediction
    # tensor is left untouched (the old in-place assignment silently modified it).
    diag = torch.eye(n, dtype=torch.bool, device=y_hat.device)
    y_hat = y_hat.masked_fill(diag, torch.finfo(y_hat.dtype).max)

    bce = F.binary_cross_entropy_with_logits(y_hat, y)  # cross entropy

    probs = torch.sigmoid(y_hat)
    tp = (probs * y).sum(dim=(1, 2))
    fn = ((1. - probs) * y).sum(dim=(1, 2))
    fp = (probs * (1. - y)).sum(dim=(1, 2))
    fscore = (2 * tp) / (2 * tp + fp + fn + 1e-10)

    return bce - fscore.sum()

# Demo settings: a single optimization step keeps this cell fast.
# Set MAX_BATCHES_PER_EPOCH = None to train on every batch of each epoch.
N_EPOCHS = 1
MAX_BATCHES_PER_EPOCH = 1

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

model.train()
for epoch in range(N_EPOCHS):
    for batch_idx, (graph, _, edge_target) in enumerate(data_loader):
        edge_prediction = model(graph)
        loss = loss_func(edge_prediction, edge_target.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if MAX_BATCHES_PER_EPOCH is not None and batch_idx + 1 >= MAX_BATCHES_PER_EPOCH:
            break
    print("Epoch:", epoch)
    
    

Epoch: 0
