In [47]:
# --- Standard libraries
import os.path as osp
import heapq as hq
import matplotlib.pyplot as plt
import numpy as np
# --- PyTorch
import torch
# --- Science python
from sklearn.metrics import mean_squared_error
from scipy.stats import wasserstein_distance, spearmanr
from utils import pred_spec, calculate_rse, bokeh_spectra
# --- Bokeh
from bokeh.io import output_notebook, show
from bokeh.layouts import row
from bokeh.models import Label
output_notebook()
# --- XASNet
from XASNet.models import XASNet_GNN, XASNet_GAT, XASNet_GraphNet
from XASNet.utils import GraphDataProducer
from XASNet.utils.visualisation import plot_prediction
from XASNet.utils.rse import rse_predictions, rse_loss, rse_histogram
from utils import bokeh_hist
torch.__version__

'2.4.0+cu121'

#### Coronene

In [48]:
# --- Load in the test dataset
test_dataset = torch.load('./processed/test_xasnet_cor.pt')

  test_dataset = torch.load('./processed/test_xasnet_cor.pt')


In [49]:
# --- Print details of the dataset
print(f'Number of graphs: {len(test_dataset)}')
print('')

# --- Print details of the first molecule/graph in the dataset
data = test_dataset[2]

print(data)
print('------------')
print(f'Molecule index: {data.idx}')
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

Number of graphs: 49

Data(x=[33, 15], edge_index=[2, 82], edge_attr=[82, 5], spectrum=[200], idx=[1], smiles='c12[c:2]3[c:3]([H:33])[c:6]([H:34])[c:7]4[c:5]1[c:9]1[c:13]5[c:17]6[c:21]2[C:23]([O:30][H:42])([C:22]([C:25](=[O:26])[O:27][H:40])([H:46])[C:20]([H:45])=[C:19]6[C:18]([H:44])=[C:16]([H:43])[C:15]52[C:14]([H:38])([C:12]5([H:37])[C:11]1([C:10]([H:36])=[C:8]4[H:35])[O:31]5)[O:32]2)[C:24]([H:39])=[C:1]3[C:4](=[O:28])[O:29][H:41]')
------------
Molecule index: tensor([138])
Number of nodes: 33
Number of edges: 82
Average node degree: 2.48
Has isolated nodes: False
Has self loops: False
Is undirected: True


##### XASNet GNN

In [52]:
# --- Set device for model to run on
device = 'cpu'

# --- Set ML model parameters to match the loaded model
xasnet_gnn = XASNet_GNN(
    gnn_name = 'gcn', # model type
    in_channels = [15, 256, 128], # input nodes for each layer
    out_channels = [256, 128, 64], # output nodes for each layer
    num_targets = 200, # nodes for final layers
    num_layers = 3, # number of layers
    heads = 0
).to(device)

# --- Location of the ML model
path_to_model = osp.join('./best_model/xasnet_gcn_model.pt')

# --- Check if there is an existing model
if osp.exists(path_to_model):
    xasnet_gnn.load_state_dict(torch.load(path_to_model, map_location=device))
else:
    print('Model is not loaded.')

  xasnet_gnn.load_state_dict(torch.load(path_to_model, map_location=device))


##### XASNet GAT

In [65]:
# --- Set device for model to run on
device = 'cpu'

# --- Create the type of ML model you want to run
xasnet_gat = XASNet_GAT(
    node_features_dim=15,
    in_channels=[128, 128, 128, 128],
    out_channels=[128, 128, 128, 400],
    targets=200,
    n_layers=4,
    n_heads=3,
    gat_type='gatv2_custom',
    use_residuals=True,
    use_jk=True
).to(device)

# --- Location to save model
path_to_model = osp.join('./best_model/xasnet_gat_model.pt')

# --- Check if there is an already existing model
if osp.exists(path_to_model):
    xasnet_gat.load_state_dict(torch.load(path_to_model))
else:
    print('Model is not loaded.')

  xasnet_gat.load_state_dict(torch.load(path_to_model))


##### XASNet GraphNet

In [87]:
# --- Set device for model to run on
device = 'cpu'

# --- Create the type of ML model you want to run
xasnet_graphnet = XASNet_GraphNet(
    node_dim=15,
    edge_dim=5,
    hidden_channels=512,
    out_channels=200,
    gat_hidd=512,
    gat_out=200,
    n_layers=3,
    n_targets=200
).to(device)

