#### 1. Setup: Install and Import libraries

First we need to install the libraries we will be using for this tutorial.

Next we import these libraries. We are going to use PyTorch Geometrics Data, Dataset and DataLoader objects to process our data. We will then convert the dataset to a Jraph GraphsTuple.

In [1]:
# Imports
import functools
import jax
import jax.numpy as jnp
import time
import jraph
import flax
import haiku as hk
import optax
import pickle
import numpy as np
import torch

from torch_geometric.data import Data, Dataset
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset

from flax import linen as nn
from flax.training import train_state
import pathlib
import csv
import time
import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


#### 2. Download dataset from PyTorch Geometrics Common Benchmark Datasets.

In [None]:
# from ogb.nodeproppred import NodePropPredDataset

# d_name = 'ogbn-proteins'

# dataset = NodePropPredDataset(name = d_name)

# split_idx = dataset.get_idx_split()
# train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
# graph, label = dataset[0] # graph: library-agnostic graph object


In [2]:
from ogb.nodeproppred import PygNodePropPredDataset

dataset = PygNodePropPredDataset(name = 'ogbn-proteins')

split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]


data = dataset[0] # pyg graph object



Downloading http://snap.stanford.edu/ogb/data/nodeproppred/proteins.zip


Downloaded 0.21 GB: 100%|██████████| 216/216 [00:59<00:00,  3.62it/s]


Extracting dataset/proteins.zip


Processing...


Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


Converting graphs into PyG objects...


100%|██████████| 1/1 [00:00<00:00, 2935.13it/s]


Saving...


Done!


In [3]:
dataset = PygNodePropPredDataset(name = 'ogbn-proteins', root='/data101/makinen/ogbn/')
splitted_idx = dataset.get_idx_split()
data = dataset[0]
data.node_species = None
data.y = data.y.to(torch.float)

In [4]:
data

Data(num_nodes=132534, edge_index=[2, 79122504], edge_attr=[79122504, 8], y=[132534, 112])

In [5]:
from torch_geometric.loader import RandomNodeLoader
from torch_geometric.utils import scatter


row, col = data.edge_index
data.x = scatter(data.edge_attr, col, dim_size=data.num_nodes, reduce='sum')

# Set split indices to masks.
for split in ['train', 'valid', 'test']:
    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    mask[splitted_idx[split]] = True
    data[f'{split}_mask'] = mask

train_reader = RandomNodeLoader(data, num_parts=200, shuffle=True,
                                num_workers=0)
test_reader = RandomNodeLoader(data, num_parts=5, num_workers=0)

In [6]:
for d in train_reader:
  print(d.x.shape)
  print(d.train_mask.sum())

torch.Size([663, 8])
tensor(430)
torch.Size([663, 8])
tensor(432)
torch.Size([663, 8])
tensor(444)


KeyboardInterrupt: 

Next we split the toy data and apply it to our Custom dataset class. The class returns one graph sample as a Pytorch Geometric Data object.

In [13]:
data.x.shape

torch.Size([132534, 8])

In [None]:
132534 / 3314

39.99215449607725

#### 4. [Optional]: Add Graph Padding to Speed Up Training.

