# Introduction

# Loading Data

The dataset consists of detection *events*, wherein a variable number of particles come into contact with the detector's cells. The location of these impacts are called *hits*, and there are likewise a variable number of them per particle.

In [1]:
import pickle
import numpy as np

with open('tracks.pickle', 'rb') as f:
    samples = pickle.load(f)

print("Loaded {} samples.".format(len(samples)))

Loaded 100 samples.


### Features

Each sample contains a set of hits, and each hit contains the following information:

* *x,y,z* coordinates
* Cell count and impact magnitude
* A learned hit embedding, output from the previous graph creation stage
* Ground truth cluster ID, denoting the particle which created the hit

Additionally, samples contain graphs as output from the previous stage which aims to connect hits created by the same particle. The two graphs included are

* A predicted graph, the raw output from the graph building stage
* An augmented graph, which contains the predicted graph, plus any connections missed between hits created by the same particle. This is used in the GNN's loss function.

### Visualizations

Choosing a sample to explore, one can see how the embedding differs from the raw features for graph creation.

In [2]:
import matplotlib.pyplot as plt

def plot_clusters(x,y,pid):
    for g in np.unique(pid):
        i = np.where(pid == g)
        plt.scatter(x[i],y[i], label=g)
    plt.show()
    plt.clf()

hits = samples[0]['hits']
xyz = hits['xyz']
emb = hits['emb']
pid = hits['particle_id']

# Hit coordinates
plot_clusters(xyz[:,0], xyz[:,1], pid)

# Emb coordinates
plot_clusters(emb[:,0], emb[:,1], pid)

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

Clearly, the embedding will lead to superior clustering as compared with the raw *x,y,z* positions. However, this embedding incorporates information from only each hit individually. With a GNN, one can create node embeddings which incorporate information from the hit's neighborhood. As we will see, this allows for superior embeddings and thus improved performance in clustering.

# Model

The GNN model chosen is a simple message-passing architecture. One layer concatenates each node's features with an aggregation of the node's neighborhood, before applying a transformation via a fully-connected neural network layer.

The output of the model is a set of node embeddings, where this new embedding has the same goal as in the graph building stage: according to some distance metric, node pairs whose hits belong to the same particle should be close, and otherwise they should be far.

In [3]:
import dgl
import torch
import torch.nn as nn

class GNN(nn.Module):
  def __init__(self,
               nb_hidden_gnn,
               nb_layer,
               nb_hidden_kernel,
               nb_kernel,
               input_dim,
               emb_dim=2):
    super(GNN, self).__init__()

    # Construct first layer
    gnn_layers = [GNN_Layer(input_dim,
                            nb_hidden_gnn,
                            nb_kernel,
                            nb_hidden_kernel,
                            apply_norm=True,
                            softmax=False)]
    
    # Construct additional layers
    for _ in range(nb_layer-1):
        gnn_layers.append(GNN_Layer(nb_hidden_gnn,
                                    nb_hidden_gnn,
                                    nb_kernel,
                                    nb_hidden_kernel))
    self.layers = nn.ModuleList(gnn_layers)

    self.final_emb = nn.Linear(nb_hidden_gnn, emb_dim)

  def forward(self, g):
    if torch.cuda.is_available():
        g.ndata['feat'] = g.ndata.pop('feat').to('cuda', non_blocking=True)

    emb = g.ndata.pop('feat')
    for i, layer in enumerate(self.layers):
      emb = layer(g, emb)
    emb = self.final_emb(emb)
    return emb

def weighted_msg(e):
    return {'msg': e.src['feat'] * e.data['e']}


