## GNN Model Code

GNN model trained on molecular scent data from the Leffingwell Odor Dataset (loaded using Pyrfume - https://pyrfume.org)

Code below modified from example code given in the "Predicting DFT Energies with GNNs", "Interpretability and Deep Learning" sections of "Deep Learning for Molecules and Materials" textbook (https://dmol.pub/applied/QM9.html)

In [None]:
# Imports
import pyrfume
import tensorflow as tf
import numpy as np
import seaborn as sns
import jax
import jax.numpy as jnp
import pandas as pd
import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw

import haiku as hk
import optax
import sklearn.metrics
import numpy as np

import warnings

# Keep notebook output free of library warning messages
warnings.filterwarnings("ignore")

# Seed NumPy's and TensorFlow's global RNGs so runs are reproducible
np.random.seed(0)
tf.random.set_seed(0)

# Plotting style
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import urllib.request

# Download IBM Plex Mono and register it with matplotlib under the family name "plexmono"
_plexMonoFile = "IBMPlexMono-Regular.ttf"
urllib.request.urlretrieve(
    "https://github.com/google/fonts/raw/main/ofl/ibmplexmono/IBMPlexMono-Regular.ttf",
    _plexMonoFile,
)
fe = font_manager.FontEntry(fname=_plexMonoFile, name="plexmono")
font_manager.fontManager.ttflist.append(fe)

# Global matplotlib defaults used by every figure in the notebook
_notebookStyle = {
    "axes.facecolor": "#f5f4e9",
    "grid.color": "#AAAAAA",
    "axes.edgecolor": "#333333",
    "figure.facecolor": "#FFFFFF",
    "axes.grid": False,
    "axes.prop_cycle": plt.cycler("color", plt.cm.Dark2.colors),
    "font.family": fe.name,
    "figure.figsize": (3.5, 3.5 / 1.2),
    "ytick.left": True,
    "xtick.bottom": True,
}
plt.rcParams.update(_notebookStyle)

### Model Training Related Code
NOTE: numEpochs is currently set to 2 and the dataset is downsampled; when actually training/evaluating the model, edit the code so it uses the entire dataset and a larger number of epochs.

In [None]:
# Save model inputs and hyperparameters

# --- Architecture ---
num_GNN_layers = 4  # number of message-passing (GNN) layers
num_Dense_layers = 2  # number of dense layers in the output head
node_feat_length = 256  # Nf: length of each atom's feature vector
message_feat_length = node_feat_length  # Mf: message length, tied to Nf
graph_feat_length = 512  # Gf: length of the graph-level feature vector
weights_stddevGNN = 1e-2  # std-dev of the normal initialiser for GNN weights

# --- Optimisation ---
learning_rate = 1e-5
steps_for_gradUpdate = 8  # k passed to optax.apply_every
regularizationStrength = 1e-6  # coefficient of the L2 penalty on all parameters
# NOTE: reduced number of epochs so the notebook can be checked quickly; increase when
# actually training/evaluating the model (the paper used 138 epochs)
numEpochs = 2

# --- Early stopping on the validation loss ---
earlyStopping = True
earlyStopping_patience = 3  # epochs without sufficient improvement before stopping
earlyStopping_minDelta = 0  # minimum decrease in validation loss that counts as improvement

In [None]:
# Use the train/test split given in the Leffingwell Dataset, but carve a validation set out of
# the training portion: 70% train, 10% validation, 20% test (instead of 80% train & 20% test).
# Load data
scentdata = pyrfume.load_data("leffingwell/leffingwell_data.csv", remote=True)

# Test set = rows marked labels_train/test == 0. reset_index(drop=True) gives every split a
# 0..n-1 index, so the row-number lookups later in the notebook keep working even when the
# downsampling lines below are commented out.
testData = scentdata[scentdata["labels_train/test"] == 0].reset_index(drop=True)
numTestData = len(testData)

trainAndValidationData = scentdata[scentdata["labels_train/test"] == 1]
numTrainAndValidationData = len(trainAndValidationData)
# After reset_index() the index is 0..n-1, so the positional iloc and the label-based drop
# below select complementary sets of rows.
trainAndValidationData = trainAndValidationData.reset_index()

numMoleculesInDataset = numTestData + numTrainAndValidationData

# Randomly select rows of trainAndValidationData for validation (validation set = 10% of the
# entire dataset), using NumPy's global RNG.
numValidationData = int(0.10 * numMoleculesInDataset)
validationIndices = np.random.choice(
    a=numTrainAndValidationData, size=numValidationData, replace=False
)  # https://numpy.org/doc/stable/reference/random/generated/numpy.random.choice.html
validationData = trainAndValidationData.iloc[validationIndices].reset_index(drop=True)
trainData = trainAndValidationData.drop(index=validationIndices).reset_index(drop=True)


# Use smaller set of data (just to check code does not crash)
# NOTE: comment out the next 3 lines when actually training/evaluating model to use the entire dataset
trainData = trainData.sample(frac=0.01, random_state=0).reset_index(drop=True)
validationData = validationData.sample(frac=0.01, random_state=0).reset_index(drop=True)
# Sample a larger fraction of the test set so that each scent class is more likely to have both
# positive and negative molecules (AUROC is undefined otherwise)
testData = testData.sample(frac=0.8, random_state=0).reset_index(drop=True)

In [None]:
# Code to generate list of all scent labels (scentClasses) and the scent list of every molecule
numMolecules = len(scentdata.odor_labels_filtered)
numClasses = 112  # No odorless class
# Class names are read from the local scentClasses.csv ("Scent" column); their order defines
# the position of each scent in the label vectors.
scentClasses = pd.read_csv("scentClasses.csv")
scentClasses = scentClasses["Scent"].tolist()

# Characters deleted from the stringified label lists stored in odor_labels_filtered
_labelCharsToDelete = str.maketrans("", "", "[]' ")


def _parseScentLabels(scentString):
    """Parse a stringified label list such as "['fruity', 'green']" into ["fruity", "green"].

    Brackets, single quotes and all spaces are deleted, so multi-word labels lose their
    spaces (presumably matching the names in scentClasses.csv). Empty entries are dropped,
    so "[]" gives [] rather than [""]. The first "odorless" entry is removed because there
    is no odorless class.
    """
    scentList = [s for s in scentString.translate(_labelCharsToDelete).split(",") if s]
    if "odorless" in scentList:
        scentList.remove("odorless")
    return scentList


# Iterating over the column values (instead of label lookups like column[i]) keeps the
# lists in row order and works for any DataFrame index.
moleculeScentList = [_parseScentLabels(s) for s in scentdata.odor_labels_filtered]

# Generate moleculeScentList_train, moleculeScentList_validation, moleculeScentList_test
numTrainMolecules = len(trainData.odor_labels_filtered)
moleculeScentList_train = [_parseScentLabels(s) for s in trainData.odor_labels_filtered]

numValidationMolecules = len(validationData.odor_labels_filtered)
moleculeScentList_validation = [
    _parseScentLabels(s) for s in validationData.odor_labels_filtered
]

numTestMolecules = len(testData.odor_labels_filtered)
moleculeScentList_test = [_parseScentLabels(s) for s in testData.odor_labels_filtered]

In [None]:
def gen_smiles2graph(sml):
    """Convert a SMILES string into the (nodes, adj) arrays fed to the GNN.

    Hydrogens are made explicit with AddHs, so every atom (including H) becomes a node.

    Args:
        sml: SMILES string for one molecule.

    Returns:
        nodes: array of shape (N, node_feat_length). Row i has a 1 in column
            atomic_number(atom i) and, if atom i is in a ring, a 1 in the last column.
        adj: array of shape (N, N) with 1 for every bonded pair (both directions)
            and 1 on the diagonal (self-loops).

    Raises:
        ValueError: if RDKit cannot parse ``sml``.
        Warning: if the molecule contains a bond type other than single, double,
            triple or aromatic.
    """
    m = rdkit.Chem.MolFromSmiles(sml)
    if m is None:
        # MolFromSmiles signals a parse failure by returning None; report it clearly
        # instead of failing inside AddHs.
        raise ValueError(f"RDKit could not parse SMILES string {sml!r}")
    m = rdkit.Chem.AddHs(m)
    # Supported bond types. The order value is only used to validate the bond: the
    # adjacency matrix stores connectivity, not bond order.
    order_string = {
        rdkit.Chem.rdchem.BondType.SINGLE: 1,
        rdkit.Chem.rdchem.BondType.DOUBLE: 2,
        rdkit.Chem.rdchem.BondType.TRIPLE: 3,
        rdkit.Chem.rdchem.BondType.AROMATIC: 4,
    }
    N = len(list(m.GetAtoms()))
    nodes = np.zeros((N, node_feat_length))
    for atom in m.GetAtoms():
        # One-hot atomic number; requires atomic number < node_feat_length - 1 so it does
        # not collide with the ring flag stored in the last column.
        nodes[atom.GetIdx(), atom.GetAtomicNum()] = 1
        if atom.IsInRing():
            nodes[atom.GetIdx(), -1] = 1

    adj = np.zeros((N, N))
    for bond in m.GetBonds():
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        bondType = bond.GetBondType()
        if bondType not in order_string:
            # f-string: BondType cannot be concatenated to a str with "+"
            raise Warning(f"Unsupported bond order {bondType}")
        adj[u, v] = 1
        adj[v, u] = 1
    adj += np.eye(N)
    return nodes, adj

In [None]:
# Multi-hot label vectors: position i is 1 when the molecule has scent scentClasses[i],
# and 0 otherwise.


def createLabelVector(scentsList):
    """Return a length-numClasses multi-hot vector marking every scent in scentsList.

    Each scent's position is its index in scentClasses; a scent missing from
    scentClasses raises ValueError (from list.index).
    """
    labelVector = np.zeros(numClasses)
    for scent in scentsList:
        labelVector[scentClasses.index(scent)] = 1
    return labelVector

In [None]:
def generateGraphsTrain():
    """Yield ((nodes, adj), labels) for every training molecule, in row order."""
    for i in range(numTrainMolecules):
        # .iloc is positional, so this works whatever trainData's index is
        graph = gen_smiles2graph(trainData.smiles.iloc[i])
        labels = createLabelVector(moleculeScentList_train[i])
        yield graph, labels


def generateGraphsValidation():
    """Yield ((nodes, adj), labels) for every validation molecule, in row order."""
    for i in range(numValidationMolecules):
        graph = gen_smiles2graph(validationData.smiles.iloc[i])
        labels = createLabelVector(moleculeScentList_validation[i])
        yield graph, labels


def generateGraphsTest():
    """Yield ((nodes, adj), labels) for every test molecule, in row order."""
    for i in range(numTestMolecules):
        graph = gen_smiles2graph(testData.smiles.iloc[i])
        labels = createLabelVector(moleculeScentList_test[i])
        yield graph, labels


def generateGraphs():
    """Yield ((nodes, adj), labels) for every molecule in the full dataset, in row order."""
    for i in range(numMolecules):
        graph = gen_smiles2graph(scentdata.smiles.iloc[i])
        labels = createLabelVector(moleculeScentList[i])
        yield graph, labels


# Every dataset yields ((nodes, adjacency), labels): nodes is (N, node_feat_length),
# adjacency is (N, N) and labels is a 1-D vector; N differs from molecule to molecule.
_graphOutputTypes = ((tf.float32, tf.float32), tf.float32)
_graphOutputShapes = (
    (tf.TensorShape([None, node_feat_length]), tf.TensorShape([None, None])),
    tf.TensorShape([None]),
)


def _graphDatasetFrom(generatorFn):
    """Wrap a generator function of ((nodes, adj), labels) items in a tf.data.Dataset."""
    return tf.data.Dataset.from_generator(
        generatorFn,
        output_types=_graphOutputTypes,
        output_shapes=_graphOutputShapes,
    )


data = _graphDatasetFrom(generateGraphs)
train_set = _graphDatasetFrom(generateGraphsTrain)
valid_set = _graphDatasetFrom(generateGraphsValidation)
test_set = _graphDatasetFrom(generateGraphsTest)

In [None]:
# Number of molecules per split (used to average losses and to size prediction arrays)
train_N, valid_N, test_N = numTrainMolecules, numValidationMolecules, numTestMolecules

In [None]:
class GNNLayer(hk.Module):
    """One message-passing layer, adapted from the dmol.pub "Predicting DFT Energies with GNNs" example.

    Takes and returns a tuple (nodes, edges, features):
        nodes:    (N, Nf) atom features; N = number of atoms, Nf = node feature length
        edges:    (N, N) adjacency matrix (gen_smiles2graph builds it symmetric with self-loops)
        features: (Gf,) graph-level feature vector; Gf = graph feature length
    Nodes and graph features are updated; edges are returned unchanged.
    """

    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        # Kept for interface compatibility; not read by the layer (sizes come from the
        # inputs and the module-level message_feat_length).
        self.output_size = output_size

    def __call__(self, inputs):
        nodes, edges, features = inputs

        graph_feature_len = features.shape[-1]  # Gf
        node_feature_len = nodes.shape[-1]  # Nf
        message_feature_len = message_feat_length  # Mf

        w_init = hk.initializers.RandomNormal(stddev=weights_stddevGNN)

        # Parameters are created in this order (we, b, wv, wu) and under these names so that
        # initialisation and saved checkpoints are unaffected.
        we = hk.get_parameter(
            "we", shape=[node_feature_len, message_feature_len], init=w_init
        )  # (Nf, Mf)
        b = hk.get_parameter("b", shape=[message_feature_len], init=w_init)  # (Mf,)
        wv = hk.get_parameter(
            "wv", shape=[message_feature_len, node_feature_len], init=w_init
        )  # (Mf, Nf)
        wu = hk.get_parameter(
            "wu", shape=[node_feature_len, graph_feature_len], init=w_init
        )  # (Nf, Gf)

        # Messages ek[i, j] = leaky_relu(b + (nodes[j] @ we) * edges[i, j]), shape (N, N, Mf).
        # nodes @ we is computed once as (N, Mf) and broadcast over i, instead of tiling
        # nodes into an (N, N, Nf) array and repeating the same matmul N times.
        # NOTE(review): b is added outside the adjacency mask (as in the dmol.pub code), so
        # non-bonded pairs still contribute leaky_relu(b). Left unchanged because trained
        # parameters depend on it.
        ek = jax.nn.leaky_relu(
            b + jnp.matmul(nodes, we)[jnp.newaxis, :, :] * edges[..., None]
        )

        # Edge features are not updated; edges are passed through unchanged.

        # Sum messages over neighbours j -> ebar shape (N, Mf)
        ebar = jnp.sum(ek, axis=1)

        # Node update with a residual connection -> (N, Nf)
        new_nodes = jax.nn.leaky_relu(ebar @ wv) + nodes

        # Layer normalisation jointly over the atom and feature axes, no learned scale/offset
        new_nodes = hk.LayerNorm(
            axis=[0, 1], create_scale=False, create_offset=False, eps=1e-05
        )(new_nodes)

        # Sum-pool over atoms -> (Nf,)
        global_node_features = jnp.sum(new_nodes, axis=0)

        # Graph-feature update with a residual connection -> (Gf,)
        new_features = jax.nn.leaky_relu(global_node_features @ wu) + features

        return new_nodes, edges, new_features


def model_fn(x):
    """Forward pass: num_GNN_layers GNN layers followed by num_Dense_layers dense layers.

    Args:
        x: tuple (nodes, edges) as returned by gen_smiles2graph.

    Returns:
        Logits of shape (numClasses,); apply a sigmoid to get per-scent probabilities.

    Raises:
        ValueError: if num_Dense_layers < 1 (the output would not have numClasses entries).
    """
    if num_Dense_layers < 1:
        raise ValueError("num_Dense_layers must be at least 1")

    nodes, edges = x
    # The graph-level feature vector starts as all ones
    features = jnp.ones(graph_feat_length)
    x = nodes, edges, features

    # Haiku names repeated modules gnn_layer, gnn_layer_1, ... (and linear, linear_1, ...),
    # so with 4 GNN and 2 dense layers the parameter tree is identical to writing the layers
    # out by hand, and previously saved parameters still load.
    for _ in range(num_GNN_layers):
        x = GNNLayer(output_size=graph_feat_length)(x)

    # Dense head on the graph features. There is no activation between the dense layers,
    # so together they act as a single affine map.
    logits = x[-1]
    for _ in range(num_Dense_layers):
        logits = hk.Linear(numClasses)(logits)

    return logits  # Model returns logits


model = hk.without_apply_rng(hk.transform(model_fn))

In [None]:
# from https://dmol.pub/dl/xai.html
def cross_entropy_logits(logits, y):
    """Mean binary cross-entropy computed directly from logits, in a numerically stable form.

    Per element: max(z, 0) - z * y + log(1 + exp(-|z|)), averaged over all elements.
    log1p keeps precision when exp(-|z|) is tiny (large |z|), where log(1 + tiny)
    rounds to exactly 0.

    Args:
        logits: raw model outputs z.
        y: 0/1 targets with the same shape as logits.

    Returns:
        Scalar mean loss.
    """
    return jnp.mean(
        jnp.clip(logits, 0, None) - logits * y + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    )


# Loss without L2 regularization (the model outputs logits)
def loss_fn_logits(params, x, y):
    """Mean binary cross-entropy between the model's logits for input x and labels y."""
    return cross_entropy_logits(model.apply(params, x), y)


# Loss with L2 regularization (the model outputs logits).
# L2 term based on the "MLP on MNIST" example in the Haiku GitHub repository
# (https://github.com/deepmind/dm-haiku/blob/main/examples/mnist.py)
def loss_fn_logits_reg(params, x, y):
    """Cross-entropy loss plus regularizationStrength * (sum of squares of all parameters).

    Every parameter leaf (weights and biases) is penalised.
    """
    # jax.tree_util.tree_leaves is the supported API; the top-level jax.tree_leaves alias
    # was deprecated and removed in newer JAX releases.
    l2_lossTerm = regularizationStrength * sum(
        jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params)
    )
    dataLoss = loss_fn_logits(params, x, y)
    return dataLoss + l2_lossTerm

In [None]:
rng = jax.random.PRNGKey(0)

# Haiku needs one example input to create parameters of the right shapes:
# use the first molecule of the full dataset.
sampleData = data.take(1)
for dataVal in sampleData:
    (nodes_i, edges_i), yi = dataVal
nodes_i, edges_i, yi = nodes_i.numpy(), edges_i.numpy(), yi.numpy()
xi = (nodes_i, edges_i)

params = model.init(rng, xi)

# Optimiser: optax.apply_every(k=steps_for_gradUpdate) chained before Adam
gradientTransform = optax.chain(
    optax.apply_every(k=steps_for_gradUpdate), optax.adam(learning_rate)
)
opt_init, opt_update = gradientTransform

opt_state = opt_init(params)


@jax.jit
def update(opt_state, x, y, params):
    """Run one optimiser step on a single example.

    Returns (regularised loss before the step, new optimiser state, new parameters).
    """
    lossValue, gradients = jax.value_and_grad(loss_fn_logits_reg)(params, x, y)
    paramUpdates, newOptState = opt_update(gradients, opt_state)
    return lossValue, newOptState, optax.apply_updates(params, paramUpdates)

In [None]:
# Train model
epochs = numEpochs
print(
    f"Number of Epochs: {epochs}, learning rate: {learning_rate}, node_feature_len: {node_feat_length}, graph_feature_len: {graph_feat_length}, message_feature_length: {message_feat_length}, {num_Dense_layers} Dense, {num_GNN_layers} GNN layers"
)
val_loss = np.zeros(epochs)
train_loss = np.zeros(epochs)

# Early stopping: counter = number of consecutive epochs in which the validation loss did not
# decrease by at least earlyStopping_minDelta. Only enforced when earlyStopping is True.
counter = 0
epochStoppedAt = epochs

# Parameters from the epoch with the lowest validation loss seen so far. JAX arrays are
# immutable, so later update() calls do not modify bestParams.
bestParams = None
bestValidLoss = np.inf
bestEpoch = -1

for e in range(epochs):
    if earlyStopping and counter == earlyStopping_patience:
        print(f"Early stopping, stopped at Epoch {e} (Note 1st epoch = 0)")
        epochStoppedAt = e
        break  # Early stopping

    # One pass over the training set, one molecule per update() call
    for elementInTrainSet in train_set:
        (ni, ei), yi = elementInTrainSet
        ni = ni.numpy()
        ei = ei.numpy()
        yi = yi.numpy()
        xi = ni, ei
        value, opt_state, params = update(opt_state, xi, yi, params)
        train_loss[e] += value

    train_loss[e] = train_loss[e] / train_N  # Take average loss over all molecules
    print(f"Training Loss, Epoch {e}: {train_loss[e]}")

    for v in valid_set:
        (n_val, e_val), y = v
        n_val = n_val.numpy()
        e_val = e_val.numpy()
        y = y.numpy()
        x = n_val, e_val
        loss = loss_fn_logits_reg(params, x, y)
        val_loss[e] += loss

    val_loss[e] = val_loss[e] / valid_N  # Take average loss over all molecules

    # Did the validation loss improve by at least earlyStopping_minDelta? (early stopping)
    if e > 0:
        lossDiff = prevValidLoss - val_loss[e]  # > 0 means the validation loss decreased
        if lossDiff < earlyStopping_minDelta:
            counter += 1
        else:
            counter = 0

    prevValidLoss = val_loss[e]

    if val_loss[e] < bestValidLoss:
        bestParams, bestValidLoss, bestEpoch = params, val_loss[e], e

    print(f"Epoch {e}, Validation Loss: {val_loss[e]}")


# Use the parameters from the epoch with the lowest validation loss; fall back to the final
# parameters if no epoch produced a finite validation loss (or no epoch ran).
opt_params = bestParams if bestEpoch >= 0 else params
print(f"Using parameters from Epoch {bestEpoch} (validation loss: {bestValidLoss})")
# Save optimal parameters
opt_params_flattened = jax.tree_util.tree_flatten(opt_params)
# fileName = f'optParams_{epochs}Epochs.npy'
# np.save(fileName, opt_params_flattened[0])

### Model Evaluation Related Code

#### Run 2 cells below if reading model parameters from file & have not run earlier cells in the notebook
NOTE: the training and validation sets are downsampled; edit the code when actually training/evaluating the model so it uses the entire dataset.

In [None]:
# NOTE: If reading model parameters from file (rather than computing metrics directly after training), run 2 cells below

# Hyperparameters of the GNN whose saved parameters are loaded below; they must match the
# values used in training so that the parameter shapes line up.
node_feat_length = 256  # Nf
message_feat_length = node_feat_length  # Mf (256)
graph_feat_length = 512  # Gf
weights_stddevGNN = 0.01  # only used by model.init; the loaded parameters replace those values

# Use the train/test split given in the Leffingwell Dataset, but carve a validation set out of
# the training portion: 70% train, 10% validation, 20% test (instead of 80% train & 20% test).
# Load data
scentdata = pyrfume.load_data("leffingwell/leffingwell_data.csv", remote=True)

# Test set = rows marked labels_train/test == 0. reset_index(drop=True) gives every split a
# 0..n-1 index, so the row-number lookups later in the notebook keep working even when the
# downsampling lines below are commented out (as the test-set one already is).
testData = scentdata[scentdata["labels_train/test"] == 0].reset_index(drop=True)
numTestData = len(testData)

trainAndValidationData = scentdata[scentdata["labels_train/test"] == 1]
numTrainAndValidationData = len(trainAndValidationData)
# After reset_index() the index is 0..n-1, so the positional iloc and the label-based drop
# below select complementary sets of rows.
trainAndValidationData = trainAndValidationData.reset_index()

numMoleculesInDataset = numTestData + numTrainAndValidationData

# Randomly select rows of trainAndValidationData for validation (validation set = 10% of the
# entire dataset). Re-seed first so the split does not depend on which cells ran before.
# NOTE(review): this reproduces the training-time split only if no other NumPy random draws
# happened between the setup cell's np.random.seed(0) and that split — TODO confirm.
numValidationData = int(0.10 * numMoleculesInDataset)
np.random.seed(0)
validationIndices = np.random.choice(
    a=numTrainAndValidationData, size=numValidationData, replace=False
)  # https://numpy.org/doc/stable/reference/random/generated/numpy.random.choice.html
validationData = trainAndValidationData.iloc[validationIndices].reset_index(drop=True)
trainData = trainAndValidationData.drop(index=validationIndices).reset_index(drop=True)


# Use smaller set of data (just to check code does not crash)
# NOTE: comment out the next 2 lines when actually training/evaluating model to use the entire dataset
trainData = trainData.sample(frac=0.01, random_state=0).reset_index(drop=True)
validationData = validationData.sample(frac=0.01, random_state=0).reset_index(drop=True)
# testData = testData.sample(frac=0.8, random_state=0).reset_index(
#    drop=True
# )  # Sample more from test set to avoid error that there is only 1 molecule with a certain scent when calculating AUROC score

# Code to generate list of all scent labels (scentClasses) and the scent list of every molecule
numMolecules = len(scentdata.odor_labels_filtered)
numClasses = 112  # No odorless class
# Class names are read from the local scentClasses.csv ("Scent" column); their order defines
# the position of each scent in the label vectors.
scentClasses = pd.read_csv("scentClasses.csv")
scentClasses = scentClasses["Scent"].tolist()

# Characters deleted from the stringified label lists stored in odor_labels_filtered
_labelCharsToDelete = str.maketrans("", "", "[]' ")


def _parseScentLabels(scentString):
    """Parse a stringified label list such as "['fruity', 'green']" into ["fruity", "green"].

    Brackets, single quotes and all spaces are deleted, so multi-word labels lose their
    spaces (presumably matching the names in scentClasses.csv). Empty entries are dropped,
    so "[]" gives [] rather than [""]. The first "odorless" entry is removed because there
    is no odorless class.
    """
    scentList = [s for s in scentString.translate(_labelCharsToDelete).split(",") if s]
    if "odorless" in scentList:
        scentList.remove("odorless")
    return scentList


# Iterating over the column values (instead of label lookups like column[i]) keeps the
# lists in row order and works for any DataFrame index.
moleculeScentList = [_parseScentLabels(s) for s in scentdata.odor_labels_filtered]

# Generate moleculeScentList_train, moleculeScentList_validation, moleculeScentList_test
numTrainMolecules = len(trainData.odor_labels_filtered)
moleculeScentList_train = [_parseScentLabels(s) for s in trainData.odor_labels_filtered]

numValidationMolecules = len(validationData.odor_labels_filtered)
moleculeScentList_validation = [
    _parseScentLabels(s) for s in validationData.odor_labels_filtered
]

numTestMolecules = len(testData.odor_labels_filtered)
moleculeScentList_test = [_parseScentLabels(s) for s in testData.odor_labels_filtered]


def gen_smiles2graph(sml):
    """Convert a SMILES string into the (nodes, adj) arrays fed to the GNN.

    Hydrogens are made explicit with AddHs, so every atom (including H) becomes a node.

    Args:
        sml: SMILES string for one molecule.

    Returns:
        nodes: array of shape (N, node_feat_length). Row i has a 1 in column
            atomic_number(atom i) and, if atom i is in a ring, a 1 in the last column.
        adj: array of shape (N, N) with 1 for every bonded pair (both directions)
            and 1 on the diagonal (self-loops).

    Raises:
        ValueError: if RDKit cannot parse ``sml``.
        Warning: if the molecule contains a bond type other than single, double,
            triple or aromatic.
    """
    m = rdkit.Chem.MolFromSmiles(sml)
    if m is None:
        # MolFromSmiles signals a parse failure by returning None; report it clearly
        # instead of failing inside AddHs.
        raise ValueError(f"RDKit could not parse SMILES string {sml!r}")
    m = rdkit.Chem.AddHs(m)
    # Supported bond types. The order value is only used to validate the bond: the
    # adjacency matrix stores connectivity, not bond order.
    order_string = {
        rdkit.Chem.rdchem.BondType.SINGLE: 1,
        rdkit.Chem.rdchem.BondType.DOUBLE: 2,
        rdkit.Chem.rdchem.BondType.TRIPLE: 3,
        rdkit.Chem.rdchem.BondType.AROMATIC: 4,
    }
    N = len(list(m.GetAtoms()))
    nodes = np.zeros((N, node_feat_length))
    for atom in m.GetAtoms():
        # One-hot atomic number; requires atomic number < node_feat_length - 1 so it does
        # not collide with the ring flag stored in the last column.
        nodes[atom.GetIdx(), atom.GetAtomicNum()] = 1
        if atom.IsInRing():
            nodes[atom.GetIdx(), -1] = 1

    adj = np.zeros((N, N))
    for bond in m.GetBonds():
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        bondType = bond.GetBondType()
        if bondType not in order_string:
            # f-string: BondType cannot be concatenated to a str with "+"
            raise Warning(f"Unsupported bond order {bondType}")
        adj[u, v] = 1
        adj[v, u] = 1
    adj += np.eye(N)
    return nodes, adj


# Multi-hot label vectors: position i is 1 when the molecule has scent scentClasses[i],
# and 0 otherwise.


def createLabelVector(scentsList):
    """Return a length-numClasses multi-hot vector marking every scent in scentsList.

    Each scent's position is its index in scentClasses; a scent missing from
    scentClasses raises ValueError (from list.index).
    """
    labelVector = np.zeros(numClasses)
    for scent in scentsList:
        labelVector[scentClasses.index(scent)] = 1
    return labelVector


def generateGraphsTrain():
    """Yield ((nodes, adj), labels) for every training molecule, in row order."""
    for i in range(numTrainMolecules):
        # .iloc is positional, so this works whatever trainData's index is
        graph = gen_smiles2graph(trainData.smiles.iloc[i])
        labels = createLabelVector(moleculeScentList_train[i])
        yield graph, labels


def generateGraphsValidation():
    """Yield ((nodes, adj), labels) for every validation molecule, in row order."""
    for i in range(numValidationMolecules):
        graph = gen_smiles2graph(validationData.smiles.iloc[i])
        labels = createLabelVector(moleculeScentList_validation[i])
        yield graph, labels


def generateGraphsTest():
    """Yield ((nodes, adj), labels) for every test molecule, in row order."""
    for i in range(numTestMolecules):
        # Needed here in particular: the test set is not downsampled/re-indexed in this section
        graph = gen_smiles2graph(testData.smiles.iloc[i])
        labels = createLabelVector(moleculeScentList_test[i])
        yield graph, labels


def generateGraphs():
    """Yield ((nodes, adj), labels) for every molecule in the full dataset, in row order."""
    for i in range(numMolecules):
        graph = gen_smiles2graph(scentdata.smiles.iloc[i])
        labels = createLabelVector(moleculeScentList[i])
        yield graph, labels


# Get graph data for training, testing & validation sets.
# Every dataset yields ((nodes, adjacency), labels): nodes is (N, node_feat_length),
# adjacency is (N, N) and labels is a 1-D vector; N differs from molecule to molecule.
_graphOutputTypes = ((tf.float32, tf.float32), tf.float32)
_graphOutputShapes = (
    (tf.TensorShape([None, node_feat_length]), tf.TensorShape([None, None])),
    tf.TensorShape([None]),
)


def _graphDatasetFrom(generatorFn):
    """Wrap a generator function of ((nodes, adj), labels) items in a tf.data.Dataset."""
    return tf.data.Dataset.from_generator(
        generatorFn,
        output_types=_graphOutputTypes,
        output_shapes=_graphOutputShapes,
    )


data = _graphDatasetFrom(generateGraphs)
train_set = _graphDatasetFrom(generateGraphsTrain)
valid_set = _graphDatasetFrom(generateGraphsValidation)
test_set = _graphDatasetFrom(generateGraphsTest)

# Number of molecules per split (test_N sizes the prediction arrays in the AUROC cell)
train_N, valid_N, test_N = numTrainMolecules, numValidationMolecules, numTestMolecules


class GNNLayer(hk.Module):
    """One message-passing layer, adapted from the dmol.pub "Predicting DFT Energies with GNNs" example.

    Takes and returns a tuple (nodes, edges, features):
        nodes:    (N, Nf) atom features; N = number of atoms, Nf = node feature length
        edges:    (N, N) adjacency matrix (gen_smiles2graph builds it symmetric with self-loops)
        features: (Gf,) graph-level feature vector; Gf = graph feature length
    Nodes and graph features are updated; edges are returned unchanged.
    """

    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        # Kept for interface compatibility; not read by the layer (sizes come from the
        # inputs and the module-level message_feat_length).
        self.output_size = output_size

    def __call__(self, inputs):
        nodes, edges, features = inputs

        graph_feature_len = features.shape[-1]  # Gf
        node_feature_len = nodes.shape[-1]  # Nf
        message_feature_len = message_feat_length  # Mf

        w_init = hk.initializers.RandomNormal(stddev=weights_stddevGNN)

        # Parameters are created in this order (we, b, wv, wu) and under these names so that
        # initialisation and saved checkpoints are unaffected.
        we = hk.get_parameter(
            "we", shape=[node_feature_len, message_feature_len], init=w_init
        )  # (Nf, Mf)
        b = hk.get_parameter("b", shape=[message_feature_len], init=w_init)  # (Mf,)
        wv = hk.get_parameter(
            "wv", shape=[message_feature_len, node_feature_len], init=w_init
        )  # (Mf, Nf)
        wu = hk.get_parameter(
            "wu", shape=[node_feature_len, graph_feature_len], init=w_init
        )  # (Nf, Gf)

        # Messages ek[i, j] = leaky_relu(b + (nodes[j] @ we) * edges[i, j]), shape (N, N, Mf).
        # nodes @ we is computed once as (N, Mf) and broadcast over i, instead of tiling
        # nodes into an (N, N, Nf) array and repeating the same matmul N times.
        # NOTE(review): b is added outside the adjacency mask (as in the dmol.pub code), so
        # non-bonded pairs still contribute leaky_relu(b). Left unchanged because trained
        # parameters depend on it.
        ek = jax.nn.leaky_relu(
            b + jnp.matmul(nodes, we)[jnp.newaxis, :, :] * edges[..., None]
        )

        # Edge features are not updated; edges are passed through unchanged.

        # Sum messages over neighbours j -> ebar shape (N, Mf)
        ebar = jnp.sum(ek, axis=1)

        # Node update with a residual connection -> (N, Nf)
        new_nodes = jax.nn.leaky_relu(ebar @ wv) + nodes

        # Layer normalisation jointly over the atom and feature axes, no learned scale/offset
        new_nodes = hk.LayerNorm(
            axis=[0, 1], create_scale=False, create_offset=False, eps=1e-05
        )(new_nodes)

        # Sum-pool over atoms -> (Nf,)
        global_node_features = jnp.sum(new_nodes, axis=0)

        # Graph-feature update with a residual connection -> (Gf,)
        new_features = jax.nn.leaky_relu(global_node_features @ wu) + features

        return new_nodes, edges, new_features


def model_fn(x):
    """Forward pass used with the loaded parameters: 4 GNN layers, then 2 dense layers.

    The layer counts are fixed here and must match the checkpoint being loaded.
    Haiku names the layers gnn_layer ... gnn_layer_3 and linear, linear_1, in call order.

    Args:
        x: tuple (nodes, edges) as returned by gen_smiles2graph.

    Returns:
        Logits of shape (numClasses,).
    """
    nodes, edges = x
    state = (nodes, edges, jnp.ones(graph_feat_length))

    for _ in range(4):
        state = GNNLayer(output_size=graph_feat_length)(state)

    graphFeatures = state[-1]
    hidden = hk.Linear(numClasses)(graphFeatures)
    # No activation between the two dense layers (a ReLU here was tried and left disabled)
    logits = hk.Linear(numClasses)(hidden)
    return logits  # Model returns logits


model = hk.without_apply_rng(hk.transform(model_fn))

# Initialize model: Haiku needs one example input to create parameters of the right shapes,
# so use the first molecule of the full dataset.
rng = jax.random.PRNGKey(0)
sampleData = data.take(1)
for dataVal in sampleData:
    (nodes_i, edges_i), yi = dataVal
nodes_i, edges_i, yi = nodes_i.numpy(), edges_i.numpy(), yi.numpy()
xi = (nodes_i, edges_i)

params = model.init(rng, xi)

In [None]:
# Load optimal parameters for GNN model
print("Edit fileName to change parameters being loaded")
fileName = "optParams_dry-waterfall-17.npy"  # Currently optimal parameters, edit when get better model
paramsArr = jnp.load(fileName, allow_pickle=True)

# The file holds 20 arrays: (b, we, wu, wv) for each of the 4 GNN layers, then (b, w) for
# each of the 2 dense layers. NOTE(review): this appears to be the leaf order of
# jax.tree_util.tree_flatten on the Haiku params (sorted keys) — confirm for new checkpoints.
_gnnLayerNames = ["gnn_layer", "gnn_layer_1", "gnn_layer_2", "gnn_layer_3"]
_gnnParamNames = ["b", "we", "wu", "wv"]

opt_params = {}
for layerPos, layerName in enumerate(_gnnLayerNames):
    firstLeaf = layerPos * len(_gnnParamNames)
    opt_params[layerName] = {
        paramName: paramsArr[firstLeaf + offset]
        for offset, paramName in enumerate(_gnnParamNames)
    }
opt_params["linear"] = {"b": paramsArr[16], "w": paramsArr[17]}
opt_params["linear_1"] = {"b": paramsArr[18], "w": paramsArr[19]}

#### Calculate AUROC values on test set

In [None]:
# Compute AUROC & Create ROC curve for each scent class - uses scikit-learn
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html

# Predicted probabilities and true multi-hot labels for the test set, one row per molecule
test_yhat = np.empty((test_N, numClasses))
test_y = np.empty((test_N, numClasses))


for i, testVal in enumerate(test_set):
    (nodes_i, edges_i), yi = testVal
    nodes_i = nodes_i.numpy()
    edges_i = edges_i.numpy()
    yi = yi.numpy()
    xi = nodes_i, edges_i
    # The model outputs logits; a sigmoid gives an independent probability per scent
    test_yhat[i] = jax.nn.sigmoid(model.apply(opt_params, xi))
    test_y[i] = yi

# Shape of test_yhat and test_y should be (n_samples, n_classes)
# print(f'Shape of test_y: {np.shape(test_y)} and test_yhat: {np.shape(test_yhat)}')

scentClasses_testSet = []
aurocs_allClasses = []

for c in range(numClasses):
    numPositives = np.count_nonzero(test_y[:, c])
    if numPositives == 0:
        print(f"Test set does not have any molecules with scent {scentClasses[c]}")
    elif numPositives == len(test_y):
        # No negative examples: AUROC is undefined and roc_auc_score would raise ValueError
        print(f"Every molecule in the test set has scent {scentClasses[c]}, AUROC undefined")
    else:
        ##Uncomment lines below to create ROC curves
        # fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true = test_y[:,c], y_score = test_yhat[:,c])
        # plt.plot(fpr, tpr, '-o', label='Trained Model')
        # plt.plot([0,1], [0, 1], label='Naive Classifier')
        # plt.ylabel('True Positive Rate')
        # plt.xlabel('False Positive Rate')
        # plt.title(f'ROC Curve for {scentClasses[c]}')
        # plt.legend()
        # plt.show()
        # plt.savefig(f'GNN_ROC_Curve_{scentClasses[c]}.jpg')
        # plt.close()
        scentClasses_testSet.append(scentClasses[c])
        auroc = sklearn.metrics.roc_auc_score(
            y_true=test_y[:, c], y_score=test_yhat[:, c]
        )
        aurocs_allClasses.append(auroc)
        print(f"AUROC for scent {scentClasses[c]}: {auroc}")

# Write AUROC results to csv file
# aurocTable = pd.DataFrame({'Scent': scentClasses_testSet,'AUROC': aurocs_allClasses})
# csvFileName  = f'AurocTable_{runName}.csv'
# aurocTable.to_csv(csvFileName,index=False)

# Print Mean AUROC (over the classes that have both positive and negative test molecules)
mean_AUROC = np.mean(aurocs_allClasses)
print(f"Mean AUROC: {mean_AUROC}")

In [None]:
# Classes with both positive and negative molecules in the test set. AUROC is undefined for
# any other class, and sklearn's macro/weighted averages raise ValueError if such a class is
# included, so those two averages are computed over these columns only.
positivesPerClass = np.count_nonzero(test_y, axis=0)
scoredClasses = np.flatnonzero(
    (positivesPerClass > 0) & (positivesPerClass < test_y.shape[0])
)

# Check that calculating AUROC for each scent class & taking the mean is equivalent to
# using the average="macro" parameter
auroc_sklearnAverageMacro = sklearn.metrics.roc_auc_score(
    y_true=test_y[:, scoredClasses],
    y_score=test_yhat[:, scoredClasses],
    average="macro",
)
# Tolerant comparison: two float means computed in different orders can differ in the last bits
print(np.isclose(mean_AUROC, auroc_sklearnAverageMacro))
print(f"macro-average AUROC: {auroc_sklearnAverageMacro}")

# Micro-average pools every (molecule, class) pair into one binary problem, so classes with
# a single label value do not make it undefined
auroc_sklearnAverageMicro = sklearn.metrics.roc_auc_score(
    y_true=test_y, y_score=test_yhat, average="micro"
)
print(f"micro-average AUROC: {auroc_sklearnAverageMicro}")

# Weighted AUROC (each class weighted by how many test molecules have that scent)
auroc_sklearnAverageWeighted = sklearn.metrics.roc_auc_score(
    y_true=test_y[:, scoredClasses],
    y_score=test_yhat[:, scoredClasses],
    average="weighted",
)
print(f"weighted-average AUROC: {auroc_sklearnAverageWeighted}")

# Median of the per-class AUROC values
auroc_median = np.median(aurocs_allClasses)
print(f"median AUROC: {auroc_median}")

# Print out values for each AUROC score value as pandas table
tableOfAUROCValues = pd.DataFrame(index=None)
tableOfAUROCValues["macro-average AUROC"] = [auroc_sklearnAverageMacro]
tableOfAUROCValues["micro-average AUROC"] = [auroc_sklearnAverageMicro]
tableOfAUROCValues["weighted-average AUROC"] = [auroc_sklearnAverageWeighted]
tableOfAUROCValues["median AUROC"] = [auroc_median]
tableOfAUROCValues