# --- Location to save model
path_to_model = osp.join('./best_model/xasnet_graphnet.pt')

# --- Check if there is an already existing model
if osp.exists(path_to_model):
    xasnet_graphnet.load_state_dict(torch.load(path_to_model))
else:
    print('Model is not loaded.')

Model is not loaded.


##### View Predictions

In [88]:
cor_predict = {}
cor_true = {}

for data in enumerate(test_dataset):
    cor_predict[data[0]], cor_true[data[0]] = pred_spec(xasnet_graphnet, data, test_dataset, graphnet=True)

In [89]:
cor_wasser = []
cor_mse = []
cor_rse = []
cor_spear = []

for x in range(len(cor_predict)):
    # Wasserstein metric
    wass_temp = wasserstein_distance(cor_true[x], cor_predict[x])
    cor_wasser.append(wass_temp)
    # Mean squared error
    mse_temp = mean_squared_error(cor_true[x], cor_predict[x])
    cor_mse.append(mse_temp)
    # RSE
    rse_temp = calculate_rse(cor_true[x], cor_predict[x])
    cor_rse.append(rse_temp)
    # Spearman coefficient
    spear_temp = spearmanr(cor_true[x], cor_predict[x])
    cor_spear.append(spear_temp[0])

print(f"Average Wasserstein distance = {sum(cor_wasser) / len(cor_wasser)}")
print(f'Average spearman correlation coefficient = {sum(cor_spear) / len(cor_spear)}')
print(f"Average MSE = {sum(cor_mse) / len(cor_mse)}")
print(f'Average RSE = {sum(cor_rse) / len(cor_rse)}')

Average Wasserstein distance = 48.1359417738496
Average spearman correlation coefficient = -0.04781634925488523
Average MSE = 3553.1473724029875
Average RSE = -1.7302225146974837


In [90]:
cor_rank_rse = hq.nsmallest(49, cor_rse)

cor_rank_graph = []

for x in range(48):
    cor_rank_idx = cor_rse.index(cor_rank_rse[x])
    cor_rank_graph.append(cor_rank_idx)

print('The 5 best RSE values are:')
for x in range(5):
    print(f'RSE = {cor_rank_rse[x]:.4f}, graph number = {cor_rank_graph[x]}')

print('')
print('The 5 worst RSE values are:')
for x in range(-1, -6, -1):
    print(f'RSE = {cor_rank_rse[x]:.4f}, graph number = {cor_rank_graph[x]}')

The 5 best RSE values are:
RSE = -2.0367, graph number = 50
RSE = -1.9302, graph number = 28
RSE = -1.9094, graph number = 3
RSE = -1.8989, graph number = 41
RSE = -1.8963, graph number = 26

The 5 worst RSE values are:
RSE = -1.7071, graph number = 80
RSE = -1.7085, graph number = 49
RSE = -1.7092, graph number = 16
RSE = -1.7137, graph number = 46
RSE = -1.7142, graph number = 74


In [91]:
# --- Plot best spectra prediction
p1 = bokeh_spectra(cor_predict[cor_rank_graph[0]], cor_true[cor_rank_graph[0]])
p2 = bokeh_spectra(cor_predict[cor_rank_graph[1]], cor_true[cor_rank_graph[1]])
p3 = bokeh_spectra(cor_predict[cor_rank_graph[2]], cor_true[cor_rank_graph[2]])
p = row(p1, p2, p3)
show(p)
#export_png(p, filename='GO_mol_best_RSE.png')

In [92]:
# --- Plot worst spectra prediction
p1 = bokeh_spectra(cor_predict[cor_rank_graph[-1]], cor_true[cor_rank_graph[-1]])
p2 = bokeh_spectra(cor_predict[cor_rank_graph[-2]], cor_true[cor_rank_graph[-2]])
p3 = bokeh_spectra(cor_predict[cor_rank_graph[-3]], cor_true[cor_rank_graph[-3]])
p = row(p1, p2, p3)
show(p)
#export_png(p, filename='GO_mol_best_RSE.png')

#### Circumcoronene

In [71]:
# --- Load in the test dataset
test_dataset = torch.load('./processed/test_xasnet_circ.pt')

  test_dataset = torch.load('./processed/test_xasnet_circ.pt')


In [59]:
# --- Set device for model to run on
device = 'cpu'