class GNN_Layer(nn.Module):
    def __init__(self,
               input_dim,
               nb_hidden_gnn,
               nb_kernel,
               nb_hidden_kernel,
               apply_norm=True,
               softmax=False):
        super(GNN_Layer, self).__init__()

        if softmax:
            self.kernel = MLP_Kernel_DGL_Softmax(input_dim, nb_hidden_kernel)
        else:
            self.kernel = MLP_Kernel_DGL(input_dim, nb_hidden_kernel)

        self.gconv = DGL_Convolution(input_dim, nb_hidden_gnn)
        self.bn = nn.BatchNorm1d(input_dim,momentum=0.10) if apply_norm else None

    def forward(self, g, features):
        # maybe apply normalization
        if self.bn is not None:
            features = self.bn(features)
        g.ndata['feat'] = features

        # set edge weights for this layer
        g = self.kernel(g)
        
        # send weighted messages and apply graph convolution to nodes
        g.send_and_recv(g.edges(),
                        message_func=weighted_msg,
                        reduce_func=dgl.function.sum(msg='msg', out='agg_msg'),
                        apply_node_func=self.gconv)
        g.ndata.pop('feat')
        g.ndata.pop('agg_msg')
        return g.ndata.pop('emb')

class DGL_Convolution(nn.Module):
    def __init__(self,
               input_dim,
               nb_hidden_gnn):
        super(DGL_Convolution, self).__init__()
        self.weights = nn.Linear(2*input_dim, nb_hidden_gnn)
        self.act = nn.ReLU()

    def forward(self, n):
        feats = n.data['feat']
        agg_msg = n.data['agg_msg']
        node_feats = torch.cat((feats, agg_msg), dim=1)
        emb = self.weights(node_feats)
        emb = self.act(emb)
        return {'emb':emb}


class MLP_Kernel_DGL(nn.Module):
    def __init__(self, nb_input, nb_hidden_gnn, nb_output=1, nb_layer=1):
        super(MLP_Kernel_DGL, self).__init__()
        layers = [nn.Linear(nb_input*2, nb_hidden_gnn)]
        for _ in range(nb_layer-1):
            layers.append(nn.Linear(nb_hidden_gnn, nb_hidden_gnn))
        layers.append(nn.Linear(nb_hidden_gnn, nb_output))
        self.layers = nn.ModuleList(layers)
        self.act1 = nn.ReLU()
        self.act2 = nn.Sigmoid()

    def forward(self, g):
        g.apply_edges(self.mlp)
        return g

    def mlp(self, e):
        # Gather features from all relevant node pairs
        src = e.src['feat']
        dst = e.dst['feat']
        e_feats = torch.cat((src,dst),dim=1)
        
        # Apply MLP layers to node pairs
        for l in self.layers[:-1]:
            e_feats = self.act1(l(e_feats))
        
        # Apply final output with sigmoid
        e_feats = self.layers[-1](e_feats)
        e_feats = self.act2(e_feats)
        return {'e':e_feats}

# Dataset, Dataloader

In [18]:
from torch.utils.data import Dataset

def get_edge_indices(edges):
    edge_pairs = []
    for i, neighbors in enumerate(edges):
        for e_idx in neighbors:
            edge_pairs.append([i,e_idx])
    return edge_pairs

def get_true_edge_values(pred_edge_idx, true_edges):
    values = [0] * len(pred_edge_idx)
    for i, (src, dst) in enumerate(pred_edge_idx):
        if dst in true_edges[src]:
            values[i] = 1
    return values

class TrackML_Dataset(Dataset):
    def __init__(self, samples):
        self.samples = samples

    def __getitem__(self, index):
        s = self.samples[index]
        
        hits = s['hits']
        xyz  = hits['xyz']
        emb  = hits['emb']
        hits = torch.FloatTensor(np.concatenate((xyz, emb), axis=1))

        graphs = s['graphs']
        pred_edges = graphs['pred']
        loss_edges = graphs['loss']
        true_edges = graphs['true']
    
        pred_edge_idx = get_edge_indices(pred_edges)
        true_edge_idx = get_edge_indices(loss_edges)
        true_edge_values = get_true_edge_values(true_edge_idx,true_edges)

        # Build inference graph
        g_input = dgl.DGLGraph()
        g_input.add_nodes(len(hits))
        src, dst = tuple(zip(*pred_edge_idx))
        g_input.add_edges(src, dst)
        g_input.ndata['feat'] = hits

        # Build ground truth graph
        g_true = dgl.DGLGraph()
        g_true.add_nodes(len(hits))
        src, dst = tuple(zip(*true_edge_idx))
        g_true.add_edges(src, dst)
        g_true.edata['truth'] = torch.FloatTensor(true_edge_values)
        
        g_input.set_n_initializer(dgl.init.zero_initializer)
        g_true.set_n_initializer(dgl.init.zero_initializer)
        g_input.set_e_initializer(dgl.init.zero_initializer)
        g_true.set_e_initializer(dgl.init.zero_initializer)
        
        return g_input, g_true
    
    def __len__(self):
        return len(self.samples)
    
