### GNN Model Code

GNN model trained on the molecular scent dataset from the AIcrowd "Learning to Smell" challenge (https://www.aicrowd.com/challenges/learning-to-smell)

Code below modified from example code given in the "Predicting DFT Energies with GNNs" and "Interpretability and Deep Learning" sections of "Deep Learning for Molecules and Materials" textbook (https://whitead.github.io/dmol-book/applied/QM9.html)

In [1]:
#Select which GPU this process may see. On GPU nodes, change the value
#assigned below so that it matches the GPU that was allocated to the job.
import os

print('Remember to update CUDA_VISIBLE_DEVICES')
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

Remember to update CUDA_VISIBLE_DEVICES


In [2]:
#!pip install matplotlib numpy pandas seaborn jax jaxlib dm-haiku tensorflow wandb optax

#Code uses Weights & Biases
import wandb
#If running this code in a notebook and you have not yet logged in to W&B, uncomment the lines below
#wandb.login()
#%env "WANDB_NOTEBOOK_NAME" "GNN Model 4_Latest"

#Other imports
import tensorflow as tf
import numpy as np
import seaborn as sns
import jax.numpy as jnp
import jax
import jax.experimental.optimizers as opt
import pandas as pd
import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw
import matplotlib as mpl
import matplotlib.pyplot as plt
import haiku as hk
import optax
import sklearn.metrics

import warnings

#Suppress all Python warnings for the rest of the session
warnings.filterwarnings('ignore')

#Seaborn/matplotlib appearance used for every plot in this notebook
sns.set_context('notebook')
sns.set_style(
    'dark',
    {
        'xtick.bottom': True,
        'ytick.left': True,
        'xtick.color': '#666666',
        'ytick.color': '#666666',
        'axes.edgecolor': '#666666',
        'axes.linewidth': 0.8,
        'figure.dpi': 300,
    },
)
color_cycle = ['#1BBC9B', '#F06060', '#5C4B51', '#F3B562', '#6e5687']
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=color_cycle)

#Seed the NumPy and TensorFlow random number generators
np.random.seed(0)
tf.random.set_seed(0)

In [None]:
# 1. Start a W&B run; the returned handle is kept in `run` so later cells can
#    read its name, save it and finish it
run = wandb.init(
    entity='aseshad4',
    project='GNN_Model',
)

In [None]:
# 2. Save model inputs and hyperparameters
#Every value is stored as an attribute on the W&B config object
config = wandb.config

##NOTE: num_GNN_layers must stay consistent with the layers built in model_fn
##and with the code that stores/averages the parameters of the last 10 epochs
_hyperparameters = {
    'learning_rate': 1e-5,
    'num_Dense_layers': 2,
    'num_GNN_layers': 4,
    'numEpochs': 150,
    'steps_for_gradUpdate': 8,
    'graph_feat_length': 256,
    'node_feat_length': 256,
    'weights_stddevGNN': 1e-2,
    'layerNormalization': True,
    'layerNormalization_edges': True,
    'edgeUpdates': True,
}
for _key, _val in _hyperparameters.items():
    setattr(config, _key, _val)

#Messages use the same feature length as the nodes
config.message_feat_length = config.node_feat_length

In [None]:
#Load data --> file uploaded to jhub (locally stored)
scentdata = pd.read_csv('train.csv')

#Read in vocabulary text file --> this file gives all of the scent classes used in dataset
#The context manager closes the file handle once the contents are read
with open('vocabulary.txt') as file:
    #Create list that stores all scent classes (one entry per line of the file)
    #NOTE(review): if the file ends with a newline, split('\n') yields a trailing
    #empty-string entry that is counted as a class - confirm this is intended
    scentClasses = file.read().split('\n')
numClasses = len(scentClasses)


In [None]:
def gen_smiles2graph(sml):
    '''Convert a SMILES string into a (node features, adjacency matrix) pair.

    Hydrogens are added explicitly before the graph is built.

    sml: a valid SMILES string
    returns: (nodes, adj) where
        nodes has shape (N, config.node_feat_length) and row i is one-hot at
            the column given by atom i's atomic number,
        adj has shape (N, N), is symmetric, has 1 for every bonded pair and
            1 on the diagonal (self loops),
        with N the number of atoms including hydrogens.
    raises: ValueError if RDKit cannot parse the SMILES string,
            Warning if the molecule contains a bond type other than
            single/double/triple/aromatic.
    '''
    m = rdkit.Chem.MolFromSmiles(sml)
    if m is None:
        #MolFromSmiles signals a parse failure by returning None
        raise ValueError(f'Could not parse SMILES string: {sml!r}')
    m = rdkit.Chem.AddHs(m)
    #Bond types accepted by this function. The numeric orders are only used to
    #validate the bond type; they are not encoded in the adjacency matrix.
    order_string = {rdkit.Chem.rdchem.BondType.SINGLE: 1,
                    rdkit.Chem.rdchem.BondType.DOUBLE: 2,
                    rdkit.Chem.rdchem.BondType.TRIPLE: 3,
                    rdkit.Chem.rdchem.BondType.AROMATIC: 4}
    N = m.GetNumAtoms()
    nodes = np.zeros((N,config.node_feat_length))
    for atom in m.GetAtoms():
        nodes[atom.GetIdx(), atom.GetAtomicNum()] = 1

    adj = np.zeros((N,N))
    for bond in m.GetBonds():
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        order = bond.GetBondType()
        if order not in order_string:
            #str() is required: the bond type is not a string, so plain
            #concatenation would fail with a TypeError
            raise Warning('Ignoring bond order ' + str(order))
        #Undirected graph: mark the bond in both directions
        adj[u, v] = 1
        adj[v, u] = 1
    #Self loops
    adj += np.eye(N)
    return nodes, adj

In [None]:
#Builds the multi-hot label vector for one molecule from the list of strings describing its scent
#Index i of the vector corresponds to scentClasses[i]: 1 means the molecule has that scent, 0 means it does not

def createLabelVector(scentsList):
    '''Return a length-numClasses vector with 1 at the index of every scent in scentsList.

    Each scent is looked up in scentClasses; a scent that is not in the
    vocabulary makes list.index raise ValueError.
    '''
    labelVector = np.zeros(numClasses)
    for scent in scentsList:
        #Position of this scent in the vocabulary = its index in the label vector
        labelVector[scentClasses.index(scent)] = 1
    return labelVector

In [None]:
def generateGraphs():
    '''Yield (graph, labels) for every row of scentdata.

    graph is the (nodes, adjacency) pair built from the row's SMILES string;
    labels is the label vector built from the row's comma-separated SENTENCE.
    '''
    for smiles, sentence in zip(scentdata.SMILES, scentdata.SENTENCE):
        #sentence.split(',') is the list of strings describing the scent of this molecule
        yield gen_smiles2graph(smiles), createLabelVector(sentence.split(','))

#Check that generateGraphs() works for 1st molecule
#print(gen_smiles2graph(scentdata.SMILES[0]))
#print(scentdata.SENTENCE[0].split(','))
#print(np.nonzero(createLabelVector(scentdata.SENTENCE[0].split(','))))
#print(scentClasses[89])

#Dataset whose elements are ((node features, adjacency matrix), label vector), all float32
#The number of atoms varies per molecule, hence the None dimensions
data = tf.data.Dataset.from_generator(
    generateGraphs,
    output_types=((tf.float32, tf.float32), tf.float32),
    output_shapes=(
        (tf.TensorShape([None, config.node_feat_length]), tf.TensorShape([None, None])),
        tf.TensorShape([None]),
    ),
)

#Generates graphs where molecule i is associated with molecule i+1's scent (scrambles scents & molecular features)
#NOT correct way to generate graphs
def generateGraphs_scrambled():
    '''Yield (graph, labels) pairs where the labels come from the NEXT row.

    The graph is built from row i's SMILES while the labels are built from
    row i+1's SENTENCE; the last row wraps around to row 0.
    '''
    numRows = len(scentdata)
    for i in range(numRows):
        graph = gen_smiles2graph(scentdata.SMILES[i])
        #Row that supplies the scents: i+1, wrapping to 0 after the last row
        labelRow = (i + 1) % numRows
        tempScents = scentdata.SENTENCE[labelRow].split(',')
        yield graph, createLabelVector(tempScents)

#Same element structure as the regular dataset, but fed by the scrambled generator
data_scrambled = tf.data.Dataset.from_generator(
    generateGraphs_scrambled,
    output_types=((tf.float32, tf.float32), tf.float32),
    output_shapes=(
        (tf.TensorShape([None, config.node_feat_length]), tf.TensorShape([None, None])),
        tf.TensorShape([None]),
    ),
)



In [None]:
numMolecules = len(scentdata.SMILES)
print(f'Number of molecules: {numMolecules}')

#Split data into training, validation & testing sets (80% / 10% / remainder)
#The split is positional: the dataset is NOT shuffled first

#shuffle dataset
#shuffled_data = data.shuffle(numMolecules)

train_N, valid_N = int(numMolecules * 0.8), int(numMolecules * 0.1)
test_N = numMolecules - train_N - valid_N

#First train_N elements, then the next valid_N, then the last test_N
train_set = data.take(train_N)
valid_set = data.skip(train_N).take(valid_N)
test_set = data.skip(valid_N + train_N).take(test_N)

#train_set_scrambled = data_scrambled.take(train_N)
#valid_set_scrambled = data_scrambled.skip(train_N).take(valid_N)
#test_set_scrambled = data_scrambled.skip(valid_N+train_N).take(test_N)

In [None]:
def cross_entropy(yhat, y):
    '''Mean binary cross entropy between predicted probabilities yhat and labels y.

    A small constant (1e-10) is added inside each log to avoid log(0).
    '''
    #yhat = jax.nn.sigmoid(logits)
    eps = 1e-10
    positive_term = y * jnp.log(yhat + eps)
    negative_term = (1 - y) * jnp.log(1 - yhat + eps)
    return -jnp.mean(positive_term + negative_term)

def cross_entropy_logits(logits, y):
    '''Mean binary cross entropy computed directly from logits (no sigmoid applied).

    Per element this evaluates max(logits, 0) - logits * y + log(1 + exp(-|logits|)).
    '''
    clipped = jnp.clip(logits, 0, None)
    log_term = jnp.log(1 + jnp.exp(-jnp.abs(logits)))
    return jnp.mean(clipped - logits * y + log_term)


class GNNLayer(hk.Module): #TODO: If increase number of layers, stack features & new_features and shrink via dense layer
    '''One message-passing layer that updates nodes, edges and graph features.

    The layer is called with a tuple (nodes, edges, features) and returns a
    tuple of the same structure, so layers can be chained.

    NOTE(review): output_size is stored but never read in this class; the
    output sizes follow the input shapes instead.
    NOTE(review): config.layerNormalization, config.layerNormalization_edges
    and config.edgeUpdates are not consulted here - normalization and edge
    updates are always applied.
    '''

    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        self.output_size = output_size

    def __call__(self, inputs):
        '''Apply the layer.

        inputs: tuple (nodes, edges, features)
            nodes is of shape (N, Nf) --> N = # atoms, Nf = node feature length
            edges is of shape (N, N) (adjacency matrix)
            features is of shape (Gf) --> Gf = graph feature length
        returns: tuple (new_nodes, new_edges, new_features) of shapes
            (N, Nf), (N, N) and (Gf)
        '''
        # split input into nodes, edges & features
        nodes, edges, features = inputs

        graph_feature_len = features.shape[-1] #graph_feature_len (Gf)
        node_feature_len = nodes.shape[-1] #node_feature_len (Nf)
        message_feature_len = config.message_feat_length #message_feature_length (Mf)

        #Initialize weights (the bias uses the same random-normal initializer)
        w_init = hk.initializers.RandomNormal(stddev = config.weights_stddevGNN)

        #we is of shape (Nf,Mf)
        we = hk.get_parameter("we", shape=[node_feature_len, message_feature_len], init=w_init)

        #b is of shape (Mf)
        b = hk.get_parameter("b", shape=[message_feature_len], init=w_init)

        #wv is of shape (Mf,Nf)
        wv = hk.get_parameter("wv", shape=[message_feature_len, node_feature_len], init=w_init)

        #wu is of shape (Nf,Gf)
        wu = hk.get_parameter("wu", shape=[node_feature_len, graph_feature_len], init=w_init)

        #Project every node once: node_messages is of shape (N, Mf).
        #Broadcasting it against edges gives ek[i, j] = leaky_relu(b + (nodes[j] @ we) * edges[i, j]),
        #the same values as tiling nodes to (N, N, Nf) first, but without
        #repeating the matrix product N times.
        node_messages = nodes @ we
        # ek is shaped N x N x Mf
        ek = jax.nn.leaky_relu(b + node_messages[jnp.newaxis, ...] * edges[..., None])

        #ek *= edges[...,None]

        #Update edges, use jnp.any to have new_edges be of shape N x N
        #NOTE(review): jnp.any reduces each message to a boolean (any non-zero
        #feature), and b is added even where edges is 0, so entries for
        #non-bonded pairs can be non-zero too - confirm this is intended
        new_edges = jnp.any(ek, axis=-1)

        #Normalize the edge matrix w/layer normalization over both of its axes
        new_edges = hk.LayerNorm(axis=[0,1], create_scale=False, create_offset=False, eps=1e-05)(new_edges)

        # take sum over neighbors to get ebar shape = N x Mf
        ebar = jnp.sum(ek, axis=1)

        # dense layer + residual connection to get new_nodes shape = N x Nf
        new_nodes = jax.nn.leaky_relu(ebar @ wv) + nodes

        #Normalize w/layer normalization over both the node axis and the feature axis
        new_nodes = hk.LayerNorm(axis=[0,1], create_scale=False, create_offset=False, eps=1e-05)(new_nodes)

        # sum over nodes so global_node_features shape = Nf
        global_node_features = jnp.sum(new_nodes, axis=0)

        # dense layer + residual connection so new_features shape = Gf
        new_features = jax.nn.leaky_relu(global_node_features  @ wu) + features

        # return the full tuple so that layers can be chained
        return new_nodes, new_edges, new_features

    
def model_fn(x):
    '''Map one molecular graph to a vector of numClasses logits.

    x: tuple (nodes, edges) for a single molecule
    returns: logits of shape (numClasses); no sigmoid is applied

    The number of GNN layers and dense layers is read from
    config.num_GNN_layers and config.num_Dense_layers. Modules are created in
    the same order as when they were written out one by one, so with the
    default config (4 GNN layers, 2 dense layers) the model is unchanged.

    NOTE: the training code that stores and averages the parameters of the
    last 10 epochs assumes 4 GNN layers and 2 dense layers and must be edited
    if these config values are changed.
    '''
    nodes, edges = x
    #Initial graph-level features: a vector of ones
    features = jnp.ones(config.graph_feat_length)
    x = nodes, edges, features

    # GNN layers; each maps (nodes, edges, features) to a tuple of the same structure
    for _ in range(config.num_GNN_layers):
        x = GNNLayer(output_size=config.graph_feat_length)(x)

    # dense layers applied to the graph-level features (last element of x)
    logits = x[-1]
    for _ in range(config.num_Dense_layers):
        logits = hk.Linear(numClasses)(logits)

    return logits #Model now returns logits

#Turn model_fn into an (init, apply) pair whose apply does not take an rng argument
_transformed_model_fn = hk.transform(model_fn)
model = hk.without_apply_rng(_transformed_model_fn)

#Use loss function below if model outputs yhat directly (not logits)
def loss_fn(params, x, y):
    '''Cross entropy between the model output for x (treated as probabilities) and labels y.'''
    return cross_entropy(model.apply(params, x), y)

#Use loss function below if model outputs logits
def loss_fn_logits(params, x, y):
    '''Cross entropy between the model output for x (treated as logits) and labels y.'''
    return cross_entropy_logits(model.apply(params, x), y)

#Accuracy function where accuracy is measured as |intersection of true and predicted labels|/|union of true and predicted labels|
def accuracy_fn(params, x, y):
    '''Return |predicted AND true| / |predicted OR true| for one molecule.

    A class counts as predicted when the model output for it is > 0 and as
    true when its entry in y is non-zero.
    '''
    yhat = np.asarray(model.apply(params, x))

    # convert to hard classes -> positive yhat means the class is predicted
    predicted_mask = yhat > 0
    true_mask = np.asarray(y) != 0

    #Size of the intersection and of the union of the two label sets
    correctlyPredicted = int(np.count_nonzero(predicted_mask & true_mask))
    totalLabels = int(np.count_nonzero(predicted_mask | true_mask))
    return correctlyPredicted / totalLabels


In [None]:
#Compute competition accuracy
#accuracy is measured as |intersection of true and predicted labels for top 3 predictions|/|union of true and predicted labels for top 3 predictions|
def competitionAccuracy_fn(params, x, y):
    '''Score the 15 highest-scoring classes as 5 triples and return the best triple.

    Triple 0 holds the 3 highest-scoring classes, triple 1 the next 3, and so on.
    returns: (accuracyComp, topPredictionSet) where topPredictionSet is the
        index of the triple with the most correct labels and accuracyComp is
        that count divided by 3 + (3 - count).
    '''
    yhat = model.apply(params, x)
    true_scentIndices = jnp.nonzero(y)

    #Class indices sorted by ascending score; keep the 15 highest
    ranked = np.argsort(yhat)
    top15Pred = ranked[len(ranked) - 15:]

    #Rows of 3 class indices, reversed so that the best triple comes first
    predictions = top15Pred.reshape(5, 3)[::-1].astype(float)

    #Number of true labels hit by each triple
    numCorrect = np.array(
        [len(np.intersect1d(triple, true_scentIndices)) for triple in predictions],
        dtype=float,
    )

    topPredictionSet = np.argmax(numCorrect)
    topNumCorrect = numCorrect[topPredictionSet]
    #Total number of labels is number of labels in union of predicted & actual/true labels set (keep only 3 true labels)
    totalLabels = 3 + (3 - topNumCorrect)
    return topNumCorrect / totalLabels, topPredictionSet

In [None]:
rng = jax.random.PRNGKey(0)

print('Using Regular Dataset')
sampleData = data.take(1)

#sampleData_scrambled = data_scrambled.take(1)
#print('Using Scrambled Dataset')

#Pull one example out of the dataset; it is only used to initialize the parameters
for dataVal in sampleData: #Look into later how to get larger set
    (nodes_i, edges_i), yi = dataVal
nodes_i, edges_i, yi = nodes_i.numpy(), edges_i.numpy(), yi.numpy()
xi = (nodes_i, edges_i)

params = model.init(rng, xi)

#Optimizer: optax.apply_every(k) chained in front of Adam
opt_init, opt_update = optax.chain(
    optax.apply_every(k=config.steps_for_gradUpdate),
    optax.adam(config.learning_rate),
)

opt_state = opt_init(params)

@jax.jit
def update(opt_state, x, y, params):
    '''Run one optimizer step on a single example.

    returns: (loss value, new optimizer state, updated parameters)
    '''
    loss_and_grad = jax.value_and_grad(loss_fn_logits)
    value, grads = loss_and_grad(params, x, y)
    updates, opt_state = opt_update(grads, opt_state)
    return value, opt_state, optax.apply_updates(params, updates)

In [None]:
# Train model
epochs = config.numEpochs
print(f'Number of Epochs: {epochs}, learning rate: {config.learning_rate}, node_feature_len: {config.node_feat_length}, graph_feature_len: {config.graph_feat_length}, message_feature_length: {config.message_feat_length}, {config.num_Dense_layers} Dense, {config.num_GNN_layers} GNN layers')

#Per-epoch metrics (one entry per epoch)
val_loss = np.zeros(epochs)
val_accuracy = np.zeros(epochs)
val_accComp = np.zeros(epochs)
train_loss = np.zeros(epochs)
train_accuracy = np.zeros(epochs)
train_accComp = np.zeros(epochs)

#Create arrays to store parameters for last 10 epochs (leading axis = epoch slot)
pastTenParams_denseLayer1_w = np.zeros((10,config.graph_feat_length, numClasses))
pastTenParams_denseLayer1_b = np.zeros((10,numClasses))
pastTenParams_denseLayer2_w = np.zeros((10,numClasses, numClasses))
pastTenParams_denseLayer2_b = np.zeros((10,numClasses))

#NOTE: If edited config.num_GNN_layers, need to edit code below (add or remove pastTenParams_GNNLayerN...)
#Shapes mirror the parameters created in GNNLayer:
#b is (Mf), we is (Nf,Mf), wu is (Nf,Gf), wv is (Mf,Nf) - the same for every layer
pastTenParams_GNNLayer1_b = np.zeros((10,config.message_feat_length))
pastTenParams_GNNLayer1_we = np.zeros((10,config.node_feat_length,config.message_feat_length))
pastTenParams_GNNLayer1_wu = np.zeros((10,config.node_feat_length, config.graph_feat_length))
pastTenParams_GNNLayer1_wv = np.zeros((10,config.message_feat_length,config.node_feat_length))
pastTenParams_GNNLayer2_b = np.zeros((10,config.message_feat_length))
pastTenParams_GNNLayer2_we = np.zeros((10,config.node_feat_length,config.message_feat_length))
pastTenParams_GNNLayer2_wu = np.zeros((10,config.node_feat_length, config.graph_feat_length))
pastTenParams_GNNLayer2_wv = np.zeros((10,config.message_feat_length,config.node_feat_length))
pastTenParams_GNNLayer3_b = np.zeros((10,config.message_feat_length))
pastTenParams_GNNLayer3_we = np.zeros((10,config.node_feat_length,config.message_feat_length))
pastTenParams_GNNLayer3_wu = np.zeros((10,config.node_feat_length, config.graph_feat_length))
pastTenParams_GNNLayer3_wv = np.zeros((10,config.message_feat_length,config.node_feat_length))
pastTenParams_GNNLayer4_b = np.zeros((10,config.message_feat_length))
pastTenParams_GNNLayer4_we = np.zeros((10,config.node_feat_length,config.message_feat_length))
pastTenParams_GNNLayer4_wu = np.zeros((10,config.node_feat_length, config.graph_feat_length))
pastTenParams_GNNLayer4_wv = np.zeros((10,config.message_feat_length,config.node_feat_length))

#early stopping counter
#counter = 0
#epochStoppedAt = epochs

#Buffers that receive the parameter snapshots of the last 10 epochs, listed in
#the order of the flattened parameter leaves: for each GNN layer (b,we,wu,wv),
#then for each dense layer (b,w)
#NOTE: If edited config.num_GNN_layers, need to edit the list below (add or remove pastTenParams_GNNLayerN...)
pastTenParams_targets = [
    pastTenParams_GNNLayer1_b, pastTenParams_GNNLayer1_we, pastTenParams_GNNLayer1_wu, pastTenParams_GNNLayer1_wv,
    pastTenParams_GNNLayer2_b, pastTenParams_GNNLayer2_we, pastTenParams_GNNLayer2_wu, pastTenParams_GNNLayer2_wv,
    pastTenParams_GNNLayer3_b, pastTenParams_GNNLayer3_we, pastTenParams_GNNLayer3_wu, pastTenParams_GNNLayer3_wv,
    pastTenParams_GNNLayer4_b, pastTenParams_GNNLayer4_we, pastTenParams_GNNLayer4_wu, pastTenParams_GNNLayer4_wv,
    pastTenParams_denseLayer1_b, pastTenParams_denseLayer1_w,
    pastTenParams_denseLayer2_b, pastTenParams_denseLayer2_w,
]

for e in range(epochs):
    #if(counter == 3):
       # print('early stopping')
       # epochStoppedAt = e
       # break #Early stopping

    #Training pass: one optimizer update call per molecule
    for elementInTrainSet in train_set:
        (ni,ei), yi = elementInTrainSet
        xi = ni.numpy(), ei.numpy()
        yi = yi.numpy()
        value, opt_state, params = update(opt_state, xi, yi, params)
        train_loss[e] += value
        #Metrics are computed with the parameters returned by this update
        train_accuracy[e] += accuracy_fn(params, xi,yi)
        accComp,index = competitionAccuracy_fn(params,xi,yi)
        train_accComp[e] += accComp

    train_loss[e] = train_loss[e]/ train_N #Take average loss over all molecules
    train_accuracy[e] = train_accuracy[e]/train_N
    train_accComp[e] = train_accComp[e]/train_N
    print(f'Training Loss, Epoch {e}: {train_loss[e]}')
    print(f'Training Comp. Accuracy, Epoch{e}: {train_accComp[e]}')

    #Store the parameters of the last 10 epochs. params does not change during
    #validation, so one snapshot per epoch is enough.
    if (e >= epochs - 10):
        paramsArr = jax.tree_util.tree_flatten(params)
        if len(paramsArr[0]) != len(pastTenParams_targets):
            raise ValueError(
                f'Model has {len(paramsArr[0])} parameter arrays but '
                f'{len(pastTenParams_targets)} storage buffers are defined'
            )
        pos = 10 - (epochs - e)
        for target, leaf in zip(pastTenParams_targets, paramsArr[0]):
            target[pos] = leaf

    #Validation pass: no parameter updates
    for v in valid_set:
        (n_val,e_val), y = v
        x = n_val.numpy(), e_val.numpy()
        y = y.numpy()

        loss = loss_fn_logits(params, x, y)
        val_loss[e] += loss
        val_accuracy[e] += accuracy_fn(params,x,y)
        vAccComp, vIndex = competitionAccuracy_fn(params,x,y)
        val_accComp[e] += vAccComp

    val_loss[e] = val_loss[e] / valid_N #Take average loss over all molecules
    val_accuracy[e] = val_accuracy[e] / valid_N
    val_accComp[e] = val_accComp[e] / valid_N

    #if (e > 0 and val_loss[e] > prevValidLoss): #Check if have increase in validation loss (early stopping)
    #    counter += 1
    #else:
    #    counter = 0
    #prevValidLoss = val_loss[e]

    print(f'Epoch {e}, Validation Loss: {val_loss[e]}')
    print(f'Validation Comp. Accuracy, Epoch{e}: {val_accComp[e]}')

    # 3. Log metrics over time to visualize performance (using Weights & Biases)
    wandb.log({'Training loss': train_loss[e], 'Epoch': e})
    wandb.log({"Training accuracy (competition accuracy)": train_accComp[e], 'Epoch': e})
    wandb.log({"Training accuracy (standard accuracy)": train_accuracy[e], 'Epoch': e})
    wandb.log({"Validation loss": val_loss[e], 'Epoch': e})
    wandb.log({"Validation accuracy (competition accuracy)": val_accComp[e], 'Epoch': e})
    wandb.log({"Validation accuracy (standard accuracy)": val_accuracy[e], 'Epoch': e})
    
    
def _save_param_leaves(path, leaves):
    '''Save a list of differently-shaped parameter arrays to one .npy file.

    The arrays are packed into a 1-D object array (one element per leaf)
    instead of being passed to save as a plain list: converting a list of
    arrays with different shapes to an array implicitly is rejected by newer
    NumPy releases. Load the file with np.load(path, allow_pickle=True).
    '''
    packed = np.empty(len(leaves), dtype=object)
    for k, leaf in enumerate(leaves):
        packed[k] = np.asarray(leaf)
    np.save(path, packed)


#save W&B run
run.save()
runName = run.name

#Parameters after the final epoch
opt_params = params
#Save optimal parameters
opt_params_flattened = jax.tree_util.tree_flatten(opt_params)
fileName = f'optParams_{epochs}Epochs_{runName}.npy'
_save_param_leaves(fileName, opt_params_flattened[0])

#NOTE: If edited config.num_GNN_layers, need to edit code below (add/remove opt_params_GNNN_b/we/wu/wv = jnp.mean(...))
#Take average of parameter values for last ten epochs & use them as weights
opt_params_dense1_b = jnp.mean(pastTenParams_denseLayer1_b,0)
opt_params_dense1_w = jnp.mean(pastTenParams_denseLayer1_w,0)
opt_params_dense2_b = jnp.mean(pastTenParams_denseLayer2_b,0)
opt_params_dense2_w = jnp.mean(pastTenParams_denseLayer2_w,0)
opt_params_GNN1_b = jnp.mean(pastTenParams_GNNLayer1_b,0)
opt_params_GNN1_we = jnp.mean(pastTenParams_GNNLayer1_we,0)
opt_params_GNN1_wu = jnp.mean(pastTenParams_GNNLayer1_wu,0)
opt_params_GNN1_wv = jnp.mean(pastTenParams_GNNLayer1_wv,0)
opt_params_GNN2_b = jnp.mean(pastTenParams_GNNLayer2_b,0)
opt_params_GNN2_we = jnp.mean(pastTenParams_GNNLayer2_we,0)
opt_params_GNN2_wu = jnp.mean(pastTenParams_GNNLayer2_wu,0)
opt_params_GNN2_wv = jnp.mean(pastTenParams_GNNLayer2_wv,0)
opt_params_GNN3_b = jnp.mean(pastTenParams_GNNLayer3_b,0)
opt_params_GNN3_we = jnp.mean(pastTenParams_GNNLayer3_we,0)
opt_params_GNN3_wu = jnp.mean(pastTenParams_GNNLayer3_wu,0)
opt_params_GNN3_wv = jnp.mean(pastTenParams_GNNLayer3_wv,0)
opt_params_GNN4_b = jnp.mean(pastTenParams_GNNLayer4_b,0)
opt_params_GNN4_we = jnp.mean(pastTenParams_GNNLayer4_we,0)
opt_params_GNN4_wu = jnp.mean(pastTenParams_GNNLayer4_wu,0)
opt_params_GNN4_wv = jnp.mean(pastTenParams_GNNLayer4_wv,0)

#NOTE: If edited config.num_GNN_layers, need to edit dict below (need to add/remove 'gnn_layer_N':{'b': opt_params_...., 'wu': ...})
#Averaged parameters arranged under the same module/parameter names as the model's parameters
opt_params_avg = {
    'gnn_layer': {'b': opt_params_GNN1_b, 'we': opt_params_GNN1_we, 'wu': opt_params_GNN1_wu, 'wv': opt_params_GNN1_wv},
    'gnn_layer_1': {'b': opt_params_GNN2_b, 'we': opt_params_GNN2_we, 'wu': opt_params_GNN2_wu, 'wv': opt_params_GNN2_wv},
    'gnn_layer_2': {'b': opt_params_GNN3_b, 'we': opt_params_GNN3_we, 'wu': opt_params_GNN3_wu, 'wv': opt_params_GNN3_wv},
    'gnn_layer_3': {'b': opt_params_GNN4_b, 'we': opt_params_GNN4_we, 'wu': opt_params_GNN4_wu, 'wv': opt_params_GNN4_wv},
    'linear': {'b': opt_params_dense1_b, 'w': opt_params_dense1_w},
    'linear_1': {'b': opt_params_dense2_b, 'w': opt_params_dense2_w},
}

#Save average of last 10 parameters
opt_params_avg_flattened = jax.tree_util.tree_flatten(opt_params_avg)
fileName_avg = f'optParamsAvg_{epochs}Epochs_{runName}.npy'
_save_param_leaves(fileName_avg, opt_params_avg_flattened[0])

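#To create the plots & save them, remove the triple quotes around the block below
#NOTE: the block uses epochStoppedAt, which is only assigned in the commented-out early-stopping code,
#so define it first (e.g. epochStoppedAt = epochs). It also calls plt.show() before plt.savefig(),
#which may save a blank image; consider saving before showing.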
'''
print('List of Plots Created & Saved: ')
plt.title('GNN Loss vs Epoch')
plt.plot(range(epochStoppedAt), val_loss[:epochStoppedAt], label = 'Validation Loss')
plt.plot(range(epochStoppedAt), train_loss[:epochStoppedAt], label = 'Training Loss' )
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
print('GNNLossPlot_721_2Dense2Graph_GraphFeature256NodeFeature256AvgMessage1000Epoch.jpg')
plt.savefig('GNNLossPlot_721_2Dense2Graph_GraphFeature256NodeFeature256AvgMessage1000Epoch.jpg')
plt.close()

plt.title('GNN Accuracy vs Epoch')
plt.plot(range(epochStoppedAt), val_accuracy[:epochStoppedAt], label = 'Validation')
plt.plot(range(epochStoppedAt), train_accuracy[:epochStoppedAt], label = 'Training' )
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
print('GNNAccuracyPlot_721_2Dense2Graph_GraphFeature256NodeFeature256AvgMessage1000Epoch.jpg')
plt.savefig('GNNAccuracyPlot_721_2Dense2Graph_GraphFeature256NodeFeature256AvgMessage1000Epoch.jpg')
plt.close()

plt.title('GNN Competition Accuracy vs Epoch')
plt.plot(range(epochStoppedAt), val_accComp[:epochStoppedAt], label = 'Validation')
plt.plot(range(epochStoppedAt), train_accComp[:epochStoppedAt], label = 'Training' )
plt.xlabel('Epoch')
plt.ylabel('Competition Accuracy')
plt.legend()
plt.show()
print('GNNCompAccuracyPlot_721_2Dense2Graph_GraphFeature256NodeFeature256AvgMessage1000Epoch.jpg')
plt.savefig('GNNCompAccuracyPlot_721_2Dense2Graph_GraphFeature256NodeFeature256AvgMessage1000Epoch.jpg')
plt.close()

#Plot validation loss & training loss from 8th epoch onwards
plt.plot(range(epochStoppedAt)[8:epochStoppedAt], val_loss[8:epochStoppedAt], label='Validation Loss')
plt.plot(range(epochStoppedAt)[8:epochStoppedAt], train_loss[8:epochStoppedAt], label='Training Loss')
plt.legend()
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.title(f'GNN Training & Validation Loss, Epochs 8-{epochStoppedAt}')
plt.show()
print('GNNLossPlotPast8Epoch_721_2Dense2Graph_GraphFeature256NodeFeature256AvgMessage1000Epoch.jpg')
plt.savefig('GNNLossPlotPast8Epoch_721_2Dense2Graph_GraphFeature256NodeFeature256AvgMessage1000Epoch.jpg')
plt.close()
'''

In [None]:
#Compute accuracy (standard) on test set
#Each molecule is scored twice: with the final parameters and with the
#parameters averaged over the last 10 epochs
acc = np.zeros(test_N)
accAverageParams = np.zeros(test_N)
for i, ((nodes_i, edges_i), yi) in enumerate(test_set):
    xi = nodes_i.numpy(), edges_i.numpy()
    yi = yi.numpy()
    acc[i] = accuracy_fn(opt_params, xi, yi)
    accAverageParams[i] = accuracy_fn(opt_params_avg, xi, yi)

#Report the mean over the test set and record it in the W&B run summary
meanAcc = np.mean(acc)
print(f'Overall Accuracy: {meanAcc}')
wandb.run.summary["Test set accuracy (standard accuracy)"] = meanAcc
meanAccAverageParams = np.mean(accAverageParams)
print(f'Overall Accuracy, average of last 10 parameters: {meanAccAverageParams}')
wandb.run.summary["Test set accuracy, average params(standard accuracy)"] = meanAccAverageParams

In [None]:
#Compute competition accuracy on test set
#For every molecule keep the accuracy and the index of the best-scoring triple,
#once for the final parameters and once for the averaged parameters
accCompetition = np.zeros(test_N)
accComp_avg = np.zeros(test_N)
topPredSet = np.empty(test_N)
topPredSet_avg = np.empty(test_N)
for i, ((nodes_i, edges_i), yi) in enumerate(test_set):
    xi = nodes_i.numpy(), edges_i.numpy()
    yi = yi.numpy()
    accCompetition[i], topPredSet[i] = competitionAccuracy_fn(opt_params, xi, yi)
    accComp_avg[i], topPredSet_avg[i] = competitionAccuracy_fn(opt_params_avg, xi, yi)

meanAccCompetition = np.mean(accCompetition)
print(f'Overall Accuracy (Competition): {meanAccCompetition}')
wandb.run.summary["Test set accuracy (competition accuracy)"] = meanAccCompetition
meanAccComp_avg = np.mean(accComp_avg)
print(f'Overall Accuracy (Competition), used average of last 10 parameters: {meanAccComp_avg}')
wandb.run.summary["Test set accuracy, average params(competition accuracy)"] = meanAccComp_avg

#Print out how often the first 3 predictions were the best when computing competition accuracy
#Triple index 0 is the top-3 set, so count the molecules whose best index is 0
numTop3Best = test_N - np.count_nonzero(topPredSet)
numTop3Best_avg = test_N - np.count_nonzero(topPredSet_avg)
top3Pred_correctPercent = 100*(numTop3Best/test_N)
top3Pred_correctPercent_avg = 100*(numTop3Best_avg/test_N)
print(f'Top 3 predictions were best on average (competition accuracy, regular params): {top3Pred_correctPercent}%')
print(f'Top 3 predictions were best on average (competition accuracy, avg of last 10 params): {top3Pred_correctPercent_avg}%')


In [None]:
#Compute AUROC & Create ROC curve for each scent class - uses scikit-learn
test_yhat = np.empty((test_N, numClasses)) #create empty array to store predictions on test set
test_y = np.empty((test_N, numClasses))

#Model outputs (final parameters) and true labels for every test molecule
for i, testVal in enumerate(test_set):
    (nodes_i, edges_i), yi = testVal
    nodes_i = nodes_i.numpy()
    edges_i = edges_i.numpy()
    yi = yi.numpy()
    xi = nodes_i,edges_i
    test_yhat[i] = model.apply(opt_params, xi)
    test_y[i] = yi

#Count how often each scent occurs in the full dataset
#The whole label vector is added at once instead of one class at a time
occurrences = np.zeros(numClasses)
for (nodes_i, edges_i), yi in data:
    occurrences += yi.numpy()

aurocs_scikit = []
aurocs_omitUncommonClasses = []
for c in range(numClasses):
    #AUROC is only computed for classes with at least one positive example in the test set
    if(np.count_nonzero(test_y[:,c]) == 0):
        print(f'Test set does not have any molecules with scent {scentClasses[c]}')
    else:
        ##Uncomment lines below to create ROC curves
        #fpr, tpr, thresholds = sklearn.metrics.roc_curve(test_y[:,c], test_yhat[:,c])
        #plt.plot(fpr, tpr, '-o', label='Trained Model')
        #plt.plot([0,1], [0, 1], label='Naive Classifier')
        #plt.ylabel('True Positive Rate')
        #plt.xlabel('False Positive Rate')
        #plt.title(f'ROC Curve for {scentClasses[c]}')
        #plt.legend()
        #plt.show()
        #plt.savefig(f'GNN_ROC_Curve_{scentClasses[c]}_{runName}.jpg')
        #plt.close()

        auroc = sklearn.metrics.roc_auc_score(test_y[:,c], test_yhat[:,c])
        aurocs_scikit.append(auroc)
        #Classes with fewer than 30 occurrences in the full dataset are left out of the second mean
        if(occurrences[c] >= 30):
            aurocs_omitUncommonClasses.append(auroc)
            print(f'Included {scentClasses[c]}')
        else:
            print(f'Omitted {scentClasses[c]}')
        print(f'AUROC for scent {scentClasses[c]}: {auroc}')

mean_AUROC = np.mean(aurocs_scikit)
mean_AUROC_omitUncommonScents = np.mean(aurocs_omitUncommonClasses)
print(f'Mean AUROC: {mean_AUROC}')
print(f'Mean AUROC (w/uncommon scent classes omitted): {mean_AUROC_omitUncommonScents}')
wandb.run.summary['Mean AUROC'] = mean_AUROC
wandb.run.summary['Mean AUROC w/uncommon scents omitted'] = mean_AUROC_omitUncommonScents

In [None]:
#Stop W&B run
#`run` is the handle returned by wandb.init earlier in the notebook
run.finish()