In [21]:
# --- Standard libraries
import os.path as osp
import heapq as hq
import matplotlib.pyplot as plt
import numpy as np
# --- PyTorch
import torch
# --- Science python
from sklearn.metrics import mean_squared_error
from scipy.stats import wasserstein_distance, spearmanr
from utils import pred_spec, calculate_rse, bokeh_spectra
# --- Bokeh
from bokeh.io import output_notebook, show, export_png
from bokeh.layouts import row
from bokeh.models import Label
output_notebook()
# --- XASNet
from XASNet.models import XASNet_GNN, XASNet_GAT, XASNet_GraphNet
from XASNet.utils import GraphDataProducer
from XASNet.utils.visualisation import plot_prediction
from XASNet.utils.rse import rse_predictions, rse_loss, rse_histogram
from utils import bokeh_hist
torch.__version__

'2.4.0+cu121'

In [39]:
# --- Load in the test dataset
test_dataset = torch.load('./processed/test_xasnet_circ.pt')

  test_dataset = torch.load('./processed/test_xasnet_circ.pt')


In [40]:
# --- Print details of the dataset
print(f'Number of graphs: {len(test_dataset)}')
print('')

# --- Print details of the first molecule/graph in the dataset
data = test_dataset[2]

print(data)
print('------------')
print(f'Molecule index: {data.idx}')
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

Number of graphs: 91

Data(x=[59, 15], edge_index=[2, 156], edge_attr=[156, 5], spectrum=[200], idx=[1], smiles='c12[c:1]3[c:2]4[c:3]([c:4]([H:59])[c:5]1[H:60])[c:38]([H:72])[c:39]1[c:40]5[c:9]4[c:8]4[c:7]6[c:6]3[c:24]3[c:23]([c:22]2[H:73])[c:31]([H:74])[c:32]([H:75])[c:33]2[c:25]3[c:26]3[c:10]6[c:11]6[c:12]7[c:13]4[c:42]4[c:41]5[c:49]([c:48]([H:67])[c:47]1[H:66])[c:50]([H:68])[c:51]1[c:43]4[c:44]4[c:17]7[c:16]5[c:15]7[c:14]6[c:28]6[c:27]3[c:35]([c:34]2[H:76])[c:36]([H:64])[c:37]([C:54](=[O:58])[H:71])[c:29]6[c:30]([H:63])[c:18]7[C:19]2([H:61])[C:20]([H:62])([c:21]5[c:46]([H:65])[c:45]4[c:53]([H:70])[c:52]1[C:55](=[O:57])[H:69])[O:56]2')
------------
Molecule index: tensor([35])
Number of nodes: 59
Number of edges: 156
Average node degree: 2.64
Has isolated nodes: False
Has self loops: False
Is undirected: True


##### XASNet GNN

In [41]:
# --- Set device for model to run on
device = 'cpu'

# --- Set ML model parameters to match the loaded model
xasnet_gnn = XASNet_GNN(
    gnn_name = 'gcn', # model type
    in_channels = [15, 256, 128], # input nodes for each layer
    out_channels = [256, 128, 64], # output nodes for each layer
    num_targets = 200, # nodes for final layers
    num_layers = 3, # number of layers
    heads = 0
).to(device)

# --- Location of the ML model
path_to_model = osp.join('./best_model/xasnet_gcn_model.pt')

# --- Check if there is an existing model
if osp.exists(path_to_model):
    xasnet_gnn.load_state_dict(torch.load(path_to_model, map_location=device))
else:
    print('Model is not loaded.')

  xasnet_gnn.load_state_dict(torch.load(path_to_model, map_location=device))


##### View Predictions

In [42]:
predict = {}
true = {}

for data in enumerate(test_dataset):
    predict[data[0]], true[data[0]] = pred_spec(xasnet_gnn, data, test_dataset, graphnet=False)

In [43]:
wasser = []
mse = []
rse = []
spear = []

for x in range(len(predict)):
    # Wasserstein metric
    wass_temp = wasserstein_distance(true[x], predict[x])
    wasser.append(wass_temp)
    # Mean squared error
    mse_temp = mean_squared_error(true[x], predict[x])
    mse.append(mse_temp)
    # RSE
    rse_temp = calculate_rse(true[x], predict[x])
    rse.append(rse_temp)
    # Spearman coefficient
    spear_temp = spearmanr(true[x], predict[x])
    spear.append(spear_temp[0])

print(f"Average Wasserstein distance = {sum(wasser) / len(wasser)}")
print(f'Average spearman correlation coefficient = {sum(spear) / len(spear)}')
print(f"Average MSE = {sum(mse) / len(mse)}")
print(f'Average RSE = {sum(rse) / len(rse)}')

Average Wasserstein distance = 0.07013247682480286
Average spearman correlation coefficient = 0.9813529019544172
Average MSE = 0.008817651961510489
Average RSE = 0.049025699823767276


In [44]:
rank_rse = hq.nsmallest(48, rse)

rank_graph = []

for x in range(48):
    rank_idx = rse.index(rank_rse[x])
    rank_graph.append(rank_idx)

print('The 5 best RSE values are:')
for x in range(5):
    print(f'RSE = {rank_rse[x]:.4f}, graph number = {rank_graph[x]}')

print('')
print('The 5 worst RSE values are:')
for x in range(-1, -6, -1):
    print(f'RSE = {rank_rse[x]:.4f}, graph number = {rank_graph[x]}')

The 5 best RSE values are:
RSE = 0.0270, graph number = 49
RSE = 0.0288, graph number = 47
RSE = 0.0324, graph number = 79
RSE = 0.0325, graph number = 28
RSE = 0.0327, graph number = 37

The 5 worst RSE values are:
RSE = 0.0493, graph number = 24
RSE = 0.0489, graph number = 46
RSE = 0.0489, graph number = 82
RSE = 0.0486, graph number = 75
RSE = 0.0481, graph number = 51


In [45]:
# --- Plot best spectra prediction
p1 = bokeh_spectra(predict[rank_graph[0]], true[rank_graph[0]])
p2 = bokeh_spectra(predict[rank_graph[1]], true[rank_graph[1]])
p3 = bokeh_spectra(predict[rank_graph[2]], true[rank_graph[2]])
p = row(p1, p2, p3)
show(p)
#export_png(p, filename='GO_mol_best_RSE.png')

In [46]:
# --- Plot worst spectra prediction
p1 = bokeh_spectra(predict[rank_graph[-1]], true[rank_graph[-1]])
p2 = bokeh_spectra(predict[rank_graph[-2]], true[rank_graph[-2]])
p3 = bokeh_spectra(predict[rank_graph[-3]], true[rank_graph[-3]])
p = row(p1, p2, p3)
show(p)
#export_png(p, filename='GO_mol_best_RSE.png')