def trackml_collate(sample):
    g_input = [s[0] for s in sample]
    g_input = dgl.batch(g_input)

    g_true = [s[1] for s in sample]
    g_true = dgl.batch(g_true)

    return g_input, g_true

# Training

## Setup

In [19]:
import torch.nn.functional as F
from torch.utils.data import DataLoader

# PARAMETERS
batch_size = 4
nb_hidden = 16
nb_layers = 4
learn_rate = 0.001

dataset = TrackML_Dataset(samples)
dataloader = DataLoader(dataset, 
                        batch_size=batch_size, 
                        collate_fn=trackml_collate, 
                        drop_last=True, 
                        shuffle=True)

net = GNN(nb_hidden, nb_layers, nb_hidden, 1, 6)
optim = torch.optim.Adamax(net.parameters(), lr=learn_rate)
print(net)

GNN(
  (layers): ModuleList(
    (0): GNN_Layer(
      (kernel): MLP_Kernel_DGL(
        (layers): ModuleList(
          (0): Linear(in_features=12, out_features=16, bias=True)
          (1): Linear(in_features=16, out_features=1, bias=True)
        )
        (act1): ReLU()
        (act2): Sigmoid()
      )
      (gconv): DGL_Convolution(
        (weights): Linear(in_features=12, out_features=16, bias=True)
        (act): ReLU()
      )
      (bn): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): GNN_Layer(
      (kernel): MLP_Kernel_DGL(
        (layers): ModuleList(
          (0): Linear(in_features=32, out_features=16, bias=True)
          (1): Linear(in_features=16, out_features=1, bias=True)
        )
        (act1): ReLU()
        (act2): Sigmoid()
      )
      (gconv): DGL_Convolution(
        (weights): Linear(in_features=32, out_features=16, bias=True)
        (act): ReLU()
      )
      (bn): BatchNorm1d(16, eps=1e-05, momentum=0.

In [20]:
def get_emb_for_loss(e):
    src = e.src['emb']
    dst = e.dst['emb']
    truth = e.data['truth']
    pred_dst = nn.functional.pairwise_distance(src, dst)
    true_dst = truth*2 -1
    loss = nn.functional.hinge_embedding_loss(pred_dst, true_dst, reduction='none')
    return {'loss':loss, 'pred_dst':pred_dst, 'true_dst':true_dst}

def score_dist_accuracy(pred, true):
    pred = pred.round()
    pred[pred!=0] = 1
    pred = 1-pred
    correct = pred==true
    nb_correct = correct.sum()
    nb_total = true.size(0)
    score = float(nb_correct.item()) / nb_total
    return score

def train_one_epoch(net, batch_size, optimizer, train_loader):
    net.train()

    nb_batch = len(train_loader)
    nb_train = nb_batch * batch_size
    epoch_score = 0
    epoch_loss  = 0

    print("\nTraining on {} samples".format(nb_train))
    for i, (g_input, g_true) in enumerate(train_loader):
        optimizer.zero_grad()
        
        f = g_input.ndata['feat']

        hits_emb = net(g_input)
        g_true.ndata['emb'] = hits_emb
        

        g_true.apply_edges(get_emb_for_loss)

        loss = g_true.edata.pop('loss').mean()
        score = score_dist_accuracy(g_true.edata.pop('pred_dst'),
                                          g_true.edata.pop('truth'))

        loss.backward()
        optimizer.step()

        epoch_score += score * 100
        epoch_loss  += loss.item()

        nb_proc = (i+1) * batch_size
        if (((i+1) % (nb_batch//2)) == 0):
            print("  {:2d}  Loss: {:.3f}  Acc: {:2.1f}".format(nb_proc, epoch_loss/(i+1), epoch_score/(i+1)))
    return epoch_loss / nb_batch, epoch_score / nb_batch

for i in range(10):
    train_one_epoch(net, batch_size, optim, dataloader)


Training on 100 samples




  48  Loss: 0.288  Acc: 70.5
  96  Loss: 0.245  Acc: 75.9

Training on 100 samples
  48  Loss: 0.161  Acc: 85.1
  96  Loss: 0.139  Acc: 87.9

Training on 100 samples
  48  Loss: 0.111  Acc: 90.9
  96  Loss: 0.105  Acc: 91.0

Training on 100 samples
  48  Loss: 0.096  Acc: 92.5
  96  Loss: 0.092  Acc: 92.9

Training on 100 samples
  48  Loss: 0.085  Acc: 93.0
  96  Loss: 0.081  Acc: 93.4

Training on 100 samples
  48  Loss: 0.077  Acc: 94.2
  96  Loss: 0.078  Acc: 94.0

Training on 100 samples
  48  Loss: 0.066  Acc: 94.7
  96  Loss: 0.071  Acc: 94.3

Training on 100 samples
  48  Loss: 0.071  Acc: 95.0
  96  Loss: 0.068  Acc: 95.1

Training on 100 samples
  48  Loss: 0.061  Acc: 95.6
  96  Loss: 0.064  Acc: 95.3

Training on 100 samples
  48  Loss: 0.058  Acc: 95.9
  96  Loss: 0.059  Acc: 95.8


# Clustering

In [21]:
# Embed samples
embeddings = []
net.eval()
with torch.autograd.no_grad():
    for i, (g_input, g_true) in enumerate(dataloader):
        f = g_input.ndata['feat']
        hits_emb = net(g_input)
        embeddings.append(hits_emb.numpy())

In [22]:
# Cluster
from sklearn.cluster import DBSCAN

c = DBSCAN(eps=.1, min_samples=3)
clusters = c.fit_predict(embeddings[0])
print(clusters)

[ 0  0  0  0  1  0  2  3  4  1  5  5  0  2  6  3 -1  2  6  3  3  4  6  3
  3  4  5  5  5  5  5  5  5  5  5  5  1  1  0  1  1  0  1  0 -1  0  2  2
  0  2  2  6  3 -1  6  3  4  3  3  3  3  4  4  4  1  4  5  5  3  4  4  7
  7  7  7  7  7  7  7  7  7  7  8  8  9  9 -1 10 11 12 13 14 15  8 -1 -1
 10 13 12 11  8  9 16 16 10 -1 13 11 12 12  8 16  9 10 13 12 -1 14 14 15
 15 17 17 17 14 15 15 17 14 15 15 17 14 15 15 17 15 15 17 17 17  8  8  8
  9 16 16 10 13 12 12 -1  8 16  9 16 13 10 12 -1  9  9 16 13 10 12 16  9
 13 10 11 14 11 11 14 15 14 15 15 14 15 14 15 17 16  9 13 10 16  9 13 10
 10  1 18 18 18 18 18  1 18 18 18 18 18 19 20 21 22 23 23 19 -1 20 21 22
 23 19 20 21 22 -1 22 23 24 24 23 23 24 24 23 23 24 24 23 24 24 23 24 24
 24 24 18 18 19 19 20 19 20 19 20 19 20 20 22 20 22 22 23 23 23 23 19 19
 19 25 25 19 25 26 26 26 26 27 26 26 27 26 26 26 27 26 26 26 27 27 27 27
 27 27 28 29 30 31 32 28 28 29 30 31 -1 33 29 29 30 32 32 31 29 30 32 31
 27 27 33 27 33  1 33 33 -1 33 29 29 30 32 31 29 30