In [None]:
!pip install git+https://github.com/xju2/root_gnn.git@release2.0

### Creating graphs using networkx

[networkx](https://networkx.org/documentation/stable/tutorial.html) is a Python package for the study of graphs.


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import networkx as nx

from graph_nets import utils_np
from graph_nets import utils_tf
from graph_nets import graphs

In [None]:
# Build a small directed graph: four nodes joined in a ring 0 -> 1 -> 2 -> 3 -> 0.
g = nx.DiGraph()

# add nodes; each node carries a one-element "features" array holding its index as a float
for idx in range(4):
    g.add_node(idx, features=np.array([1.*idx]))

# add edges; each edge carries a one-element "features" array holding |i - j|
edge_lists = [(0, 1), (1, 2), (2, 3), (3, 0)]
for i, j in edge_lists:
    g.add_edge(i, j, features=np.array([abs(i-j)]))

In [None]:
plt.figure(figsize=(4, 4))
pos = nx.spring_layout(g)
nx.draw(g, pos, node_size=400, alpha=0.85, node_color="#1f78b4", with_labels=True)

Obtain the adjacency matrix of the graph:

In [None]:
# Dense adjacency matrix as a plain ndarray: entry [i, j] is non-zero when there is an edge i -> j.
# `nx.to_numpy_matrix` was removed in NetworkX 3.0; `nx.to_numpy_array` returns the ndarray directly.
adj = nx.to_numpy_array(g)
adj

In [None]:
g.edges()

In [None]:
g_tuple = utils_np.networkxs_to_graphs_tuple([g])

In [None]:
g_tuple

In [None]:
def print_graphs_tuple(g, data=True):
    """Print the shape of every field of a graphs tuple.

    The fields listed in ``graphs.ALL_FIELDS`` are visited in order. A field
    that is ``None`` is reported as ``EMPTY``; otherwise its ``shape`` is
    printed. When ``data`` is true the field contents are printed as well,
    except for the ``edges`` field, whose contents are never shown.
    """
    for name in graphs.ALL_FIELDS:
        value = getattr(g, name)
        if value is None:
            print(name, "EMPTY")
            continue

        print(name, "is with shape", value.shape)
        show_contents = data and name != "edges"
        if show_contents:
            print(value)

In [None]:
print_graphs_tuple(g_tuple)

### Create GraphsTuple using data-dict \[recommended\]

In [None]:
# Graph dimensions: four nodes and four edges, one feature each.
n_node, n_node_features = 4, 1
n_edge, n_edge_features = 4, 1

# Random float32 features (nodes are drawn first, then edges).
nodes = np.random.rand(n_node, n_node_features).astype(np.float32)
edges = np.random.rand(n_edge, n_edge_features).astype(np.float32)

# Ring connectivity: edge k goes from senders[k] to receivers[k].
senders = np.array([0, 1, 2, 3])
receivers = np.array([1, 2, 3, 0])

# Everything needed to describe one graph, keyed by the field names used by the data-dict format.
datadict = dict(
    n_node=n_node,
    n_edge=n_edge,
    nodes=nodes,
    edges=edges,
    senders=senders,
    receivers=receivers,
    globals=np.array([0], dtype=np.float32),
)

In [None]:
g_tuple2 = utils_tf.data_dicts_to_graphs_tuple([datadict])

In [None]:
print_graphs_tuple(g_tuple2)

### Can you finish implementing the following function?

In [None]:
def fully_connected_edges(n_nodes: int):
    """Exercise: build the edge lists of a fully-connected graph.

    Given the number of nodes, the finished function should return the
    senders, the receivers and the number of edges. All three placeholders
    below stay ``None`` until you fill them in.
    """
    # TODO: replace the placeholders with your implementation.
    senders = None
    receivers = None
    n_edge = None

    return dict(receivers=receivers, senders=senders, n_edge=n_edge)

### Convert an event to a fully-connected graph

In [None]:
# Absolute path on the original author's system; change it to wherever your copy of test.h5 lives.
filename = '/global/homes/x/xju/atlas/data/top-tagger/test.h5'

In [None]:
with pd.HDFStore(filename, mode='r') as store:
    df = store['table']

In [None]:
df.head()

In [None]:
df[df['is_signal_new'] == 1].head()

In [None]:
event = df.iloc[0]
event

In [None]:
import itertools
from typing import Optional

# Per-constituent columns are named "<feature>_<index>", e.g. "E_0", "PX_0", ...
features = ['E', 'PX', 'PY', 'PZ']
# Multiplicative factor applied to every node feature.
scale = 0.001
# Column holding the per-event label used as the target global attribute.
solution = 'is_signal_new'

def make_graph(event, debug: bool = False, n_max_nodes: int = 200):
    """Convert one event into a pair of (input, target) graphs.

    Every constituent ``i`` in ``range(n_max_nodes)`` whose ``E_i`` entry is at
    least 0.1 becomes a node with features ``[E_i, PX_i, PY_i, PZ_i] * scale``.
    Nodes are connected by one edge per unordered pair (sender index < receiver
    index), and every edge carries a single zero-valued feature.

    Args:
        event: One row of the event table (a ``pandas.Series`` indexed by
            column name) containing the ``<feature>_<i>`` entries and the
            ``solution`` column.
        debug: When true, print the node/edge counts and the connectivity.
        n_max_nodes: Number of constituent slots to scan in ``event``.

    Returns:
        A one-element list ``[(input_graph, target_graph)]``. Both graphs share
        nodes, edges and connectivity; the input's global attribute is the node
        count and the target's is ``event[solution]``.
    """
    # Collect the scaled feature vector of every constituent that passes the
    # energy threshold (slots below the threshold are skipped).
    selected = []
    for inode in range(n_max_nodes):
        if event['E_{}'.format(inode)] < 0.1:
            continue
        f_keynames = ['{}_{}'.format(x, inode) for x in features]
        selected.append(event[f_keynames].values*scale)

    n_nodes = len(selected)
    # reshape keeps the (n_nodes, n_features) layout even when no node was selected
    nodes = np.array(selected, dtype=np.float32).reshape(-1, len(features))
    if debug:
        print(n_nodes, "nodes")
        print("node features:", nodes.shape)

    # edges 1) fully connected, 2) objects nearby in eta/phi are connected
    # TODO: implement 2). <xju>
    all_edges = list(itertools.combinations(range(n_nodes), 2))
    # Explicit integer dtype: an empty edge list would otherwise give float index arrays.
    senders = np.array([s for s, _ in all_edges], dtype=np.int32)
    receivers = np.array([r for _, r in all_edges], dtype=np.int32)
    n_edges = len(all_edges)
    edges = np.zeros((n_edges, 1), dtype=np.float32)
    if debug:
        print(n_edges, "edges")
        print("senders:", senders)
        print("receivers:", receivers)

    input_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": np.array([n_nodes], dtype=np.float32)
    }
    # The target graph differs from the input only in its global attribute.
    target_datadict = dict(
        input_datadict,
        globals=np.array([event[solution]], dtype=np.float32),
    )
    input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])
    target_graph = utils_tf.data_dicts_to_graphs_tuple([target_datadict])
    return [(input_graph, target_graph)]

