Skip to content

Commit

Permalink
Change typechecks for pieces of tfgnn.keras.layers.GraphUpdate:
Browse files Browse the repository at this point in the history
 - Check isinstance(arg, tf.keras.layers.Layer) at runtime.
 - Type annotations refer to Protocols for the call() signature.
   NOTE: This is not checked yet, google/pytype#81.

PiperOrigin-RevId: 405832960
  • Loading branch information
Graph Learning Team authored and phanein committed Nov 4, 2021
1 parent e799500 commit d44a927
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 68 deletions.
4 changes: 2 additions & 2 deletions tensorflow_gnn/graph/keras/builders.py
Expand Up @@ -51,8 +51,8 @@ class ConvGNNBuilder:
"""

def __init__(
self, convolutions_factory: Callable[[const.EdgeSetName],
graph_update_lib.EdgesToNodePooling],
self, convolutions_factory: Callable[
[const.EdgeSetName], graph_update_lib.EdgesToNodePoolingLayer],
nodes_next_state_factory: Callable[[const.NodeSetName],
next_state_lib.NextStateForNodeSet]):
self._convolutions_factory = convolutions_factory
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_gnn/graph/keras/layers/convolutions.py
Expand Up @@ -36,7 +36,7 @@ class ConvolutionFromEdgeSetUpdate(tf.keras.layers.Layer):
"""

