# Build NN

In [1]:
import abc
from typing import Any, Callable, List, Optional, Tuple

import chex
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


# Type aliases used in the signatures below.
_Array = chex.Array
_Fn = Callable[..., Any]

# Module-wide constants.
BIG_NUMBER = 1_000_000.0
PROCESSOR_TAG = 'clrs_processor'


class Processor(hk.Module):
  """Abstract base for processors (a single inference step over a graph).

  Subclasses implement `__call__`. The module name always ends with
  `PROCESSOR_TAG`; it is appended when the caller's name lacks it.
  """

  def __init__(self, name: str):
    tagged_name = (name if name.endswith(PROCESSOR_TAG)
                   else f'{name}_{PROCESSOR_TAG}')
    super().__init__(name=tagged_name)

  @abc.abstractmethod
  def __call__(
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **kwargs,
  ) -> Tuple[_Array, Optional[_Array]]:
    """Runs one processor inference step.

    Args:
      node_fts: Per-node features.
      edge_fts: Per-edge features.
      graph_fts: Per-graph features.
      adj_mat: Adjacency matrix of the graph.
      hidden: Hidden node state.
      **kwargs: Extra keyword arguments for subclasses.

    Returns:
      A 2-tuple `(node_embeddings, edge_embeddings)`; the edge embeddings
      may be None.
    """

  @property
  def inf_bias(self):
    """Base-class value: always False."""
    return False

  @property
  def inf_bias_edge(self):
    """Base-class value: always False."""
    return False


class GAT(Processor):
  """Graph Attention Network (Velickovic et al., ICLR 2018).

  Each node attends, with `nb_heads` heads, over the nodes allowed by
  `adj_mat`. Attention logits are the sum of a term for the attending node,
  a term for the attended node, an edge term and a graph term, passed
  through a leaky ReLU.
  """

  def __init__(
      self,
      out_size: int,
      nb_heads: int,
      activation: Optional[_Fn] = jax.nn.relu,
      residual: bool = True,
      use_ln: bool = False,
      name: str = 'gat_aggr',
  ):
    """Constructs a GAT processor.

    Args:
      out_size: Output width; must be divisible by `nb_heads`.
      nb_heads: Number of attention heads.
      activation: Optional nonlinearity applied after the residual sum.
      residual: Whether to add a linear projection of the input to the output.
      use_ln: Whether to layer-normalise the output.
      name: Module name; the processor tag is appended if missing.

    Raises:
      ValueError: If `nb_heads` does not divide `out_size`.
    """
    super().__init__(name=name)
    self.out_size = out_size
    self.nb_heads = nb_heads
    if out_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the width!')
    self.head_size = out_size // nb_heads
    self.activation = activation
    self.residual = residual
    self.use_ln = use_ln

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> Tuple[_Array, Optional[_Array]]:
    """GAT inference step.

    Args:
      node_fts: Node features, [B, N, Fn].
      edge_fts: Edge features, [B, N, N, Fe].
      graph_fts: Graph features, [B, Fg].
      adj_mat: Adjacency matrix, [B, N, N]; entries of 0 are masked out
        (assumes a 0/1 matrix).
      hidden: Hidden node state, concatenated to `node_fts` on the last axis.
      **unused_kwargs: Ignored.

    Returns:
      A 2-tuple `(node_embeddings, None)` with node embeddings of shape
      [B, N, out_size]; this processor produces no edge embeddings.
    """
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    z = jnp.concatenate([node_fts, hidden], axis=-1)
    # NOTE(review): Haiku names submodules in creation order (linear,
    # linear_1, ...); keep the constructor order below to keep param names.
    m = hk.Linear(self.out_size)
    skip = hk.Linear(self.out_size)

    # 0 where an edge exists and -1e9 elsewhere, so the softmax below gives
    # (numerically) zero weight to non-neighbours.
    bias_mat = (adj_mat - 1.0) * 1e9
    bias_mat = jnp.tile(bias_mat[..., None],
                        (1, 1, 1, self.nb_heads))     # [B, N, N, H]
    bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2))  # [B, H, N, N]

    # Per-head scalar contributions to the attention logits.
    a_1 = hk.Linear(self.nb_heads)
    a_2 = hk.Linear(self.nb_heads)
    a_e = hk.Linear(self.nb_heads)
    a_g = hk.Linear(self.nb_heads)

    values = m(z)                                      # [B, N, H*F]
    values = jnp.reshape(
        values,
        values.shape[:-1] + (self.nb_heads, self.head_size))  # [B, N, H, F]
    values = jnp.transpose(values, (0, 2, 1, 3))              # [B, H, N, F]

    att_1 = jnp.expand_dims(a_1(z), axis=-1)          # [B, N, H, 1]
    att_2 = jnp.expand_dims(a_2(z), axis=-1)          # [B, N, H, 1]
    att_e = a_e(edge_fts)                             # [B, N, N, H]
    att_g = jnp.expand_dims(a_g(graph_fts), axis=-1)  # [B, H, 1]

    logits = (
        jnp.transpose(att_1, (0, 2, 1, 3)) +  # + [B, H, N, 1]
        jnp.transpose(att_2, (0, 2, 3, 1)) +  # + [B, H, 1, N]
        jnp.transpose(att_e, (0, 3, 1, 2)) +  # + [B, H, N, N]
        jnp.expand_dims(att_g, axis=-1)       # + [B, H, 1, 1]
    )                                         # = [B, H, N, N]
    # Masked softmax over the attended nodes (last axis).
    coefs = jax.nn.softmax(jax.nn.leaky_relu(logits) + bias_mat, axis=-1)
    ret = jnp.matmul(coefs, values)  # [B, H, N, F]
    ret = jnp.transpose(ret, (0, 2, 1, 3))  # [B, N, H, F]
    ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,))  # [B, N, H*F]

    if self.residual:
      ret += skip(z)

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    return ret, None  # pytype: disable=bad-return-type  # numpy-scalars