In [None]:
graphs = make_graph(event)

In [None]:
g_evt_input, g_evt_target = graphs[0]

In [None]:
print_graphs_tuple(g_evt_input, data=False)

In [None]:
17*16//2

In [None]:
g_evt_target.globals

### 2. Graph Neural Network

-----------------------------------
```python

# Number of times the [128, 64] pair of layer sizes is repeated in each MLP.
NUM_LAYERS = 2 
def make_mlp_model():
  """Instantiates a new MLP, followed by LayerNorm.

  The MLP layer sizes are [128, 64] repeated NUM_LAYERS times, with ReLU
  activations that are also applied after the final layer
  (activate_final=True). The LayerNorm acts on the last axis and creates a
  scale but no offset.

  The parameters of each new MLP are not shared with others generated by
  this function.

  Returns:
    A Sonnet module which contains the MLP and LayerNorm.
  """
  # the activation function choices:
  # swish, relu, relu6, leaky_relu
  return snt.Sequential([
      snt.nets.MLP([128, 64]*NUM_LAYERS,
                    activation=tf.nn.relu,
                    activate_final=True, 
                  #  dropout_rate=DROPOUT_RATE
        ),
      snt.LayerNorm(axis=-1, create_scale=True, create_offset=False)
  ])
```

-----------------------------------

