##### Copyright 2022 The TensorFlow GNN Authors.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Solving OGBN-MAG end-to-end with TF-GNN


<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/gnn/blob/master/examples/notebooks/ogbn_mag_e2e.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/gnn/blob/main/examples/notebooks/ogbn_mag_e2e.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View on GitHub</a>
  </td>
</table>

### Abstract

This advanced tutorial shows how to build a Graph Neural Network (or GNN) for the popular [OGBN-MAG](https://ogb.stanford.edu/docs/nodeprop/#ogbn-mag) benchmark dataset with [TensorFlow](https://www.tensorflow.org) and [Keras](https://keras.io), using the [TF-GNN](https://github.com/tensorflow/gnn) library. It explicitly shows all the main building blocks in action, and assumes that readers have a working knowledge of TF2/Keras already. (For a much shorter example, see the [MUTAG tutorial](https://colab.research.google.com/github/tensorflow/gnn/blob/main/examples/notebooks/intro_mutag_example.ipynb). For a more hands-off approach, see TF-GNN's [Runner](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/runner.md) and its own [OGBN-MAG example](https://github.com/tensorflow/gnn/tree/main/tensorflow_gnn/runner/examples/ogbn/mag).)

The code in this tutorial uses subgraphs sampled from the original OGBN-MAG dataset. The [Data Preparation and Sampling](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/data_prep.md) guide explains how to do that for OGBN-MAG or your own dataset.

## Colab set-up

In [None]:
!pip install -q --pre tensorflow-gnn || echo "Ignoring package errors..."

In [None]:
import functools
import os
import random
import re

from google.protobuf import text_format
import tensorflow as tf
import tensorflow_gnn as tfgnn

print(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")

Running TF-GNN 0.2.0rc0 under TensorFlow 2.8.2.


## Introduction

### Problem statement and dataset

OGBN-MAG is [Open Graph Benchmark](https://ogb.stanford.edu)'s node classification task on a subset of the [Microsoft Academic Graph](https://www.microsoft.com/en-us/research/publication/microsoft-academic-graph-when-experts-are-not-enough/).

The OGBN-MAG dataset is one big heterogeneous graph. The graph has four sets (or types) of nodes.

  * Node set "paper" contains 736,389 published academic papers, each with a 128-dimensional word2vec feature vector computed by averaging the embeddings of the words in its title and abstract.
  * Node set "field_of_study" contains 59,965 fields of study, with no associated features.
  * Node set "author" contains the 1,134,649 distinct authors of the papers, with no associated features.
  * Node set "institution" contains 8,740 institutions listed as affiliations of authors, with no associated features.

The graph has four sets (or types) of directed edges, with no associated features on any of them.

  * Edge set "cites" contains 5,416,217 edges from papers to the papers they cite.
  * Edge set "has_topic" contains 7,505,078 edges from papers to their zero or more fields of study.
  * Edge set "writes" contains 7,145,660 edges from authors to the papers that list them as authors.
  * Edge set "affiliated_with" contains 1,043,998 edges from authors to the zero or more institutions that have been listed as their affiliation(s) on any paper.

The task is to **predict the venue** (journal or conference) at which each of the papers has been published. There are 349 distinct venues, not represented in the graph itself. The benchmark metric is the accuracy of the predicted venue.

Results for this benchmark confirm that the graph structure provides a lot of relevant but "latent" information. Baseline models that only use the one explicit input feature (the word2vec embedding of a paper's title and abstract) perform less well.

OGBN-MAG defines a split of node set "paper" into **train, validation and test nodes**, based on its "year" feature:

  * "train" has the 629,571 papers with `year<=2017`,
  * "validation" has the 64,879 papers with `year==2018`, and
  * "test" has the 41,939 papers with `year==2019`.

Under OGB rules, however, training may happen on the full graph, just restricted to predictions on the "train" nodes. We follow that for consistency in benchmarking. Users working on their own datasets may wish to validate and test with a more realistic separation between training data from the past and evaluation data representative of future inputs for prediction.
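
The year-based split rule can be written down as a small function. This is only a plain-Python sketch; the name `ogbn_mag_split` is made up for illustration and is not part of OGB or TF-GNN.

```python
def ogbn_mag_split(year: int) -> str:
    """Maps a paper's publication year to its OGBN-MAG split name."""
    if year <= 2017:
        return "train"
    if year == 2018:
        return "validation"
    if year == 2019:
        return "test"
    raise ValueError(f"No split defined for year {year}")


# Split sizes as quoted above; they add up to the size of node set "paper".
split_sizes = {"train": 629_571, "validation": 64_879, "test": 41_939}
total_papers = sum(split_sizes.values())
print(total_papers)  # 736389
```

The three splits together cover all 736,389 papers, so every "paper" node belongs to exactly one split.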

### Approach

OGBN-MAG asks us to classify each of the "paper" nodes. The number of nodes is on the order of a million, and we intuit that the most informative other nodes are found just a few hops away (cited papers, papers with overlapping authors, etc.).

Therefore, and to stay scalable for even bigger datasets, we approach this task with **graph sampling**: Each "paper" node becomes one training example, expressed by a subgraph that has the node to be classified as its root and stores a sample of its neighborhood in the original graph. The sample is taken by going out a fixed number of steps along specific edge sets, and randomly downsampling the edges in each step if they are too numerous.

The actual **TensorFlow model** runs on batches of these sampled subgraphs, applies a Graph Neural Network to propagate information from related nodes towards the root node of each batch, and then applies a softmax classifier to predict one of 349 classes (each venue is a class).

The exponential fan-out of graph sampling quickly gets expensive. Sampling and model should be designed together to make the most of the available information in carefully sampled subgraphs.


## Data preparation and graph sampling


TF-GNN's [Data Preparation and Sampling](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/data_prep.md) guide explains how to sample subgraphs from a big input graph with the [graph_sampler](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/sampler/graph_sampler.py) tool, and what its expected input format is. (TF-GNN comes with a [converter script](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/converters/ogb/convert_ogb_dataset.py) from the original OGB format.)

The sampling expected by the model in this colab proceeds as follows:

  1. Start from all "paper" nodes.
  2. For each paper from 1, follow a random sample of "cites" edges to other "paper" nodes.
  3. For each paper from 1 or 2, follow a random sample of reversed "writes" edges to "author" nodes and store them as edge set "written".
  4. For each author from 3, follow a random sample of "writes" edges to more "paper" nodes.
  5. For each author from 3, follow a random sample of "affiliated_with" edges to "institution" nodes.
  6. For each paper from 1, 2 or 4, follow a random sample of "has_topic" edges to "field_of_study" nodes.

The sampling output contains all nodes and edges traversed by sampling, in their respective node/edge sets and with their associated features. An edge between two sampled nodes that exists in the input graph but has not been traversed by sampling is not included in the sampled output. For example, we get the "cites" edges followed in step 2, but no edges for citations between the papers discovered in step 2. Moreover, while edge set "written" is defined in the sampler input as the reversal of edge set "writes", the sampler output has different edges in these two edge sets (namely those traversed in steps 3 and 4, respectively).
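
To make the six steps concrete, here is a minimal plain-Python sketch that runs them on a tiny made-up graph. It is not the graph_sampler tool: the node ids, the `sample_size` of 2, and the helper names `sample_step` and `sample_subgraph` are all assumptions chosen for illustration.

```python
import random

# Toy heterogeneous graph: edge set name -> {source node: [target nodes]}.
# All node ids are made up for illustration.
EDGES = {
    "cites": {"p0": ["p1", "p2"], "p1": ["p2"]},
    "writes": {"a0": ["p0", "p3"], "a1": ["p1"], "a2": ["p2", "p4"]},
    "affiliated_with": {"a0": ["i0"], "a1": ["i0", "i1"]},
    "has_topic": {"p0": ["f0"], "p1": ["f0", "f1"], "p3": ["f2"]},
}
# Edge set "written" is the reversal of "writes" (paper -> author).
EDGES["written"] = {}
for author, papers in EDGES["writes"].items():
    for paper in papers:
        EDGES["written"].setdefault(paper, []).append(author)


def sample_step(rng, edge_set, sources, sample_size):
    """Follows at most `sample_size` random edges out of each source node."""
    traversed = []
    for source in sorted(sources):
        targets = EDGES[edge_set].get(source, [])
        chosen = rng.sample(targets, min(sample_size, len(targets)))
        traversed.extend((source, target) for target in chosen)
    return traversed


def sample_subgraph(seed_paper, sample_size=2, rng_seed=0):
    """Returns the traversed edges per edge set, following steps 1 to 6."""
    rng = random.Random(rng_seed)
    out = {}
    papers_1 = {seed_paper}                                          # Step 1
    out["cites"] = sample_step(rng, "cites", papers_1, sample_size)  # Step 2
    papers_2 = {target for _, target in out["cites"]}
    out["written"] = sample_step(                                    # Step 3
        rng, "written", papers_1 | papers_2, sample_size)
    authors_3 = {target for _, target in out["written"]}
    out["writes"] = sample_step(rng, "writes", authors_3, sample_size)  # Step 4
    papers_4 = {target for _, target in out["writes"]}
    out["affiliated_with"] = sample_step(                            # Step 5
        rng, "affiliated_with", authors_3, sample_size)
    out["has_topic"] = sample_step(                                  # Step 6
        rng, "has_topic", papers_1 | papers_2 | papers_4, sample_size)
    return out


subgraph = sample_subgraph("p0")
```

In this toy run, the citation from `p1` to `p2` exists in the input but is absent from `subgraph["cites"]`, because step 2 only follows edges out of the seed paper. Likewise, `subgraph["written"]` and `subgraph["writes"]` hold different edges, even though one edge set is defined as the reversal of the other.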

## Reading the sampled subgraphs

The result of sampling is available here (subject to this [license](https://storage.googleapis.com/download.tensorflow.org/data/ogbn-mag/sampled/v1/edge/LICENSE.txt)):


In [None]:
input_file_pattern = "gs://download.tensorflow.org/data/ogbn-mag/sampled/v1/edge/samples-?????-of-00100"
graph_schema_file = "gs://download.tensorflow.org/data/ogbn-mag/sampled/v1/edge/schema.pbtxt"
graph_schema = tfgnn.read_schema(graph_schema_file)

TF-GNN's guide on [Describing your Graph](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/schema.md) explains what a graph schema is. Here is the one for this dataset of sampled subgraphs:

In [None]:
graph_schema

context {
  features {
    key: "sample_id"
    value {
      dtype: DT_STRING
    }
  }
  features {
    key: "seed_id"
    value {
      dtype: DT_STRING
    }
  }
  metadata {
  }
}
node_sets {
  key: "author"
  value {
    features {
      key: "#id"
      value {
        dtype: DT_STRING
      }
    }
    metadata {
    }
  }
}
node_sets {
  key: "field_of_study"
  value {
    features {
      key: "#id"
      value {
        dtype: DT_STRING
      }
    }
    metadata {
    }
  }
}
node_sets {
  key: "institution"
  value {
    features {
      key: "#id"
      value {
        dtype: DT_STRING
      }
    }
    metadata {
    }
  }
}
node_sets {
  key: "paper"
  value {
    features {
      key: "#id"
      value {
        dtype: DT_STRING
      }
    }
    features {
      key: "feat"
      value {
        dtype: DT_FLOAT
        shape {
          dim {
            size: 128
          }
        }
      }
    }
    features {
      key: "labels"
      value {
        dtype: DT_IN

Training a neural network with stochastic gradient descent requires randomly shuffled training data, but ours is too big to fully reshuffle it on the fly while reading. Fortunately, the graph sampler tool has already reshuffled its outputs before writing to a sharded TFRecord file.

For speed, we want to read from several shards in parallel.

For distributed training, each trainer replica reads from its own subset of the TFRecord shards. To achieve some randomization between training runs, each replica reshuffles the order of input shards and then shuffles examples within a moderate window.
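
The division of shards between replicas can be illustrated without TensorFlow. The sketch below assumes that each replica takes every `num_replicas`-th file, which is how we understand `Dataset.shard()` to select elements; the helper names `shard_filenames` and `shards_for_replica` are hypothetical.

```python
import random


def shard_filenames(basename, num_shards):
    """Lists the files of a sharded output, e.g. 'samples-00000-of-00100'."""
    return [f"{basename}-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)]


def shards_for_replica(filenames, num_replicas, replica_id, seed=None):
    """Gives every num_replicas-th file to this replica, in shuffled order."""
    own = filenames[replica_id::num_replicas]  # Slicing makes a copy.
    random.Random(seed).shuffle(own)
    return own


filenames = shard_filenames("samples", 100)
per_replica = [shards_for_replica(filenames, 4, r, seed=r) for r in range(4)]
print([len(shards) for shards in per_replica])  # [25, 25, 25, 25]
```

Each file ends up with exactly one replica, so no example is read twice within one pass over the data.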


In [None]:
def _get_dataset(file_pattern, *, shuffle=False, filter_fn=None,
                 input_context=None):
  # For your own file system or GCS bucket, call the usual helper
  # filenames = tf.io.gfile.glob(file_pattern)
  # For gs://download.tensorflow.org, we avoid listing it and do
  filenames = _glob_sharded(file_pattern)
  ds = tf.data.Dataset.from_tensor_slices(filenames)
  if input_context and input_context.num_input_pipelines > 1:
    ds = ds.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)
  if shuffle:
    ds = ds.shuffle(len(filenames))

  def interleave_fn(filename):
    ds = tf.data.TFRecordDataset(filename)
    if filter_fn is not None:
      ds = ds.filter(filter_fn)
    return ds
  # TODO(b/234644900): sync with runner/examples on cycle_length.
  ds = ds.interleave(
      interleave_fn, cycle_length=10,
      deterministic=False, num_parallel_calls=tf.data.AUTOTUNE)
  if shuffle:
    ds = ds.shuffle(10000)
  ds = ds.prefetch(tf.data.AUTOTUNE)
  return ds

def _glob_sharded(file_pattern):
  match = re.fullmatch(r"(.*)-\?\?\?\?\?-of-(\d\d\d\d\d)", file_pattern)
  if match is None:  # No shard suffix found.
    return [file_pattern]
  basename = match[1]
  n = int(match[2])
  return [f"{basename}-{i:05d}-of-{n:05d}" for i in range(n)]

Graph sampling stores each sampled subgraph in its output as one tf.Example proto, with structured feature names, as explained in the [Data Preparation](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/data_prep.md) guide. If you're curious, take a look here:

In [None]:
demo_ds = _get_dataset(input_file_pattern.replace("-?????-of-00100", "-00000-of-00100"))

In [None]:
def _example_to_text(example):
  """MessageToString() with fewer linebreaks."""
  lines = ["features {"]
  for k, v in sorted(example.features.feature.items()):
    v_str = text_format.MessageToString(v, as_one_line=True,
                                        use_short_repeated_primitives=True)
    lines.append("  feature {")
    lines.append(f"    key: \"{k}\"")
    lines.append(f"    value {{ {v_str} }}")
    lines.append("  }")
  lines.append("}")
  return "\n".join(lines)

for serialized_example in demo_ds.take(1):
  example = tf.train.Example.FromString(serialized_example.numpy())
  print(_example_to_text(example))

features {
  feature {
    key: "context/sample_id"
    value { bytes_list { value: "paper109537" } }
  }
  feature {
    key: "context/seed_id"
    value { bytes_list { value: "paper109537" } }
  }
  feature {
    key: "edges/affiliated_with.#size"
    value { int64_list { value: [6] } }
  }
  feature {
    key: "edges/affiliated_with.#source"
    value { int64_list { value: [0, 1, 2, 3, 4, 4] } }
  }
  feature {
    key: "edges/affiliated_with.#target"
    value { int64_list { value: [0, 0, 1, 0, 1, 2] } }
  }
  feature {
    key: "edges/cites.#size"
    value { int64_list { value: [1] } }
  }
  feature {
    key: "edges/cites.#source"
    value { int64_list { value: [0] } }
  }
  feature {
    key: "edges/cites.#target"
    value { int64_list { value: [20] } }
  }
  feature {
    key: "edges/has_topic.#size"
    value { int64_list { value: [307] } }
  }
  feature {
    key: "edges/has_topic.#source"
    value { int64_list { value: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 

The following helper lets us filter input data by OGB's specific rule for the test/validation/train split before parsing the full GraphTensor. (Users working on their own datasets may want to use separate datasets to begin with.)

In [None]:
def _is_in_split(split_name):
  def filter_fn(serialized_example):
    features = {"years": tf.io.RaggedFeature(value_key="nodes/paper.year",
                                             dtype=tf.int64)}
    years = tf.io.parse_single_example(serialized_example, features)["years"]
    year = years[0]  # By convention, the root node is the first node.
    if split_name == "train":
      return year <= 2017
    elif split_name == "validation":
      return year == 2018
    elif split_name == "test":
      return year == 2019
    else:
      raise ValueError(f"Unknown split_name: '{split_name}'")
  return filter_fn

### The GraphTensor type

The cornerstone of model building with TF-GNN is the `tfgnn.GraphTensor` type. The following code cells demonstrate the essentials of using it. For more information, please see the comprehensive [Introduction to GraphTensor](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/graph_tensor.md).

`tfgnn.GraphTensor` is a TensorFlow Extension Type (or "composite tensor") that consists of multiple Tensors but can be used as one object in a tf.data.Dataset, and in the inputs/outputs of a Keras layer or a tf.function. Other examples of this are `tf.RaggedTensor` and `tf.SparseTensor`.

Every such type has a matching type spec (like a tf.Tensor has a tf.TensorSpec). A `GraphTensorSpec` can be created from the `GraphSchema` and contains information about its node sets, their connection by edge sets, and the features on the graph. With that, we can parse `tf.Example`s as above into `GraphTensor` values:

In [None]:
example_input_spec = tfgnn.create_graph_spec_from_schema_pb(graph_schema)
parse = functools.partial(tfgnn.parse_single_example, example_input_spec)
graph = parse(serialized_example)

The `GraphTensor` is an immutable container of tensors, indexed by names. For example, here are the labels of all papers in this subgraph:

In [None]:
graph.node_sets["paper"]["labels"]

<tf.Tensor: shape=(30, 1), dtype=int64, numpy=
array([[193],
       [265],
       [ 56],
       [193],
       [277],
       [277],
       [283],
       [ 83],
       [236],
       [193],
       [265],
       [193],
       [236],
       [144],
       [236],
       [ 97],
       [265],
       [265],
       [258],
       [258],
       [262],
       [258],
       [236],
       [236],
       [198],
       [145],
       [236],
       [265],
       [193],
       [277]])>

By convention, the root node of the sampled subgraph is stored as the first node (index 0) of its node set, so the target label for this subgraph is the first item in the tensor above.

A tf.data.Dataset of GraphTensors can freely be batched (and unbatched), which simply stacks (or unstacks) all the tensors that make up the individual graphs. Let's take edge set "cites" as an example.

In [None]:
for batched_graph in demo_ds.map(parse).batch(3).take(1):
  print(batched_graph.edge_sets["cites"].adjacency.source)
  print(batched_graph.edge_sets["cites"].adjacency.target)

<tf.RaggedTensor [[0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]>
<tf.RaggedTensor [[20],
 [11, 90, 144, 167, 181, 196, 199, 201, 228, 264, 267, 287, 376, 401],
 [32, 88, 111, 175, 211, 235, 301, 309, 311, 350, 377, 521, 631, 688, 693,
  723, 760, 770, 788, 795, 940, 1100, 1127]                               ]>


The resulting RaggedTensors and per-graph indices are inconvenient for the modeling code. Hence, GraphTensor lets you merge a batch of graphs into a single graph with contiguous indexing. The distinction between the inputs is preserved by the notion of **components** inside the merged graph: all node sets and edge sets keep track of the size of each component. (This is always true; there just happens to be a single component in each graph before merging.)

In [None]:
merged_graph = batched_graph.merge_batch_to_components()
print(merged_graph.edge_sets["cites"].sizes)
print(merged_graph.edge_sets["cites"].adjacency.source)
print(merged_graph.edge_sets["cites"].adjacency.target)
print(merged_graph.node_sets["paper"].sizes)

tf.Tensor([ 1 14 23], shape=(3,), dtype=int32)
tf.Tensor(
[  0  30  30  30  30  30  30  30  30  30  30  30  30  30  30 503 503 503
 503 503 503 503 503 503 503 503 503 503 503 503 503 503 503 503 503 503
 503 503], shape=(38,), dtype=int32)
tf.Tensor(
[  20   41  120  174  197  211  226  229  231  258  294  297  317  406
  431  535  591  614  678  714  738  804  812  814  853  880 1024 1134
 1191 1196 1226 1263 1273 1291 1298 1443 1603 1630], shape=(38,), dtype=int32)
tf.Tensor([  30  473 1308], shape=(3,), dtype=int32)


Features on the node sets or edge sets of the merged graph have the set's total size as their first dimension, followed by the dimensions of individual feature values. This makes them compatible with many standard layers like tf.keras.layers.Dense, which accept an unknown batch size as the first dimension. Here, the "batch size" of a node feature is the total number of nodes in the batch.
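
The index arithmetic behind the merged output above can be reproduced in plain Python: each graph's node indices are shifted by the number of nodes in the graphs before it. This is a sketch of the idea only, not TF-GNN's implementation, and `merge_indices` is a hypothetical name.

```python
from itertools import accumulate


def merge_indices(per_graph_indices, per_graph_num_nodes):
    """Shifts each graph's node indices by the node count of earlier graphs."""
    offsets = [0] + list(accumulate(per_graph_num_nodes))[:-1]
    return [index + offset
            for indices, offset in zip(per_graph_indices, offsets)
            for index in indices]


# Papers per input graph, as printed above.
paper_sizes = [30, 473, 1308]
# The first few "cites" targets of each graph, abridged from the output above.
cites_targets = [[20], [11, 90, 144], [32, 88]]

merged = merge_indices(cites_targets, paper_sizes)
print(merged)            # [20, 41, 120, 174, 535, 591]
print(sum(paper_sizes))  # 1811
```

The offsets 0, 30 and 503 are also the merged indices of the three root nodes, which is why they appear as the repeated "cites" source indices above.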


In [None]:
print(merged_graph.node_sets["paper"].total_size.numpy())
print(merged_graph.node_sets["paper"]["feat"].shape)

1811
(1811, 128)


**Please remember:** Each TF-GNN model needs to call `.merge_batch_to_components()` at some point after the final input batches for each model replica have been formed but before the actual GNN model starts. For TPUs, this and the subsequent padding to fixed sizes have to happen before data is fed into the trained Model.

## Model building and training

We use TensorFlow's [Distribution Strategy](https://www.tensorflow.org/guide/distributed_training) API to write a model that can train in parallel on multiple [Cloud TPUs](https://cloud.google.com/tpu), multiple GPUs, or maybe just locally on CPU. (This is needed on Colab to use Cloud TPUs. This is not required to use the single GPU on a Colab, but we might as well show how it's done for the general case.)

In [None]:
try:
  tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  print("Running on TPU ", tpu_resolver.cluster_spec().as_dict()["worker"])
except ValueError:
  tpu_resolver = None
if tpu_resolver:
  print("Using TPUStrategy")
  tf.config.experimental_connect_to_cluster(tpu_resolver)
  tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
  strategy = tf.distribute.TPUStrategy(tpu_resolver)
  assert isinstance(strategy, tf.distribute.TPUStrategy)
elif tf.config.list_physical_devices("GPU"):
  gpu_list = !nvidia-smi -L
  print("\n".join(gpu_list))
  print("Using MirroredStrategy for GPUs")
  strategy = tf.distribute.MirroredStrategy()
else:
  strategy = tf.distribute.get_strategy()
  print("Using default strategy")
print(f"Found {strategy.num_replicas_in_sync} replicas in sync")

GPU 0: Tesla T4 (UUID: GPU-8626bec7-bcaa-d0d6-5940-d0effe718db0)
Using MirroredStrategy for GPUs
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Found 1 replicas in sync


### Padding (for TPUs)


Training on Cloud TPUs involves just-in-time compilation of a TensorFlow model to TPU code, and requires fixed shapes for all Tensors involved. To achieve that for graph data with variable numbers of nodes and edges, we need to pad each input Tensor to some fixed maximum size.

For training on GPUs or CPU, this extra step is not necessary.
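
As a rough illustration of what padding means for a merged graph, the sketch below pads a list of per-component node counts to fixed totals and returns a mask that tells real components from padding. It is a simplified stand-in, not TF-GNN's padding code; the function name and the choice to put all spare nodes into the first padding component are assumptions.

```python
def pad_component_sizes(sizes, total_num_components, total_num_nodes,
                        min_nodes_per_component=0):
    """Appends padding components so that the totals match the fixed sizes.

    Returns the padded list of per-component node counts and a mask that is
    True for real components and False for padding.
    """
    num_padding = total_num_components - len(sizes)
    spare_nodes = total_num_nodes - sum(sizes)
    if num_padding < 0 or spare_nodes < num_padding * min_nodes_per_component:
        raise ValueError("Input does not fit into the fixed sizes")
    if num_padding == 0 and spare_nodes > 0:
        raise ValueError("Need a padding component to hold the spare nodes")
    padding = [min_nodes_per_component] * num_padding
    if padding:
        # Assumption: all remaining fake nodes go into the first padding
        # component.
        padding[0] += spare_nodes - num_padding * min_nodes_per_component
    mask = [True] * len(sizes) + [False] * num_padding
    return sizes + padding, mask


padded_sizes, mask = pad_component_sizes(
    [30, 473, 1308], total_num_components=5, total_num_nodes=2000,
    min_nodes_per_component=1)
print(padded_sizes)  # [30, 473, 1308, 188, 1]
print(mask)          # [True, True, True, False, False]
```

Whatever the real input sizes, the padded result always has 5 components and 2000 nodes here, which is the fixed shape a compiled model needs. The mask lets later steps ignore predictions on padding components.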

In [None]:
#@title Pad batched GraphTensors to fixed sizes?
#@markdown By default (`None`), padding is used for TPUs only.
use_padding = None #@param ["None", "True", "False"] {type:"raw"}
if use_padding is None:
  use_padding = isinstance(strategy, tf.distribute.TPUStrategy)
print("Padding is", ["OFF", "ON"][use_padding])
if isinstance(strategy, tf.distribute.TPUStrategy) and not use_padding:
  raise ValueError("Padding is required for running on TPU")

Padding is OFF


For the validation dataset, we need to make sure that every batch of examples fits within the fixed sizes, no matter how the parallelism in the input pipeline ends up combining examples into batches. Therefore, we use a rather generous estimate, basically scaling each Tensor's observed maximum size by a factor of batch_size. If that were to run into limitations of accelerator memory, we'd rather shrink the batch size than lose examples.

The dataset in this example is not too big, so we can scan it within a few minutes to determine constraints large enough for all inputs. (For huge datasets under your control, it may be worth inferring an upper bound from the sampling spec instead.)
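
The generous estimate described above amounts to taking each size's maximum over all examples and multiplying it by the batch size. Here is a minimal sketch of that arithmetic. It makes no claim about how `tfgnn.find_tight_size_constraints` works internally, and `generous_size_budget` is a hypothetical name.

```python
def generous_size_budget(example_sizes, batch_size):
    """Upper-bounds batch totals by batch_size times the per-example maximum.

    example_sizes: one dict per example, mapping a node set or edge set name
      to its number of items in that example.
    """
    largest = {}
    for sizes in example_sizes:
        for name, size in sizes.items():
            largest[name] = max(largest.get(name, 0), size)
    return {name: batch_size * size for name, size in largest.items()}


# Sizes of the three demo graphs shown earlier in this tutorial.
demo_sizes = [
    {"paper": 30, "cites": 1},
    {"paper": 473, "cites": 14},
    {"paper": 1308, "cites": 23},
]
budget = generous_size_budget(demo_sizes, batch_size=32)
print(budget)  # {'paper': 41856, 'cites': 736}
```

Such a budget is safe for any combination of examples in a batch, at the cost of mostly unused padding.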

In [None]:
validation_global_batch_size = 32
assert validation_global_batch_size % strategy.num_replicas_in_sync == 0, "divisibility required"
validation_replica_batch_size = validation_global_batch_size // strategy.num_replicas_in_sync
print(f"Validation uses a global batch size of {validation_global_batch_size} "
      f"({validation_replica_batch_size} per replica).")
validation_size_constraints = None
if use_padding:
  # The "paper" node set needs at least one node in each graph component,
  # incl. those added for padding, because the model will read out the state
  # of the sampled subgraph's root node from each component.
  min_nodes_per_component = {"paper": 1}
  validation_size_constraints = tfgnn.find_tight_size_constraints(
      _get_dataset(input_file_pattern, shuffle=False,
                  filter_fn=_is_in_split("validation"),  # For OGB only.
      ).map(parse),
      target_batch_size=validation_replica_batch_size,
      min_nodes_per_component=min_nodes_per_component)
  print(f"Validation data is padded to: {validation_size_constraints}")

Validation uses a global batch size of 32 (32 per replica).


For the training dataset, TF-GNN allows you to optimize more aggressively for large batch sizes: size constraints satisfied by 100% of all possible inputs would have to accommodate the rare combination of many large examples in one batch.

Instead, we use size constraints that will fit *close to* 100% of the randomly drawn training batches. This is not covered by the theory supporting stochastic gradient descent (which calls for examples drawn independently at random), but in practice, it often works, and allows larger batch sizes within the limits of accelerator memory, and hence faster convergence of the training. Running the code block below with `use_padding=True` shows the sizes required for some success ratios. (Compare them to the constraints above, which are satisfied 100% at a fourth of the batch size.)
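
The trade-off between success ratio and padded size can be explored with a small simulation: draw many random batches, add up their sizes, and read off the size that the requested share of batches stays within. This sketch is an illustration under made-up example sizes and is not the algorithm of `tfgnn.learn_fit_or_skip_size_constraints`; the function name is hypothetical.

```python
import random


def size_for_success_ratio(example_sizes, batch_size, success_ratio,
                           num_trials=2000, seed=0):
    """Estimates a total size that fits the requested share of random batches."""
    rng = random.Random(seed)
    totals = sorted(
        sum(rng.choices(example_sizes, k=batch_size))
        for _ in range(num_trials))
    # At least `success_ratio` of the simulated totals are <= totals[index].
    index = min(num_trials - 1, int(success_ratio * num_trials))
    return totals[index]


# Made-up size distribution: mostly small examples, a few large ones.
example_sizes = [30] * 90 + [1308] * 10
batch_size = 8
budget_90 = size_for_success_ratio(example_sizes, batch_size, 0.90)
budget_99 = size_for_success_ratio(example_sizes, batch_size, 0.99)
worst_case = batch_size * max(example_sizes)
print(budget_90, budget_99, worst_case)
```

With a skewed size distribution like this one, the 99% budget stays well below the worst case, because a batch made up only of the largest examples is very unlikely.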


In [None]:
training_global_batch_size = 128
assert training_global_batch_size % strategy.num_replicas_in_sync == 0, "divisibility required"
training_replica_batch_size = training_global_batch_size // strategy.num_replicas_in_sync
print(f"Training uses a batch size of {training_global_batch_size} "
      f"({training_replica_batch_size} per replica).")
training_size_constraints = None
if use_padding:
  success_ratios = [0.90, 0.98, 0.99]
  constraints = tfgnn.learn_fit_or_skip_size_constraints(
      _get_dataset(input_file_pattern,
                   filter_fn=_is_in_split("train"),  # For OGB only.
      ).map(parse),
      training_replica_batch_size,
      min_nodes_per_component=min_nodes_per_component,
      success_ratio=success_ratios, sample_size=20000)
  for sr_idx, sr in enumerate(success_ratios):
    print(f"Success ratio {sr} requires: {constraints[sr_idx]}")
  sr_idx = 2
  training_size_constraints = constraints[sr_idx]
  print(f"\nSelected success ratio: {success_ratios[sr_idx]}.")
  print(f"Training data is padded to: {training_size_constraints}")  

Training uses a batch size of 128 (128 per replica).


### The model architecture

We build a model on sampled subgraphs that predicts one of 349 classes (venues) for the subgraph's root node. We use a Graph Neural Network (GNN) to propagate information along edge sets towards the subgraph's root node. GNNs are an active field of academic research. As an introduction, we recommend the textbook by [Hamilton (2020)](https://www.cs.mcgill.ca/~wlh/grl_book/).

Observe how the various node sets play different roles:

  * Node set "paper" has many nodes. It contains the node to predict on. Some of its nodes are linked by "cites" edges, which seem relevant for the prediction task. Its nodes also carry the only input feature besides adjacency, namely the word2vec embedding of title and abstract.
  * Node set "author" also has many nodes. Authors have no features of their own, but having an author in common provides a seemingly relevant relation between papers.
  * Node set "field_of_study" has relatively few nodes. They have no features by themselves, but having a common field of study provides a seemingly relevant relation between papers.
  * Node set "institution" has relatively few nodes. It provides an additional relation on authors.

For node sets "paper" and "author", we follow the standard GNN approach to maintain a hidden state for each node and update it several times with information from the inbound edges. Notice how sampling has equipped each "paper" or "author" adjacent to the root node with a 1-hop neighborhood of its own. Our model does 4 rounds of updates, which covers the longest possible path in a sampled subgraph: a seed paper "cites" a paper that was "written" by an author who "writes" another paper that "has_topic" in some field of study.
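
The choice of 4 rounds can be checked on a toy subgraph by counting hops from the root. The sketch below assumes that each round of updates moves information one hop closer to the root, against the direction in which sampling traversed the edges. The node ids and the helper name `rounds_to_reach_root` are made up.

```python
from collections import deque

# Sampled edges of a made-up subgraph, stored as (edge set, source, target).
# Sampling walks away from the root; the GNN passes information back.
SAMPLED_EDGES = [
    ("cites", "paper0", "paper1"),
    ("written", "paper1", "author0"),
    ("writes", "author0", "paper2"),
    ("has_topic", "paper2", "field0"),
    ("has_topic", "paper0", "field1"),
]


def rounds_to_reach_root(root, edges):
    """Returns, per node, the number of rounds its information needs."""
    neighbors = {}
    for _, source, target in edges:
        neighbors.setdefault(source, []).append(target)
    distance = {root: 0}
    queue = deque([root])
    while queue:  # Breadth-first search along the sampled edges.
        node = queue.popleft()
        for neighbor in neighbors.get(node, []):
            if neighbor not in distance:
                distance[neighbor] = distance[node] + 1
                queue.append(neighbor)
    return distance


distance = rounds_to_reach_root("paper0", SAMPLED_EDGES)
print(max(distance.values()))  # 4
```

The field of study at the end of the longest path is 4 hops away from the root, so fewer than 4 rounds would leave its information unused.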

For node sets "field_of_study" and "institution", a GNN on the full graph could produce meaningful hidden states for their few elements in the same way. However, in the sampled approach, it seems wasteful to do that from scratch for every subgraph. Instead, our model reads hidden states for them out of an embedding table. This way, the GNN can treat them as read-only nodes with outgoing edges only; the writing happens implicitly by gradient updates to their embeddings. (We choose to maintain a single embedding shared between the rounds of GNN updates.) Because of this modeling decision, subgraph sampling can stop at the node sets backed by an embedding.
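
Reading a hidden state out of an embedding table boils down to two steps: hash the node's string id to a bin index, then look up that row of the table. The sketch below shows this in plain Python with CRC32 as a stand-in hash function, which is not the one Keras uses; the ids, table sizes and helper names are made up for illustration.

```python
import random
import zlib


def hashed_index(node_id: str, num_bins: int) -> int:
    """Maps a string id to a stable bin index in [0, num_bins)."""
    return zlib.crc32(node_id.encode("utf-8")) % num_bins


num_bins, embedding_dim = 8, 4  # Tiny made-up sizes for illustration.
rng = random.Random(0)
# In a real model, these rows are trainable variables, not fixed numbers.
embedding_table = [[rng.uniform(-1.0, 1.0) for _ in range(embedding_dim)]
                   for _ in range(num_bins)]


def initial_state(node_id):
    """Looks up the hidden state for a node that has no features of its own."""
    return embedding_table[hashed_index(node_id, num_bins)]


# The same id yields the same state in every sampled subgraph.
state_a = initial_state("hypothetical_field_42")
state_b = initial_state("hypothetical_field_42")
print(state_a == state_b)  # True
```

Because the lookup depends only on the id, a field of study or institution gets the same state in every subgraph that contains it. When there are fewer bins than distinct ids, some ids necessarily share a row.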


### Preprocessing the input features

As usual in TensorFlow, the non-trainable transformations of the input features are split off into a `Dataset.map()` call while the model proper consists of the trainable and accelerator-compatible parts. However, even this non-trainable part is put into a Keras model, which is a convenient way to track resources for export (such as lookup tables). The following code cells provide a fully worked example; for more background information, please see our [Input pipeline](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/input_pipeline.md) guide.

In [None]:
# Callbacks used by tfgnn.keras.layers.MapFeatures below.

def _preprocess_node_features(node_set, *, node_set_name):
  if node_set_name == "paper":
    return {
        # Retain the word2vec embedding unchanged.
        "feat": node_set["feat"],
        # Keep the label, until popped later on.
        "labels": node_set["labels"]}
  elif node_set_name == "author":
    # There are no useful features. Instead, we create a tensor of hidden node
    # states that are empty, i.e., with shape [batch_size, (num_nodes), 0].
    return {"empty_state": tfgnn.keras.layers.MakeEmptyFeature()(node_set)}
  elif node_set_name == "field_of_study":
    # Convert the string id to an index into an embedding table.
    # Conveniently, this Keras layer can handle RaggedTensors.
    return {"hashed_id": tf.keras.layers.Hashing(num_bins=50_000)(
        node_set["#id"])}
  elif node_set_name == "institution":
    # Convert the string id to an index into an embedding table.
    return {"hashed_id": tf.keras.layers.Hashing(num_bins=6_500)(
        node_set["#id"])}
  else:
    raise KeyError(f"Unexpected node_set_name='{node_set_name}'")

def _drop_all_features(graph_piece, **unused_kwargs):
  return {}

In [None]:
def _make_preprocessing_model(graph_tensor_spec, size_constraints):
  """Returns Keras model to preprocess a batched and parsed GraphTensor."""
  graph = input_graph = tf.keras.layers.Input(type_spec=graph_tensor_spec)

  # Convert input features to suitable representations for use on GPU/TPU.
  # Drop unused features (like id strings for tracking the source of examples).
  graph = tfgnn.keras.layers.MapFeatures(
      node_sets_fn=_preprocess_node_features,
      edge_sets_fn=_drop_all_features,
      context_fn=_drop_all_features)(graph)
  assert "labels" in graph.node_sets["paper"].features

  ### IMPORTANT: All TF-GNN modeling code assumes a GraphTensor of shape []
  ### in which the graphs of the input batch have been merged to components of
  ### one contiguously indexed graph. There are no edges between components,
  ### so no information flows between them.
  graph = graph.merge_batch_to_components()

  # Optionally, pad to size_constraints (required for TPU).
  if size_constraints:
    graph, mask = tfgnn.keras.layers.PadToTotalSizes(size_constraints)(graph)
  else:
    mask = None

  # Split the label off the padded input graph.
  root_label = tfgnn.keras.layers.ReadoutFirstNode(
      node_set_name="paper", feature_name="labels")(graph)
  graph = graph.remove_features(node_sets={"paper": ["labels"]})
  assert "labels" not in graph.node_sets["paper"].features

  outputs = (graph, root_label) if mask is None else (graph, root_label, mask)
  return tf.keras.Model(input_graph, outputs)

The preprocessing model merges its batch of input graphs into one contiguously indexed graph with multiple components (as explained above), and then pads it (if applicable). After that, the dataset can no longer be rebatched to split it up between replicas, so we use `tf.distribute.Strategy.distribute_datasets_from_function()` to build a `DistributedDataset` of per-replica inputs. (If it weren't for TPUs or distribution strategies, the call to `merge_batch_to_components()` could be deferred to the start of the trained model, and we could leave it to `Model.fit()` to split up a single Dataset.)

In [None]:
example_input_spec = tfgnn.create_graph_spec_from_schema_pb(graph_schema)

def _get_preprocessed_dataset(
    input_context, split_name, per_replica_batch_size, size_constraints):
  training = split_name == "train"
  ds = _get_dataset(input_file_pattern, shuffle=training,
                    filter_fn=_is_in_split(split_name),  # For OGB only.
                    input_context=input_context)
  if training:
    ds = ds.repeat()
  # There is no need to drop_remainder when batching, even for TPU:
  # padding the GraphTensor can handle variable numbers of components.
  ds = ds.batch(per_replica_batch_size)
  ds = ds.map(tfgnn.keras.layers.ParseExample(example_input_spec))
  if training and size_constraints:
    ds = ds.filter(functools.partial(tfgnn.satisfies_size_constraints,
                                     total_sizes=size_constraints))
  ds = ds.map(_make_preprocessing_model(ds.element_spec, size_constraints),
              deterministic=False, num_parallel_calls=tf.data.AUTOTUNE)
  return ds

def _get_distributed_preprocessed_dataset(
    strategy, split_name, per_replica_batch_size, size_constraints):
  """Returns a DistributedDataset of preprocessed per-replica inputs."""
  return strategy.distribute_datasets_from_function(functools.partial(
      _get_preprocessed_dataset,
      split_name=split_name, per_replica_batch_size=per_replica_batch_size,
      size_constraints=size_constraints))

train_ds = _get_distributed_preprocessed_dataset(
    strategy, "train",
    training_replica_batch_size, training_size_constraints)
valid_ds = _get_distributed_preprocessed_dataset(
    strategy, "validation",
    validation_replica_batch_size, validation_size_constraints)



### The GNN core of the model



We are going to use the same Model object for training, validation, and export for inference, so we need to build it from an input type spec with generic tensor shapes. (For TPUs, using it on a *dataset* with fixed-size elements will suffice.) Rather than spelling out the type spec by hand, we create a non-distributed, non-padded dummy dataset whose element spec reflects the preprocessing. (That's cheap, as long as we don't create an iterator on it.)

In [None]:
build_model_graph_tensor_spec, *_ = _get_preprocessed_dataset(
    input_context=None, split_name="train",
    per_replica_batch_size=2, size_constraints=None).element_spec

In [None]:
def _build_model(
    # To be called with the build_model_graph_tensor_spec from above.
    graph_tensor_spec,
    # Dimensions of initial states.
    field_of_study_dim=32,
    institution_dim=16,
    paper_dim=512,
    # Dimensions for message passing.
    message_dim=128,
    next_state_dim=128,
    # Dimension for the logits.
    num_classes=349,
    # Other hyperparameters.
    l2_regularization=5e-4,
    dropout_rate=0.1,
):
  # Model building with Keras's Functional API starts with an input object
  # (a placeholder for future inputs). This works for composite tensors, too.
  graph = input_graph = tf.keras.layers.Input(type_spec=graph_tensor_spec)

  # The initial hidden states for each node set.
  def set_initial_node_state(node_set, *, node_set_name):
    if node_set_name == "paper":
      # A trainable transformation of the word2vec input features.
      return tf.keras.layers.Dense(paper_dim)(node_set["feat"])
    elif node_set_name == "author":
      # The empty initial state for each node was created in preprocessing
      # and now comes out here with the correct shape (fixed in case of TPU).
      return node_set["empty_state"]
    elif node_set_name == "field_of_study":
      # A trainable embedding (as discussed above).
      return tf.keras.layers.Embedding(50_000, field_of_study_dim)(
          node_set["hashed_id"])
    elif node_set_name == "institution":
      # A trainable embedding (as discussed above).
      return tf.keras.layers.Embedding(6_500, institution_dim)(
          node_set["hashed_id"])
    else:
      raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
  graph = tfgnn.keras.layers.MapFeatures(
      node_sets_fn=set_initial_node_state, name="init_states")(graph)

  # Abbreviations for repeated building blocks in the GNN.
  def dense(units, activation="relu"):
    """A Dense layer with regularization (L2 and Dropout)."""
    regularizer = tf.keras.regularizers.l2(l2_regularization)
    return tf.keras.Sequential([
        tf.keras.layers.Dense(units,
                              activation=activation,
                              kernel_regularizer=regularizer,
                              bias_regularizer=regularizer),
        tf.keras.layers.Dropout(dropout_rate)])

  def convolution(message_dim, receiver_tag):
    return tfgnn.keras.layers.SimpleConv(dense(message_dim), "sum",
                                         receiver_tag=receiver_tag)

  def next_state(next_state_dim):
    return tfgnn.keras.layers.NextStateFromConcat(dense(next_state_dim))

  # The GNN "core" of the model.
  # Convolutions let data flow towards the specified endpoint of edges, e.g.,
  # along "cites" edges from TARGET (cited paper) to SOURCE (citing paper).
  # See the text below the colab cell for more explanations.
  for i in range(4):
    graph = tfgnn.keras.layers.GraphUpdate(node_sets={
        "paper": tfgnn.keras.layers.NodeSetUpdate(
            {"cites": convolution(message_dim, tfgnn.SOURCE),
             "written": convolution(message_dim, tfgnn.SOURCE),
             "has_topic": convolution(message_dim, tfgnn.SOURCE)},
            next_state(next_state_dim)),
        "author": tfgnn.keras.layers.NodeSetUpdate(
            {"writes": convolution(message_dim, tfgnn.SOURCE),
             "affiliated_with": convolution(message_dim, tfgnn.SOURCE)},
            next_state(next_state_dim)),
         })(graph)

  # Read out the hidden state of the root node of each **component** in the
  # graph (cf. .merge_batch_to_components() above).
  root_states = tfgnn.keras.layers.ReadoutFirstNode(node_set_name="paper")(graph)
  # Put a linear classifier on top. (Never use dropout here.)
  logits = tf.keras.layers.Dense(num_classes)(root_states)

  return tf.keras.Model(input_graph, logits)

Some explanations are in order. (You can already start model training in code blocks below while reading here.)

At its core, the model above consists of four rounds of graph updates, each expressed in Keras as

```
graph = GraphUpdate(...)(graph)
```

with a `graph` of type `GraphTensor`.

A graph update consists of updates to node sets, that is to say, replacing the `tfgnn.HIDDEN_STATE` feature (the string `"hidden_state"`) on each of them. The node set updates of one graph update happen in parallel, that is, on the same input `graph`.

Each node set update is expressed as a `NodeSetUpdate` layer, which receives the input `graph` and returns a new hidden state for the node set it gets applied to. The new hidden state is computed with the given next-state layer from the node set's prior state and the aggregated results from each incoming edge set.

For example, each round of the model above computes a new state for node set "paper" by applying `dense(next_state_dim)` to the concatenation of

  * the prior state `graph.node_sets["paper"][tfgnn.HIDDEN_STATE]`,
  * the result of `convolution(message_dim, tfgnn.SOURCE)(graph, edge_set_name="cites")`,
  * the result of `convolution(message_dim, tfgnn.SOURCE)(graph, edge_set_name="written")` and
  * the result of `convolution(message_dim, tfgnn.SOURCE)(graph, edge_set_name="has_topic")`.
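As a rough check on the sizes involved, the width of this concatenation can be worked out from the default hyperparameters of `_build_model`. The sketch below is plain arithmetic, not TF-GNN code. It assumes that each convolution yields a vector of width `message_dim` and that nothing else (such as a context state) enters the concatenation. The helper name `concat_width` is hypothetical.

```python
def concat_width(state_dim, num_incoming_edge_sets, message_dim):
    """Width of the vector fed to the next-state Dense layer."""
    return state_dim + num_incoming_edge_sets * message_dim


paper_dim = 512       # Initial "paper" state, from _build_model's defaults.
message_dim = 128
next_state_dim = 128

# Round 1: the prior "paper" state is the initial state of width paper_dim.
first_round = concat_width(paper_dim, 3, message_dim)
# Rounds 2 to 4: the prior state is the output of the previous round.
later_rounds = concat_width(next_state_dim, 3, message_dim)
print(first_round, later_rounds)  # 896 512
```

In every round, `dense(next_state_dim)` maps this concatenation back to a state of width 128.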

A convolution on an edge set computes a value for each edge (a "message") as a trainable function of the node states at both endpoints, and then aggregates the results at the receiver nodes by forming the sum (or mean or max) over all incoming edges.

For example, the convolution on edge set "written" concatenates the node state of each edge's incident "paper" and "author" node, applies `dense(message_dim)`, and sums the results over the edges incident to each "paper" node (that is, at the `SOURCE` node of each edge).

Notice that the conventional names *source* and *target* for the endpoints of a directed edge do **not** prescribe the direction of information flow: each "written" edge logically goes from a paper to its author (so the "author" node is its `TARGET`), yet this model lets the data flow towards the paper (and the "paper" node is its `SOURCE`). In fact, sampled subgraphs have edges directed away from the root node, so data flow towards the root often goes from `TARGET` to `SOURCE`.
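The following pure-Python sketch mimics what such a convolution computes when the receiver is `SOURCE`. It is an illustration under simplifying assumptions, not the `SimpleConv` implementation: states are plain lists, the helper `conv_to_source` is hypothetical, and a fixed toy function stands in for the trainable `dense(message_dim)`.

```python
def conv_to_source(edges, source_states, target_states, message_fn, message_dim):
    """Sums one message per edge at the edge's SOURCE node.

    edges: list of (source_index, target_index) pairs.
    source_states, target_states: lists of per-node state vectors (lists).
    message_fn: maps the concatenated endpoint states to a message vector.
    """
    pooled = [[0.0] * message_dim for _ in source_states]
    for s, t in edges:
        # Concatenate the states at both endpoints of the edge.
        message = message_fn(source_states[s] + target_states[t])
        for k in range(message_dim):
            pooled[s][k] += message[k]
    return pooled


# Toy "written" edge set: paper 0 has authors 0 and 1, paper 1 has author 1.
written = [(0, 0), (0, 1), (1, 1)]
paper_states = [[1.0], [2.0]]
author_states = [[10.0], [20.0]]


def toy_message_fn(concatenated):
    # Stand-in for the trainable dense(message_dim): just add the inputs.
    return [sum(concatenated)]


pooled = conv_to_source(written, paper_states, author_states,
                        toy_message_fn, message_dim=1)
print(pooled)  # [[32.0], [22.0]]
```

Each "paper" node receives the sum over its incident "written" edges, although the edges point from paper to author.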

> Note on terminology: Convolutions are best known in deep learning for convolutional neural networks on image data, in which they aggregate information from a small, fixed, implicitly understood neighborhood of each element in a pixel grid. The term loosely carries over to graphs by interpreting edges as explicit, variable definitions of a node's neighborhood.

The code above creates fresh Convolution and NextState layer objects for each edge set and node set, resp., and for each round of updates. This means they all have separate trainable weights. If desired, weight sharing is possible in the standard Keras way by sharing convolution and next-state layer objects, provided the input sizes match.

For more information on defining your own GNN models (incl. those with edge and context states), please see the [TF-GNN Modeling Guide](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/gnn_modeling.md).

For a collection of pre-defined models, see [`tensorflow_gnn/models/README.md`](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/models/README.md). The `GraphUpdate` spelled out explicitly above is roughly equivalent to a `VanillaMPNNGraphUpdate(node_set_names=["paper", "author"])` [[doc](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/VanillaMPNNGraphUpdate.md)].

### Training the model

As usual for a Keras model, we build it under the distribution strategy scope, then compile and fit it.


In [None]:
with strategy.scope():
  model = _build_model(build_model_graph_tensor_spec)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  metrics = [tf.keras.metrics.SparseCategoricalAccuracy(),
             tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)]

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

**IMPORTANT:** In `model.compile()`, pass `weighted_metrics=[...]` instead of plain `metrics=[...]` if your code ever gets used with GraphTensors that are padded to fixed sizes. Otherwise the padding mask is not applied in metric computations, so that padding values interfere with the metrics (making them worse than they really are).
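The effect is easy to see on a toy example. The sketch below is plain Python rather than Keras code, and the helper `accuracy` and the sample values are made up for illustration. It assumes that the padding mask acts as a per-example weight of 1 for real components and 0 for padding.

```python
def accuracy(labels, predictions, weights=None):
    """Fraction of correct predictions, optionally weighted per example."""
    if weights is None:
        weights = [1.0] * len(labels)
    correct = sum(w for y, p, w in zip(labels, predictions, weights) if y == p)
    return correct / sum(weights)


# Three real components followed by two padding components.
labels = [7, 3, 5, 0, 0]
predictions = [7, 3, 5, 2, 9]
mask = [1.0, 1.0, 1.0, 0.0, 0.0]

unweighted = accuracy(labels, predictions)
weighted = accuracy(labels, predictions, mask)
print(unweighted, weighted)  # 0.6 1.0
```

Without the mask, the arbitrary predictions on padding count as errors and drag the metric down.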

In [None]:
model.compile(tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=loss, weighted_metrics=metrics, steps_per_execution=20)

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


Training for 5 epochs of sampled subgraphs takes a few hours in a Colab with one GPU and should achieve an accuracy above 0.44. Lucky runs with many more epochs can reach 0.47. (Training with a Cloud TPU runtime is faster, but note that this set-up is not optimized specifically for TPUs.)

To keep this demo more interactive, we let Keras train and evaluate on fractions of a true epoch for a few minutes only. Of course the resulting accuracy will be poor. To fix that, feel free to edit the `epoch_divisor` according to your patience and ambition. ;-)
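To get a feel for how `epoch_divisor` shortens an epoch, the step counts can be computed by hand. The example counts 629571 and 64879 come from the `model.fit()` call. The global batch size of 128 is a hypothetical value chosen for illustration, because the actual `training_global_batch_size` is set earlier in the notebook.

```python
def steps_for(num_examples, global_batch_size, epoch_divisor=1):
    """Number of steps that make up one (possibly shortened) epoch."""
    return num_examples // global_batch_size // epoch_divisor


num_train_examples = 629571
num_validation_examples = 64879
hypothetical_global_batch_size = 128  # Assumption, for illustration only.

full_epoch_steps = steps_for(num_train_examples, hypothetical_global_batch_size)
demo_epoch_steps = steps_for(num_train_examples, hypothetical_global_batch_size,
                             epoch_divisor=100)
demo_validation_steps = steps_for(num_validation_examples,
                                  hypothetical_global_batch_size,
                                  epoch_divisor=100)
print(full_epoch_steps, demo_epoch_steps, demo_validation_steps)  # 4918 49 5
```

Under that assumption, each demo "epoch" sees about 1% of the training examples, which is why the resulting accuracy stays poor.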

In [None]:
epoch_divisor = 100  # To speed up the interactive demo.

history = model.fit(
    train_ds,
    steps_per_epoch=629571 // training_global_batch_size // epoch_divisor,
    epochs=5,
    validation_data=valid_ds,
    validation_steps=64879 // validation_global_batch_size // epoch_divisor,
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


## Export for inference

At the end of training, a SavedModel is exported for inference. C++ inference environments like TF Serving do not support input of extension types like GraphTensor, so we export the model with a SavedModel Signature that accepts a batch of serialized tf.Examples and preprocesses them like training did.

In [None]:
serving_input = tf.keras.layers.Input(shape=[], dtype=tf.string, name="examples")
preproc_input = tfgnn.keras.layers.ParseExample(example_input_spec)(serving_input)
preproc_model = _make_preprocessing_model(preproc_input.type_spec,
                                          size_constraints=None)
model_input, _ = preproc_model(preproc_input)  # Drop labels. (No mask.)
logits = model(model_input)
serving_output = {
    # SavedModel signature outputs are keyed by the name of the last layer.
    # Restored Keras model outputs preserve the dict seen here.
    # This code puts the same key into both places.
    "logits": tf.keras.layers.Layer(name="logits")(logits),
    "probabilities": tf.keras.layers.Layer(name="probabilities")(
        tf.math.softmax(logits))}
serving_model = tf.keras.Model(serving_input, serving_output)

In [None]:
export_path = "/tmp/exported_keras_model"
!rm -rf {export_path}
# Save everything on the Colab host (even the variables from TPU memory).
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
serving_model.save(export_path, include_optimizer=False, options=save_options)

INFO:tensorflow:Assets written to: /tmp/exported_keras_model/assets


Note that any resources for preprocessing (like lookup tables) have to be attached to the saved object. Our current code has none yet, but composing Keras models as above will take care of it.



For demonstration, let's call the exported model on the example dataset from above, but without labels. We load it as a SavedModel, like TF Serving would. (Using `tf.keras.models.load_model()` instead would rebuild the original Keras layers; see TensorFlow's [Save and load models](https://www.tensorflow.org/tutorials/keras/save_and_load) tutorial for more.)

Note: After connecting this Colab to a TPU worker, explicit device placements are necessary to do the test on the Colab host (which has the `/tmp` directory). You can omit those when loading the SavedModel elsewhere.

In [None]:
with tf.device('/job:localhost'):
  restored_model = tf.saved_model.load(export_path)

In [None]:
def _clean_example_for_serving(serialized):
  example = tf.train.Example.FromString(serialized)
  example.features.feature.pop('nodes/paper.labels')
  return example.SerializeToString()

num_examples = 10

with tf.device('/job:localhost'):
  clean_ds = tf.data.Dataset.from_tensor_slices(tf.constant(
      [_clean_example_for_serving(it.numpy()) for it in demo_ds.take(num_examples)],
      dtype=tf.string))

  for serialized_example in clean_ds.batch(num_examples).take(1):
    outputs = restored_model.signatures["serving_default"](
        examples=serialized_example)
    probabilities = outputs["probabilities"].numpy()
    classes = probabilities.argmax(axis=1)
    for i, c in enumerate(classes):
      print(f"The predicted class for input {i} is {c:3} "
            f"with predicted probability {probabilities[i, c]:.4}")

The predicted class for input 0 is 134 with predicted probability 0.00384
The predicted class for input 1 is 134 with predicted probability 0.00384
The predicted class for input 2 is 134 with predicted probability 0.00384
The predicted class for input 3 is 134 with predicted probability 0.00384
The predicted class for input 4 is 134 with predicted probability 0.00384
The predicted class for input 5 is 134 with predicted probability 0.00384
The predicted class for input 6 is 134 with predicted probability 0.00384
The predicted class for input 7 is 134 with predicted probability 0.00384
The predicted class for input 8 is 134 with predicted probability 0.00384
The predicted class for input 9 is 134 with predicted probability 0.00384


Recall that this is not a fully trained model, so the results will be inaccurate, unless you changed `epoch_divisor` above.
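For comparison, a model that knows nothing would spread its probability evenly over all 349 classes. The short sketch below computes that chance level with a hand-written softmax. It is plain Python for illustration and is not part of the exported model.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


num_classes = 349
uniform = softmax([0.0] * num_classes)
chance_level = uniform[0]
print(f"{chance_level:.4}")  # 0.002865
```

The probability of 0.00384 printed above is only slightly higher than this chance level, which is consistent with a model that has barely been trained.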

## Next steps

This tutorial has shown how to solve a node classification problem in a large graph with TF-GNN, using
  * the graph sampler tool to obtain manageable inputs for each classification target, and
  * a TensorFlow model built in Colab with `tfgnn.keras.layers`.

The [Data Preparation and Sampling](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/data_prep.md) guide describes how you can create training data for other datasets.

The [Modeling](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/gnn_modeling.md) guide explains how to use other GNN architectures or write your own.

For an OGBN-MAG example outside Colab, with many technicalities abstracted away and closer to what production code might look like, see the [OGBN-MAG example](https://github.com/tensorflow/gnn/tree/main/tensorflow_gnn/runner/examples/ogbn/mag) written with the TF-GNN [Runner](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/runner.md).