class GATFull(GAT):
  """Graph Attention Network with full adjacency matrix.

  The values of `adj_mat` are ignored: every node attends to every node.
  """

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array,
               **unused_kwargs) -> Tuple[_Array, Optional[_Array]]:
    """Runs the GAT step on the complete graph.

    Only the shape of `adj_mat` is used; it is replaced by all ones. Returns
    the same `(node_embeddings, None)` pair as `GAT.__call__`.
    """
    adj_mat = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)


class GATv2(Processor):
  """Graph Attention Network v2 (Brody et al., ICLR 2022).

  Unlike `GAT`, the leaky ReLU is applied before the per-head scoring
  vector: logits are computed from a `mid_size`-wide sum of node, edge and
  graph projections.
  """

  def __init__(
      self,
      out_size: int,
      nb_heads: int,
      mid_size: Optional[int] = None,
      activation: Optional[_Fn] = jax.nn.relu,
      residual: bool = True,
      use_ln: bool = False,
      name: str = 'gatv2_aggr',
  ):
    """Constructs a GATv2 processor.

    Args:
      out_size: Output width; must be divisible by `nb_heads`.
      nb_heads: Number of attention heads.
      mid_size: Width of the attention space; defaults to `out_size`. Must be
        divisible by `nb_heads`.
      activation: Optional nonlinearity applied after the residual sum.
      residual: Whether to add a linear projection of the input to the output.
      use_ln: Whether to layer-normalise the output.
      name: Module name; the processor tag is appended if missing.

    Raises:
      ValueError: If `nb_heads` does not divide `out_size` or `mid_size`.
    """
    super().__init__(name=name)
    self.mid_size = out_size if mid_size is None else mid_size
    self.out_size = out_size
    self.nb_heads = nb_heads
    if out_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the width!')
    self.head_size = out_size // nb_heads
    if self.mid_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the message!')
    self.mid_head_size = self.mid_size // nb_heads
    self.activation = activation
    self.residual = residual
    self.use_ln = use_ln

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> Tuple[_Array, Optional[_Array]]:
    """GATv2 inference step.

    Args:
      node_fts: Node features, [B, N, Fn].
      edge_fts: Edge features, [B, N, N, Fe].
      graph_fts: Graph features, [B, Fg].
      adj_mat: Adjacency matrix, [B, N, N]; entries of 0 are masked out
        (assumes a 0/1 matrix).
      hidden: Hidden node state, concatenated to `node_fts` on the last axis.
      **unused_kwargs: Ignored.

    Returns:
      A 2-tuple `(node_embeddings, None)` with node embeddings of shape
      [B, N, out_size]; this processor produces no edge embeddings.
    """
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    z = jnp.concatenate([node_fts, hidden], axis=-1)
    # NOTE(review): Haiku names submodules in creation order (linear,
    # linear_1, ...); keep the constructor order below to keep param names.
    m = hk.Linear(self.out_size)
    skip = hk.Linear(self.out_size)

    # 0 where an edge exists and -1e9 elsewhere, so the softmax below gives
    # (numerically) zero weight to non-neighbours.
    bias_mat = (adj_mat - 1.0) * 1e9
    bias_mat = jnp.tile(bias_mat[..., None],
                        (1, 1, 1, self.nb_heads))     # [B, N, N, H]
    bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2))  # [B, H, N, N]

    # Projections into the attention space of width M = mid_size.
    w_1 = hk.Linear(self.mid_size)
    w_2 = hk.Linear(self.mid_size)
    w_e = hk.Linear(self.mid_size)
    w_g = hk.Linear(self.mid_size)

    # One scoring vector per head.
    a_heads = [hk.Linear(1) for _ in range(self.nb_heads)]

    values = m(z)                                      # [B, N, H*F]
    values = jnp.reshape(
        values,
        values.shape[:-1] + (self.nb_heads, self.head_size))  # [B, N, H, F]
    values = jnp.transpose(values, (0, 2, 1, 3))              # [B, H, N, F]

    pre_att_1 = w_1(z)          # [B, N, M]
    pre_att_2 = w_2(z)          # [B, N, M]
    pre_att_e = w_e(edge_fts)   # [B, N, N, M]
    pre_att_g = w_g(graph_fts)  # [B, M]

    pre_att = (
        jnp.expand_dims(pre_att_1, axis=1) +     # + [B, 1, N, M]
        jnp.expand_dims(pre_att_2, axis=2) +     # + [B, N, 1, M]
        pre_att_e +                              # + [B, N, N, M]
        jnp.expand_dims(pre_att_g, axis=(1, 2))  # + [B, 1, 1, M]
    )                                            # = [B, N, N, M]

    # Split the attention space into H heads of width M/H.
    pre_att = jnp.reshape(
        pre_att,
        pre_att.shape[:-1] + (self.nb_heads, self.mid_head_size)
    )  # [B, N, N, H, M/H]
    pre_att = jnp.transpose(pre_att, (0, 3, 1, 2, 4))  # [B, H, N, N, M/H]

    # A per-head Python loop is not very efficient, but it is kept for
    # readability on the assumption that `nb_heads` is small.
    logits = jnp.stack(
        [jnp.squeeze(a_head(jax.nn.leaky_relu(pre_att[:, head])), axis=-1)
         for head, a_head in enumerate(a_heads)],  # each [B, N, N]
        axis=1)  # [B, H, N, N]

    # Masked softmax over the attended nodes (last axis).
    coefs = jax.nn.softmax(logits + bias_mat, axis=-1)
    ret = jnp.matmul(coefs, values)  # [B, H, N, F]
    ret = jnp.transpose(ret, (0, 2, 1, 3))  # [B, N, H, F]
    ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,))  # [B, N, H*F]

    if self.residual:
      ret += skip(z)

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    return ret, None  # pytype: disable=bad-return-type  # numpy-scalars