# --- Set ML model parameters to match the loaded model
xasnet_gnn = XASNet_GNN(
    gnn_name = 'gcn', # model type
    in_channels = [15, 256, 128], # input nodes for each layer
    out_channels = [256, 128, 64], # output nodes for each layer
    num_targets = 200, # nodes for final layers
    num_layers = 3, # number of layers
    heads = 0
).to(device)

# --- Location of the ML model
path_to_model = osp.join('./best_model/xasnet_gcn_model.pt')

# --- Check if there is an existing model
if osp.exists(path_to_model):
    xasnet_gnn.load_state_dict(torch.load(path_to_model, map_location=device))
else:
    print('Model is not loaded.')

  xasnet_gnn.load_state_dict(torch.load(path_to_model, map_location=device))


In [82]:
circ_predict = {}
circ_true = {}

for data in enumerate(test_dataset):
    circ_predict[data[0]], circ_true[data[0]] = pred_spec(xasnet_gat, data, test_dataset, graphnet=False)

In [83]:
circ_wasser = []
circ_mse = []
circ_rse = []
circ_spear = []

for x in range(len(circ_predict)):
    # Wasserstein metric
    wass_temp = wasserstein_distance(circ_true[x], circ_predict[x])
    circ_wasser.append(wass_temp)
    # Mean squared error
    mse_temp = mean_squared_error(circ_true[x], circ_predict[x])
    circ_mse.append(mse_temp)
    # RSE
    rse_temp = calculate_rse(circ_true[x], circ_predict[x])
    circ_rse.append(rse_temp)
    # Spearman coefficient
    spear_temp = spearmanr(circ_true[x], circ_predict[x])
    circ_spear.append(spear_temp[0])

print(f"Average Wasserstein distance = {sum(circ_wasser) / len(circ_wasser)}")
print(f'Average spearman correlation coefficient = {sum(circ_spear) / len(circ_spear)}')
print(f"Average MSE = {sum(circ_mse) / len(circ_mse)}")
print(f'Average RSE = {sum(circ_rse) / len(circ_rse)}')

Average Wasserstein distance = 0.07344967306536562
Average spearman correlation coefficient = 0.970696734451328
Average MSE = 0.010760290860351953
Average RSE = 0.05466330280670753


In [84]:
circ_rank_rse = hq.nsmallest(91, circ_rse)

circ_rank_graph = []

for x in range(91):
    circ_rank_idx = circ_rse.index(circ_rank_rse[x])
    circ_rank_graph.append(circ_rank_idx)

print('The 5 best RSE values are:')
for x in range(5):
    print(f'RSE = {circ_rank_rse[x]:.4f}, graph number = {circ_rank_graph[x]}')

print('')
print('The 5 worst RSE values are:')
for x in range(-1, -6, -1):
    print(f'RSE = {circ_rank_rse[x]:.4f}, graph number = {circ_rank_graph[x]}')

The 5 best RSE values are:
RSE = 0.0331, graph number = 47
RSE = 0.0352, graph number = 79
RSE = 0.0366, graph number = 49
RSE = 0.0384, graph number = 37
RSE = 0.0387, graph number = 69

The 5 worst RSE values are:
RSE = 0.0847, graph number = 77
RSE = 0.0745, graph number = 2
RSE = 0.0728, graph number = 44
RSE = 0.0709, graph number = 42
RSE = 0.0702, graph number = 62


In [85]:
# --- Plot best spectra prediction
p1 = bokeh_spectra(circ_predict[circ_rank_graph[0]], circ_true[circ_rank_graph[0]])
p2 = bokeh_spectra(circ_predict[circ_rank_graph[1]], circ_true[circ_rank_graph[1]])
p3 = bokeh_spectra(circ_predict[circ_rank_graph[2]], circ_true[circ_rank_graph[2]])
p = row(p1, p2, p3)
show(p)
#export_png(p, filename='GO_mol_best_RSE.png')

In [86]:
# --- Plot worst spectra prediction
p1 = bokeh_spectra(circ_predict[circ_rank_graph[-1]], circ_true[circ_rank_graph[-1]])
p2 = bokeh_spectra(circ_predict[circ_rank_graph[-2]], circ_true[circ_rank_graph[-2]])
p3 = bokeh_spectra(circ_predict[circ_rank_graph[-3]], circ_true[circ_rank_graph[-3]])
p = row(p1, p2, p3)
show(p)
#export_png(p, filename='GO_mol_best_RSE.png')