```python
import tensorflow as tf
import sonnet as snt

from graph_nets import utils_tf
from graph_nets import modules
from graph_nets import blocks

from root_gnn.src.models.base import MLPGraphNetwork
from root_gnn.src.models.base import make_mlp_model

LATENT_SIZE = 128

class GlobalClassifierNoEdgeInfo(snt.Module):
    """Graph-level classifier whose encoders do not read the input edge features.

    The input is encoded by a node block, an edge block and a global block,
    then refined by repeated applications of a core ``MLPGraphNetwork``. After
    every core step the globals are mapped to a single sigmoid output.
    """

    def __init__(self, name="GlobalClassifierNoEdgeInfo"):
        """Builds the encoder blocks, the core network and the output transform.

        Args:
            name: Name passed on to the ``snt.Module`` constructor.
        """
        super(GlobalClassifierNoEdgeInfo, self).__init__(name=name)

        # Edge encoder: uses only the sender and receiver node features
        # (existing edge features and globals are switched off).
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=make_mlp_model,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='edge_encoder_block')

        # Node encoder: uses only the node's own features
        # (no received/sent edges, no globals).
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=make_mlp_model,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            name='node_encoder_block'
        )

        # Global encoder: uses edges and nodes, but not the existing globals.
        self._global_block = blocks.GlobalBlock(
            global_model_fn=make_mlp_model,
            use_edges=True,
            use_nodes=True,
            use_globals=False,
        )
        
        # Core network applied once per processing step in __call__.
        self._core = MLPGraphNetwork()
        # Transforms the outputs into appropriate shapes.
        global_output_size = 1
        # Output head: MLP with layer sizes [LATENT_SIZE, 1] followed by a sigmoid.
        global_fn =lambda: snt.Sequential([
            snt.nets.MLP([LATENT_SIZE, global_output_size],
                         name='global_output'), tf.sigmoid])

        # Only the third argument (global_fn) is supplied; the first two are None.
        self._output_transform = modules.GraphIndependent(None, None, global_fn)

    def __call__(self, input_op, num_processing_steps):
        """Runs the encoders, then ``num_processing_steps`` core iterations.

        Args:
            input_op: Input graph(s); presumably a graph_nets ``GraphsTuple``
                (TODO confirm against the callers).
            num_processing_steps: Number of times the core network is applied.

        Returns:
            A list with one entry per processing step, each being the output
            transform applied to the latent graph after that step.
        """
        # Encoding order: node encoder -> edge block -> global block.
        latent = self._global_block(self._edge_block(self._node_encoder_block(input_op)))
        # Keep the encoded graph so it can be concatenated back in at every step.
        latent0 = latent

        output_ops = []
        for _ in range(num_processing_steps):
            # Concatenate the initial encoding with the current latent along axis 1.
            core_input = utils_tf.concat([latent0, latent], axis=1)
            latent = self._core(core_input)
            output_ops.append(self._output_transform(latent))

        return output_ops
```
-----------------------------------

### 3. Training GNN

In [None]:
from root_gnn import model as all_models
import sonnet as snt
from root_gnn import losses

In [None]:
model = all_models.GlobalClassifierNoEdgeInfo()

In [None]:
num_processing_steps_tr = 10
outputs_tr = model(g_evt_input, num_processing_steps_tr)

In [None]:
outputs_tr[-1].globals

In [None]:
g_evt_target.globals

In [None]:
loss_fcn = losses.GlobalLoss(real_global_weight=1., fake_global_weight=1.)

In [None]:
loss_fcn(g_evt_target, outputs_tr)

In [None]:
learning_rate = 0.0005
optimizer = snt.optimizers.Adam(learning_rate)

```python
# NOTE: `functools` and `input_signature` are not defined in this snippet; they must be
# provided by the surrounding code before this definition is evaluated.
@functools.partial(tf.function, input_signature=input_signature)
def update_step(inputs_tr, targets_tr):
    """Runs one training step: forward pass, loss, gradients, parameter update.

    Uses the module-level ``model``, ``loss_fcn``, ``optimizer`` and
    ``num_processing_steps_tr``.

    Args:
        inputs_tr: Input graphs passed to ``model``.
        targets_tr: Target graphs passed to ``loss_fcn``.

    Returns:
        A tuple ``(outputs_tr, loss_op_tr)`` holding the model outputs and the
        scalar loss that was differentiated.
    """
    # Presumably shown only while tf.function traces the function, not on every call.
    print("Tracing update_step")
    with tf.GradientTape() as tape:
        outputs_tr = model(inputs_tr, num_processing_steps_tr)
        loss_ops_tr = loss_fcn(targets_tr, outputs_tr)
        # Sum of the returned losses divided by the number of processing steps.
        loss_op_tr = tf.math.reduce_sum(loss_ops_tr) / tf.constant(num_processing_steps_tr, dtype=tf.float32)

    gradients = tape.gradient(loss_op_tr, model.trainable_variables)
    optimizer.apply(gradients, model.trainable_variables)
    return outputs_tr, loss_op_tr
```

### 4. Evaluating GNN