# Instructions

Make a copy of this notebook, and save with your initials at the end, so that we do not overwrite each other's.



In [None]:
!pip install tensorflow_gnn --pre
!pip install tensorflow_ranking

In [2]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_gnn as tfgnn
import tensorflow_ranking as tfr
import argparse
import tqdm
import collections
import functools
import hashlib
import io
import os
from typing import NamedTuple
from absl import flags



In [None]:
### TILES
class FakeBoolFlag(NamedTuple):
    value: bool
_TOY_DATA = FakeBoolFlag(value=False)


class TileExample(NamedTuple):
  """Single example of tile graph."""
  node_features: tf.Tensor
  node_ops: tf.Tensor
  edges: tf.Tensor
  config_features: tf.Tensor
  config_runtimes: tf.Tensor
  config_runtime_normalizers: tf.Tensor
  tile_id: tf.Tensor
  total_nodes: tf.Tensor
  total_edges: tf.Tensor
  total_configs: tf.Tensor

  def to_graph_tensor(
      self, config_samples: int = -1,
      normalize_runtimes: bool = True) -> tfgnn.GraphTensor:
    """Packages instance tensors (edges, features) into `GraphTensor`.

    Args:
      config_samples: if -1, then all module configurations (and their runtimes)
        are returned. If >=0, then this many module configurations (and their
        corresponding runtimes) are sampled uniformly at random.
      normalize_runtimes: If set (default), runtimes will be normalized by
        dividing over the runtime of "default tile size" (to account for worker
        machine differences).

    Returns:
      GraphTensor with node-sets:
        + op (feats='op': int-categorical, 'feats': float-vector).
          This is the only "real" graph node"
        + configs (feats='feats': float-vector, 'runtimes': float scalar,
                   'normalizers': float scalar).
          These are "fake" nodes. There will be one node per configuration.
        + g (stands for "graph") has one (root) node connecting to all op and
          config nodes.
      and edge-sets:
        + 'feed': directed edges connecting op-node to op-node.
        + 'g_op': edges connecting the singleton "g" node to every "op" node.
        + 'g_config': connecting the singleton "g" node to every "config" node.
    """
    config_features = self.config_features
    config_runtimes = self.config_runtimes
    config_runtime_normalizers = self.config_runtime_normalizers
    num_configs = tf.shape(config_features)[0]

    # If sampling is requested.
    if config_samples >= 0:
      rnd = tf.random.shuffle(tf.range(num_configs, dtype=tf.int32))
      rnd = rnd[:config_samples]
      config_features = tf.gather(config_features, rnd)
      config_runtimes = tf.gather(config_runtimes, rnd)
      config_runtime_normalizers = tf.gather(config_runtime_normalizers, rnd)
      num_configs = tf.shape(config_features)[0]

    if normalize_runtimes:
      config_runtimes /= config_runtime_normalizers

    return tfgnn.GraphTensor.from_pieces(
        node_sets={
            'op': tfgnn.NodeSet.from_fields(
                sizes=tf.expand_dims(self.total_nodes, 0),
                features={
                    'op': self.node_ops,
                    'feats': self.node_features,
                }
            ),
            'config': tfgnn.NodeSet.from_fields(
                features={
                    'feats': config_features,
                    'runtimes': config_runtimes,
                    'normalizers': config_runtime_normalizers,
                },
                sizes=tf.expand_dims(num_configs, 0),
            ),
            'g': tfgnn.NodeSet.from_fields(
                features={'tile_id': tf.expand_dims(self.tile_id, 0)},
                sizes=tf.constant([1]))
        },
        edge_sets={
            'feed': tfgnn.EdgeSet.from_fields(
                sizes=tf.expand_dims(self.total_edges, 0),
                adjacency=tfgnn.Adjacency.from_indices(
                    source=('op', self.edges[:, 0]),
                    target=('op', self.edges[:, 1]))),
            'g_op': tfgnn.EdgeSet.from_fields(
                sizes=tf.expand_dims(self.total_nodes, 0),
                adjacency=tfgnn.Adjacency.from_indices(
                    source=('g', tf.zeros([self.total_nodes], dtype=tf.int32)),
                    target=('op', tf.range(self.total_nodes, dtype=tf.int32)))),
            'g_config': tfgnn.EdgeSet.from_fields(
                sizes=tf.expand_dims(num_configs, 0),
                adjacency=tfgnn.Adjacency.from_indices(
                    source=('g', tf.zeros([num_configs], dtype=tf.int32)),
                    target=('config', tf.range(num_configs, dtype=tf.int32)))),
        })


