##### Copyright 2022 The TensorFlow GNN Authors.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Solving OGBN-MAG end-to-end with TF-GNN


<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/gnn/blob/master/examples/notebooks/ogbn_mag_e2e.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/gnn/blob/main/examples/notebooks/ogbn_mag_e2e.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View on GitHub</a>
  </td>
</table>

### Abstract

[Graph Neural Networks](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/intro.md) (GNNs) are a powerful tool for deep learning on relational data. This tutorial introduces the two main tools required to train GNNs at scale:

1. *Graph Sampler*: The [graph sampler](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/sampler/graph_sampler.py) helps to efficiently sample subgraphs from huge graphs.
2. *The Runner*: Also known as the Orchestrator, [the runner](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/runner.md) orchestrates the end-to-end training of GNNs with minimal coding. The Runner is a high-level abstraction for training GNN models provided by the TensorFlow GNN (TF-GNN) library.

This tutorial is intended for ML practitioners with a basic idea of GNNs.

## Colab set-up

In [None]:
!pip install -q --pre tensorflow-gnn || echo "Ignoring package errors..."

In [None]:
import functools
import itertools
import os
import re
from typing import Mapping

import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner
from tensorflow_gnn.experimental import sampler
from tensorflow_gnn.experimental.in_memory import datasets
from tensorflow_gnn.models import vanilla_mpnn
tf.get_logger().setLevel('ERROR')

print(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")

NUM_TRAINING_SAMPLES = 629571
NUM_VALIDATION_SAMPLES = 64879

Running TF-GNN 0.6.0rc0 under TensorFlow 2.12.0.


## Introduction

### Problem statement and dataset

OGBN-MAG is [Open Graph Benchmark](https://ogb.stanford.edu)'s node classification task on a subset of the [Microsoft Academic Graph](https://www.microsoft.com/en-us/research/publication/microsoft-academic-graph-when-experts-are-not-enough/).

The OGBN-MAG dataset is one big heterogeneous graph. The graph has four sets (or types) of nodes.

  * Node set "paper" contains 736,389 published academic papers, each with a 128-dimensional word2vec feature vector computed by averaging the embeddings of the words in its title and abstract.
  * Node set "field_of_study" contains 59,965 fields of study, with no associated features.
  * Node set "author" contains the 1,134,649 distinct authors of the papers, with no associated features.
  * Node set "institution" contains 8,740 institutions listed as affiliations of authors, with no associated features.

The graph has four sets (or types) of directed edges, with no associated features on any of them.

  * Edge set "cites" contains 5,416,217 edges from papers to the papers they cite.
  * Edge set "has_topic" contains 7,505,078 edges from papers to their zero or more fields of study.
  * Edge set "writes" contains 7,145,660 edges from authors to the papers that list them as authors.
  * Edge set "affiliated_with" contains 1,043,998 edges from authors to the zero or more institutions that have been listed as their affiliation(s) on any paper.

The task is to **predict the venue** (journal or conference) at which each of the papers has been published. There are 349 distinct venues, not represented in the graph itself. The benchmark metric is the accuracy of the predicted venue.

Results for this benchmark confirm that the graph structure provides a lot of relevant but "latent" information. Baseline models that only use the one explicit input feature (the word2vec embedding of a paper's title and abstract) perform less well.

OGBN-MAG defines a split of node set "paper" into **train, validation and test nodes**, based on its "year" feature:

  * "train" has the 629,571 papers with `year<=2017`,
  * "validation" has the 64,879 papers with `year==2018`, and
  * "test" has the 41,939 papers with `year==2019`.

Under OGB rules, however, training may happen on the full graph, just restricted to predictions on the "train" nodes. We follow that for consistency in benchmarking. Users working on their own datasets may wish to validate and test with a more realistic separation between training data from the past and evaluation data representative of future inputs for prediction.
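
As a quick sanity check, the three splits should partition the "paper" node set. The sketch below uses only the counts quoted in this section; the variable names are ours and are not part of TF-GNN or OGB.

```python
# Split sizes as quoted above (number of papers per split).
split_sizes = {"train": 629_571, "validation": 64_879, "test": 41_939}
num_papers = 736_389  # Size of node set "paper".

total = sum(split_sizes.values())
assert total == num_papers, "The splits should cover every paper exactly once."

# Show the share of each split in the full node set.
for name, size in split_sizes.items():
  print(f"{name:>10}: {size:>7,} papers ({size / num_papers:.1%})")
```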

### Approach

OGBN-MAG asks to classify each of the "paper" nodes. The number of nodes is on the order of a million, and we intuit that the most informative other nodes are found just a few hops away (cited papers, papers with overlapping authors, etc.).

Therefore, and to stay scalable for even bigger datasets, we approach this task with **graph sampling**: Each "paper" node becomes one training example, expressed by a subgraph that has the node to be classified as its root and stores a sample of its neighborhood in the original graph. The sample is taken by going out a fixed number of steps along specific edge sets, and randomly downsampling the edges in each step if they are too numerous.

The actual **TensorFlow model** runs on batches of these sampled subgraphs, applies a Graph Neural Network to propagate information from related nodes towards the root node of each batch, and then applies a softmax classifier to predict one of 349 classes (each venue is a class).

The exponential fan-out of graph sampling quickly gets expensive. Sampling and model should be designed together to make the most of the available information in carefully sampled subgraphs.
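
To make the idea of sampling concrete, here is a minimal pure-Python sketch of rooted neighborhood sampling on a made-up homogeneous toy graph. It is not the TF-GNN sampler (which handles heterogeneous graphs and is introduced below); the function name `sample_subgraph` and the toy adjacency lists are assumptions for illustration only.

```python
import random

def sample_subgraph(adjacency, root, hops, sample_size, rng):
  """Samples up to `sample_size` out-edges per node for `hops` steps from `root`."""
  edges = []  # Sampled edges, stored in the direction of their discovery.
  visited = {root}
  frontier = [root]
  for _ in range(hops):
    next_frontier = []
    for node in frontier:
      neighbors = adjacency.get(node, [])
      # Randomly downsample the edges if they are too numerous.
      if len(neighbors) > sample_size:
        neighbors = rng.sample(neighbors, sample_size)
      for neighbor in neighbors:
        edges.append((node, neighbor))
        if neighbor not in visited:
          visited.add(neighbor)
          next_frontier.append(neighbor)
    frontier = next_frontier
  return visited, edges

# A made-up toy graph: node -> list of neighbors.
toy_graph = {0: [1, 2, 3, 4], 1: [5, 6], 2: [6, 7, 8], 3: [], 4: [0, 9]}
nodes, edges = sample_subgraph(
    toy_graph, root=0, hops=2, sample_size=2, rng=random.Random(42))
print(sorted(nodes), edges)
```

With `sample_size=2` and two hops, the subgraph holds at most 2 + 2·2 = 6 edges, no matter how large the full graph is. That bound grows exponentially with the number of hops, which is why the sample sizes need to be chosen with care.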


## Data preparation and graph sampling


### Preparing the graph

We provide the entire OGBN-MAG graph, cast as a TF-GNN graph tensor, as input to the graph sampler. TF-GNN comes with a [converter script](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/converters/ogb/convert_ogb_dataset.py) to download and convert OGB datasets into `.npz` files. The command below loads the entire OGBN-MAG dataset as a single graph tensor from the already-saved `.npz` files (subject to this [license](https://storage.googleapis.com/download.tensorflow.org/data/ogbn-mag/npz/LICENSE.txt)). Note that we set `add_reverse_edge_sets=True`; as we will see later, this allows us to sample edge sets in the reverse direction.


In [None]:
input_file = 'gs://download.tensorflow.org/data/ogbn-mag/npz/ogbn-mag.npz'
add_reverse_edge_sets = True
full_ogbn_mag_graph_tensor = datasets.load_ogbn_graph_tensor(input_file, add_reverse_edge_sets=add_reverse_edge_sets)
graph_schema = tfgnn.create_schema_pb_from_graph_spec(full_ogbn_mag_graph_tensor.spec)

### Sampling

Because the OGBN-MAG graph is huge, we sample from it to facilitate training on batches of subgraphs.

The sampling we have chosen for OGBN-MAG proceeds as follows:

  1. Start from all "paper" (seed) nodes.
  2. For each paper from 1, follow a random sample of "cites" edges to other "paper" nodes.
  3. For each paper from 1 or 2, follow a random sample of "rev_writes" edges to "author" nodes.
  4. For each author from 3, follow a random sample of "writes" edges to more "paper" nodes.
  5. For each author from 3, follow a random sample of "affiliated_with" edges to "institution" nodes.
  6. For each paper from 1, 2 or 4, follow a random sample of "has_topic" edges to "field_of_study" nodes.

Below, we spell out the above sampling strategy in easy-to-read Python code.

In [None]:
train_sampling_sizes = {
    "cites": 8,
    "rev_writes": 8,
    "writes": 8,
    "affiliated_with": 8,
    "has_topic": 8,
}
validation_sample_sizes = train_sampling_sizes.copy()

def create_sampling_model(
    full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int]
) -> tf.keras.Model:

  def edge_sampler(sampling_op: tfgnn.sampler.SamplingOp):
    edge_set_name = sampling_op.edge_set_name
    sample_size = sizes[edge_set_name]
    return sampler.InMemUniformEdgesSampler.from_graph_tensor(
        full_graph_tensor, edge_set_name, sample_size=sample_size
    )

  def get_features(node_set_name: tfgnn.NodeSetName):
    return sampler.InMemIndexToFeaturesAccessor.from_graph_tensor(
        full_graph_tensor, node_set_name
    )

  # Spell out the sampling procedure in python
  sampling_spec_builder = tfgnn.sampler.SamplingSpecBuilder(graph_schema)
  seed = sampling_spec_builder.seed("paper")
  papers_cited_from_seed = seed.sample(sizes["cites"], "cites")
  authors_of_papers = papers_cited_from_seed.join([seed]).sample(sizes["rev_writes"], "rev_writes")
  papers_by_authors = authors_of_papers.sample(sizes["writes"], "writes")
  institutions = authors_of_papers.sample(sizes["affiliated_with"], "affiliated_with")
  fields_of_study = (seed.join([papers_cited_from_seed, papers_by_authors]).sample(sizes["has_topic"], "has_topic"))
  sampling_spec = sampling_spec_builder.build()

  model = sampler.create_sampling_model_from_spec(
      graph_schema, sampling_spec, edge_sampler, get_features,
      seed_node_dtype=tf.int64)

  return model



Notice how our sampler allows sampling edge sets in the reverse direction because we set `add_reverse_edge_sets=True` while loading `full_ogbn_mag_graph_tensor`. The edge set `rev_writes` is derived from the edge set `writes` of the original OGBN-MAG graph and goes in the opposite direction, from node set `paper` to node set `author`.

The sampling output contains all nodes and edges traversed by sampling, in their respective node/edge sets and with their associated features. An edge between two sampled nodes that exists in the input graph but has not been traversed by sampling is not included in the sampled output. For example, we get the `cites` edges followed in step 2, but no edges for citations between the papers discovered in step 4.
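
The sampling steps and sample sizes above imply an upper bound on the size of each sampled subgraph. The arithmetic below is a sketch under the assumption that each step follows at most `sizes[edge_set]` edges per source node; the actual subgraphs are usually smaller, because many nodes have fewer edges than the sample size.

```python
sizes = {"cites": 8, "rev_writes": 8, "writes": 8,
         "affiliated_with": 8, "has_topic": 8}

seed_papers = 1                                                # Step 1.
cited_papers = seed_papers * sizes["cites"]                    # Step 2.
authors = (seed_papers + cited_papers) * sizes["rev_writes"]   # Step 3.
papers_by_authors = authors * sizes["writes"]                  # Step 4.
institutions = authors * sizes["affiliated_with"]              # Step 5.
all_papers = seed_papers + cited_papers + papers_by_authors
fields_of_study = all_papers * sizes["has_topic"]              # Step 6.

max_nodes = {"paper": all_papers, "author": authors,
             "institution": institutions, "field_of_study": fields_of_study}
# Every sampled edge discovers at most one node, so the edge bounds
# equal the per-step node bounds.
max_edges = {"cites": cited_papers, "rev_writes": authors,
             "writes": papers_by_authors, "affiliated_with": institutions,
             "has_topic": fields_of_study}
print(max_nodes)
print(max_edges)
```

Under these assumptions, a single subgraph has at most 585 "paper" nodes, 72 "author" nodes, 576 "institution" nodes and 4,680 "field_of_study" nodes. The last step dominates, which illustrates how quickly the fan-out grows.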


## Data Split Preparation


Under [OGB rules](https://ogb.stanford.edu/docs/leader_rules/), we can sample subgraphs for the training, validation and test dataset from the full graph, just with different seed nodes, selected by the year of publication. We define the `seed_dataset` responsible for providing the seeds for the different splits. (Models for production systems should probably use separate validation and test data, to prevent leakage of their seed nodes into the sampled subgraphs of other splits.)

In [None]:
def seed_dataset(years: tf.Tensor, split_name: str) -> tf.data.Dataset:
  """Seed dataset as indices of papers within split years."""
  if split_name == "train":
    mask = years <= 2017  # 629,571 examples
  elif split_name == "validation":
    mask = years == 2018  # 64,879 examples
  elif split_name == "test":
    mask = years == 2019  # 41,939 examples
  else:
    raise ValueError(f"Unknown split_name: '{split_name}'")
  seed_indices = tf.squeeze(tf.where(mask), axis=-1)
  return tf.data.Dataset.from_tensor_slices(seed_indices)

Next, we combine the `seed_dataset` with the sampling model to obtain the `SubgraphDatasetProvider`.

In [None]:
class SubgraphDatasetProvider(runner.DatasetProvider):
  """Dataset Provider based on Sampler V2."""

  def __init__(self,
               full_graph_tensor: tfgnn.GraphTensor,
               sizes: Mapping[str, int],
               split_name: str):
    super().__init__()
    # Extract years of publication of all papers for determining seeds.
    self._years = tf.squeeze(full_graph_tensor.node_sets["paper"]["year"], axis=-1)
    self._sampling_model = create_sampling_model(full_graph_tensor, sizes)
    self._split_name = split_name
    self.input_graph_spec = self._sampling_model.output.spec

  def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
    """Creates TF dataset."""
    self._seed_dataset = seed_dataset(self._years, self._split_name)
    ds = self._seed_dataset.shard(
        num_shards=context.num_input_pipelines, index=context.input_pipeline_id)
    if self._split_name == "train":
      ds = ds.shuffle(NUM_TRAINING_SAMPLES).repeat()
    # samples 128 subgraphs in parallel. Larger is better, but could cause OOM.
    ds = ds.batch(128)
    ds = ds.map(
        self.sample,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False,
    )
    return ds.unbatch().prefetch(tf.data.AUTOTUNE)

  def sample(self, seeds: tf.Tensor) -> tfgnn.GraphTensor:
    seeds = tf.cast(seeds, tf.int64)
    batch_size = tf.size(seeds)
    # samples subgraphs for each seed independently as [[seed1], [seed2], ...]
    seeds_ragged = tf.RaggedTensor.from_row_lengths(
        seeds, tf.ones([batch_size], tf.int64),
    )
    return self._sampling_model(seeds_ragged)

train_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, train_sampling_sizes, "train")
valid_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, validation_sample_sizes, "validation")
example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch()

## Distributed Training



We use TensorFlow's [Distribution Strategy](https://www.tensorflow.org/guide/distributed_training) API to write a model that can run on multiple TPUs, multiple GPUs, or maybe just locally on CPU.



In [None]:
try:
  tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  print("Running on TPU ", tpu_resolver.cluster_spec().as_dict()["worker"])
except Exception:  # No TPU runtime could be resolved.
  tpu_resolver = None

if tpu_resolver:
  print("Using TPUStrategy")
  min_nodes_per_component = {"paper": 1}
  strategy = runner.TPUStrategy()
  train_padding = runner.FitOrSkipPadding(example_input_graph_spec, train_ds_provider, min_nodes_per_component)
  valid_padding = runner.TightPadding(example_input_graph_spec, valid_ds_provider, min_nodes_per_component)
elif tf.config.list_physical_devices("GPU"):
  print("Using MirroredStrategy for GPUs")
  gpu_list = !nvidia-smi -L
  print("\n".join(gpu_list))
  strategy = tf.distribute.MirroredStrategy()
  train_padding = None
  valid_padding = None
else:
  print("Using default strategy")
  strategy = tf.distribute.get_strategy()
  train_padding = None
  valid_padding = None
print(f"Found {strategy.num_replicas_in_sync} replicas in sync")

Running on TPU  ['10.49.184.122:8470']
Using TPUStrategy
Found 8 replicas in sync


As you might have noticed above, we need to provide a padding strategy when we want to train on TPUs. Next, we explain the need for paddings on TPU and the different padding strategies employed during training and validation.

### Padding (for TPUs)


Training on TPUs involves just-in-time compilation of a TensorFlow model to TPU code, and requires *fixed shapes* for all Tensors involved. To achieve that for graph data with variable numbers of nodes and edges, we need to pad each input Tensor to some fixed maximum size. For training on GPUs or CPU, this extra step is not necessary.
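
The principle of padding to a fixed size can be sketched on plain Python lists. TF-GNN pads whole GraphTensors rather than flat lists, so the helper `pad_to_fixed_size` below is a hypothetical illustration of the idea, not a TF-GNN API.

```python
def pad_to_fixed_size(values, target_size, pad_value=0):
  """Pads `values` to `target_size` and returns a mask marking the real entries."""
  if len(values) > target_size:
    raise ValueError(
        f"Input of size {len(values)} does not fit into {target_size}.")
  num_padding = target_size - len(values)
  padded = list(values) + [pad_value] * num_padding
  mask = [True] * len(values) + [False] * num_padding
  return padded, mask

# Two inputs of different lengths end up with the same fixed shape.
for node_ids in ([7, 3, 9], [4]):
  padded, mask = pad_to_fixed_size(node_ids, target_size=5)
  print(padded, mask)
```

The mask lets downstream computations (such as losses and metrics) ignore the padded entries, so that padding changes the shapes but not the results.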

#### TightPadding: Padding for the validation dataset

For the validation dataset, we need to make sure that every batch of examples fits within the fixed size, no matter how the parallelism in the input pipeline ends up combining examples into batches. Therefore, we use a rather generous estimate, basically scaling each Tensor's observed maximum size by a factor of `batch_size`. If that were to run into limitations of accelerator memory, we'd rather shrink the batch size than lose examples.

The dataset in this example is not too big, so we can scan it within a few minutes to determine constraints large enough for all inputs. (For huge datasets under your control, it may be worth inferring an upper bound from the sampling spec instead.)


#### FitOrSkipPadding: Padding for the training dataset

For the training dataset, TF-GNN allows you to optimize more aggressively for large batch sizes. Size constraints satisfied by 100% of the inputs would have to accommodate the rare combination of many large examples in one batch.

Instead, we use size constraints that will fit *close to* 100% of the randomly drawn training batches. This is not covered by the theory supporting stochastic gradient descent (which calls for examples drawn independently at random), but in practice, it often works, and allows larger batch sizes within the limits of accelerator memory, and hence faster convergence of the training.
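
The difference between the two strategies can be illustrated with a small simulation. This is a sketch of the idea only, not TF-GNN's actual algorithm; the function names, the made-up size distribution and the `success_ratio` parameter are assumptions.

```python
import random

def tight_budget(example_sizes, batch_size):
  """Worst case: every example in a batch has the maximum observed size."""
  return max(example_sizes) * batch_size

def fit_or_skip_budget(example_sizes, batch_size, success_ratio,
                       num_trials, rng):
  """A budget that fits roughly `success_ratio` of randomly drawn batches."""
  batch_totals = sorted(
      sum(rng.choices(example_sizes, k=batch_size)) for _ in range(num_trials))
  index = min(num_trials - 1, int(success_ratio * num_trials))
  return batch_totals[index]

rng = random.Random(0)
# Made-up node counts per sampled subgraph: mostly small, a few very large.
example_sizes = [rng.randint(50, 200) for _ in range(990)] + [2000] * 10

tight = tight_budget(example_sizes, batch_size=32)
relaxed = fit_or_skip_budget(
    example_sizes, batch_size=32, success_ratio=0.99, num_trials=1000, rng=rng)
print(f"tight budget: {tight}, fit-or-skip budget: {relaxed}")
```

In this toy setting, the relaxed budget is far smaller than the worst-case one. The memory saved this way can be spent on a larger batch size, at the cost of occasionally dropping a batch that does not fit.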

## Model Building and Training

We build a model on sampled subgraphs that predicts one of 349 classes (venues) for the subgraph's root node. We use a Graph Neural Network (GNN) to propagate information along edge sets towards the subgraph's root node.

Observe how the various node sets play different roles:

  * Node set "paper" has many nodes. It contains the node to predict on. Some of its nodes are linked by "cites" edges, which seem relevant for the prediction task. Its nodes also carry the only input feature besides adjacency, namely the word2vec embedding of title and abstract.
  * Node set "author" also has many nodes. Authors have no features of their own, but having an author in common provides a seemingly relevant relation between papers.
  * Node set "field_of_study" has relatively few nodes. They have no features by themselves, but having a common field of study provides a seemingly relevant relation between papers.
  * Node set "institution" has relatively few nodes. It provides an additional relation on authors.

For node sets "paper" and "author", we follow the standard GNN approach to maintain a hidden state for each node and update it several times with information from the inbound edges. Notice how sampling has equipped each "paper" or "author" adjacent to the root node with a 1-hop neighborhood of its own. Our model does 4 rounds of updates, which covers the longest possible path in a sampled subgraph: a seed paper "cites" a paper that was written by ("rev_writes") an author who "writes" another paper that "has_topic" in some field of study.

For node sets "field_of_study" and "institution", a GNN on the full graph could produce meaningful hidden states for their few elements in the same way. However, in the sampled approach, it seems wasteful to do that from scratch for every subgraph. Instead, our model reads hidden states for them out of an embedding table. This way, the GNN can treat them as read-only nodes with outgoing edges only; the writing happens implicitly by gradient updates to their embeddings. (We choose to maintain a single embedding shared between the rounds of GNN updates.) Notice how this modeling decision directly influences the sampling spec.

## Process Features

Usually in TensorFlow, the non-trainable transformations of the input features are split off into a `Dataset.map()` call while the main model consists of the trainable and accelerator-compatible parts. However, even this non-trainable part is put into a Keras model, which is a convenient way to track resources (such as lookup tables) for exporting to a SavedModel.

### Feature Preprocessing

Typically, feature preprocessing happens locally on nodes and edges. TF-GNN strives to reuse standard Keras implementations for this.  The `tfgnn.keras.layers.MapFeatures` layer lets you express feature transformations on the graph as a collection of feature transformations for the various graph pieces (node sets, edge sets, and context).

At this stage, the eventual training label is still a feature on the `GraphTensor`. If necessary, it could also be preprocessed (e.g., turn a string-valued class label into a numeric id), but that's not the case here.
The training `Task` (defined below) splits the label out of the `GraphTensor`.

In [None]:
# For nodes
def process_node_features(node_set: tfgnn.NodeSet, node_set_name: str):
  if node_set_name == "field_of_study":
    return {"hashed_id": tf.keras.layers.Hashing(50_000)(node_set["#id"])}
  if node_set_name == "institution":
    return {"hashed_id": tf.keras.layers.Hashing(6_500)(node_set["#id"])}
  if node_set_name == "paper":
    # Keep `labels` for eventual extraction.
    return {"feat": node_set["feat"], "labels": node_set["label"]}
  if node_set_name == "author":
    return {"empty_state": tfgnn.keras.layers.MakeEmptyFeature()(node_set)}
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")

# For context and edges, in this example, we drop all features.
def drop_all_features(_, **unused_kwargs):
  return {}

# The combined feature mapping of context, edges and nodes
# is all the preprocessing we need for this dataset.
feature_processors = [
    tfgnn.keras.layers.MapFeatures(context_fn=drop_all_features,
                                   node_sets_fn=process_node_features,
                                   edge_sets_fn=drop_all_features),
]

## Model Architecture

Typically, a model with a GNN architecture at its core consists of three parts:

1. The initialization of hidden states on nodes (and possibly also edges and/or the graph context) from their respective preprocessed features.
2. The base Graph Neural Network: several rounds of updating hidden states from neighboring items in the graph.
3. The readout of one or more hidden states into some prediction head, such as a linear classifier.

We are going to use one model for training, validation, and export for inference, so we need to build it from an input type spec with generic tensor shapes. (For TPUs, it's good enough to use it on a *dataset* with fixed-size elements.) Before defining the base Graph Neural Network, we show how to initialize the hidden states of all the necessary components (nodes, edges and context) given the pre-processed features.

## Initialization of Hidden States

The hidden states on nodes are created by mapping a dict of (preprocessed) features to fixed-size hidden states for nodes. Similarly to feature preprocessing, the `tfgnn.keras.layers.MapFeatures` layer lets you specify such a transformation as a callback function that transforms feature dicts, with GraphTensor mechanics taken off your shoulders.

In [None]:
# Hyperparameters
paper_dim = 512

def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
  if node_set_name == "field_of_study":
    return tf.keras.layers.Embedding(50_000, 32)(node_set["hashed_id"])
  if node_set_name == "institution":
    return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
  if node_set_name == "paper":
    return tf.keras.layers.Dense(paper_dim)(node_set["feat"])
  if node_set_name == "author":
    return node_set["empty_state"]
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")

It is important to understand the distinction between feature preprocessing and hidden state initialization, even though both steps are defined using `tfgnn.keras.layers.MapFeatures`. Feature preprocessing is non-trainable and runs asynchronously to the training loop. Hidden state initialization, on the other hand, is trainable and runs on the accelerator.
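
For the "field_of_study" node set, the two stages are a hashing step followed by an embedding lookup. The sketch below mimics that division of labor in pure Python. It uses `zlib.crc32` as a stand-in hash function (the Keras `Hashing` layer uses a different one), and the names `hashed_id` and `lookup` are hypothetical.

```python
import random
import zlib

NUM_BUCKETS = 50_000   # As in Hashing(50_000) and Embedding(50_000, 32).
EMBEDDING_DIM = 32
NUM_FIELDS_OF_STUDY = 59_965

def hashed_id(node_id, num_buckets=NUM_BUCKETS):
  """Preprocessing: deterministic, with no trainable parameters."""
  return zlib.crc32(str(node_id).encode("utf-8")) % num_buckets

rng = random.Random(0)
embedding_table = {}  # bucket -> vector; stands in for trainable weights.

def lookup(bucket):
  """State initialization: reads the (trainable) row for a bucket."""
  if bucket not in embedding_table:
    embedding_table[bucket] = [
        rng.uniform(-0.05, 0.05) for _ in range(EMBEDDING_DIM)]
  return embedding_table[bucket]

used_buckets = {hashed_id(i) for i in range(NUM_FIELDS_OF_STUDY)}
print(f"{NUM_FIELDS_OF_STUDY} ids share {len(used_buckets)} buckets")
initial_state = lookup(hashed_id(123))
```

Since there are more fields of study than buckets, some nodes necessarily share an embedding. That is a deliberate trade-off between model size and resolution.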

## Base Graph Neural Network

After the hidden states have been initialized, we pass the graph through the base Graph Neural Network, which is a sequence of GraphUpdates. Each GraphUpdate inputs a GraphTensor and returns a GraphTensor with the same graph structure, but the hidden states of nodes have been updated using the information of the neighbor nodes. In our example, the input examples are sampled subgraphs with up to 4 hops, so we perform 4 rounds of graph updates which suffice to bring all information into the root node. Here, we utilize the already-available [`VanillaMPNNGraphUpdate`](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/models/vanilla_mpnn/layers.py) to perform GraphUpdate. TF-GNN offers various modelling choices which are described in [tf-gnn-modeling-guide](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/gnn_modeling.md) and the [tf-gnn-models README](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/models/README.md).


In [None]:
# Hyperparameters
l2_regularization = 6E-6
dropout_rate = 0.1
use_layer_normalization = True

def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
  graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
  graph = tfgnn.keras.layers.MapFeatures(
      node_sets_fn=set_initial_node_states)(graph)
  for _ in range(4):
    graph = vanilla_mpnn.VanillaMPNNGraphUpdate(
        units=128,
        message_dim=128,
        receiver_tag=tfgnn.SOURCE,
        l2_regularization=l2_regularization,
        dropout_rate=dropout_rate,
        use_layer_normalization=use_layer_normalization,
    )(graph)
  return tf.keras.Model(inputs, graph)

An important parameter of the GraphUpdate is the `receiver_tag`. To choose it, it helps to understand the difference between `tfgnn.SOURCE` and `tfgnn.TARGET`: *source* indicates the node at which an edge originates, while *target* indicates the node to which an edge points.

The graph sampler starts sampling from the root node (one can think of the root node as the main source of the subgraph) and stores edges in the direction of their discovery while sampling. Given this construct, the GNN needs to send information in the reverse direction towards the root. In other words, the information needs to be propagated towards the `SOURCE` of each edge, so that it can reach and update the hidden state of the root. Thus, we set the `receiver_tag` to be `tfgnn.SOURCE`. An interesting observation arising from the fact that `receiver_tag=tfgnn.SOURCE` is that since the node sets `"field_of_study"` and `"institution"` have no outgoing edge sets, the `VanillaMPNNGraphUpdate` does not change their hidden states: these remain the embedding tables from node state initialization.
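
The effect of sending messages towards the source of each edge can be traced on a tiny chain. In this pure-Python sketch, each node's "state" is simply the set of nodes whose information has reached it. The node names and the helper `update_round` are made up for illustration and do not use TF-GNN.

```python
# Sampled edges are stored as (source, target) in the direction of discovery.
edges = [("root", "a"), ("a", "b"), ("b", "c")]
state = {node: {node} for node in ("root", "a", "b", "c")}

def update_round(state, edges):
  """One round of message passing in which the SOURCE of each edge receives."""
  new_state = {node: set(known) for node, known in state.items()}
  for source, target in edges:
    new_state[source] |= state[target]  # Message flows from target to source.
  return new_state

history = [state]
for _ in range(3):
  history.append(update_round(history[-1], edges))

for num_rounds, snapshot in enumerate(history):
  print(num_rounds, sorted(snapshot["root"]))
```

Information from the node three hops away reaches the root only after the third round, which is why the number of GraphUpdate rounds should match the depth of the sampled subgraphs. The node `"c"`, which has no outgoing edges, never changes its state.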

## The Task

A Task collects the ancillary pieces for training a Keras model with the graph learning objective. It also provides losses and metrics for that objective. Common implementations for classification and regression (by graph or root node) are provided in the TF-GNN library.


In [None]:
label_fn = runner.RootNodeLabelFn(node_set_name="paper", feature_name="labels")
task = runner.RootNodeMulticlassClassification(
    node_set_name="paper",
    num_classes=349,
    label_fn=label_fn)

## The Trainer

A Trainer provides the training and validation loops. These may be uses of `tf.keras.Model.fit` or arbitrary custom training loops. The Trainer provides accessors to training properties (like its `tf.distribute.Strategy` and `model_dir`) and is expected to return a trained `tf.keras.Model`.

In [None]:
# Hyperparameters
global_batch_size = 128
epochs = 10
initial_learning_rate = 0.001
if tpu_resolver:
  # Training on TPU takes ~90 secs / epoch, so we train for the entire epoch.
  epoch_divisor = 1
else:
  # Training on GPU / CPU is slower, so we train for 1/100th of a true epoch.
  # Feel free to edit the `epoch_divisor` according to your patience and ambition. ;-)
  epoch_divisor = 100
steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size // epoch_divisor
validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size // epoch_divisor
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate, steps_per_epoch*epochs)
optimizer_fn = functools.partial(tf.keras.optimizers.Adam,
                                  learning_rate=learning_rate)

# Trainer
trainer = runner.KerasTrainer(
    strategy=strategy,
    model_dir="/tmp/gnn_model/",
    callbacks=None,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    restore_best_weights=False,
    checkpoint_every_n_steps="never",
    summarize_every_n_steps="never",
    backup_and_restore=False,
)
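
The step counts and the learning rate schedule above can be checked with plain arithmetic. The sketch below re-implements the cosine decay formula in pure Python, assuming the default settings of `CosineDecay` (no warmup, `alpha=0.0`); the helper names `steps` and `cosine_decay` are ours.

```python
import math

NUM_TRAINING_SAMPLES = 629571
NUM_VALIDATION_SAMPLES = 64879
global_batch_size = 128
epochs = 10
initial_learning_rate = 0.001

def steps(num_samples, epoch_divisor):
  """Number of batches per (possibly shortened) epoch."""
  return num_samples // global_batch_size // epoch_divisor

def cosine_decay(step, decay_steps, alpha=0.0):
  """Learning rate after `step` steps of cosine decay without warmup."""
  step = min(step, decay_steps)
  cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
  return initial_learning_rate * ((1.0 - alpha) * cosine + alpha)

for divisor in (1, 100):
  train_steps = steps(NUM_TRAINING_SAMPLES, divisor)
  valid_steps = steps(NUM_VALIDATION_SAMPLES, divisor)
  print(f"epoch_divisor={divisor}: {train_steps} training steps and "
        f"{valid_steps} validation steps per epoch")

decay_steps = steps(NUM_TRAINING_SAMPLES, 1) * epochs
for step in (0, decay_steps // 2, decay_steps):
  print(step, cosine_decay(step, decay_steps))
```

On TPU (`epoch_divisor = 1`), this gives 4,918 training steps per epoch. The learning rate starts at 0.001, passes through half that value midway, and reaches zero at the end of the tenth epoch.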

## Export options for inference

For inference, a SavedModel must be exported by the runner at the end of training. C++ inference environments like TF Serving do not support input of extension types like GraphTensor, so the `KerasModelExporter` exports the model with a SavedModel Signature that accepts a batch of serialized tf.Examples and preprocesses them like training did.

Note: After connecting this Colab to a TPU worker, explicit device placements are necessary to do the test on the Colab host (which has the `/tmp/gnn_model` directory).

In [None]:
save_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
model_exporter = runner.KerasModelExporter(options=save_options)

## Let the Runner do its magic!

Orchestration (a term for the composition, wiring and execution of the above abstractions) happens via a single `run` method with the signature shown below.

Training for 10 epochs of sampled subgraphs takes a few hours on a free Colab with one GPU (T4) and should achieve an accuracy above 0.50. Training with the free Cloud TPU runtime is *much* faster and completes the entire training within 20 minutes.

NOTE: On TPU, it takes ~4 minutes to learn optimal TPU padding constraints before training starts.

In [None]:
runner.run(
    train_ds_provider=train_ds_provider,
    train_padding=train_padding,
    model_fn=model_fn,
    optimizer_fn=optimizer_fn,
    epochs=epochs,
    trainer=trainer,
    task=task,
    gtspec=example_input_graph_spec,
    global_batch_size=global_batch_size,
    model_exporters=[model_exporter],
    feature_processors=feature_processors,
    valid_ds_provider=valid_ds_provider, # <<< Remove if not training for real.
    valid_padding=valid_padding)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10




RunResult(preprocess_model=<keras.engine.functional.Functional object at 0x7c1d644f4e50>, base_model=<keras.engine.sequential.Sequential object at 0x7c1d64882980>, trained_model=<keras.engine.functional.Functional object at 0x7c1d6457a350>)

## Inference using Exported Model
At the end of training, a SavedModel is exported by the Runner for inference. For demonstration, let's call the exported model on the validation dataset from above, but without labels. We load it as a SavedModel, like TF Serving would. Analogous to the SaveOptions above, LoadOptions with a device placement are necessary when connecting this Colab to a TPU worker.

NOTE: TF Serving usually expects examples in form of serialized strings, therefore we explicitly convert the graph tensors to serialized string format and pass it to the loaded model.





In [None]:
load_options = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
saved_model = tf.saved_model.load(os.path.join(trainer.model_dir, "export"),
                                  options=load_options)

def _clean_example_for_serving(graph_tensor):
  serialized_example = tfgnn.write_example(graph_tensor)
  return serialized_example.SerializeToString()

# Convert 10 examples to serialized string format.
num_examples = 10
demo_ds = valid_ds_provider.get_dataset(tf.distribute.InputContext())
serialized_examples = [_clean_example_for_serving(gt) for gt in itertools.islice(demo_ds, num_examples)]

# Inference on 10 examples
ds = tf.data.Dataset.from_tensor_slices(serialized_examples)
kwargs = {"examples": next(iter(ds.batch(num_examples)))}
output = saved_model.signatures["serving_default"](**kwargs)

# Outputs are in the form of logits
logits = next(iter(output.values()))
probabilities = tf.math.softmax(logits).numpy()
classes = probabilities.argmax(axis=1)

# Print the predicted classes
for i, c in enumerate(classes):
  print(f"The predicted class for input {i} is {c:3} "
        f"with predicted probability {probabilities[i, c]:.4}")

The predicted class for input 0 is   9 with predicted probability 0.6633
The predicted class for input 1 is 189 with predicted probability 0.2399
The predicted class for input 2 is 189 with predicted probability 0.3526
The predicted class for input 3 is 158 with predicted probability 0.9894
The predicted class for input 4 is 236 with predicted probability 0.1657
The predicted class for input 5 is 247 with predicted probability 0.8843
The predicted class for input 6 is 209 with predicted probability 0.7814
The predicted class for input 7 is 247 with predicted probability 0.5961
The predicted class for input 8 is 192 with predicted probability 0.2086
The predicted class for input 9 is 311 with predicted probability 0.8125


## Next steps

This tutorial has shown how to solve a node classification problem in a large graph with TF-GNN using
  * the graph sampler tool to obtain manageable-sized inputs for each classification target,
  * the Runner for training GNNs with minimal coding.

The [Data Preparation and Sampling](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/data_prep.md) guide describes how you can create training data for other datasets.

The Colab notebook [An in-depth look at TF-GNN](https://colab.research.google.com/github/tensorflow/gnn/blob/main/examples/notebooks/ogbn_mag_indepth.ipynb) solves OGBN-MAG again, but without the abstractions provided by the Runner and the ready-to-use VanillaMPNN model. Take a look if you would like to know more, or want more control in designing GNNs for your own task.

For more complete documentation, please check out the [TF-GNN documentation](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/overview.md).

