<a href="https://colab.research.google.com/github/exatrkx/exatrkx-iml2020/blob/main/notebooks/WalkThroughGNN4Tracking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Last update: 2020-11-25

Changelog:
* 2020-11-25: install PyTorch 1.6.0 manually.

# Preparation

The following steps install the required libraries. After each `pip install`, restart the runtime and continue with the next cell. Once everything is installed, the remaining cells can be run.

In [1]:
import torch
import tensorflow as tf
print(torch.__version__)
print(tf.__version__)

1.6.0+cu101
2.3.0


In [2]:
!python --version
!pip --version

Python 3.8.5
pip 21.0.1 from /home/zachary/anaconda3/envs/exatrkx/lib/python3.8/site-packages/pip (python 3.8)


The PyTorch Geometric installation used here requires PyTorch 1.6.0, so reinstall PyTorch at that version:

In [3]:
!pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html

Looking in links: https://download.pytorch.org/whl/torch_stable.html
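
Before continuing, it can help to confirm that the installed PyTorch build matches the pinned one. The following is a minimal sketch; `parse_torch_version` and `matches_pin` are hypothetical helpers written for this illustration and are not part of exatrkx or PyTorch.

```python
def parse_torch_version(version):
    """Split a PyTorch version string such as '1.6.0+cu101'.

    Returns (release, local), where local is the build tag after '+',
    or None when the string carries no such tag.
    """
    release, _, local = version.partition("+")
    return release, (local or None)


def matches_pin(version, release="1.6.0", cuda_tag="cu101"):
    """True if the version string has the pinned release and CUDA tag."""
    return parse_torch_version(version) == (release, cuda_tag)


# In the notebook one would call matches_pin(torch.__version__).
print(matches_pin("1.6.0+cu101"))
```

If the check fails after the reinstall, the runtime most likely still holds the previously imported version and needs a restart.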


In [4]:
!pip install git+https://github.com/exatrkx/exatrkx-iml2020.git@v1.1.1

Collecting git+https://github.com/exatrkx/exatrkx-iml2020.git@v1.1.1
  Cloning https://github.com/exatrkx/exatrkx-iml2020.git (to revision v1.1.1) to /tmp/pip-req-build-gwq05lxv
  Running command git clone -q https://github.com/exatrkx/exatrkx-iml2020.git /tmp/pip-req-build-gwq05lxv
  Running command git checkout -q d0ad6bde13132cd458d8ee63d1c3d6a23a1bf5f4
Collecting trackml@ https://github.com/LAL/trackml-library/tarball/master#egg=trackml-3
  Using cached https://github.com/LAL/trackml-library/tarball/master


In [None]:
!install_geometric.sh

In [7]:
!pip install mpi4py



Restart the runtime manually now; otherwise Horovod cannot be installed.

In [8]:
!gcc --version

gcc (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.



In [2]:
!HOROVOD_WITH_MPI=1 pip install horovod --no-cache-dir

Collecting horovod
  Downloading horovod-0.21.3.tar.gz (3.2 MB)
Building wheels for collected packages: horovod
  Building wheel for horovod (setup.py) ... done
  Created wheel for horovod: filename=horovod-0.21.3-cp38-cp38-linux_x86_64.whl size=37878115 sha256=75440648c624e65f1d76764fb54df104146869bc31d0e8d9ba4fcacb778c3d24
  Stored in directory: /tmp/pip-ephem-wheel-cache-dyr_6cil/wheels/44/59/88/c9cf522ef11b3ce26e2a40f12bc79b2809a475d2d9c38d83a3
Successfully built horovod
Installing collected packages: horovod
Successfully installed horovod-0.21.3


Restart the runtime again to let new packages take effect.

# Prepare input data

In [3]:
%%bash
wget https://portal.nersc.gov/project/atlas/xju/train_10evts.tar
tar xf train_10evts.tar
rm train_10evts.tar

--2021-02-18 10:44:33--  https://portal.nersc.gov/project/atlas/xju/train_10evts.tar
Resolving portal.nersc.gov (portal.nersc.gov)... 128.55.206.28, 128.55.206.24
Connecting to portal.nersc.gov (portal.nersc.gov)|128.55.206.28|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94582666 (90M) [application/x-tar]
Saving to: ‘train_10evts.tar’


In [6]:
!ls

Untitled.ipynb		       detectors.csv	inference.py	 output
WalkThroughGNN4Tracking.ipynb  inference.ipynb	inference_fn.py  train_10evts


In [7]:
!ls train_10evts/

event000001000-cells.csv      event000001005-cells.csv
event000001000-hits.csv       event000001005-hits.csv
event000001000-particles.csv  event000001005-particles.csv
event000001000-truth.csv      event000001005-truth.csv
event000001001-cells.csv      event000001006-cells.csv
event000001001-hits.csv       event000001006-hits.csv
event000001001-particles.csv  event000001006-particles.csv
event000001001-truth.csv      event000001006-truth.csv
event000001002-cells.csv      event000001007-cells.csv
event000001002-hits.csv       event000001007-hits.csv
event000001002-particles.csv  event000001007-particles.csv
event000001002-truth.csv      event000001007-truth.csv
event000001003-cells.csv      event000001008-cells.csv
event000001003-hits.csv       event000001008-hits.csv
event000001003-particles.csv  event000001008-particles.csv
event000001003-truth.csv      event000001008-truth.csv
event000001004-cells.csv      event000001009-cells.csv
event000001004-hits.csv       event000001009-hits.csv

In [8]:
!pwd

/home/zachary/exatrkx-iml2020/notebooks


# Introduction

The code is located at [exatrkx-iml2020](https://github.com/exatrkx/exatrkx-iml2020) and is organized as follows:


```text
exatrkx
├── __init__.py
├── configs
│   ├── prepare_feature_store.yaml
│   ├── train_embedding.yaml
│   └── train_filter.yaml
├── scripts
│   ├── run_lightning.py
│   ├── convert2tf.py
│   ├── train_gnn_tf.py
│   ├── eval_gnn_tf.py
│   ├── tracks_from_gnn.py
│   ├── count_node_edges.py
│   └── install_geometric.sh
└── src
    ├── processing
    │   ├── __init__.py
    │   ├── cell_direction_utils/
    │   ├── feature_construction.py
    │   └── utils.py
    ├── embedding
    │   ├── __init__.py
    │   ├── embedding_base.py
    │   └── layerless_embedding.py
    ├── filter
    │   ├── __init__.py
    │   ├── filter_base.py
    │   └── vanilla_filter.py
    ├── tfgraphs/
    ├── torchgnn/
    ├── utils_dir.py
    └── utils_torch.py
```

Every time the runtime is restarted, the following two cells must be executed again. The code reads its input and output locations from two environment variables, `TRKXINPUTDIR` and `TRKXOUTPUTDIR`, and uses them to organize the output directories internally.

In [2]:
import os
os.environ['TRKXINPUTDIR'] = "train_10evts"
os.environ['TRKXOUTPUTDIR'] = "output"
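
The output locations that appear later in this notebook all sit under `TRKXOUTPUTDIR`. The sketch below collects them in one place. The sub-directory names are copied from the cell outputs in this notebook; `STAGE_SUBDIRS` and `stage_dirs` are hypothetical names for illustration, and how exatrkx builds these paths internally is an assumption.

```python
import os

# Sub-directory names as they appear in this notebook's outputs.
STAGE_SUBDIRS = {
    "feature_store": "feature_store",
    "embedding": "embedding_output",
    "filtering": "filtering_output",
    "gnn_inputs": "gnn_inputs",
    "gnn_models": "gnn_models",
}


def stage_dirs(env=None):
    """Map each stage name to its directory under TRKXOUTPUTDIR."""
    env = os.environ if env is None else env
    base = env["TRKXOUTPUTDIR"]
    return {stage: os.path.join(base, sub) for stage, sub in STAGE_SUBDIRS.items()}


dirs = stage_dirs({"TRKXINPUTDIR": "train_10evts", "TRKXOUTPUTDIR": "output"})
print(dirs["embedding"])
```

Passing the environment as an argument keeps the helper easy to test without touching the real process environment.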

In [3]:
# system import
import pkg_resources
import yaml
import pprint

# 3rd party
import torch
from trackml.dataset import load_event
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# local import
from exatrkx import config_dict # for accessing predefined configuration files
from exatrkx import outdir_dict # for accessing predefined output directories
from exatrkx.src import utils_dir

# for preprocessing
from exatrkx import FeatureStore

# for embedding
from exatrkx import LayerlessEmbedding
from exatrkx import EmbeddingInferenceCallback
# for filtering
from exatrkx import VanillaFilter
from exatrkx import FilterInferenceCallback

# Pre-processing

A quick recap of the input dataset from the [TrackML challenge](https://www.kaggle.com/c/trackml-particle-identification): each event consists of four tables (hits, cells, particles and truth).

In [7]:
hits, cell, particles, truth = load_event("train_10evts/event000001000")

In [8]:
hits.head(3)

Unnamed: 0,hit_id,x,y,z,volume_id,layer_id,module_id
0,1,-64.409897,-7.1637,-1502.5,7,2,1
1,2,-55.336102,0.635342,-1502.5,7,2,1
2,3,-83.830498,-1.14301,-1502.5,7,2,1


In [9]:
cell.head(3)

Unnamed: 0,hit_id,ch0,ch1,value
0,1,209,617,0.013832
1,1,210,617,0.079887
2,1,209,618,0.211723


In [10]:
particles.head(3)

Unnamed: 0,particle_id,vx,vy,vz,px,py,pz,q,nhits
0,4503668346847232,-0.009288,0.009861,-0.077879,-0.055269,0.323272,-0.203492,-1,8
1,4503737066323968,-0.009288,0.009861,-0.077879,-0.948125,0.470892,2.01006,1,11
2,4503805785800704,-0.009288,0.009861,-0.077879,-0.886484,0.105749,0.683881,-1,0


In [11]:
truth.head(3)

Unnamed: 0,hit_id,particle_id,tx,ty,tz,tpx,tpy,tpz,weight
0,1,0,-64.411598,-7.16412,-1502.5,250710.0,-149908.0,-956385.0,0.0
1,2,22525763437723648,-55.338501,0.630805,-1502.5,-0.570605,0.02839,-15.4922,1e-05
2,3,0,-83.828003,-1.14558,-1502.5,626295.0,-169767.0,-760877.0,0.0


In [12]:
import os
action = 'build'

config_file = pkg_resources.resource_filename(
                    "exatrkx",
                    os.path.join('configs', config_dict[action]))
with open(config_file) as f:
  b_config = yaml.load(f, Loader=yaml.FullLoader)

In [13]:
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(b_config)

{   'adjacent': True,
    'cell_information': True,
    'endcaps': False,
    'layerless': True,
    'layerwise': False,
    'n_files': 10,
    'n_tasks': 1,
    'n_workers': 2,
    'noise': False,
    'pt_min': 1}


In [14]:
b_config['endcaps'] = False
b_config['pt_min'] = 1.0 
b_config['n_workers'] = 1
b_config['cell_information'] = True
b_config['n_files'] = 3 
b_config['noise'] = False
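
To make the `pt_min` setting concrete, the sketch below applies a transverse-momentum cut to the three rows printed by `particles.head(3)` above. It assumes that `pt_min` is a threshold on pT = sqrt(px² + py²) in the same units as `px` and `py`; the exact selection that `FeatureStore` performs is not shown in this notebook, and `select_by_pt` is a hypothetical helper.

```python
import math

# The three rows shown by particles.head(3): (particle_id, px, py).
particles_head = [
    (4503668346847232, -0.055269, 0.323272),
    (4503737066323968, -0.948125, 0.470892),
    (4503805785800704, -0.886484, 0.105749),
]


def transverse_momentum(px, py):
    """Momentum component perpendicular to the beam (z) axis."""
    return math.hypot(px, py)


def select_by_pt(rows, pt_min):
    """Return the ids of particles whose pT exceeds pt_min."""
    return [pid for pid, px, py in rows if transverse_momentum(px, py) > pt_min]


selected = select_by_pt(particles_head, pt_min=1.0)
print(selected)
```

Under this assumption only the second particle survives a cut at 1.0, which illustrates how strongly `pt_min` can reduce the number of hits per event.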

In [15]:
preprocess_dm = FeatureStore(b_config)
preprocess_dm.prepare_data()

Loading detector...
Detector loaded.
Writing outputs to output/feature_store


In [16]:
feature_data = torch.load("output/feature_store/1000", map_location='cpu')

In [17]:
feature_data

Data(cell_data=[7837, 9], event_file="train_10evts/event000001000", hid=[7837], layerless_true_edges=[2, 7875], layers=[7837], pid=[7837], pt=[7837], weights=[7875], x=[7837, 3])

# Embedding

The embedding and filtering modules are written in PyTorch as a [LightningModule](https://pytorch-lightning.readthedocs.io/en/stable/lightning_module.html) and trained by the [Trainer](https://pytorch-lightning.readthedocs.io/en/stable/trainer.html), inheriting all the APIs of both.

The training and validation steps for embedding and filtering are defined in a base class. The neural network itself is left abstract in the base class and implemented in the derived class.

In [18]:
action = 'embedding'

config_file = pkg_resources.resource_filename(
                    "exatrkx",
                    os.path.join('configs', config_dict[action]))
with open(config_file) as f:
  e_config = yaml.load(f, Loader=yaml.FullLoader)

pp = pprint.PrettyPrinter(indent=4)
pp.pprint(e_config)

{   'adjacent': False,
    'clustering': 'build_edges',
    'emb_dim': 8,
    'emb_hidden': 512,
    'endcaps': True,
    'factor': 0.3,
    'in_channels': 12,
    'knn_train': 20,
    'knn_val': 100,
    'layerless': True,
    'layerwise': False,
    'lr': 0.002,
    'margin': 1,
    'n_workers': 1,
    'nb_layer': 6,
    'noise': False,
    'overwrite': True,
    'patience': 5,
    'pt_min': 0,
    'r_train': 1,
    'r_val': 1.0,
    'randomisation': 2,
    'regime': ['rp', 'hnm', 'ci'],
    'train_split': [8, 1, 1],
    'warmup': 500,
    'weight': 4}


The entries of `train_split` must sum to the total number of events produced by the preprocessing step (three here, since `n_files` was set to 3).

In [19]:
e_config['train_split'] = [1, 1, 1]
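
The constraint on `train_split` can be checked before training starts. This is a minimal sketch; `check_train_split` is a hypothetical helper, not an exatrkx function, and it assumes the three entries are the train, validation and test sizes in that order.

```python
def check_train_split(train_split, n_events):
    """Validate [n_train, n_val, n_test] against the number of events.

    Returns the split as a dict, or raises ValueError if it is malformed
    or does not add up to n_events.
    """
    if len(train_split) != 3:
        raise ValueError("expected three entries: [n_train, n_val, n_test]")
    if any(n < 0 for n in train_split):
        raise ValueError("split sizes must be non-negative")
    total = sum(train_split)
    if total != n_events:
        raise ValueError(
            f"train_split sums to {total}, but {n_events} events are available"
        )
    return dict(zip(("train", "val", "test"), train_split))


print(check_train_split([1, 1, 1], n_events=3))
```

With the default `[8, 1, 1]` and only three preprocessed events, the same call raises a `ValueError`, which is why the split is overridden above.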

In [20]:
e_model = LayerlessEmbedding(e_config)

e_checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filepath=os.path.join(utils_dir.embedding_outdir,'ckpt-{epoch:02d}-{val_loss:.2f}') ,
    save_top_k=3,
    mode='min')
e_callback_list = [EmbeddingInferenceCallback()]

In [21]:
e_trainer = Trainer(
    max_epochs = 2,
    limit_train_batches=1,
    limit_val_batches=1,
    callbacks=e_callback_list,
    gpus=0,
    checkpoint_callback=e_checkpoint_callback
    )

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [22]:
e_trainer.fit(e_model)


  | Name      | Type       | Params
-----------------------------------------
0 | layers    | ModuleList | 1 M   
1 | emb_layer | Linear     | 4 K   
2 | norm      | LayerNorm  | 1 K   
3 | act       | Tanh       | 0     


Validation sanity check: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Training finished, running inference to build graphs...
66.7% inference complete 

1

In [23]:
!ls output/embedding_output/train/

1002


In [24]:
embed_outfile = os.path.join(utils_dir.embedding_outdir, "train", "1002")
dd = torch.load(embed_outfile)
dd

Data(cell_data=[8729, 9], e_radius=[2, 410893], event_file="train_10evts/event000001002", hid=[8729], layerless_true_edges=[2, 8863], layers=[8729], pid=[8729], pt=[8729], weights=[8863], x=[8729, 3], y=[410893])

The steps above can also be run from the command line:

`run_lightning.py --action embedding --max_epochs 2 --gpus 0 --limit_train_batches 1 --limit_val_batches 1`

Next, we walk through the training step for embedding learning.

```python
def training_step(self, batch, batch_idx):

    # apply the embedding neural network on the hit features
    # and return hidden features in the embedding space.
    spatial = self(torch.cat([batch.cell_data, batch.x], axis=-1))

    # create another direction for true doublets
    # doublets are also called edges
    e_bidir = torch.cat([batch.layerless_true_edges,
                        torch.stack([batch.layerless_true_edges[1],
                                    batch.layerless_true_edges[0]], axis=1).T
                        ], axis=-1)

    # engineer true and false doublets for each batch, where
    # [0, :] are reference hits and [1, :] are neighbor hits.
    # two methods are used: random sampling (rp) and hard negative mining (hnm).
    # start from an empty edge list:
    e_spatial = torch.empty([2,0], dtype=torch.int64, device=self.device)

    if 'rp' in self.hparams["regime"]:
        # randomly draw `randomisation` (here 2) times as many pairs as there are bidirectional true edges
        n_random = int(self.hparams["randomisation"]*e_bidir.shape[1])
        e_spatial = torch.cat([e_spatial,
            torch.randint(e_bidir.min(), e_bidir.max(), (2, n_random), device=self.device)], axis=-1)

    # use a clustering algorithm to connect hits based on embedding information
    # euclidean distance is used.
    # r_train: radius for training, typical value, 1.
    # knn_train: the maximum number of neighbours, typical value, 20
    if 'hnm' in self.hparams["regime"]:
        e_spatial = torch.cat([e_spatial,
                        self.clustering(spatial, self.hparams["r_train"], self.hparams["knn_train"])], axis=-1)

    # label the engineered doublets according to the truth
    # a sparse representation is used to reduce execution time
    e_spatial, y_cluster = graph_intersection(e_spatial, e_bidir)

    # add all truth edges "weight" times, which is 4,
    # to balance the number of truth and fake edges in one batch
    e_spatial = torch.cat([
        e_spatial,
        e_bidir.transpose(0,1).repeat(1,self.hparams["weight"]).view(-1, 2).transpose(0,1)
        ], axis=-1)
    y_cluster = np.concatenate([y_cluster.astype(int), np.ones(e_bidir.shape[1]*self.hparams["weight"])])

    hinge = torch.from_numpy(y_cluster).float().to(device)
    hinge[hinge == 0] = -1

    # extract embedding features of reference hits and neighbor hits
    reference = spatial.index_select(0, e_spatial[1])
    neighbors = spatial.index_select(0, e_spatial[0])
    d = torch.sum((reference - neighbors)**2, dim=-1)

    loss = torch.nn.functional.hinge_embedding_loss(d, hinge, margin=self.hparams["margin"], reduction="mean")

    self.log("train_loss", loss, prog_bar=True)

    return loss
```
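
The labelling and loss steps of `training_step` can be traced by hand on a toy example. The sketch below is pure Python and uses made-up toy coordinates. `label_edges` is a set-based stand-in for `graph_intersection`, not the actual implementation, and the repetition of true edges by `weight` is left out. The loss follows the definition of `torch.nn.functional.hinge_embedding_loss`: the distance itself for label +1, and `max(0, margin - d)` for label -1.

```python
def make_bidirectional(edges):
    """Return the edge list with every (a, b) also present as (b, a)."""
    return edges + [(b, a) for a, b in edges]


def label_edges(candidates, true_edges):
    """Drop duplicate candidates and label each: +1 if true, else -1."""
    truth = set(true_edges)
    unique = sorted(set(candidates))
    return unique, [1 if edge in truth else -1 for edge in unique]


def squared_distance(u, v):
    """Squared Euclidean distance between two embedded points."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def hinge_embedding_loss(distances, labels, margin=1.0):
    """Mean of d for y = +1 and max(0, margin - d) for y = -1."""
    terms = [d if y == 1 else max(0.0, margin - d)
             for d, y in zip(distances, labels)]
    return sum(terms) / len(terms)


# Toy 2-D embedding of four hits; hits 0 and 1 belong to the same track.
embedding = {0: (0.0, 0.0), 1: (0.1, 0.0), 2: (0.5, 0.0), 3: (3.0, 0.0)}
true_edges = make_bidirectional([(0, 1)])
candidates = [(0, 1), (0, 2), (0, 3), (0, 1)]  # contains one duplicate

edges, labels = label_edges(candidates, true_edges)
distances = [squared_distance(embedding[a], embedding[b]) for a, b in edges]
loss = hinge_embedding_loss(distances, labels, margin=1.0)
print(edges, labels, loss)
```

The true pair is penalized by its distance, the nearby fake pair (0, 2) is penalized for sitting inside the margin, and the distant fake pair (0, 3) contributes nothing. This is why hard negative mining, which finds fake pairs inside the radius, supplies the most informative training examples.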

# Filtering

In [25]:
action = 'filtering'

config_file = pkg_resources.resource_filename(
                    "exatrkx",
                    os.path.join('configs', config_dict[action]))
with open(config_file) as f:
  f_config = yaml.load(f, Loader=yaml.FullLoader)

pp.pprint(f_config)

{   'adjacent': False,
    'batchnorm': False,
    'clustering': 'build_edges',
    'datatype_names': ['train', 'val', 'test'],
    'emb_channels': 0,
    'endcaps': True,
    'factor': 0.3,
    'filter_cut': 0.3,
    'hidden': 512,
    'in_channels': 12,
    'layerless': True,
    'layernorm': True,
    'layerwise': False,
    'lr': 0.002,
    'nb_layer': 3,
    'noise': False,
    'patience': 8,
    'pt_min': 0,
    'ratio': 2,
    'regime': ['ci'],
    'train_split': [8, 1, 1],
    'val_subset': 0.1,
    'warmup': 200,
    'weight': 2}


In [26]:
f_config['train_split'] = [1,1, 1]

In [27]:
f_model = VanillaFilter(f_config)
f_callback_list = [FilterInferenceCallback()]

In [28]:
f_trainer = Trainer(
    max_epochs = 2,
    limit_train_batches=1,
    limit_val_batches=1,
    callbacks=f_callback_list,
    gpus=0,
    )

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [29]:
f_trainer.fit(f_model)


  | Name         | Type        | Params
---------------------------------------------
0 | input_layer  | Linear      | 12 K  
1 | layers       | ModuleList  | 525 K 
2 | output_layer | Linear      | 513   
3 | layernorm    | LayerNorm   | 1 K   
4 | batchnorm    | BatchNorm1d | 1 K   
5 | act          | Tanh        | 0     


Validation sanity check: 0it [00:00, ?it/s]

  'eff': torch.tensor(edge_true_positive/edge_true),
  'pur': torch.tensor(edge_true_positive/edge_positive)})


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Training finished, running inference to filter graphs...
66.7% inference complete 

1

In [30]:
filter_outfile = os.path.join(utils_dir.filtering_outdir, "train", "1002")
dd = torch.load(filter_outfile)
dd

Data(cell_data=[8729, 9], e_radius=[2, 57167], event_file="train_10evts/event000001002", hid=[8729], layerless_true_edges=[2, 8863], layers=[8729], pid=[8729], pt=[8729], weights=[8863], x=[8729, 3], y=[57167], y_pid=[57167])
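
For event 1002, filtering reduces `e_radius` from 410893 candidate edges to 57167. The validation output above reports efficiency and purity as `edge_true_positive/edge_true` and `edge_true_positive/edge_positive`. The sketch below computes both on toy values. The scores and labels are illustrative, and keeping edges with `score > cut` is an assumption about how `filter_cut` is applied.

```python
def edge_efficiency_purity(scores, truth, cut):
    """Efficiency and purity of the edges kept by `score > cut`.

    efficiency = true edges kept / all true edges
    purity     = true edges kept / all edges kept
    """
    kept = [s > cut for s in scores]
    n_true = sum(truth)
    n_kept = sum(kept)
    n_true_kept = sum(1 for k, t in zip(kept, truth) if k and t)
    eff = n_true_kept / n_true if n_true else 0.0
    pur = n_true_kept / n_kept if n_kept else 0.0
    return eff, pur


# Toy filter scores and truth labels (1 = true edge, 0 = fake edge).
scores = [0.9, 0.8, 0.4, 0.2, 0.1]
truth = [1, 1, 0, 1, 0]
eff, pur = edge_efficiency_purity(scores, truth, cut=0.3)
print(eff, pur)
```

Lowering the cut raises efficiency at the cost of purity, so `filter_cut` trades graph size against the fraction of true edges that reach the GNN.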

# Graph Neural Network

## Prepare `tf.data` for training GNN
We convert the filtering output to `tf.data` (see the [official guide](https://www.tensorflow.org/guide/data#batching_dataset_elements)) to take advantage of TensorFlow input pipelines. With this, data processing contributes negligibly to the training time, as seen in profiling results from TensorBoard. The actual implementation can be found in [dataset.py](https://github.com/exatrkx/exatrkx-iml2020/blob/main/exatrkx/src/tfgraphs/dataset.py).

The executable that converts the input dataset to `tf.data` is [convert2tf.py](https://github.com/exatrkx/exatrkx-iml2020/blob/main/exatrkx/scripts/convert2tf.py).

In [31]:
!convert2tf.py --edge-name "e_radius" --truth-name "y_pid"

processing files in folder: output/filtering_output/train
DoubletsDataset added 0 events, in 0.0 mins
processing files in folder: output/filtering_output/val
DoubletsDataset added 0 events, in 0.0 mins
processing files in folder: output/filtering_output/test
DoubletsDataset added 0 events, in 0.0 mins


## SegmentClassifier (GNN for doublet classification)
The core of the classifier is an [interaction network](https://arxiv.org/abs/1612.00222) with small modifications, implemented with [graph_nets](https://github.com/deepmind/graph_nets) and explained below.

The modifications are: (1) the node update aggregates both sent and received edges; (2) the edge block runs after the node block, so edges are updated from the already-updated nodes.

The training script is [train_gnn_tf.py](https://github.com/exatrkx/exatrkx-iml2020/blob/main/exatrkx/scripts/train_gnn_tf.py). We will first walk through the graph neural network and the training step, then run the training.

```python
class InteractionNetwork(snt.Module):
  """Implementation of an Interaction Network.

  An interaction network computes interactions on the edges based on the
  previous edge features and on the features of the nodes connected to those
  edges. This modified version updates the nodes first, from their incoming
  and outgoing edges, and then updates the edges from the updated nodes.
  See https://arxiv.org/abs/1612.00222 for more details.

  This model does not update the graph globals, and they are allowed to be
  `None`.
  """

  def __init__(self,
               edge_model_fn,
               node_model_fn,
               reducer=tf.math.unsorted_segment_sum,
               name="interaction_network"):
    """Initializes the InteractionNetwork module.

    Args:
      edge_model_fn: A callable that will be passed to `EdgeBlock` to perform
        per-edge computations. The callable must return a Sonnet module (or
        equivalent; see `blocks.EdgeBlock` for details), and the shape of the
        output of this module must match the one of the input nodes, but for the
        first and last axis.
      node_model_fn: A callable that will be passed to `NodeBlock` to perform
        per-node computations. The callable must return a Sonnet module (or
        equivalent; see `blocks.NodeBlock` for details).
      reducer: Reducer to be used by NodeBlock to aggregate edges. Defaults to
        tf.math.unsorted_segment_sum.
      name: The module name.
    """
    super(InteractionNetwork, self).__init__(name=name)
    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=edge_model_fn, use_globals=False)
    self._node_block = blocks.NodeBlock(
        node_model_fn=node_model_fn,
        use_received_edges=True,
        use_sent_edges=True,
        use_globals=False,
        received_edges_reducer=reducer)

  def __call__(self, graph):
    """Connects the InteractionNetwork.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s. `graph.globals` can be
        `None`. The features of each node and edge of `graph` must be
        concatenable on the last axis (i.e., the shapes of `graph.nodes` and
        `graph.edges` must match but for their first and last axis).

    Returns:
      An output `graphs.GraphsTuple` with updated edges and nodes.

    Raises:
      ValueError: If any of `graph.nodes`, `graph.edges`, `graph.receivers` or
        `graph.senders` is `None`.
    """
    return self._edge_block(self._node_block(graph))
```
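
The update order in `__call__` above (node block first, then edge block) can be illustrated without TensorFlow. The sketch below uses scalar features and plain functions in place of MLPs, with a sum reducer for both sent and received edges. `node_then_edge_step` is a hypothetical name, and the toy update functions are illustrative only.

```python
def node_then_edge_step(nodes, edges, senders, receivers, node_fn, edge_fn):
    """One pass in the order used above: nodes first, then edges.

    nodes, edges: lists of scalar features (stand-ins for feature vectors).
    senders[i], receivers[i]: node indices at the two ends of edge i.
    """
    incoming = [0.0] * len(nodes)
    outgoing = [0.0] * len(nodes)
    for e, s, r in zip(edges, senders, receivers):
        outgoing[s] += e  # sum-reduce edges leaving node s
        incoming[r] += e  # sum-reduce edges entering node r

    # Node update sees the node itself plus both aggregated edge sums.
    new_nodes = [node_fn(n, i, o) for n, i, o in zip(nodes, incoming, outgoing)]
    # Edge update sees the old edge plus the already-updated end nodes.
    new_edges = [edge_fn(e, new_nodes[s], new_nodes[r])
                 for e, s, r in zip(edges, senders, receivers)]
    return new_nodes, new_edges


# A three-node chain 0 -> 1 -> 2 with toy additive update functions.
new_nodes, new_edges = node_then_edge_step(
    nodes=[1.0, 2.0, 3.0],
    edges=[0.5, 0.25],
    senders=[0, 1],
    receivers=[1, 2],
    node_fn=lambda n, inc, out: n + inc + out,
    edge_fn=lambda e, ns, nr: e + ns + nr,
)
print(new_nodes, new_edges)
```

Because the edge update reads `new_nodes`, each edge already reflects one round of neighbourhood information from both of its end points within a single step.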

```python
class SegmentClassifier(snt.Module):

  def __init__(self, name="SegmentClassifier"):
    super(SegmentClassifier, self).__init__(name=name)

    # objective is to initialize node features by
    # transforming input node features via MLP.
    self._node_encoder_block = blocks.NodeBlock(
        node_model_fn=make_mlp_model,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        name='node_encoder_block'
    )

    # objective is to initialize edge features by
    # transforming aggregated neighbouring node information via MLP.
    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=make_mlp_model,
        use_edges=False,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False,
        name='edge_encoder_block'
    )

    # reducer can be an unsorted-segment sum, mean, max, min, ...
    self._core = InteractionNetwork(
        edge_model_fn=make_mlp_model,
        node_model_fn=make_mlp_model,
        reducer=tf.math.unsorted_segment_sum
    )

    # Transforms the outputs into appropriate shapes
    edge_output_size = 1
    edge_fn = lambda: snt.Sequential([
        snt.nets.MLP([edge_output_size],
                      activation=tf.nn.relu, # default is relu
                      name='edge_output'),
        tf.sigmoid])

    self._output_transform = modules.GraphIndependent(edge_fn, None, None)

  def __call__(self, input_op, num_processing_steps):
    # make a full-fledged graph that has node and edge features
    latent = self._edge_block(self._node_encoder_block(input_op))
    latent0 = latent

    output_ops = []
    # message passing with skip connections
    for _ in range(num_processing_steps):
        core_input = utils_tf.concat([latent0, latent], axis=1)
        latent = self._core(core_input)
        output_ops.append(self._output_transform(latent))
    return output_ops
```

The training step

```python
@functools.partial(tf.function, input_signature=input_signature)
def train_step(inputs_tr, targets_tr, first_batch):
    print("Tracing update_step")
    print("inputs nodes", inputs_tr.nodes.shape)
    print("inputs edges", inputs_tr.edges.shape)
    print("input n_node", inputs_tr.n_node.shape)
    print(inputs_tr.nodes)
    with tf.GradientTape() as tape:
        outputs_tr = model(inputs_tr, num_processing_steps_tr)
        loss_ops_tr = create_loss_ops(targets_tr, outputs_tr)
        loss_op_tr = tf.math.reduce_sum(loss_ops_tr) / tf.constant(num_processing_steps_tr, dtype=tf.float32)

    # Horovod: add Horovod Distributed GradientTape.
    if args.distributed:
        tape = hvd.DistributedGradientTape(tape)

    gradients = tape.gradient(loss_op_tr, model.trainable_variables)
    # optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    optimizer.apply(gradients, model.trainable_variables)

    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    #
    # Note: broadcast should be done after the first gradient step to ensure optimizer
    # initialization.
    if args.distributed and first_batch:
        hvd.broadcast_variables(model.trainable_variables, root_rank=0)
        hvd.broadcast_variables(optimizer.variables, root_rank=0)

    return loss_op_tr
```
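
In `train_step`, the model returns one output per message-passing step, and the per-step losses are summed and divided by the number of steps. The body of `create_loss_ops` is not shown in this notebook. The sketch below assumes a weighted binary cross-entropy in which the `--real-edge-weight` and `--fake-edge-weight` options act as per-class weights and the result is normalized by the sum of weights; both the function names and that normalization are assumptions.

```python
import math


def weighted_log_loss(targets, predictions, real_weight=1.0, fake_weight=1.0,
                      eps=1e-7):
    """Weighted binary cross-entropy, normalized by the sum of weights."""
    total, norm = 0.0, 0.0
    for y, p in zip(targets, predictions):
        p = min(max(p, eps), 1.0 - eps)  # keep log() finite
        w = real_weight if y == 1 else fake_weight
        total += -w * (y * math.log(p) + (1 - y) * math.log(1.0 - p))
        norm += w
    return total / norm


def loss_over_steps(targets, outputs_per_step, **weights):
    """Average the per-step losses, mirroring reduce_sum / num_steps above."""
    losses = [weighted_log_loss(targets, out, **weights)
              for out in outputs_per_step]
    return sum(losses) / len(losses)


# One true and one fake edge, scored after two message-passing steps.
targets = [1, 0]
outputs_per_step = [[0.5, 0.5], [0.9, 0.1]]
loss = loss_over_steps(targets, outputs_per_step)
print(loss)
```

Averaging over all steps, rather than using only the last output, pushes the intermediate iterations toward useful edge scores as well.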

The `--train-files`, `--val-files` and `--output-dir` options are optional. If they are not specified, the internally organized directories are used.

In [32]:
!train_gnn_tf.py --help

usage: train_gnn_tf.py [-h] [--train-files TRAIN_FILES]
                       [--val-files VAL_FILES] [--output-dir OUTPUT_DIR] [-d]
                       [--num-iters NUM_ITERS] [--learning-rate LEARNING_RATE]
                       [--max-epochs MAX_EPOCHS]
                       [--real-edge-weight REAL_EDGE_WEIGHT]
                       [--fake-edge-weight FAKE_EDGE_WEIGHT]
                       [-v {DEBUG,ERROR,FATAL,INFO,WARN}]

Train nx-graph with configurations

optional arguments:
  -h, --help            show this help message and exit
  --train-files TRAIN_FILES
                        input TF records for training
  --val-files VAL_FILES
                        input TF records for validation
  --output-dir OUTPUT_DIR
                        where the model and training info saved
  -d, --distributed     data distributed training
  --num-iters NUM_ITERS
                        number of message passing steps
  --learning-rate LEARNING_RATE
                        learing

In [33]:
!train_gnn_tf.py

not doing distributed
INFO:tensorflow:found 0 GPUs
INFO:tensorflow:Checkpoints and models saved at output/gnn_models
INFO:tensorflow:1 epochs with batch size 1
INFO:tensorflow:8 processing steps in the model
INFO:tensorflow:I am in hvd rank: 0 of  total 1 ranks
INFO:tensorflow:rank 0 has 1 training files and 1 evaluation files
nodes [None, 3] float32
edges [None, 1] float32
receivers [None] int32
senders [None] int32
globals [1, 1] float32
n_node [1] int32
n_edge [1] int32
nodes [None, 1] float32
edges [None] float32
receivers [None] int32
senders [None] int32
globals [1, 1] float32
n_node [1] int32
n_edge [1] int32
INFO:tensorflow:Loading latest checkpoint from: output/gnn_models
INFO:tensorflow:start epoch 0 on CPU
Tracing update_step
inputs nodes (None, 3)
inputs edges (None, 1)
input n_node (1,)
Tensor("inputs_tr:0", shape=(None, 3), dtype=float32)
Tracing update_step
inputs nodes (None, 3)
inputs edges (None, 1)
input n_node (1,)
Tensor("inputs_tr:0", shape=(None, 3), dtype=float3

In [34]:
!train_gnn_tf.py --max-epochs 5

not doing distributed
INFO:tensorflow:found 0 GPUs
INFO:tensorflow:Checkpoints and models saved at output/gnn_models
INFO:tensorflow:5 epochs with batch size 1
INFO:tensorflow:8 processing steps in the model
INFO:tensorflow:I am in hvd rank: 0 of  total 1 ranks
INFO:tensorflow:rank 0 has 1 training files and 1 evaluation files
nodes [None, 3] float32
edges [None, 1] float32
receivers [None] int32
senders [None] int32
globals [1, 1] float32
n_node [1] int32
n_edge [1] int32
nodes [None, 1] float32
edges [None] float32
receivers [None] int32
senders [None] int32
globals [1, 1] float32
n_node [1] int32
n_edge [1] int32
INFO:tensorflow:Loading latest checkpoint from: output/gnn_models
INFO:tensorflow:start epoch 0 on CPU
Tracing update_step
inputs nodes (None, 3)
inputs edges (None, 1)
input n_node (1,)
Tensor("inputs_tr:0", shape=(None, 3), dtype=float32)
Tracing update_step
inputs nodes (None, 3)
inputs edges (None, 1)
input n_node (1,)
Tensor("inputs_tr:0", shape=(None, 3), dtype=float3

## Evaluating GNN

In [35]:
!eval_gnn_tf.py --help

usage: eval_gnn_tf.py [-h] [--input-dir INPUT_DIR] [--output-dir OUTPUT_DIR]
                      [--model-dir MODEL_DIR] [--filter-dir FILTER_DIR]
                      [--num-iters NUM_ITERS] [--inspect] [--overwrite]
                      [--max-evts MAX_EVTS] [--datatype {train,val,test}]

Evaluate trained GNN model

optional arguments:
  -h, --help            show this help message and exit
  --input-dir INPUT_DIR
                        input directory
  --output-dir OUTPUT_DIR
                        output directory
  --model-dir MODEL_DIR
                        model directory
  --filter-dir FILTER_DIR
                        filtering file directory
  --num-iters NUM_ITERS
                        number of message passing steps
  --inspect             inspect intermediate results
  --overwrite           overwrite the output
  --max-evts MAX_EVTS   process maximum number of events
  --datatype {train,val,test}


In [36]:
!eval_gnn_tf.py

Input file names: ['output/gnn_inputs/test/1001']
In total 1 files
Process 1 events
nodes [None, 3] float32
edges [None, 1] float32
receivers [None] int32
senders [None] int32
globals [1, 1] float32
n_node [1] int32
n_edge [1] int32
nodes [None, 1] float32
edges [None] float32
receivers [None] int32
senders [None] int32
globals [1, 1] float32
n_node [1] int32
n_edge [1] int32
Find model: output/gnn_models
Loaded latest checkpoint from: output/gnn_models
processing event 1001
4,884 nodes
32,815 edges
Exception ignored in: <function _CheckpointRestoreCoordinatorDeleter.__del__ at 0x7fca889d9040>
Traceback (most recent call last):
  File "/home/zachary/anaconda3/envs/exatrkx/lib/python3.8/site-packages/tensorflow/python/training/tracking/util.py", line 146, in __del__
TypeError: 'NoneType' object is not callable


# Track labeling

In [37]:
import numpy as np
import pandas as pd
import scipy as sp
import scipy.sparse  # makes sp.sparse available
from sklearn.cluster import DBSCAN

def prepare(score, senders, receivers, n_nodes):
    # prepare the DBSCAN input, which is the adjacency matrix with the edge scores as its values.
    # entries that share the same (sender, receiver) pair are summed by the CSR constructor.
    e_csr = sp.sparse.csr_matrix( (score, (senders, receivers)), shape=(n_nodes, n_nodes), dtype=np.float32)
    # rescale the duplicated edges
    e_csr.data[e_csr.data > 1] = e_csr.data[e_csr.data > 1]/2.
    # invert to treat score as an inverse distance
    e_csr.data = 1 - e_csr.data
    # make it symmetric
    e_coo = e_csr.tocoo()
    e_csr_bi = sp.sparse.coo_matrix((np.hstack([e_coo.data, e_coo.data]),
                                    np.hstack([np.vstack([e_coo.row, e_coo.col]),
                                                np.vstack([e_coo.col, e_coo.row])])))
    return e_csr_bi

def clustering(e_csr_bi, epsilon=5, min_samples=1):
    # NOTE: `used_hits` is not defined in this cell; it must exist in the enclosing
    # scope and map graph node indices back to the original hit ids.
    # dbscan clustering on the precomputed distance matrix
    labels = DBSCAN(eps=epsilon, metric='precomputed', min_samples=min_samples).fit_predict(e_csr_bi)
    rows = np.unique(e_csr_bi.tocoo().row)
    track_labels = np.vstack([rows, labels[rows]])
    track_labels = pd.DataFrame(track_labels.T)
    track_labels.columns = ["hit_id", "track_id"]
    new_hit_id = np.apply_along_axis(lambda x: used_hits[x], 0, track_labels.hit_id.values)
    tracks = pd.DataFrame.from_dict({"hit_id": new_hit_id, "track_id": track_labels.track_id})
    return tracks
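
With `min_samples=1` every hit is a core point, so DBSCAN on a precomputed sparse distance matrix groups hits into the connected components of the graph whose edges lie within `epsilon`. The sketch below swaps DBSCAN for a plain union-find over edges whose score exceeds a cut, which should give the same grouping under that setting. `label_tracks` and `score_cut` are hypothetical names, the toy scores are illustrative, and this is not the code used by `tracks_from_gnn.py`.

```python
def label_tracks(n_nodes, senders, receivers, scores, score_cut=0.0):
    """Assign a track label to every node via connected components."""
    parent = list(range(n_nodes))

    def find(i):
        # Follow parents to the root, halving the path along the way.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # Merge the two ends of every edge that passes the score cut.
    for s, r, score in zip(senders, receivers, scores):
        if score > score_cut:
            parent[find(s)] = find(r)

    # Relabel roots as consecutive track ids in order of first appearance.
    roots = {}
    return [roots.setdefault(find(i), len(roots)) for i in range(n_nodes)]


# Five hits: 0-1-2 form a chain of good edges, the 3-4 edge is rejected.
labels = label_tracks(
    n_nodes=5,
    senders=[0, 1, 3],
    receivers=[1, 2, 4],
    scores=[0.9, 0.8, 0.1],
    score_cut=0.5,
)
print(labels)
```

Unlike the DBSCAN version above, this sketch also labels hits that have no surviving edge, each as its own single-hit track.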

In [38]:
!tracks_from_gnn.py --help

usage: tracks_from_gnn.py [-h] [--max-evts MAX_EVTS] [--input-dir INPUT_DIR]
                          [--output-dir OUTPUT_DIR]
                          [--datatype {train,val,test}]
                          [--edge-score-cut EDGE_SCORE_CUT]
                          [--epsilon EPSILON] [--min-samples MIN_SAMPLES]
                          [--min-num-hits MIN_NUM_HITS]

construct tracks from the input created by the evaluate_edge_classifier

optional arguments:
  -h, --help            show this help message and exit
  --max-evts MAX_EVTS   maximum number of events for testing
  --input-dir INPUT_DIR
                        input directory
  --output-dir OUTPUT_DIR
                        output file directory for track candidates
  --datatype {train,val,test}
  --edge-score-cut EDGE_SCORE_CUT
                        edge score cuts
  --epsilon EPSILON     epsilon in DBScan
  --min-samples MIN_SAMPLES
                        minimum number of samples in DBScan
  --min-num-hits MIN_NU

In [39]:
!tracks_from_gnn.py

total 1 files for test
Event 1001 has track ML score: 0.0170


# Conclusion
The TrackML score obtained here (0.0170) is poor. Several handles can be tuned to improve it. With more training data, the embedding, filtering and GNN stages should improve significantly. Some hyperparameters in each stage are also crucial. Now it is up to you to find the optimal parameters for the problem in question.


Have Fun.

Exa.TrkX collaboration.