class GATv2FullD2(GATv2):
  """GATv2 on the complete graph, with outputs averaged over D_2.

  Inputs are supplied as one array per element of the four-element group
  D_2; see `d2_forward`.
  """

  def d2_forward(self,
                 node_fts: List[_Array],
                 edge_fts: List[_Array],
                 graph_fts: List[_Array],
                 adj_mat: _Array,
                 hidden: _Array,
                 **unused_kwargs) -> List[_Array]:
    """Runs GATv2 on paired group-indexed inputs and averages the results.

    For each group element g and each element h, the features at index g are
    concatenated (last axis) with those at index g^{-1} * h. The pair is run
    through the GATv2 step on a complete graph, and the resulting node
    embeddings are averaged over h.

    Args:
      node_fts: Four node-feature arrays, one per element of D_2.
      edge_fts: Four edge-feature arrays, one per element of D_2.
      graph_fts: Four graph-feature arrays, one per element of D_2.
      adj_mat: Only its shape is used; it is replaced by all ones.
      hidden: Hidden node state passed unchanged to every GATv2 call.
      **unused_kwargs: Ignored.

    Returns:
      A list of four node-embedding arrays; entry g is the mean over h.
    """
    group_order = 4
    # Every element of D_2 is its own inverse.
    inverse = [0, 1, 2, 3]
    # Multiplication table: product[a][b] is the index of a * b.
    product = [
        [0, 1, 2, 3],
        [1, 0, 3, 2],
        [2, 3, 0, 1],
        [3, 2, 1, 0],
    ]

    assert len(node_fts) == group_order
    assert len(edge_fts) == group_order
    assert len(graph_fts) == group_order

    full_adj = jnp.ones_like(adj_mat)
    # Bind the parent step once here so that zero-argument super() is only
    # ever evaluated in the method body itself.
    gatv2_step = super().__call__

    averaged = []
    for g in range(group_order):
      node_outputs = []
      for h in range(group_order):
        partner = product[inverse[g]][h]
        nodes_in = jnp.concatenate((node_fts[g], node_fts[partner]), axis=-1)
        edges_in = jnp.concatenate((edge_fts[g], edge_fts[partner]), axis=-1)
        graph_in = jnp.concatenate(
            (graph_fts[g], graph_fts[partner]), axis=-1)
        nodes_out, _ = gatv2_step(
            node_fts=nodes_in,
            edge_fts=edges_in,
            graph_fts=graph_in,
            adj_mat=full_adj,
            hidden=hidden,
        )
        node_outputs.append(nodes_out)
      averaged.append(jnp.mean(jnp.stack(node_outputs, axis=0), axis=0))

    return averaged