class NpzDatasetPartition:
  """Holds one data partition (train, test, validation) on device memory."""

  def __init__(self):
    # Populated in `add_npz_file()`.
    self._data_dict: dict[str, list[np.ndarray]] = collections.defaultdict(list)
    self._num_edges: list[int] = [0]    # prepend with 0 to prep for cumsum.
    self._num_configs: list[int] = [0]  # ^^
    self._num_nodes: list[int] = [0]    # ^^

    # Populated in `finalize()`.
    self.node_feat: 'tf.Tensor | None' = None   # indexed by node_ranges.
    self.node_opcode: 'tf.Tensor | None' = None  # ^^
    self.edge_index: 'tf.Tensor | None' = None   # indexed by edge_ranges.
    self.config_feat: 'tf.Tensor | None' = None      # indexed by config_ranges.
    self.config_runtime: 'tf.Tensor | None' = None   # ^^
    self.config_runtime_normalizers: 'tf.Tensor | None' = None  # ^^
    self.tile_id: 'tf.Tensor | None' = None

    # finalize() sets to: cumsum([0, numEdges(graph_1), numEdges(graph_2), ..]).
    self.edge_ranges: 'tf.Tensor | None' = None
    # finalize() sets to: cumsum([0, numNodes(graph_1), numNodes(graph_2), ..]).
    self.node_ranges: 'tf.Tensor | None' = None
    # finalize() sets to: cumsum([0, numModules(graph_1), nModul(graph_2), ..]).
    self.config_ranges: 'tf.Tensor | None' = None

  def save_to_file(self, cache_file: str):
    """Saves dataset as numpy. Can be restored with `load_from_file`."""
    assert self.node_feat is not None, 'finalize() was not invoked'
    assert self.node_opcode is not None
    assert self.edge_index is not None
    assert self.config_feat is not None
    assert self.config_runtime is not None
    assert self.config_runtime_normalizers is not None
    assert self.tile_id is not None
    assert self.edge_ranges is not None
    assert self.node_ranges is not None
    assert self.config_ranges is not None
    np_dict = dict(
        node_feat=self.node_feat.numpy(),
        node_opcode=self.node_opcode.numpy(),
        edge_index=self.edge_index.numpy(),
        config_feat=self.config_feat.numpy(),
        config_runtime=self.config_runtime.numpy(),
        config_runtime_normalizers=self.config_runtime_normalizers.numpy(),
        edge_ranges=self.edge_ranges.numpy(),
        node_ranges=self.node_ranges.numpy(),
        config_ranges=self.config_ranges.numpy()
    )
    bytes_io = io.BytesIO()
    np.savez_compressed(bytes_io, **np_dict)
    with tf.io.gfile.GFile(cache_file, 'wb') as fout:
      fout.write(bytes_io.getvalue())
    print('wrote ' + cache_file)
    tile_ids_file = cache_file + '.tiles.txt'
    with tf.io.gfile.GFile(tile_ids_file, 'w') as fout:
      fout.write(b'\n'.join(self.tile_id.numpy().tolist()).decode())
    print('wrote ' + tile_ids_file)

  def load_from_file(self, cache_file: str):
    """Loads dataset from numpy file."""
    np_dict = np.load(tf.io.gfile.GFile(cache_file, 'rb'))
    self.node_feat = tf.constant(np_dict['node_feat'])
    self.node_opcode = tf.constant(np_dict['node_opcode'])
    self.edge_index = tf.constant(np_dict['edge_index'])
    self.config_feat = tf.constant(np_dict['config_feat'])
    self.config_runtime = tf.constant(np_dict['config_runtime'])
    self.config_runtime_normalizers = tf.constant(
        np_dict['config_runtime_normalizers'])
    self.edge_ranges = tf.constant(np_dict['edge_ranges'])
    self.node_ranges = tf.constant(np_dict['node_ranges'])
    self.config_ranges = tf.constant(np_dict['config_ranges'])
    tile_ids = tf.io.gfile.GFile(cache_file + '.tiles.txt', 'r').readlines()
    self.tile_id = tf.stack([tile_id.rstrip() for tile_id in tile_ids])
    print('loaded from ' + cache_file)

  def add_npz_file(
      self, tile_id: str, npz_file: np.lib.npyio.NpzFile, min_configs: int = 2):
    """Copies data from npz file into this class instance.

    After finishing all calls `add_npz_file()`, user must invoke `finalize()`.

    Args:
      tile_id: the filename (without extension) that npz_file was read from.
      npz_file: Output of np.load on a file from the TpuGraphs Tiles dataset.
      min_configs: The file be incorporated only if the number of module
        configurations is equal or greater than this.
    """
    npz_data = dict(npz_file.items())
    #num_configs = npz_data['config_feat'].shape[0]
    num_configs = npz_data['node_config_feat'].shape[0]
    if num_configs < min_configs:
      print('skipping tile with only %i configurations' % num_configs)
      return
    for key, ndarray in npz_data.items():
      self._data_dict[key].append(ndarray)
    self._data_dict['tile_id'].append(np.array(tile_id))
    num_nodes = npz_data['node_feat'].shape[0]
    num_edges = npz_data['edge_index'].shape[0]
    assert num_nodes == npz_data['node_opcode'].shape[0]
    assert num_configs == npz_data['config_runtime'].shape[0]
    #assert num_configs == npz_data['config_runtime_normalizers'].shape[0]
    assert num_configs == npz_data['config_runtime'].shape[0]
    self._num_nodes.append(num_nodes)
    self._num_edges.append(num_edges)
    self._num_configs.append(num_configs)

  def finalize(self):
    self.tile_id = tf.stack(self._data_dict['tile_id'], axis=0)
    self.node_feat = tf.concat(self._data_dict['node_feat'], axis=0)
    self.node_opcode = tf.concat(self._data_dict['node_opcode'], axis=0)
    self.edge_index = tf.concat(self._data_dict['edge_index'], axis=0)
    self.config_feat = tf.concat(self._data_dict['config_feat'], axis=0)
    self.config_runtime = tf.concat(self._data_dict['config_runtime'], axis=0)
    self.config_runtime_normalizers = tf.concat(
        self._data_dict['config_runtime_normalizers'], axis=0)
    self.edge_ranges = tf.cumsum(self._num_edges)
    self.node_ranges = tf.cumsum(self._num_nodes)
    self.config_ranges = tf.cumsum(self._num_configs)

  def get_item(self, index: int) -> TileExample:
    node_start = self.node_ranges[index]
    node_end = self.node_ranges[index + 1]
    edge_start = self.edge_ranges[index]
    edge_end = self.edge_ranges[index + 1]
    config_start = self.config_ranges[index]
    config_end = self.config_ranges[index + 1]

    return TileExample(
        node_features=self.node_feat[node_start:node_end],
        node_ops=self.node_opcode[node_start:node_end],
        edges=self.edge_index[edge_start:edge_end],
        config_features=self.config_feat[config_start:config_end],
        config_runtimes=self.config_runtime[config_start:config_end],
        config_runtime_normalizers=(
            self.config_runtime_normalizers[config_start:config_end]),
        tile_id=self.tile_id[index],
        total_nodes=node_end - node_start,
        total_edges=edge_end - edge_start,
        total_configs=config_end - config_start)

  def get_graph_tensors_dataset(
      self, config_samples: int = -1) -> tf.data.Dataset:
    if self.edge_ranges is None:
      raise ValueError('finalize() was not invoked.')
    dataset = tf.data.Dataset.range(self.edge_ranges.shape[0] - 1)
    dataset = dataset.map(self.get_item, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.map(
        functools.partial(TileExample.to_graph_tensor,
                          config_samples=config_samples))
    return dataset


def get_npz_split(
    split_path: str, min_configs=2, cache_dir=None) -> NpzDatasetPartition:
  """Returns data for a single partition."""
  glob_pattern = os.path.join(split_path, '*.npz')
  files = tf.io.gfile.glob(glob_pattern)
  if not files:
    raise ValueError('No files matched: ' + glob_pattern)
  if _TOY_DATA.value:
    files = files[:100]

  cache_filename = None
  if cache_dir:
    if not tf.io.gfile.exists(cache_dir):
      tf.io.gfile.makedirs(cache_dir)
    filename_hash = hashlib.md5(
        f'{split_path}:{min_configs}:{_TOY_DATA.value}'.encode()).hexdigest()
    cache_filename = os.path.join(cache_dir, f'{filename_hash}-cache.npz')
    print('dataset cache file: ', cache_filename)

  npz_dataset = NpzDatasetPartition()
  if cache_filename and tf.io.gfile.exists(cache_filename):
    npz_dataset.load_from_file(cache_filename)
  else:
    for filename in tqdm.tqdm(files):
      np_data = np.load(tf.io.gfile.GFile(filename, 'rb'))
      tile_id = os.path.splitext(os.path.basename(filename))[0]
      npz_dataset.add_npz_file(tile_id, np_data, min_configs=min_configs)
    npz_dataset.finalize()
    if cache_filename:
      npz_dataset.save_to_file(cache_filename)

  return npz_dataset


class NpzDataset(NamedTuple):
  """Contains all partitions of the dataset."""
  train: NpzDatasetPartition
  validation: NpzDatasetPartition
  test: NpzDatasetPartition

  @property
  def num_ops(self):
    return int(
        tf.reduce_max([
            tf.reduce_max(self.train.node_opcode),
            tf.reduce_max(self.validation.node_opcode),
            tf.reduce_max(self.test.node_opcode),
        ]).numpy()) + 1

  def _get_normalizer(self, feature_matrix) -> tuple[
      tf.Tensor, tf.Tensor, tf.Tensor]:
    max_feat = tf.reduce_max(feature_matrix, axis=0, keepdims=True)
    min_feat = tf.reduce_min(feature_matrix, axis=0, keepdims=True)
    return min_feat[0] != max_feat[0], min_feat, max_feat

  def _apply_normalizer(self, feature_matrix, used_columns, min_feat, max_feat):
    feature_matrix = tf.boolean_mask(feature_matrix, used_columns, axis=1)
    min_feat = tf.boolean_mask(min_feat, used_columns, axis=1)
    max_feat = tf.boolean_mask(max_feat, used_columns, axis=1)
    return (feature_matrix - min_feat) / (max_feat - min_feat)

  def normalize(self):
    """Removes constant features and normalizes remaining onto [0, 1].

    The statistics are computed only from train partition then applied to all
    partitions {train, test, validation}.
    """
    normalizer_args = self._get_normalizer(self.train.node_feat)
    self.train.node_feat = self._apply_normalizer(
        self.train.node_feat, *normalizer_args)
    self.validation.node_feat = self._apply_normalizer(
        self.validation.node_feat, *normalizer_args)
    self.test.node_feat = self._apply_normalizer(
        self.test.node_feat, *normalizer_args)

    normalizer_args = self._get_normalizer(self.train.config_feat)
    self.train.config_feat = self._apply_normalizer(
        self.train.config_feat, *normalizer_args)
    self.validation.config_feat = self._apply_normalizer(
        self.validation.config_feat, *normalizer_args)
    self.test.config_feat = self._apply_normalizer(
        self.test.config_feat, *normalizer_args)


def get_npz_dataset(
    root_path: str, min_train_configs=-1,
    cache_dir: 'None | str' = None) -> NpzDataset:
  """Returns {train, test, validation} partitions of tiles dataset collection.

  All partitions will be normalized: statistics are computed from training set
  partition and applied to all partitions.

  Args:
    root_path: Path where dataset lives. It must have subdirectories 'train',
      'test' and 'valid'.
    min_train_configs: If > 0, then tile examples will be filtered to have at
      least this many configurations (features and runtimes).
    cache_dir: If given, the many files for each of {train, test, validation}
      will be stored as one file (makes loading faster, for future runs).
  """
  npz_dataset = NpzDataset(
      train=get_npz_split(
          os.path.join(root_path, 'train'), min_configs=min_train_configs,
          cache_dir=cache_dir),
      validation=get_npz_split(
          os.path.join(root_path, 'valid'), cache_dir=cache_dir),
      test=get_npz_split(
          os.path.join(root_path, 'test'), cache_dir=cache_dir))
  npz_dataset.normalize()
  return npz_dataset

In [3]:
###IMPLICIT
import tensorflow as tf
import tensorflow_gnn as tfgnn

EPSILON = 1e-6  # To prevent division by 0.


class Multiplier:
  """Holds an (implicit) matrix that can be multiplied with dense matrices."""
  _transpose: 'Multiplier' = None

  def matmul(self, mat: tf.Tensor) -> tf.Tensor:
    raise NotImplementedError()

  def rmatmul(self, mat: tf.Tensor) -> tf.Tensor:
    raise NotImplementedError()

  @property
  def shape(self) -> tuple['int|tf.Tensor', 'int|tf.Tensor']:
    raise NotImplementedError()

  def __matmul__(self, mat: tf.Tensor) -> tf.Tensor:
    tf.assert_equal(self.shape[1], shape(mat)[0])
    return self.matmul(mat)

  def __rmatmul__(self, mat: tf.Tensor) -> tf.Tensor:
    tf.assert_equal(shape(mat)[-1], self.shape[0])
    return self.rmatmul(mat)

  def __add__(self, mat: 'Multiplier') -> 'Multiplier':
    return Sum(self, mat)

  def transpose(self) -> 'Multiplier':
    if self._transpose is None:
      self._transpose = Transpose(self)
    return self._transpose

  def add_eye(self, diag_weight=float(1.0)) -> 'Multiplier':
    tf.assert_equal(self.shape[0], self.shape[1])
    return Sum(self, DiagMatrix(diag_weight * tf.ones([self.shape[0]])))

  def rowsums(self, replace_if_0: 'None|float|tf.Tensor' = None) -> tf.Tensor:
    """Returns vector with shape `num_rows = [self.shape[0]]` that sums rows.

    Args:
      replace_if_0: If None, returns the actual sum, leaving zero-entries as-is.
        Otherwise, zero-entries will be replaced by this value.
    """
    y = self @ tf.ones([self.shape[1]])  # M . 1

    if replace_if_0 is not None:
      y = tf.where(tf.abs(y) < EPSILON, replace_if_0 * tf.ones_like(y), y)
    return y

  def colsums(self, replace_if_0: 'None|float|tf.Tensor' = None) -> tf.Tensor:
    """Returns vector with shape `num_cols = [self.shape[1]]` that sums columns.

    Args:
      replace_if_0: If None, returns the actual sum, leaving zero-entries as-is.
        Otherwise, zero-entries will be replaced by this value.
    """
    y = tf.ones([self.shape[0]]) @ self  # 1^T M  [shape=[cols]]

    if replace_if_0 is not None:
      y = tf.where(tf.abs(y) < EPSILON, replace_if_0 * tf.ones_like(y), y)
    return y

  def normalize_left(self) -> 'Multiplier':
    """Returns a left-stochastic matrix."""
    return Product(self, DiagMatrix(1 / self.colsums(1.0)))

  def normalize_right(self) -> 'Multiplier':
    """Returns a right-stochastic matrix."""
    return Product(DiagMatrix(1 / self.rowsums(1.0)), self)

  def normalize_leftright(self) -> 'Multiplier':
    return Product(
        DiagMatrix(tf.math.rsqrt(self.rowsums(1.0))),
        self,
        DiagMatrix(tf.math.rsqrt(self.colsums(1.0))),
    )

  def normalize_symmetric(self) -> 'Multiplier':
    inv_sqrt_degree = DiagMatrix(tf.math.rsqrt(self.colsums(1.0)))
    return Product(inv_sqrt_degree, self, inv_sqrt_degree)


class Transpose(Multiplier):
  """Defines matrix transpose."""

  def __init__(self, multiplier: Multiplier):
    self._multiplier = multiplier

  def matmul(self, mat: tf.Tensor) -> tf.Tensor:
    # (M'X) == (X'M)'
    return tf.transpose(tf.transpose(mat) @ self._multiplier)

  def rmatmul(self, mat: tf.Tensor) -> tf.Tensor:
    # (XM') == (XM')'' == (M X')'
    return tf.transpose(self._multiplier @ tf.transpose(mat))

  @property
  def shape(self) -> tuple['int|tf.Tensor', 'int|tf.Tensor']:
    transpose_shape = self._multiplier.shape
    return (transpose_shape[1], transpose_shape[0])

  def transpose(self) -> Multiplier:
    return self._multiplier


class DiagMatrix(Multiplier):
  """Defines diagonal matrix."""

  def __init__(self, diag_vector: tf.Tensor):
    assert len(diag_vector.shape) == 1, 'Must be a vector.'
    self._diag_vector = diag_vector
    self._vec_shape = shape(diag_vector)[0]

  def matmul(self, mat: tf.Tensor) -> tf.Tensor:
    return tf.einsum('i,i...->i...', self._diag_vector, mat)

  def rmatmul(self, mat: tf.Tensor) -> tf.Tensor:
    return tf.einsum('i,...i->...i', self._diag_vector, mat)

  @property
  def shape(self) -> tuple['int|tf.Tensor', 'int|tf.Tensor']:
    return (self._vec_shape, self._vec_shape)


class Product(Multiplier):
  """Defines product of multipliers."""

  def __init__(self, *multipliers: Multiplier):
    assert multipliers
    for i in range(1, len(multipliers)):
      tf.assert_equal(multipliers[i - 1].shape[1], multipliers[i].shape[0])

    self._multipliers = multipliers

  def matmul(self, mat: tf.Tensor) -> tf.Tensor:
    for m in self._multipliers[::-1]:
      mat = m @ mat
    return mat

  def rmatmul(self, mat: tf.Tensor) -> tf.Tensor:
    for m in self._multipliers:
      mat = mat @ m
    return mat

  @property
  def shape(self) -> tuple['int|tf.Tensor', 'int|tf.Tensor']:
    return (self._multipliers[0].shape[0], self._multipliers[-1].shape[1])


class Sum(Multiplier):
  """Defines sum of multipliers."""

  def __init__(self, *multipliers: Multiplier):
    assert multipliers
    for i in range(1, len(multipliers)):
      tf.assert_equal(multipliers[i].shape[0], multipliers[0].shape[0])
      tf.assert_equal(multipliers[i].shape[1], multipliers[0].shape[1])
    self._multipliers = multipliers

  def matmul(self, mat: tf.Tensor) -> tf.Tensor:
    return tf.add_n([m @ mat for m in self._multipliers])

  def rmatmul(self, mat: tf.Tensor) -> tf.Tensor:
    return tf.add_n([mat @ m for m in self._multipliers])

  @property
  def shape(self) -> tuple['int|tf.Tensor', 'int|tf.Tensor']:
    return self._multipliers[0].shape


class AdjacencyMultiplier(Multiplier):
  r"""Multiplies (sparse) adjacency with dense matrices.

  Yields adjacency with (rows, cols) == (target, source).

  `adj_multiplier @ x` yields tensor `y` with `y[i]` being `\sum_{j->i} x[j]`.

  Init Args:
      graph:
      sender_tag: If `== tfgnn.SOURCE`, then the (implicit) adjacency will be
        of shape `size_target x size_source`. If `== tfgnn.TARGET`, then `shape`
        should be `size_source x size_target`.
  """

  def __init__(
      self, graph, edge_set_name: tfgnn.EdgeSetName,
      edge_weight_feature_name: 'None|tfgnn.FieldName' = None,
      sender_tag: tfgnn.IncidentNodeTag = tfgnn.SOURCE):
    tfgnn.check_scalar_graph_tensor(graph, 'AdjacencyMultiplier')
    self._sender_tag = sender_tag
    self._receiver_tag: tfgnn.IncidentNodeTag = 1 - sender_tag
    self._edge_set_name = edge_set_name
    self._graph = graph
    self._edge_weight_feature_name = edge_weight_feature_name

  @property
  def shape(self) -> tuple['int|tf.Tensor', 'int|tf.Tensor']:
    """Shape is (size of receiver node set, size of sender node set)."""
    adj = self._graph.edge_sets[self._edge_set_name].adjacency
    sender_node_set_name = adj.node_set_name(self._sender_tag)
    receiver_node_set_name = adj.node_set_name(self._receiver_tag)
    sender_sizes = self._graph.node_sets[sender_node_set_name].sizes
    receiver_sizes = self._graph.node_sets[receiver_node_set_name].sizes
    return (tf.cast(tf.reduce_sum(receiver_sizes), tf.int32),
            tf.cast(tf.reduce_sum(sender_sizes), tf.int32))

  def matmul(self, mat: tf.Tensor):
    edge_level = tfgnn.broadcast_node_to_edges(
        self._graph, self._edge_set_name, self._sender_tag, feature_value=mat)

    if self._edge_weight_feature_name:
      edge_set = self._graph.edge_sets[self._edge_set_name]
      edge_level *= edge_set[self._edge_weight_feature_name]

    return tfgnn.pool_edges_to_node(
        self._graph, self._edge_set_name, self._receiver_tag,
        feature_value=edge_level)

  def rmatmul(self, mat):
    edge_level = tfgnn.broadcast_node_to_edges(
        self._graph, self._edge_set_name, self._receiver_tag,
        feature_value=tf.transpose(mat))

    if self._edge_weight_feature_name:
      edge_set = self._graph.edge_sets[self._edge_set_name]
      edge_level *= edge_set[self._edge_weight_feature_name]

    return tf.transpose(tfgnn.pool_edges_to_node(
        self._graph, self._edge_set_name, self._sender_tag,
        feature_value=edge_level))


def shape(tensor: tf.Tensor) -> 'list[int]|tf.Tensor':
  """Helper function returns shape of eager or symbolic tensors."""
  if any([s is None for s in tensor.shape]):
    return tf.shape(tensor)
  else:
    return tensor.shape

In [4]:
###TRAINING
# Install standard modules
class _OpEmbedding(tf.keras.Model):
  """Embeds GraphTensor.node_sets['op']['op'] nodes into feature 'op_e'."""

  def __init__(self, num_ops: int, embed_d: int, l2reg: float = 1e-4):
    super().__init__()
    self.embedding_layer = tf.keras.layers.Embedding(
        num_ops, embed_d, activity_regularizer=tf.keras.regularizers.l2(l2reg))

  def call(
      self, graph: tfgnn.GraphTensor,
      training: bool = False) -> tfgnn.GraphTensor:
    op_features = dict(graph.node_sets['op'].features)
    op_features['op_e'] = self.embedding_layer(
        tf.cast(graph.node_sets['op']['op'], tf.int32))
    return graph.replace_features(node_sets={'op': op_features})


def pair_layout_graph_with_label(graph: tfgnn.GraphTensor):
    """Extracts label from graph (`tfgnn.GraphTensor`) and returns a pair of `(graph, label)`"""
    # Return runtimes divded over large number: only ranking is required. The
    # runtimes are in the 100K range
    label = tf.cast(graph.node_sets['g']['runtimes'], tf.float32) / 1e7
    return graph, label



class ResModel(tf.keras.Model):
    """GNN with residual connections."""

    def __init__(
        self, num_configs: int, num_ops: int, op_embed_dim: int = 32,
        num_gnns: int = 2, mlp_layers: int = 2,
        hidden_activation: str = 'leaky_relu',
        hidden_dim: int = 32, reduction: str = 'sum'):
        super().__init__()
        self._num_configs = num_configs
        self._num_ops = num_ops
        self._op_embedding = _OpEmbedding(num_ops, op_embed_dim)
        self._prenet = _mlp([hidden_dim] * mlp_layers, hidden_activation)
        self._gc_layers = []
        for _ in range(num_gnns):
            self._gc_layers.append(_mlp([hidden_dim] * mlp_layers, hidden_activation))
        self._postnet = _mlp([hidden_dim, 1], hidden_activation, use_bias=False)

    def call(self, graph: tfgnn.GraphTensor, training: bool = False):
        del training
        return self.forward(graph, self._num_configs)

    def _node_level_forward(
        self, node_features: tf.Tensor,
        config_features: tf.Tensor,
        graph: tfgnn.GraphTensor, num_configs: int,
        edgeset_prefix='') -> tf.Tensor:
        """implements the full computation within a GNN layer:
        obtains adjacency Matrices and normalizes them, 
        transforms and normalizes nodes and configuration.
        applies the Pre-processing MLP and performs the Graph Convolution Operation.
        """
    
        adj_op_op = AdjacencyMultiplier(
            graph, edgeset_prefix+'feed')  # op->op
        adj_config = AdjacencyMultiplier(
            graph, edgeset_prefix+'config')  # nconfig->op

        adj_op_op_hat = (adj_op_op + adj_op_op.transpose()).add_eye()
        adj_op_op_hat = adj_op_op_hat.normalize_symmetric()

        x = node_features

        x = tf.stack([x] * num_configs, axis=1)
        config_features = 100 * (adj_config @ config_features)
        x = tf.concat([config_features, x], axis=-1)
        x = self._prenet(x)
        x = tf.nn.leaky_relu(x)

        for layer in self._gc_layers:
            y = x
            y = tf.concat([config_features, y], axis=-1)
            y = tf.nn.leaky_relu(layer(adj_op_op_hat @ y))
            x += y
        return x

    def forward(
        self, graph: tfgnn.GraphTensor, num_configs: int,
        backprop=True) -> tf.Tensor:
        """
        Overall forward pass within the embedding layer,
        the node-level forward pass (_node_level_forward),
        and the final global pooling and post-processing stages.
        """
        graph = self._op_embedding(graph)

        config_features = graph.node_sets['nconfig']['feats']
        node_features = tf.concat([
            graph.node_sets['op']['feats'],
            graph.node_sets['op']['op_e']
        ], axis=-1)

        x_full = self._node_level_forward(
            node_features=tf.stop_gradient(node_features),
            config_features=tf.stop_gradient(config_features),
            graph=graph, num_configs=num_configs)

        if backprop:
            x_backprop = self._node_level_forward(
                node_features=node_features,
                config_features=config_features,
                graph=graph, num_configs=num_configs, edgeset_prefix='sampled_')

            is_selected = graph.node_sets['op']['selected']
            # Need to expand twice as `is_selected` is a vector (num_nodes) but
            # x_{backprop, full} are 3D tensors (num_nodes, num_configs, num_feats).
            is_selected = tf.expand_dims(is_selected, axis=-1)
            is_selected = tf.expand_dims(is_selected, axis=-1)
            x = tf.where(is_selected, x_backprop, x_full)
        else:
            x = x_full

        adj_config = AdjacencyMultiplier(graph, 'config')

        # Features for configurable nodes.
        config_feats = (adj_config.transpose() @ x)

        # Global pooling
        adj_pool_op_sum = AdjacencyMultiplier(graph, 'g_op').transpose()
        adj_pool_op_mean = adj_pool_op_sum.normalize_right()
        adj_pool_config_sum = AdjacencyMultiplier(
            graph, 'g_config').transpose()
        x = self._postnet(tf.concat([
            # (A D^-1) @ Features
            adj_pool_op_mean @ x,
            # l2_normalize( A @ Features )
            tf.nn.l2_normalize(adj_pool_op_sum @ x, axis=-1),
            # l2_normalize( A @ Features )
            tf.nn.l2_normalize(adj_pool_config_sum @ config_feats, axis=-1),
        ], axis=-1))

        x = tf.squeeze(x, -1)

        return x

def _mlp(dims, hidden_activation, l2reg=1e-4, use_bias=True):
  """Helper function for multi-layer perceptron (MLP)."""
  layers = []
  for i, dim in enumerate(dims):
    if i > 0:
      layers.append(tf.keras.layers.Activation(hidden_activation))
    layers.append(tf.keras.layers.Dense(
        dim, kernel_regularizer=tf.keras.regularizers.l2(l2reg),
        use_bias=use_bias))
  return tf.keras.Sequential(layers)

"""
CREATE DATASETS FOR TRAINING
"""

def pull_data(CONFIGS_PER_GRAPH, MAX_TRAIN_CONFIGS, MAX_NUM_CONFIGS, MAX_KEEP_NODES, BATCH_SIZE, layout_data_root_dir):
  layout_npz_dataset = get_npz_dataset(
      layout_data_root_dir,
      min_train_configs=CONFIGS_PER_GRAPH,
      max_train_configs= MAX_NUM_CONFIGS,  # If any graph has more than this configurations, it will be filtered [speeds up loading + training]
      cache_dir='cache'
  )

  layout_train_ds = (
        layout_npz_dataset.train.get_graph_tensors_dataset(
            config_samples = CONFIGS_PER_GRAPH,
            max_nodes=MAX_KEEP_NODES)
        .shuffle(100, reshuffle_each_iteration=True)
        .batch(BATCH_SIZE, drop_remainder=False)
        .map(tfgnn.GraphTensor.merge_batch_to_components)
        .map(pair_layout_graph_with_label))

  layout_valid_ds = (
        layout_npz_dataset.validation.get_graph_tensors_dataset(
            config_samples = CONFIGS_PER_GRAPH)
        .batch(BATCH_SIZE, drop_remainder=False)
        .map(tfgnn.GraphTensor.merge_batch_to_components)
        .map(pair_layout_graph_with_label))

  return layout_npz_dataset, layout_train_ds, layout_valid_ds



def create_model(CONFIGS_PER_GRAPH, layout_npz_dataset):
  model = ResModel(CONFIGS_PER_GRAPH, layout_npz_dataset.num_ops)

  loss = tfr.keras.losses.ListMLELoss()  # (temperature=10)
  opt = tf.keras.optimizers.Adam(learning_rate=1e-3, clipnorm=0.5)

  model.compile(loss=loss, optimizer=opt, metrics=[
      tfr.keras.metrics.OPAMetric(name='opa_metric'),
  ])

  return model

def train_model(model, epochs, layout_train_ds, layout_valid_ds):
  best_val_opa = -1  # Tracks best validation OPA
  best_val_at_epoch = -1  # At which epoch.

  for i in range(epochs):
      history = model.fit(
          layout_train_ds, epochs=1, verbose=1, validation_data=layout_valid_ds,
          validation_freq=1)

      train_loss = history.history['loss'][-1]
      train_opa = history.history['opa_metric'][-1]
      val_loss = history.history['val_loss'][-1]
      val_opa = history.history['val_opa_metric'][-1]
      if val_opa > best_val_opa:
          best_val_opa = val_opa
          best_val_at_epoch = i
          best_params = {v.ref: v + 0 for v in model.trainable_variables}
          print(' * [@%i] Validation (NEW BEST): %s' % (i, str(val_opa)))
      elif early_stop > 0 and i - best_val_at_epoch >= early_stop:
        print('[@%i] Best accuracy was attained at epoch %i. Stopping.' % (i, best_val_at_epoch))
        break
  # Restore best parameters.
  print('Restoring parameters corresponding to the best validation OPA.')
  assert best_params is not None
  for v in model.trainable_variables:
      v.assign(best_params[v.ref])

  return model, train_loss, train_opa, val_loss, val_opa, best_params


def run_inference(model, _INFERENCE_CONFIGS_BATCH_SIZE, layout_npz_dataset):
  print('\n\n   Running inference on test set ...\n\n')
  test_rankings = []

  assert layout_npz_dataset.test.graph_id is not None
  for graph in tqdm.tqdm(layout_npz_dataset.test.iter_graph_tensors(),
                        total=layout_npz_dataset.test.graph_id.shape[-1],
                        desc='Inference'):
      # print(graph)
      num_configs = graph.node_sets['g']['runtimes'].shape[-1]
      # print(num_configs)
      # print(MAX_KEEP_NODES)
      # print("\n\n\n")
      all_scores = []
      for i in tqdm.tqdm(range(0, num_configs, _INFERENCE_CONFIGS_BATCH_SIZE)):
          end_i = min(i + _INFERENCE_CONFIGS_BATCH_SIZE, num_configs)
          # Take a cut of the configs.
          node_set_g = graph.node_sets['g']
          subconfigs_graph = tfgnn.GraphTensor.from_pieces(
              edge_sets=graph.edge_sets,
              node_sets={
                  'op': graph.node_sets['op'],
                  'nconfig': tfgnn.NodeSet.from_fields(
                      sizes=graph.node_sets['nconfig'].sizes,
                      features={
                          'feats': graph.node_sets['nconfig']['feats'][:, i:end_i],
                      }),
                  'g': tfgnn.NodeSet.from_fields(
                      sizes=tf.constant([1]),
                      features={
                          'graph_id': node_set_g['graph_id'],
                          'runtimes': node_set_g['runtimes'][:, i:end_i],
                          'kept_node_ratio': node_set_g['kept_node_ratio'],
                      })
              })
          h = model.forward(subconfigs_graph, num_configs=(end_i - i),
                            backprop=False)
          all_scores.append(h[0])
      all_scores = tf.concat(all_scores, axis=0)
      graph_id = graph.node_sets['g']['graph_id'][0].numpy().decode()
      sorted_indices = tf.strings.join(
          tf.strings.as_string(tf.argsort(all_scores)), ';').numpy().decode()
      test_rankings.append((graph_id, sorted_indices))
  return test_rankings

def write_output(test_rankings, output_csv_filename, SOURCE, SEARCH):
    with tf.io.gfile.GFile(output_csv_filename, 'w') as fout:
        fout.write('ID,TopConfigs\n')
        for graph_id, ranks in test_rankings:
            fout.write(f'layout:{SOURCE}:{SEARCH}:{graph_id},{ranks}\n')
    print('\n\n   ***  Wrote', output_csv_filename, '\n\n')

"""
BEGIN RUNNING CODE!!!
THERE ARE SETTINGS AND HYPERPARAMETERES
"""

def main(source, search, **kwargs):
  # need to download npz
  # tile_data_ROOT = '/npz/layout'
  tile_data_ROOT = '/content/drive/MyDrive/npz 2/layout'
  SOURCE = source  # Can be "xla" or "nlp"
  SEARCH = search  # Can be "random" or "default"

  tile_data_root_dir = os.path.join(
        os.path.expanduser(tile_data_ROOT), SOURCE, SEARCH)

  # Batch size information.
  # BATCH_SIZE = 10  # Number of graphs per batch.
  # CONFIGS_PER_GRAPH = 2  # Number of configurations (features and target values) per graph.
  # MAX_NUM_CONFIGS = 20 # maximum number of configurations to filter for
  # MAX_KEEP_NODES = 100  # Useful for dropout.
  # MAX_TRAIN_CONFIGS = 20

  BATCH_SIZE = kwargs['batch_size']  # Number of graphs per batch.
  CONFIGS_PER_GRAPH = kwargs['configs_per_graph']  # Number of configurations (features and target values) per graph.
  MAX_NUM_CONFIGS = kwargs['max_num_configs'] # maximum number of configurations to filter for
  MAX_KEEP_NODES = kwargs['max_keep_nodes']  # Useful for dropout.
  MAX_TRAIN_CONFIGS = kwargs['max_train_configs']

  # edges "sampled_config" and "sampled_feed" (or, "con50fig" and "feed")
  early_stop = 5  # If validation OPA did not increase in this many epochs, terminate training.
  best_params = None  # Stores parameters corresponding to best validation OPA, to restore to them after training.
  epochs = 1  # Total number of training epochs.

  # pull the data
  
  tiles_npz_dataset, tile_train_ds, tile_valid_ds = pull_data(CONFIGS_PER_GRAPH, MAX_KEEP_NODES, BATCH_SIZE, tile_data_root_dir)
  model = create_model(CONFIGS_PER_GRAPH, tiles_npz_dataset)
  model, train_loss, train_opa, val_loss, val_opa, best_params = train_model(model, epochs, tile_train_ds, tile_valid_ds)


  _INFERENCE_CONFIGS_BATCH_SIZE = 50
  # _INFERENCE_CONFIGS_BATCH_SIZE = 100

  folder_path = '/content/drive/MyDrive/tpu_graphs/outputcsvs/'
  output_csv_filename = f'inference_layout_{SOURCE}_{SEARCH}.csv'
  output_csv_filename = folder_path + output_csv_filename

  test_rankings = run_inference(model, _INFERENCE_CONFIGS_BATCH_SIZE, tiles_npz_dataset)
  write_output(test_rankings, output_csv_filename, SOURCE, SEARCH)


# Params to Configure

For every Kaggle submission, the same hyperparameters need to be used four times, in order to produce 4 output CSVs. The 4 configurations are:
1. source = nlp; search = random
2. source = nlp; search = default
3. source = xla; search = random
4. source = xla; search = default



In [5]:
source = 'xla' # Has to be be "xla"
batch_size = 10  # Number of graphs per batch.
configs_per_graph = 2  # Number of configurations (features and target values) per graph.
max_num_configs = 20 # maximum number of configurations to filter for
max_keep_nodes = 100  # Useful for dropout.
max_train_configs = 20 # If any graph has more than this configurations, it will be filtered [speeds up loading + training]

In [None]:
main(source = source,
     search = '',
      batch_size = batch_size,
     configs_per_graph = configs_per_graph,
     max_num_configs = max_num_configs,
     max_keep_nodes = max_keep_nodes,
     max_train_configs = max_train_configs
     )

AAA
dataset cache file:  cache/da030f462f243e051c33e7eb1886d667-cache.npz
loaded from cache/da030f462f243e051c33e7eb1886d667-cache.npz
dataset cache file:  cache/4378df730c04529a89e472ddf67e9460-cache.npz
loaded from cache/4378df730c04529a89e472ddf67e9460-cache.npz
dataset cache file:  cache/66a2b3a0bac484e0aec26d531f74b257-cache.npz
loaded from cache/66a2b3a0bac484e0aec26d531f74b257-cache.npz
 * [@0] Validation (NEW BEST): 0.6000000238418579
Restoring parameters corresponding to the best validation OPA.


   Running inference on test set ...




Inference:   0%|          | 0/17 [00:00<?, ?it/s]
  0%|          | 0/20 [00:00<?, ?it/s][A
  5%|▌         | 1/20 [00:01<00:36,  1.90s/it][A
 10%|█         | 2/20 [00:03<00:32,  1.82s/it][A
 15%|█▌        | 3/20 [00:05<00:30,  1.80s/it][A
 20%|██        | 4/20 [00:07<00:28,  1.79s/it][A
 25%|██▌       | 5/20 [00:09<00:26,  1.79s/it][A
 30%|███       | 6/20 [00:10<00:25,  1.79s/it][A
 35%|███▌      | 7/20 [00:12<00:23,  1.78s/it][A
 40%|████      | 8/20 [00:14<00:21,  1.78s/it][A
 45%|████▌     | 9/20 [00:16<00:19,  1.77s/it][A
 50%|█████     | 10/20 [00:17<00:17,  1.78s/it][A
 55%|█████▌    | 11/20 [00:19<00:15,  1.77s/it][A
 60%|██████    | 12/20 [00:21<00:14,  1.77s/it][A
 65%|██████▌   | 13/20 [00:23<00:12,  1.77s/it][A
 70%|███████   | 14/20 [00:24<00:10,  1.77s/it][A
 75%|███████▌  | 15/20 [00:26<00:08,  1.78s/it][A
 80%|████████  | 16/20 [00:28<00:07,  1.78s/it][A
 85%|████████▌ | 17/20 [00:30<00:05,  1.77s/it][A
 90%|█████████ | 18/20 [00:32<00:03,  1.77s/it][A