As Jax recompiles the program for each graph size, padding the number of nodes and edges in the graph to the nearest power of two can speed up training. See this [tutorial](https://colab.research.google.com/github/deepmind/educational/blob/master/colabs/summer_schools/intro_to_graph_nets_tutorial_with_jraph.ipynb#scrollTo=lGhnsIovZQpo) for more details.

In [7]:
# Adapted from https://github.com/deepmind/jraph/blob/master/jraph/ogb_examples/train.py
def _nearest_bigger_power_of_two(x: int) -> int:
    """Computes the nearest power of two greater than x for padding."""
    y = 2
    while y < x:
        y *= 2
    return y

def pad_graph_to_nearest_power_of_two(graphs_tuple: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Pads a batched `GraphsTuple` to the nearest power of two.

    For example, if a `GraphsTuple` has 7 nodes, 5 edges and 3 graphs, this method
    would pad the `GraphsTuple` nodes and edges:
        7batch_sizedes --> 8 nodes (2^3)
        5 edges --> 8 edges (2^3)

    And since padding is accomplished using `jraph.pad_with_graphs`, an extra
    graph and node is added:
        8 nodes --> 9 nodes
        3 graphs --> 4 graphs

    Args:
        graphs_tuple: a batched `GraphsTuple` (can be batch size 1).

    Returns:
        A graphs_tuple batched to the nearest power of two.
    """
    # Add 1 since we need at least one padding node for pad_with_graphs.
    pad_nodes_to = _nearest_bigger_power_of_two(jnp.sum(graphs_tuple.n_node)) + 1
    pad_edges_to = _nearest_bigger_power_of_two(jnp.sum(graphs_tuple.n_edge))
    # Add 1 since we need at least one padding graph for pad_with_graphs.
    # We do not pad to nearest power of two because the batch size is fixed.
    pad_graphs_to = graphs_tuple.n_node.shape[0] + 1
    return jraph.pad_with_graphs(graphs_tuple, pad_nodes_to, pad_edges_to,
                               pad_graphs_to)


#### 5. Convert the PyTorch Geometric Object to a Jraph GraphsTuple

[link text](https://)Finally we will convert the PyTorch Geometric Data object to a Jraph GraphsTuple. The function written below should be called in your training loop after loading the batch data. The function also pads the whole batch graphs once and returns padded batch GraphsTuple.


# SET THE MODE TO TRAIN OR VALID !!

In [7]:
for i,d in enumerate(train_reader):
  print(d)

Data(num_nodes=884, edge_index=[2, 3824], edge_attr=[3824, 8], y=[884, 112], x=[884, 8], train_mask=[884], valid_mask=[884], test_mask=[884])
Data(num_nodes=884, edge_index=[2, 3094], edge_attr=[3094, 8], y=[884, 112], x=[884, 8], train_mask=[884], valid_mask=[884], test_mask=[884])


KeyboardInterrupt: 

In [8]:
import jax.numpy as jnp
import numpy as np

In [9]:
def get_batched_padded_graph_tuples(batch, mode="train", pad=False):

    masks = (jnp.array(batch.train_mask), jnp.array(batch.valid_mask), jnp.array(batch.test_mask))

    if mode == "train":
      mask = masks[0]
    elif mode == "valid":
      mask = masks[1]
    else:
      mask = masks[2]


    graphs = jraph.GraphsTuple(
            nodes=jnp.array(batch.x),
            edges=jnp.array(batch.edge_attr), # this particular data doesn't have edge features, hence we set to None
            n_node=jnp.array([batch.num_nodes]),
            n_edge=jnp.array([batch.edge_attr.shape[0]]),
            senders=jnp.array(batch.edge_index[0, :]),
            receivers=jnp.array(batch.edge_index[1, :]),
            globals=None)

    labels = jnp.array(batch.y)
    if pad:
        graphs = pad_graph_to_nearest_power_of_two(graphs) # padd the whole batch once
    
        # put into jnp
        graphs = jraph.GraphsTuple(
                nodes=jnp.array(graphs.nodes),
                edges=jnp.array(graphs.edges), 
                n_node=jnp.array(graphs.n_node),
                n_edge=jnp.array(graphs.n_edge),
                senders=jnp.array(graphs.senders),
                receivers=jnp.array(graphs.receivers),
                globals=None)
    
    return graphs, labels, masks

#### Benchmarking the DataLoader

In [10]:
%%time
batch = next(iter(train_reader))

CPU times: user 10.4 s, sys: 4.81 s, total: 15.2 s
Wall time: 1.94 s


In [11]:
%%time
graphs, labels, masks = get_batched_padded_graph_tuples(batch)

CPU times: user 228 ms, sys: 327 ms, total: 555 ms
Wall time: 427 ms


In [12]:
graphs.nodes

Array([[1.0903156e+01, 4.0299836e-01, 4.0299836e-01, ..., 3.4755978e+01,
        4.0299836e-01, 3.8805943e+01],
       [5.7158399e+00, 2.2160218e+00, 2.3250105e+00, ..., 4.1051447e+02,
        2.3442184e+02, 1.6764810e+02],
       [1.2919567e+01, 1.4190103e+00, 1.4190103e+00, ..., 2.0222116e+02,
        9.2318535e+01, 1.8551103e+02],
       ...,
       [3.0109994e+00, 1.1000001e-02, 1.1000001e-02, ..., 8.1799996e-01,
        1.1000001e-02, 2.8800002e-01],
       [1.3528005e+01, 2.8000003e-02, 2.8000003e-02, ..., 2.0300001e-01,
        2.8000003e-02, 2.8000003e-02],
       [4.5099998e+00, 1.0000001e-02, 1.0000001e-02, ..., 1.0000001e-02,
        1.0000001e-02, 1.6000000e-01]], dtype=float32)

In [13]:
print('graphs: ', graphs)

graphs:  GraphsTuple(nodes=Array([[1.0903156e+01, 4.0299836e-01, 4.0299836e-01, ..., 3.4755978e+01,
        4.0299836e-01, 3.8805943e+01],
       [5.7158399e+00, 2.2160218e+00, 2.3250105e+00, ..., 4.1051447e+02,
        2.3442184e+02, 1.6764810e+02],
       [1.2919567e+01, 1.4190103e+00, 1.4190103e+00, ..., 2.0222116e+02,
        9.2318535e+01, 1.8551103e+02],
       ...,
       [3.0109994e+00, 1.1000001e-02, 1.1000001e-02, ..., 8.1799996e-01,
        1.1000001e-02, 2.8800002e-01],
       [1.3528005e+01, 2.8000003e-02, 2.8000003e-02, ..., 2.0300001e-01,
        2.8000003e-02, 2.8000003e-02],
       [4.5099998e+00, 1.0000001e-02, 1.0000001e-02, ..., 1.0000001e-02,
        1.0000001e-02, 1.6000000e-01]], dtype=float32), edges=Array([[0.501, 0.001, 0.001, ..., 0.001, 0.001, 0.001],
       [0.501, 0.001, 0.001, ..., 0.001, 0.001, 0.001],
       [0.501, 0.001, 0.001, ..., 0.001, 0.001, 0.001],
       ...,
       [0.001, 0.001, 0.001, ..., 0.001, 0.001, 0.427],
       [0.001, 0.001, 0.001, .

In [13]:
labels

Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [14]:
graphs.edges.shape, graphs.nodes.shape

((51662, 8), (3314, 8))

In [15]:
graphs.nodes

Array([[3.8999990e-02, 3.8999990e-02, 3.8999990e-02, ..., 3.8999990e-02,
        3.8999990e-02, 3.8999990e-02],
       [1.9265053e+01, 6.6917572e+01, 1.7660265e+00, ..., 1.8616872e+02,
        1.7660265e+00, 1.5972354e+02],
       [5.8999884e-01, 9.0000004e-02, 9.0000004e-02, ..., 5.7219958e+00,
        9.0000004e-02, 2.2717001e+01],
       ...,
       [5.0139995e+00, 1.4000001e-02, 1.4000001e-02, ..., 6.3400006e-01,
        1.4000001e-02, 1.4000001e-02],
       [1.0027004e+01, 2.7000003e-02, 2.7000003e-02, ..., 1.1300000e+00,
        2.7000003e-02, 2.7999997e-01],
       [7.0289984e+00, 2.9000003e-02, 2.9000003e-02, ..., 2.9000003e-02,
        1.3529000e+01, 2.9000003e-02]], dtype=float32)

In [16]:
print('labels: ', labels)

labels:  [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]


# do the graphnet thing

In [14]:
class MLP(hk.Module):
  def __init__(self, features: jnp.ndarray):
    super().__init__()
    self.features = features

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    layers = []
    for feat in self.features[:-1]:
      layers.append(hk.Linear(feat))
      layers.append(jax.nn.relu)
    layers.append(hk.Linear(self.features[-1]))

    mlp = hk.Sequential(layers)
    return mlp(x)

# Use MLP block to define the update node function
update_node_fn = lambda x: MLP(features=[8, 4])(x)

In [15]:
import tensorflow_probability.substrates.jax as tfp

# custom fishnets aggregation

def fill_diagonal(a, val):
  assert a.ndim >= 2
  i, j = jnp.diag_indices(min(a.shape[-2:]))
  return a.at[..., i, j].set(val)

def construct_fisher_matrix_multiple(outputs):
    Q = (tfp.math.fill_triangular(outputs))
    # vmap the jnp.diag function for the batch
    _diag = jax.vmap(jnp.diag)
    middle = _diag(jnp.triu(Q) - jax.nn.softplus(jnp.triu(Q)))
    padding = jnp.zeros(Q.shape)

    L = Q - fill_diagonal(padding, middle)

    return jnp.einsum('...ij,...jk->...ik', L, jnp.transpose(L, (0, 2, 1)))

def construct_fisher_matrix_single(outputs):
    Q = tfp.math.fill_triangular(outputs)
    middle = jnp.diag(jnp.triu(Q) - jax.nn.softplus(jnp.triu(Q)))
    padding = jnp.zeros(Q.shape)

    L = Q - fill_diagonal(padding, middle)

    return jnp.einsum('...ij,...jk->...ik', L, jnp.transpose(L, (1, 0)))



In [16]:
from networkx.algorithms.clique import graph_clique_number
#from sklearn.utils.validation import FiniteStatus

import jax.tree_util as tree
import jraph
import flax
import haiku as hk
import optax
import pickle
import numpy as onp
import networkx as nx
from typing import Any, Callable, Dict, List, Optional, Tuple

# Adapted from https://github.com/deepmind/jraph/blob/master/jraph/ogb_examples/train.py

n_p = 8 # bottleneck
n_fisher = (n_p * (n_p + 1)) // 2

n_bottleneck = n_p + n_fisher

n_bottleneck_nodes = 112

hidden_size = 50

# custom mean function for padded inputs
def fishnets_aggregation(
                 data: jnp.ndarray,
                 segment_ids: jnp.ndarray,
                 num_segments: Optional[int] = None,
                 indices_are_sorted: bool = False,
                 unique_indices: bool = False):
  """Returns mean for each segment.
  Args:
    n_data: the number of data we want to take the mean of
    data: the values which are averaged segment-wise.
    segment_ids: indices for the segments.
    num_segments: total number of segments.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether ``segment_ids`` is known to be free of duplicates.
  """
  #print("data", data.shape)
  score = data[..., :n_p]
  fisher = data[..., n_p:]

  # print("fisher cholesky", fisher.shape)
  # print("score", score.shape)
  # print("segment_ids", segment_ids.shape)

  score = jraph.segment_sum(
      score,
      segment_ids,
      num_segments,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)

  # construct fisher matrix
  fisher = construct_fisher_matrix_multiple(fisher)

  # should construct matrix before doing sum but let's see how this works
  fisher = jraph.segment_sum(
      fisher.reshape(-1, int(n_p**2)),
      segment_ids,
      num_segments,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices).reshape(-1, n_p, n_p)

  fisher += jnp.eye(n_p) # add prior
  mle = jnp.einsum('...jk,...k->...j', jnp.linalg.inv(fisher), score)

  return mle



def fishnets_for_edges(
                 data: jnp.ndarray,
                 segment_ids: jnp.ndarray,
                 num_segments: Optional[int] = None,
                 indices_are_sorted: bool = False,
                 unique_indices: bool = False):
  """Returns mean for each segment.
  Args:
    n_data: the number of data we want to take the mean of
    data: the values which are averaged segment-wise.
    segment_ids: indices for the segments.
    num_segments: total number of segments.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether ``segment_ids`` is known to be free of duplicates.
  """
  #print("data", data.shape)
  score = data[..., :n_p]
  fisher = data[..., n_p:]

  score = jraph.segment_sum(
      score,
      segment_ids,
      num_segments,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)

  # construct fisher matrix
  fisher = construct_fisher_matrix_multiple(fisher)

  # should construct matrix before doing sum but let's see how this works
  fisher = jraph.segment_sum(
      fisher.reshape(-1, int(n_p**2)),
      segment_ids,
      num_segments,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices).reshape(-1, n_p, n_p)

  fisher += jnp.eye(n_p) # add prior

  mle = jnp.einsum('...jk,...k->...j', jnp.linalg.inv(fisher), score)

  # concatenate fisher and output mle for loss function
  output = jnp.concatenate([mle, fisher.reshape(-1, n_p**2)], 1)

 # print("edge aggregation output shape", output.shape)

  return output


@jraph.concatenated_args
def edge_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Edge update function for graph net."""
  net = hk.Sequential(
      [hk.Linear(hidden_size), jax.nn.swish,
       hk.Linear(hidden_size), jax.nn.swish,
       hk.Linear(hidden_size), jax.nn.swish,
       hk.Linear(n_bottleneck)])
  return net(feats)

# make this into a fishnets aggregation

@jraph.concatenated_args
def node_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Node update function for graph net."""

  #print('input to node fn', feats.shape)

  net = hk.Sequential(
      [hk.Linear(hidden_size), jax.nn.swish,
       hk.Linear(n_bottleneck_nodes)])
  return net(feats)

@jraph.concatenated_args
def update_global_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Global update function for graph net."""
  # MUTAG is a binary classification task, so output pos neg logits.
  #print("global feats", feats.shape)
  net = hk.Sequential(
      [hk.Linear(hidden_size), jax.nn.swish,
       hk.Linear(2)])
  return net(feats)

@jraph.concatenated_args
def node_fishnets_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Node update function for graph net."""
 #print('input to node fn', feats.shape)

  # concatenated input ordering:
  # concatenate([node, edge_sent, edge_received])

  edge_feats = feats[..., n_bottleneck_nodes:]

  edge_feats_receivers = edge_feats[..., :(n_p + n_p**2)]
  edge_feats_senders =  edge_feats[..., (n_p + n_p**2):]
  edge_feats = edge_feats_senders + edge_feats_receivers

  return edge_feats


def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:


  # Add a global paramater for graph classification.
  graph = graph._replace(globals=None,
                         #jnp.zeros([graph.n_node.shape[0], 1]),
                         nodes=graph.nodes[..., n_p:],
                         )

  embedder = jraph.GraphMapFeatures(
    embed_edge_fn=hk.Linear(hidden_size), embed_node_fn=hk.Sequential(
      [hk.Linear(hidden_size), jax.nn.swish,
       hk.Linear(n_bottleneck_nodes)])
    )

  graph = embedder(graph)

  #print("pre-gnn nodes", graph.nodes.shape)
  #print("pre-gnn globals", graph.globals)

    # rho aggregation functions
  #aggregate_nodes_for_globals_fn = lambda d,s,n: node_fishnets_fn(d,s,n) # jnp.arcsinh(jraph.segment_sum(d,s,n)) #lambda d,s,n:
  #aggregate_edges_for_globals_fn = lambda d,s,n: fishnets_aggregation(d,s,n) # jnp.arcsinh(jraph.segment_sum(d,s,n)) #
  aggregate_edges_for_nodes_fn = lambda d,s,n: fishnets_for_edges(d,s,n)  #lambda d,s,n: fishnets_aggregation(d,s,n) #


  net = jraph.InteractionNetwork(
      update_node_fn=node_update_fn,
      update_edge_fn=edge_update_fn,
      aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
      include_sent_messages_in_node_update=True
      )

  graph = net(graph)

  #print("edges", graph.edges.shape)
  #print("nodes", graph.nodes.shape)


  return graph

In [17]:
n_p, n_fisher, n_bottleneck

(8, 36, 44)

In [18]:
net = hk.without_apply_rng(hk.transform(net_fn))
# Get a candidate graph and label to initialize the network.
graph, labels, masks = get_batched_padded_graph_tuples(batch)

# Initialize the network.
params = net.init(jax.random.PRNGKey(42), graph)

  stddev = 1. / np.sqrt(self.input_size)


In [19]:
outgraph = net.apply(params, graph)

In [20]:
graph.nodes.shape

(663, 8)

In [21]:
outgraph.nodes.shape, outgraph.edges.shape

((663, 112), (2448, 44))

In [22]:
from sklearn.metrics import roc_auc_score

def compute_bce_with_logits_loss(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  """Computes binary cross-entropy with logits loss.

  Combines sigmoid and BCE, and uses log-sum-exp trick for numerical stability.
  See https://stackoverflow.com/a/66909858 if you want to learn more.

  Args:
    x: Predictions (logits).
    y: Labels.

  Returns:
    Binary cross-entropy loss with mean aggregation.

  """
  max_val = jnp.clip(x, 0, None)
  loss = x - x * y + max_val + jnp.log(
      jnp.exp(-max_val) + jnp.exp((-x - max_val)))
  return loss.mean()


def compute_loss_net(params: hk.Params, graph: jraph.GraphsTuple,
                 labels: jnp.ndarray,
                 net: hk.Transformed
                 ) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Computes loss with net."""

  # APPLY NETWORK HERE
  preds = net.apply(params, graph).nodes # extract just nodes

  # do the appropriate train mask
  loss = compute_bce_with_logits_loss(preds, labels)

  return loss, preds

def compute_loss(params: hk.Params, preds: jnp.ndarray,
                 labels: jnp.ndarray,
                 ) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Computes loss."""

  # do the appropriate train mask
  loss = compute_bce_with_logits_loss(preds, labels)
  return loss, preds

def compute_roc_auc_score(preds: jnp.ndarray,
                          labels: jnp.ndarray) -> jnp.ndarray:
  """Computes roc auc (area under the curve) score for classification."""
  s = jax.nn.sigmoid(preds)
  roc_auc = roc_auc_score(labels, s)
  return roc_auc

In [23]:
train_reader

<torch_geometric.loader.random_node_loader.RandomNodeLoader at 0x7fdd6217e460>

In [24]:
# Adapted from https://github.com/deepmind/jraph/blob/master/jraph/ogb_examples/train.py
def train(data_reader: Any, num_epochs: int, lr: float=1e-4) -> hk.Params:
  """Training loop."""

  # Transform impure `net_fn` to pure functions with hk.transform.
  net = hk.without_apply_rng(hk.transform(net_fn))
  # Get a candidate graph and label to initialize the network.
  batch = next(iter(train_reader))

  graph, labels, masks = get_batched_padded_graph_tuples(batch)

  # Initialize the network.
  params = net.init(jax.random.PRNGKey(42), graph)
  # Initialize the optimizer.
  opt_init, opt_update = optax.adam(lr)
  opt_state = opt_init(params)

  compute_loss_fn = functools.partial(compute_loss)

  compute_loss_fn = jax.jit(jax.value_and_grad(
      compute_loss_fn, has_aux=True))

  for epoch in range(num_epochs):




    total_loss = 0.
    for i,batch in enumerate(data_reader):

      graph, labels, masks = get_batched_padded_graph_tuples(batch)

      train_mask,valid_mask,test_mask = masks # unpack masks

      # apply net here and mask for training data
      train_preds = net.apply(params, graph).nodes[train_mask]
      train_labels = labels[train_mask]
    
      print("out nodes", train_preds.shape)

      (train_loss, train_preds), grad = compute_loss_fn(params, train_preds, train_labels)


      updates, opt_state = opt_update(grad, opt_state, params)
      params = optax.apply_updates(params, updates)


#     if epoch % 10 == 0 or epoch == (num_epochs - 1):
#       train_roc_auc = compute_roc_auc_score(train_preds, train_labels)


#       val_loss, val_preds = compute_loss(params, graph, valid_mask, net)
#       val_roc_auc = compute_roc_auc_score(val_preds, labels[valid_mask])
#       print(f'epoch: {epoch}, train_loss: {train_loss:.3f}, '
#             f'train_roc_auc: {train_roc_auc:.3f}, val_loss: {val_loss:.3f}, '
#             f'val_roc_auc: {val_roc_auc:.3f}')


  print('Training finished')
  return params

In [None]:
params = train(train_reader, num_epochs=2)