class GATv2Full(GATv2):
  """Graph Attention Network v2 with full adjacency matrix.

  The values of `adj_mat` are ignored: every node attends to every node.
  """

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array,
               **unused_kwargs) -> Tuple[_Array, Optional[_Array]]:
    """Runs the GATv2 step on the complete graph.

    Only the shape of `adj_mat` is used; it is replaced by all ones. Returns
    the same `(node_embeddings, None)` pair as `GATv2.__call__`.
    """
    adj_mat = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)

# main

# Toy inputs for a smoke test of GATv2FullD2. Every array carries a leading
# batch dimension of size `batch_size` (1 here for simplicity).
batch_size = 1

# Number of players (graph nodes) and feature widths.
n_players = 10
n_node_features = 5
n_edge_features = 3
n_graph_features = 2
n_latent_features = 4

# GATv2FullD2.d2_forward expects one feature set per element of D_2.
n_d2_actions = 4

# Seeded generator so that re-running the notebook reproduces the inputs.
data_rng = np.random.default_rng(0)

node_fts = [
    jnp.asarray(data_rng.random((batch_size, n_players, n_node_features)))
    for _ in range(n_d2_actions)
]
edge_fts = [
    jnp.asarray(data_rng.random(
        (batch_size, n_players, n_players, n_edge_features)))
    for _ in range(n_d2_actions)
]
graph_fts = [
    jnp.asarray(data_rng.random((batch_size, n_graph_features)))
    for _ in range(n_d2_actions)
]
# Random 0/1 adjacency (GATv2FullD2 replaces it with all ones anyway).
adj_mat = jnp.asarray(
    data_rng.integers(0, 2, (batch_size, n_players, n_players)))
hidden = jnp.zeros((batch_size, n_players, n_latent_features))


def forward(node_fts, edge_fts, graph_fts, adj_mat, hidden, *,
            out_size: int = 128, nb_heads: int = 8):
    """Builds a GATv2FullD2 processor and runs its D_2 forward pass.

    The processor is instantiated inside this function so that it is created
    within the function passed to `hk.transform`.

    Args:
        node_fts: Four node-feature arrays, one per element of D_2.
        edge_fts: Four edge-feature arrays, one per element of D_2.
        graph_fts: Four graph-feature arrays, one per element of D_2.
        adj_mat: Adjacency matrix (only its shape is used by the processor).
        hidden: Hidden node state.
        out_size: Processor output width; must be divisible by `nb_heads`.
        nb_heads: Number of attention heads.

    Returns:
        The list returned by `GATv2FullD2.d2_forward`.
    """
    model = GATv2FullD2(out_size=out_size, nb_heads=nb_heads)
    return model.d2_forward(node_fts, edge_fts, graph_fts, adj_mat, hidden)

