##### Copyright 2022 The TensorFlow GNN Authors.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Node Classification via TF-GNN
This is a hands-on tutorial that explains how to run node classification over Graph Datasets. In particular, it covers:

* How to load standard academic graph datasets (e.g., OGBN, Cora, Pubmed, Citeseer) into TF-GNN.
* How to train either on the full graph or with on-the-fly subgraph sampling.
* How to choose among simple models.


<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/gnn/blob/main/examples/tutorials/log_2022/code_tutorial_1_tfgnn_single_machine.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/gnn/blob/main/examples/tutorials/log_2022/code_tutorial_1_tfgnn_single_machine.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View on GitHub</a>
  </td>
</table>

## pip-install TF-GNN

In [None]:
!pip install ogb
!pip install tensorflow_gnn --pre

## Imports

In [None]:
# Standard imports
import tensorflow as tf
import tensorflow_gnn as tfgnn
import functools
import collections
from typing import Mapping, List
import functools
import math

In addition to standard imports, this tutorial also needs these imports.

In [None]:
# To use standard datasets that fit in memory.
from tensorflow_gnn.experimental.in_memory import datasets

# Implementations of example GNN models.
from tensorflow_gnn.experimental.in_memory import models

# Converts `tfgnn.GraphTensor` to (`tfgnn.GraphTensor`, `tf.Tensor`)
# with second item containing task labels.
from tensorflow_gnn.experimental.in_memory import reader_utils

# For on-the-fly subgraph sampling.
from tensorflow_gnn.sampler import sampling_spec_builder

## Configure dataset, model hyperparameters, and training loop.

In [None]:
#@title Configurations
#@markdown Dataset
sampling_style = 'inmemory'  #@param ['nosampling','inmemory']
dataset_name = 'ogbn-arxiv'  #@param ['ogbn-arxiv', 'ogbn-mag', 'cora', 'citeseer', 'pubmed']
sampling_num_seed_nodes = 100  #@param int  # no-op if sampling_style == 'nosampling'
# If set, the validation split is added to the training split and the test
# split is used for validation during training. Some datasets (e.g., OGB)
# allow this, but only *after* hyperparameters have been finalized.
valid_as_train = True  #@param

#@markdown Model

# Name of the model to build. It is resolved by `models.make_model_by_name`,
# where `models` is tensorflow_gnn.experimental.in_memory.models.
model_name = 'GCN'  #@param ['GCN', 'Simple', 'JKNet', 'GATv2', 'ResidualGCN']

# Passed to the model function (or constructor). Each model could accept
# different parameters.
# All models accept "depth" (i.e., number of GNN layers).
model_kwargs = {"dropout":0.2, "depth":2} #@param

# Passed to `graph_data.with_undirected_edges(...)` when the graph is loaded.
undirect_graph = True #@param

# Only used for node sets that have no 'feat' feature.
# Each node of such a node set is embedded in this many dimensions.
embed_featureless_node_dim = 32  #@param int  # For node sets without 'feat' attribute, they will be embedded.


#@markdown Training loop

# Number of training epochs.
epochs = 50  #@param int

# Validation runs once every this many epochs (passed to `fit` as
# `validation_freq`).
eval_every = 1  #@param int

# Learning rate for ADAM optimizer.
learning_rate = 1e-2 #@param float

# L2 regularization only applies to kernel weights (not biases).
# Passed to the model builder as `l2_coefficient`. Unlike the settings above,
# this one is not exposed as a form field.
l2_regularization = 1e-5

# Echo the chosen configuration so it is recorded in the cell output.
print(f'training {model_name} (with sampling={sampling_style}, num_seed_nodes={sampling_num_seed_nodes}) on dataset {dataset_name}')

## Dataset

The cell below loads the graph data as `tf.data.Dataset` instances, so that it can be passed to the model for training. The following variables are created:

* `train_ds`, `valid_ds`, `test_ds`, all are `tf.data.Dataset` instances. If training on full-graph, then the datasets will contain only one element: the full-graph, without any stochasticity. If training with sampling, then the dataset will contain subgraph samples.

* `graph_spec`, which is used to create a (symbolic) input layer representing a `GraphTensor`, from which the Keras model is built.

In [None]:
# Load the input graph in TF-GNN format. Every tf.data.Dataset built further
# down is derived from `graph_data`.
graph_data = datasets.get_in_memory_graph_data(dataset_name)
assert isinstance(graph_data, datasets.NodeClassificationGraphData)
num_classes = graph_data.num_classes()
print('loaded %s in graph_data with num_classes=%i' % (dataset_name, num_classes))


# Split names handed to `graph_data.with_split(...)` for train / validation /
# test.
#
# The validation split must be spelled the same way on both code paths. The
# original used 'valid' below but 'validation' when extending the training
# split. 'validation' is kept because it is the spelling on the default
# (valid_as_train=True) path.
# NOTE(review): confirm against NodeClassificationGraphData.with_split.
train_split = ['train']       # Gains the validation split if valid_as_train.
valid_split = 'validation'    # Becomes the test split if valid_as_train.
test_split = 'test'           # Constant.
if valid_as_train:
  # Train on train+validation and monitor on the test split instead.
  train_split.append('validation')
  valid_split = 'test'

# Expose labels as features (they are popped off into the label tensor when
# datasets are built), add self-loops, and optionally make edges undirected.
graph_data = (graph_data
              .with_labels_as_features(True)
              .with_self_loops(True)
              .with_undirected_edges(undirect_graph))
print(graph_data.edge_sets())

# Helpers for in-memory sampling
def edge_set_names_by_source(
    graph: tfgnn.GraphSchema
    ) -> Mapping[tfgnn.NodeSetName, List[tfgnn.EdgeSetName]]:
  """Groups the schema's edge set names by their source node set.

  Args:
    graph: Graph schema whose `edge_sets` are inspected.

  Returns:
    A `defaultdict(list)` mapping node set name -> names of the edge sets
    whose `source` is that node set. Looking up a node set with no outgoing
    edge sets yields an empty list.
  """
  outgoing_by_node_set = collections.defaultdict(list)
  for name, schema in graph.edge_sets.items():
    outgoing_by_node_set[schema.source].append(name)
  return outgoing_by_node_set


def make_sampling_spec_with_dfs(
    graph_data: datasets.NodeClassificationGraphData,
    fanout=5, depth=2):
  """Builds a sampling spec that follows every edge set, depth-first.

  Sampling is seeded at the labeled node set. From each node set reached, every
  edge set whose source is that node set is traversed, sampling `fanout`
  neighbors per edge set, until `depth` hops have been taken.

  Args:
    graph_data: Graph data providing the schema and the labeled node set.
    fanout: Number of neighbors to sample along each traversed edge set.
    depth: Number of hops to expand away from the seed node set.

  Returns:
    The result of calling `build()` on the seeded sampling-spec builder.
  """
  schema = graph_data.graph_schema()
  outgoing_edge_sets = edge_set_names_by_source(schema)
  seed_step = sampling_spec_builder.SamplingSpecBuilder(
      schema,
      default_strategy=sampling_spec_builder.SamplingStrategy.RANDOM_UNIFORM)
  seed_step = seed_step.seed(graph_data.labeled_nodeset)

  def _expand(node_set_name, step, hops_left):
    # Add one sampling step per outgoing edge set, then keep expanding from
    # the node set each of those edge sets points to.
    if hops_left != 0:
      for edge_set_name in outgoing_edge_sets[node_set_name]:
        next_node_set = schema.edge_sets[edge_set_name].target
        next_step = step.sample(fanout, edge_set_name)
        _expand(next_node_set, next_step, hops_left - 1)

  _expand(graph_data.labeled_nodeset, seed_step, depth)

  return seed_step.build()

def undirect_subgraph(graph_tensor: tfgnn.GraphTensor, label) -> (tfgnn.GraphTensor, tf.Tensor):
  """Adds the reverse edge (j, i) for every edge (i, j), even if it exists.

  Only edge sets whose source and target node set coincide are changed; all
  other edge sets are passed through untouched. A changed edge set is rebuilt
  from doubled sizes and the concatenated adjacency only (no features are
  passed to the new edge set).

  This could be folded into TF-GNN proper; it is kept here for demonstration.

  Args:
    graph_tensor: Input graph.
    label: Passed through unchanged, so this function can be used directly
      with `tf.data.Dataset.map` on (graph, label) pairs.

  Returns:
    A pair (graph with reversed edges added, label).
  """
  def _with_reverse_edges(edge_set):
    adjacency = edge_set.adjacency
    if adjacency.source_name != adjacency.target_name:
      # Edge set connects two different node sets: keep as-is.
      return edge_set
    sources, targets = adjacency.source, adjacency.target
    shared_node_set = adjacency.source_name
    both_way_sources = tf.concat([sources, targets], axis=0)
    both_way_targets = tf.concat([targets, sources], axis=0)
    return tfgnn.EdgeSet.from_fields(
        sizes=edge_set.sizes*2,
        adjacency=tfgnn.Adjacency.from_indices(
            source=(shared_node_set, both_way_sources),
            target=(shared_node_set, both_way_targets)))

  symmetric_edge_sets = {
      name: _with_reverse_edges(edge_set)
      for name, edge_set in graph_tensor.edge_sets.items()}

  undirected = tfgnn.GraphTensor.from_pieces(
      context=graph_tensor.context,
      node_sets=graph_tensor.node_sets,
      edge_sets=symmetric_edge_sets)
  return undirected, label


def as_dataset(graph_data, pop_labels_from_graph: bool = True):
  """Wraps the full graph of `graph_data` in a single-element dataset.

  Args:
    graph_data: Graph data; labels are first exposed as features via
      `with_labels_as_features(True)`.
    pop_labels_from_graph: If True, each element is mapped through
      `reader_utils.pop_labels_from_graph` (bound to the number of classes).
      If False, the dataset yields the graph tensor as-is.

  Returns:
    A `tf.data.Dataset` built with `from_tensors`, i.e. holding one element.
  """
  labeled_data = graph_data.with_labels_as_features(True)
  full_graph_ds = tf.data.Dataset.from_tensors(labeled_data.as_graph_tensor())

  if not pop_labels_from_graph:
    return full_graph_ds

  pop_labels = functools.partial(
      reader_utils.pop_labels_from_graph, labeled_data.num_classes())
  return full_graph_ds.map(pop_labels)


# Build train/valid/test datasets according to `sampling_style`, and derive
# the step counts used by the training and test cells.
if sampling_style == 'nosampling':
  # Full-graph mode: every dataset holds the whole graph as its only element.
  train_ds = as_dataset(graph_data.with_split(train_split))
  valid_ds = as_dataset(graph_data.with_split(valid_split))
  # The test split must be wrapped like the other two. The test loop iterates
  # `test_ds` and unpacks (graph, labels) pairs, which the raw graph-data
  # object does not provide.
  test_ds = as_dataset(graph_data.with_split(test_split))

  total_train_steps = epochs  # Every step includes all graph nodes & edges.
  train_steps_per_epoch = 1
  validation_steps = 1  # One-step of validation evaluates all validation nodes.
  total_test_steps = 1
elif sampling_style == 'inmemory':
  from tensorflow_gnn.experimental.in_memory import int_arithmetic_sampler as ia_sampler

  sampling_spec = make_sampling_spec_with_dfs(graph_data)
  sampling_args = dict(
      sampling_spec=sampling_spec,
      num_seed_nodes=sampling_num_seed_nodes,
      sampling_mode=ia_sampler.EdgeSampling.WITH_REPLACEMENT)
  sampler_class = ia_sampler.NodeClassificationGraphSampler

  def sampled_dataset(split):
    """Returns a dataset of sampled, undirected subgraphs seeded from `split`."""
    sampler = sampler_class(graph_data.with_split(split))
    return sampler.as_dataset(**sampling_args).map(undirect_subgraph)

  train_ds = sampled_dataset(train_split)
  valid_ds = sampled_dataset(valid_split)
  test_ds = sampled_dataset(test_split)

  # NOTE(review): this counts only the 'train' split, even when
  # `valid_as_train` adds validation nodes to the training seeds. Confirm
  # whether that is intended.
  size_train = graph_data.node_split().train.shape[0]
  train_steps_per_epoch = math.ceil(size_train / sampling_num_seed_nodes)
  validation_steps = 10  # Evaluate small subset of validation.
  total_test_steps = 100
  print('\nsampling_spec:\n', sampling_spec)
else:
  # Fail here with a clear message rather than with a NameError on `train_ds`
  # further down.
  raise ValueError(
      "sampling_style must be 'nosampling' or 'inmemory', got %r"
      % (sampling_style,))


# Make sure we can take an example of the dataset.
example = next(iter(train_ds.take(1)))
graph_tensor, labels = example

# Crucially, above `graph_tensor` will be used to create `graph_spec` that will
# be used as a placeholder to instantiate the model. The spec is relaxed so
# that graphs with different numbers of nodes and edges are accepted.
graph_spec = graph_tensor.spec.relax(num_edges=True, num_nodes=True)

## Model

In [None]:
# Build the GNN selected in the configuration cell. `make_model_by_name`
# returns a pair; only `model` is used below (`model_prefers_undirected` is
# not read anywhere else in this notebook).
model_prefers_undirected, model = models.make_model_by_name(
      model_name, num_classes, l2_coefficient=l2_regularization,
      model_kwargs=model_kwargs)

# Indexed by node set name in `set_init_node_features` to size the embedding
# table of node sets that have no 'feat' feature.
node_counts = graph_data.node_counts()

# Input pre-processing layer. Can be trainable.
def set_init_node_features(node_set, node_set_name):
  """Returns the initial feature tensor for one node set.

  Intended as the `node_sets_fn` callback of `tfgnn.keras.layers.MapFeatures`.

  Args:
    node_set: Node set whose initial features are requested.
    node_set_name: Name of `node_set`; used to look up its node count.

  Returns:
    `node_set['feat']` if the node set has a 'feat' feature. Otherwise, if
    `embed_featureless_node_dim > 0`, a trainable embedding of the node ids
    with that many dimensions. Otherwise `None`.
  """
  # If there is a feature already, return it.
  if 'feat' in node_set.features:
    return node_set['feat']

  # Otherwise, we can embed the node_set.
  if embed_featureless_node_dim > 0:
    embedding_layer = tf.keras.layers.Embedding(
        node_counts[node_set_name],
        embed_featureless_node_dim)
    # Fix: the layer used to be created and then discarded, so featureless
    # node sets silently got `None`. Apply it to the node ids instead.
    # NOTE(review): assumes every node set carries the '#id' feature, as the
    # test cell reads it for the labeled node set. Confirm for other node sets.
    return embedding_layer(node_set['#id'])

  # No feature and embedding disabled: same result as before (None).
  # NOTE(review): how MapFeatures treats None is not visible here.
  return None

# Symbolic input: a placeholder GraphTensor matching `graph_spec`, which was
# derived from one training example in the dataset cell.
input_graph = tf.keras.layers.Input(type_spec=graph_spec)
graph = input_graph

# Transformations.
graph = tfgnn.keras.layers.MapFeatures(  # Pre-processing.
    node_sets_fn=set_init_node_features)(graph)
graph = model(graph)  # GNN.

# Readout: extract a tf.Tensor for the seed nodes of the labeled node set.
h = reader_utils.readout_seed_node_features(
    graph, node_set_name=graph_data.labeled_nodeset)

# Post-processing (optional)
# h = tf.keras.layers.Dense(num_classes)(h)

# Capture the computation graph input_graph -> h as a `tf.keras.Model`.
# `h` is the model output that the loss compares against the labels.
keras_model = tf.keras.Model(inputs=input_graph, outputs=h)

## Optimization

In [None]:
# Adam optimizer, using the learning rate from the configuration cell.
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# The model output is treated as logits (from_logits=True); the loss is
# averaged over the batch.
loss = tf.keras.losses.CategoricalCrossentropy(
    reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
    from_logits=True)

# Track accuracy alongside the loss.
keras_model.compile(optimizer=opt, loss=loss, metrics=['acc'])

In [None]:
# Train for `epochs` epochs of `train_steps_per_epoch` steps each. Every
# `eval_every` epochs, evaluate `validation_steps` batches drawn from the
# (repeated) validation dataset.
keras_model.fit(
    train_ds,
    steps_per_epoch=train_steps_per_epoch,
    epochs=epochs,
    validation_data=valid_ds.repeat(),
    validation_freq=eval_every,
    validation_steps=validation_steps)

## Test model performance

In [None]:
# Evaluate the trained model on batches from `test_ds` and report the overall
# accuracy across all evaluated seed nodes.
#
# `total_test_steps` comes from the dataset cell (it depends on
# `sampling_style`), so it is deliberately not overwritten here.
total_correct = 0
total_size = 0
test_labels = graph_data.test_labels()

# Context feature holding the positions of the seed nodes within the
# labeled node set of each (sub)graph.
seed_feat_name = 'seed_nodes.' + graph_data.labeled_nodeset

# `take` bounds the loop to exactly `total_test_steps` batches. The loop
# variable is not called `graph`, so it does not clobber the symbolic `graph`
# tensor from the model cell.
for test_graph, unused_labels in test_ds.take(total_test_steps):
  out = keras_model(test_graph)

  # Map seed positions -> node ids -> ground-truth test labels.
  seed_pos = test_graph.context[seed_feat_name]
  id_feat = test_graph.node_sets[graph_data.labeled_nodeset]['#id']
  seed_ids = tf.gather(id_feat, seed_pos)
  labels = tf.gather(test_labels, seed_ids)
  pred_y = tf.cast(tf.argmax(out, -1), labels.dtype)

  int_is_correct_label = tf.cast(pred_y == labels, tf.int64)
  total_correct += int(tf.reduce_sum(int_is_correct_label))
  total_size += int(tf.reduce_sum(tf.ones_like(int_is_correct_label)))

# Report once, after every batch has been accumulated.
if total_size == 0:
  print('Test accuracy is undefined: no test examples were evaluated.')
else:
  print('Test accuracy = %f' % (total_correct / total_size))

## Your assignment
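Make a simple change in the configuration cell (e.g., the number of layers, the regularization strength, or the hidden dimensions), re-run the notebook, and compare the test accuracy.

Then try a residual model by setting `model_name = 'ResidualGCN'`.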

## Bonus: Import your custom dataset
This works for datasets that fit in-memory:

https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/experimental/in_memory/datasets.py