def __init__(self,
edge_set_update: graph_update.AbstractEdgeSetUpdate,
edge_set_update: graph_update.EdgeSetUpdateLayer,
*,
destination_tag: const.IncidentNodeTag = const.TARGET,
reduce_type: str = "sum",
Expand Down
190 changes: 125 additions & 65 deletions tensorflow_gnn/graph/keras/layers/graph_update.py
@@ -1,5 +1,6 @@
"""The GraphUpdate layer and its pieces."""

import sys
from typing import Any, Callable, Mapping, Optional, Sequence

import tensorflow as tf
Expand All @@ -9,63 +10,94 @@
from tensorflow_gnn.graph import graph_tensor_ops as ops
from tensorflow_gnn.graph.keras.layers import next_state as next_state_lib

# pylint:disable=g-import-not-at-top
if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol
# pylint:enable=g-import-not-at-top


# This file defines the canonical implementations of EdgeSetUpdate,
# NodeSetUpdate, and ContextUpdate. However, users are free to pass objects
# of their own compatible reimplementations as long as they adhere to
# the following interfaces:
#
# class AbstractEdgeSetUpdate(tf.keras.layers.Layer):
# def call(self, graph: GraphTensor, *, edge_set_name: EdgeSetName
# ) -> FieldOrFields: # Results shaped like edge features.
# raise NotImplementedError()
#
# class AbstractNodeSetUpdate(tf.keras.layers.Layer):
# def call(self, graph: GraphTensor, *, node_set_name: NodeSetName
# ) -> FieldOrFields: # Results shaped like node features.
# raise NotImplementedError()
#
# class AbstractContextUpdate(tf.keras.layers.Layer):
# def call(self, graph: GraphTensor
# ) -> FieldOrFields: # Results shaped like context features.
# raise NotImplementedError()
#
# For now, we don't want to require that custom classes of these kinds inherit
# from abstract interfaces. The best we can do in type annotations is to use
# the names above, reminding the programmer of the interfaces described there.
AbstractEdgeSetUpdate = tf.keras.layers.Layer
AbstractNodeSetUpdate = tf.keras.layers.Layer
AbstractContextUpdate = tf.keras.layers.Layer
# of their own compatible reimplementations as long as they:
# 1. subclass tf.keras.layers.Layer,
# 2. provide the signatures in call() (and hence __call__()) given below.
# We use the following type names to express that, and we want some docstrings
# attached to them, but we do not want to require users to subclass an abstract
# interface. Hence we resort to Protocols to check item 2 and leave it to
# runtime checks to check item 1 (which is the less surprising one in a Keras
# environment).
# NOTE: Item 2 is not checked yet, https://github.com/google/pytype/issues/81.
class EdgeSetUpdateLayer(Protocol):
"""A Keras layer that can be called like the standard EdgeSetUpdate."""

def call(self, graph: gt.GraphTensor, *,
edge_set_name: const.EdgeSetName) -> const.FieldOrFields:
"""Returns field(s) shaped like edge features."""
...


class NodeSetUpdateLayer(Protocol):
"""A Keras layer that can be called like the standard NodeSetUpdate."""

def call(self, graph: gt.GraphTensor, *,
node_set_name: const.NodeSetName) -> const.FieldOrFields:
"""Returns field(s) shaped like node features."""
...


class ContextUpdateLayer(Protocol):
"""A Keras layer that can be called like the standard ContextUpdate."""

def call(self, graph: gt.GraphTensor) -> const.FieldOrFields:
"""Returns field(s) shaped like context features."""
...


# The NodeSetUpdate and ContextUpdate layers are initialized with maps of
# input layers from those graph pieces that are in a many-to-one relation
# with the updated graph piece (e.g., many incoming edges per node, many
# nodes per graph component). There is a variety of such input layers,
# including user-defined ones, and they are expected to provide the
# following interfaces:
#
# class EdgesToNodePooling(tf.keras.layers.Layer):
# def call(self, graph: GraphTensor, *, edge_set_name: EdgeSetName
# ) -> FieldOrFields: # Results shaped like node features.
# raise NotImplementedError()
#
# class NodesToContextPooling(tf.keras.layers.Layer):
# def call(self, graph: GraphTensor, *, node_set_name: NodeSetName
# ) -> FieldOrFields: # Results shaped like context features.
# raise NotImplementedError()
#
# class EdgesToContextPooling(tf.keras.layers.Layer):
# def call(self, graph: GraphTensor, *, edge_set_name: EdgeSetName
# ) -> FieldOrFields: # Results shaped like context features.
# raise NotImplementedError()
#
# For now, we don't want to require that classes of these kinds inherit
# from abstract interfaces. The best we can do in type annotations is to use
# the names above, reminding the programmer of the interfaces described there.
EdgesToNodePooling = tf.keras.layers.Layer
NodesToContextPooling = tf.keras.layers.Layer
EdgesToContextPooling = tf.keras.layers.Layer
# including user-defined ones, and they are required to
# 1. subclass tf.keras.layers.Layer,
# 2. provide the signatures in call() (and hence __call__()) listed below.
# We use the following type names to express that, with Protocols as above.
class EdgesToNodePoolingLayer(Protocol):
"""A Keras layer for input from an EdgeSet into a NodeSetUpdate.
Typical implementations of this protocol are:
* Convolutions, which propagate state from adjacent nodes along the edge set
and pool it for the destination node. They may use edge features,
but do not update them (that is, the edge set has no evolving state).
* Edge state poolings, which pool already-computed states from incident
edges of the edge set for the destination node. Using these in a
NodeSetUpdate typically requires a corresponding EdgeSetUpdate
in the same GraphUpdate.
"""

def call(self, graph: gt.GraphTensor, *,
edge_set_name: const.EdgeSetName) -> const.FieldOrFields:
"""Returns field(s) shaped like node features."""
...


class NodesToContextPoolingLayer(Protocol):
"""A Keras layer for input from a NodeSet into a ContextUpdate."""

def call(self, graph: gt.GraphTensor, *,
node_set_name: const.NodeSetName) -> const.FieldOrFields:
"""Returns field(s) shaped like context features."""
...


class EdgesToContextPoolingLayer(Protocol):
"""A Keras layer for input from an EdgeSet into a ContextUpdate."""

def call(self, graph: gt.GraphTensor, *,
edge_set_name: const.EdgeSetName) -> const.FieldOrFields:
"""Returns field(s) shaped like context features."""
...


@tf.keras.utils.register_keras_serializable(package="GNN")
Expand Down Expand Up @@ -114,10 +146,10 @@ class GraphUpdate(tf.keras.layers.Layer):
def __init__(self,
*,
edge_sets: Optional[Mapping[const.EdgeSetName,
AbstractEdgeSetUpdate]] = None,
EdgeSetUpdateLayer]] = None,
node_sets: Optional[Mapping[const.NodeSetName,
AbstractNodeSetUpdate]] = None,
context: Optional[AbstractContextUpdate] = None,
NodeSetUpdateLayer]] = None,
context: Optional[ContextUpdateLayer] = None,
deferred_init_callback: Optional[
Callable[[gt.GraphTensorSpec], Mapping[str, Any]]] = None,
**kwargs):
Expand All @@ -138,9 +170,17 @@ def __init__(self,
self._context_update = None

def _init_from_updates(self, edge_sets=None, node_sets=None, context=None):
self._edge_set_updates = dict(edge_sets or {})
self._node_set_updates = dict(node_sets or {})
self._context_update = context
self._edge_set_updates = {
key: _check_is_layer(value, f"GraphUpdate(edge_sets={{{key}: ...}}")
for key, value in (edge_sets or {}).items()}
self._node_set_updates = {
key: _check_is_layer(value, f"GraphUpdate(node_sets={{{key}: ...}}")
for key, value in (node_sets or {}).items()}
if context is not None:
self._context_update = _check_is_layer(context,
"GraphUpdate(context=...)")
else:
self._context_update = None
self._is_initialized = True

def get_config(self):
Expand Down Expand Up @@ -230,7 +270,8 @@ def __init__(self,
context_input_feature: Optional[const.FieldNameOrNames] = None,
**kwargs):
super().__init__(**kwargs)
self._next_state = next_state
self._next_state = _check_is_layer(next_state,
"EdgeSetUpdate(next_state=...)")
self._edge_input_feature = _copy_if_sequence(edge_input_feature)
self._node_input_tags = list(node_input_tags)
if isinstance(node_input_feature, (list, tuple, set)):
Expand Down Expand Up @@ -308,16 +349,21 @@ class NodeSetUpdate(tf.keras.layers.Layer):
"""

def __init__(self,
edge_set_inputs: Mapping[const.EdgeSetName, EdgesToNodePooling],
edge_set_inputs: Mapping[const.EdgeSetName,
EdgesToNodePoolingLayer],
next_state: next_state_lib.NextStateForNodeSet,
*,
node_input_feature: Optional[const.FieldNameOrNames]
= const.DEFAULT_STATE_NAME,
context_input_feature: Optional[const.FieldNameOrNames] = None,
**kwargs):
super().__init__(**kwargs)
self._edge_set_inputs = dict(edge_set_inputs)
self._next_state = next_state
self._edge_set_inputs = {
key: _check_is_layer(value,
f"NodeSetUpdate(edge_set_inputs={{{key}: ...}}")
for key, value in (edge_set_inputs).items()}
self._next_state = _check_is_layer(next_state,
"NodeSetUpdate(next_state=...")
self._node_input_feature = _copy_if_sequence(node_input_feature)
self._context_input_feature = _copy_if_sequence(context_input_feature)

Expand Down Expand Up @@ -384,20 +430,27 @@ class ContextUpdate(tf.keras.layers.Layer):
"""

def __init__(self,
node_set_inputs: Mapping[const.NodeSetName,
NodesToContextPooling],
node_set_inputs: Mapping[
const.NodeSetName, NodesToContextPoolingLayer],
next_state: next_state_lib.NextStateForContext,
*,
edge_set_inputs: Optional[Mapping[const.EdgeSetName,
EdgesToContextPooling]] = None,
edge_set_inputs: Optional[Mapping[
const.EdgeSetName, EdgesToContextPoolingLayer]] = None,
context_input_feature: Optional[const.FieldNameOrNames]
= const.DEFAULT_STATE_NAME,
**kwargs):
super().__init__(**kwargs)
self._node_set_inputs = dict(node_set_inputs)
self._next_state = next_state
self._node_set_inputs = {
key: _check_is_layer(value,
f"ContextUpdate(node_set_inputs={{{key}: ...}}")
for key, value in (node_set_inputs).items()}
self._next_state = _check_is_layer(next_state,
"ContextUpdate(next_state=...)")
if edge_set_inputs is not None:
self._edge_set_inputs = dict(edge_set_inputs)
self._edge_set_inputs = {
key: _check_is_layer(value,
f"ContextUpdate(edge_set_inputs={{{key}: ...}}")
for key, value in (edge_set_inputs).items()}
else:
self._edge_set_inputs = None
self._context_input_feature = _copy_if_sequence(context_input_feature)
Expand Down Expand Up @@ -456,3 +509,10 @@ def _get_feature_or_features(features, names):
return {}
else:
return {name: features[name] for name in names}


def _check_is_layer(obj, description):
if not isinstance(obj, tf.keras.layers.Layer):
raise ValueError(f"{description} must be a tf.keras.layer.Layer, "
f"got type: {type(obj).__name__}")
return obj

0 comments on commit d44a927

Please sign in to comment.