# Turn the Haiku-module-using `forward` into an (init, apply) pair.
transformed_forward = hk.transform(forward)

# Fixed key so parameter initialisation is repeatable.
rng = jax.random.PRNGKey(42)

model_inputs = (node_fts, edge_fts, graph_fts, adj_mat, hidden)
params = transformed_forward.init(rng, *model_inputs)
output = transformed_forward.apply(params, rng, *model_inputs)
print(output)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[Array([[[0.4840707 , 0.        , 0.        , ..., 0.        ,
         0.02204512, 0.        ],
        [0.0693898 , 0.        , 0.        , ..., 0.        ,
         0.02279956, 0.        ],
        [0.344575  , 0.        , 0.        , ..., 0.        ,
         0.01206323, 0.        ],
        ...,
        [0.6926438 , 0.        , 0.        , ..., 0.        ,
         0.1365541 , 0.00513582],
        [0.63773996, 0.        , 0.        , ..., 0.        ,
         0.04622685, 0.        ],
        [0.85276425, 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ]]], dtype=float32), Array([[[0.7025213 , 0.        , 0.        , ..., 0.        ,
         0.07945526, 0.        ],
        [0.8328907 , 0.        , 0.        , ..., 0.        ,
         0.18241775, 0.        ],
        [0.45594347, 0.        , 0.        , ..., 0.        ,
         0.02872106, 0.        ],
        ...,
        [0.01685721, 0.        , 0.        , ..., 0.        ,
         0.        , 0.      

# Main

In [1]:
import clrs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# BFS sampler: 1000 samples, length 16, fixed seed for reproducibility.
sampler, spec = clrs.build_sampler(
    name='bfs', num_samples=1000, length=16, seed=42)

In [None]:
from clrs.examples.utils import *

In [None]:
# Translate --hint_mode into the (encode_hints, decode_hints) pair.
hint_mode_to_flags = {
    'encoded_decoded': (True, True),
    'decoded_only': (False, True),
    'none': (False, False),
}
if FLAGS.hint_mode not in hint_mode_to_flags:
    raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none}.')
encode_hints, decode_hints = hint_mode_to_flags[FLAGS.hint_mode]

train_lengths = [int(x) for x in FLAGS.train_lengths]

rng = np.random.RandomState(FLAGS.seed)
# Explicit 64-bit dtype: the default `int` dtype is int32 on Windows with
# NumPy < 2, where 2**32 is out of range. int() keeps the seed a plain Python
# int, which is what the default dtype returns.
rng_key = jax.random.PRNGKey(int(rng.randint(2**32, dtype=np.int64)))

In [None]:
# Create samplers
(train_samplers,
val_samplers, val_sample_counts,
test_samplers, test_sample_counts,
spec_list) = create_samplers(rng, train_lengths)

In [None]:
# Display the parsed training lengths as this cell's output.
train_lengths

# create_samplers

In [None]:
# Scratch, inlined version of `create_samplers`: start from empty
# accumulators. NOTE: running this cell rebinds the names assigned by the
# earlier `create_samplers(rng, train_lengths)` cell to empty lists.
train_samplers = []
val_samplers = []
val_sample_counts = []
test_samplers = []
test_sample_counts = []
spec_list = []

In [None]:
import tensorflow as tf  # Required for tf.io below.

filename = '/tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/activity_selector_test/1.0.0/clrs_dataset-test.tfrecord-00000-of-00001'
# tf.io.read_file loads the entire file contents at once; to iterate the
# individual serialized records, tf.data.TFRecordDataset(filename) is the
# usual reader.
data = tf.io.read_file(filename)

In [None]:
import tensorflow_datasets as tfds

# tfds.load takes a dataset *name* plus a data_dir, not a path to the
# dataset folder. The CLRS30 files live under
# <data_dir>/clrs_dataset/<algorithm>_<split>/<version>/.
data_dir = '/tmp/CLRS30/CLRS30_v1.0.0'
dataset_name = 'clrs_dataset/activity_selector_test'
folder = f'{data_dir}/{dataset_name}'  # On-disk dataset folder.
dataset = tfds.load(dataset_name, data_dir=data_dir, split='test')