In [1]:
# --- Standard libraries
import pickle as pkl
import numpy as np
import heapq as hq
from icecream import ic
# --- PyTorch
import torch
from torch_geometric.data import Batch
# --- RDKit
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.molSize = 300,300
# --- Bokeh
from bokeh.io import output_notebook, show, export_png
from bokeh.layouts import row
output_notebook()
# --- Science python
from sklearn.metrics import mean_squared_error
from scipy.stats import wasserstein_distance, spearmanr
# --- Modules from local files
from GNN_mol import GNN
from utils import bokeh_spectra, calculate_rse

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def pred_spec(model, index, test_dataset):
    """Predict the spectrum of one graph and return it with its reference.

    Args:
        model: Trained network; it is called as ``model(batch)`` on a batch
            that contains a single graph.
        index: Position in ``test_dataset`` of the graph to predict on.
        test_dataset: Indexable dataset whose items support ``.to(device)``
            and carry a ``spectrum`` tensor attribute.

    Returns:
        tuple: ``(predicted_spectrum, true_spectrum)`` as NumPy arrays. The
        prediction is flattened to one dimension.
    """
    # --- Set the model to evaluation mode
    model.eval()

    # --- Use the device that holds the model weights instead of relying on a
    #     notebook-level `device` variable defined in another cell
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # --- Wrap the selected graph in a one-element batch
    graph_data = test_dataset[index].to(model_device)
    batch = Batch.from_data_list([graph_data])

    # --- Pass the graph through the model without tracking gradients
    with torch.no_grad():
        pred = model(batch)

    # --- Move both spectra to the CPU and convert them to NumPy
    true_spectrum = graph_data.spectrum.cpu().numpy()
    predicted_spectrum = pred.cpu().numpy().reshape(-1)

    return predicted_spectrum, true_spectrum

### Coronene

#### Set model value

In [3]:
# --- Hyperparameters handed to the GNN constructor in the next cell
num_tasks, num_layers, emb_dim = 200, 3, 17
in_channels = [int(emb_dim), 256, 128]
out_channels = [256, 128, 64]
gnn_type, graph_pooling = "gcn", "mean"
heads = 1
drop_ratio = 0.6

# --- Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- File name of the saved model weights
model_name = "model_gnn_test.pt"

#### Load the saved model

In [4]:
# --- Output size (same value as in the configuration cell)
num_tasks = 200

# --- Build the network from the configured hyperparameters and move it to `device`
model = GNN(
    num_tasks, num_layers, emb_dim,
    in_channels, out_channels,
    gnn_type, heads, drop_ratio, graph_pooling,
).to(device)

# --- Restore the saved weights; the key report returned here is the cell output
model.load_state_dict(torch.load(f"ML_models/{model_name}", map_location=device))

<All keys matched successfully>

#### Load the test data

In [7]:
# --- Deserialize the processed test dataset from disk
# NOTE(review): torch.load unpickles whole Python objects here, so only load
# files from a trusted source; newer PyTorch releases may need
# weights_only=False for this call -- confirm against the installed version.
test_dataset = torch.load("./processed/mol_test_dataset.pt")

In [9]:
# --- Display the dataset summary as the cell output
test_dataset

XASDataset(39)

#### Use model to predict from test data

In [10]:
# --- Predicted and reference spectra, keyed by position in the test dataset
predict, true = {}, {}

# --- Run the model on every molecule in the test dataset
for index in range(len(test_dataset)):
    predict[index], true[index] = pred_spec(
        model=model, index=index, test_dataset=test_dataset
    )

# --- Bundle both dictionaries into one list: [predictions, references]
model_dict = [predict, true]

# --- Pickle the bundle into the results folder
name = "spectra_ml_test.pkl"

with open(f"spectra_results/{name}", "wb") as file:
    pkl.dump(model_dict, file)

#### View predictions

In [11]:
# --- Reload the saved results; the context manager closes the file handle,
#     which the bare open() call used before never did
# NOTE: pickle can execute arbitrary code while loading -- only open result
# files that this notebook wrote itself.
with open('spectra_results/' + name, 'rb') as file:
    data = pkl.load(file)

# --- Unpack the [predictions, references] list written by the previous cell
predict = data[0]
true = data[1]

#### Perform analysis of predictions

In [12]:
# --- Evaluate each metric for every (reference, prediction) pair
# Wasserstein metric
wasser = [wasserstein_distance(true[i], predict[i]) for i in range(len(predict))]
# Mean squared error
mse = [mean_squared_error(true[i], predict[i]) for i in range(len(predict))]
# RSE
rse = [calculate_rse(true[i], predict[i]) for i in range(len(predict))]
# Spearman coefficient (element 0 of the spearmanr result)
spear = [spearmanr(true[i], predict[i])[0] for i in range(len(predict))]

# --- Report the mean of every metric over the dataset
print(f"Average Wasserstein distance = {sum(wasser) / len(wasser)}")
print(f'Average spearman correlation coefficient = {sum(spear) / len(spear)}')
print(f"Average MSE = {sum(mse) / len(mse)}")
print(f'Average RSE = {sum(rse) / len(rse)}')

Average Wasserstein distance = 0.04412306116082719
Average spearman correlation coefficient = 0.9686132922553833
Average MSE = 0.006252224500718529
Average RSE = 0.03613615984341291


In [13]:
# --- Rank the graphs from lowest (best) to highest (worst) RSE.
# The indices are sorted directly instead of looking every value up with
# list.index(): tied RSE values then map to distinct graphs, and the dataset
# size no longer has to be hard-coded.
rank_graph = sorted(range(len(rse)), key=rse.__getitem__)
rank_rse = [rse[idx] for idx in rank_graph]

# --- Never report more entries than there are graphs
n_show = min(5, len(rank_rse))

print('The 5 best RSE values are:')
for x in range(n_show):
    print(f'RSE = {rank_rse[x]:.3f}, graph number = {rank_graph[x]}')

print('')
print('The 5 worst RSE values are:')
# Walk backwards from the end of the ranking: -1 is the worst graph
for x in range(-1, -n_show - 1, -1):
    print(f'RSE = {rank_rse[x]:.3f}, graph number = {rank_graph[x]}')

The 5 best RSE values are:
RSE = 0.020, graph number = 2
RSE = 0.024, graph number = 34
RSE = 0.025, graph number = 32
RSE = 0.025, graph number = 16
RSE = 0.025, graph number = 29

The 5 worst RSE values are:
RSE = 0.059, graph number = 21
RSE = 0.056, graph number = 11
RSE = 0.055, graph number = 30
RSE = 0.052, graph number = 8
RSE = 0.050, graph number = 13


#### View and compare predictions

In [14]:
# --- Plot the three spectra with the lowest RSE side by side
p1, p2, p3 = (
    bokeh_spectra(predict[rank_graph[k]], true[rank_graph[k]]) for k in (0, 1, 2)
)
p = row(p1, p2, p3)
show(p)
# --- Save the figure; the returned file path is the cell output
export_png(p, filename='GO_mol_best_RSE.png')

Problem reading geckodriver versions: error sending request for url (https://raw.githubusercontent.com/SeleniumHQ/selenium/trunk/common/geckodriver/geckodriver-support.json): error trying to connect: tcp connect error: No connection could be made because the target machine actively refused it. (os error 10061). Using latest geckodriver version
Exception managing firefox: error sending request for url (https://github.com/mozilla/geckodriver/releases/latest): error trying to connect: tcp connect error: No connection could be made because the target machine actively refused it. (os error 10061)


'd:\\github\\GO_molecule_GNN\\GO_mol_best_RSE.png'

In [15]:
# --- Plot the three spectra with the highest RSE side by side (worst first)
p1, p2, p3 = (
    bokeh_spectra(predict[rank_graph[k]], true[rank_graph[k]]) for k in (-1, -2, -3)
)
p = row(p1, p2, p3)
show(p)
# --- Save the figure; the returned file path is the cell output
export_png(p, filename='GO_mol_worst_RSE.png')

'd:\\github\\GO_molecule_GNN\\GO_mol_worst_RSE.png'

### Circumcoronene

#### Load the circumcoronene dataset

In [16]:
# --- Deserialize the processed circumcoronene dataset from disk
# NOTE(review): torch.load unpickles whole Python objects here, so only load
# files from a trusted source; newer PyTorch releases may need
# weights_only=False for this call -- confirm against the installed version.
circum_dataset = torch.load('./processed/mol_cir_dataset.pt')

In [17]:
# --- Summarise the dataset as a whole
print(circum_dataset, '-------------', sep='\n')
print(f'Number of graphs: {len(circum_dataset)}')
print(f'Number of features: {circum_dataset.num_features}')
print(f'Number of classes: {circum_dataset.num_classes}')
print()

# --- Inspect a single molecule/graph: the entry stored at index 1
circum_data = circum_dataset[1]

print(circum_data, '---------------', sep='\n')
print(f'Number of nodes: {circum_data.num_nodes}')
print(f'Number of edges: {circum_data.num_edges}')
# Edges divided by nodes, shown with two decimals
print(f'Average node degree: {circum_data.num_edges / circum_data.num_nodes:.2f} ')
print(f'Has isolated nodes: {circum_data.has_isolated_nodes()}')
print(f'Has self loops: {circum_data.has_self_loops()}')
print(f'Is undirected: {circum_data.is_undirected()}')

XASDataset(91)
-------------
Number of graphs: 91
Number of features: 17
Number of classes: 0

Data(x=[60, 17], edge_index=[2, 160], edge_attr=[160, 6], spectrum=[200], idx=[1], smiles='c12[c:1]3[c:2]4[c:3]([cH:4][cH:5]1)[cH:38][c:39]1[c:40]5[c:9]4[c:8]4[c:7]6[c:6]3[c:24]3[c:23]([cH:22]2)[cH:31][cH:32][c:33]2[c:25]3[c:26]3[c:10]6[c:11]6[c:12]7[c:13]4[c:42]4[c:41]5[c:49]([cH:48][cH:47]1)=[CH:50][C:51]15[c:43]4[c:44]4[c:17]7[c:16]7[c:15]8[c:14]6[c:28]6[c:27]3[C:35]([OH:58])([CH:34]=2)[CH:36]2[CH:37]([c:29]6[cH:30][c:18]8[cH:19][cH:20][c:21]7[c:46]([C:54](=[O:56])[OH:57])[c:45]4[CH2:53][CH:52]1[O:59]5)[O:55]2')
---------------
Number of nodes: 60
Number of edges: 160
Average node degree: 2.67 
Has isolated nodes: False
Has self loops: False
Is undirected: True


#### Load the model

In [18]:
# --- Hyperparameters handed to the GNN constructor in the next cell
#     (same values as in the coronene section)
num_tasks, num_layers, emb_dim = 200, 3, 17
in_channels = [int(emb_dim), 256, 128]
out_channels = [256, 128, 64]
gnn_type, graph_pooling = "gcn", "mean"
heads = 1
drop_ratio = 0.6

# --- Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- File name of the saved model weights
model_name = "model_gnn_test.pt"

In [19]:
# --- Output size (same value as in the configuration cell)
num_tasks = 200

# --- Build the network from the configured hyperparameters and move it to `device`
model = GNN(
    num_tasks, num_layers, emb_dim,
    in_channels, out_channels,
    gnn_type, heads, drop_ratio, graph_pooling,
).to(device)

# --- Restore the saved weights; the key report returned here is the cell output
model.load_state_dict(torch.load(f"ML_models/{model_name}", map_location=device))

<All keys matched successfully>

In [20]:
# --- Predicted and reference spectra, keyed by position in the circumcoronene dataset
cir_predict, cir_true = {}, {}

# --- Run the model on every molecule in the circumcoronene dataset
for index in range(len(circum_dataset)):
    cir_predict[index], cir_true[index] = pred_spec(
        model=model, index=index, test_dataset=circum_dataset
    )

# --- Bundle both dictionaries into one list: [predictions, references]
cir_model_dict = [cir_predict, cir_true]

# --- Pickle the bundle into the results folder
name = "cir_spectra_ml_test.pkl"

with open(f"spectra_results/{name}", "wb") as file:
    pkl.dump(cir_model_dict, file)

In [21]:
# --- Reload the saved results; the context manager closes the file handle,
#     which the bare open() call used before never did
# NOTE: pickle can execute arbitrary code while loading -- only open result
# files that this notebook wrote itself.
with open('spectra_results/' + name, 'rb') as file:
    cir_data = pkl.load(file)

# --- Unpack the [predictions, references] list written by the previous cell
cir_predict = cir_data[0]
cir_true = cir_data[1]

In [22]:
# --- Evaluate each metric for every (reference, prediction) pair
# Wasserstein metric
wasser = [wasserstein_distance(cir_true[i], cir_predict[i]) for i in range(len(cir_predict))]
# Mean squared error
mse = [mean_squared_error(cir_true[i], cir_predict[i]) for i in range(len(cir_predict))]
# RSE
rse = [calculate_rse(cir_true[i], cir_predict[i]) for i in range(len(cir_predict))]
# Spearman coefficient (element 0 of the spearmanr result)
spear = [spearmanr(cir_true[i], cir_predict[i])[0] for i in range(len(cir_predict))]

# --- Report the mean of every metric over the dataset
print(f"Average Wasserstein distance = {sum(wasser) / len(wasser)}")
print(f'Average Spearman correlation coefficiant = {sum(spear) / len(spear)}')
print(f"Average MSE = {sum(mse) / len(mse)}")
print(f'Average RSE = {sum(rse) / len(rse)}')

Average Wasserstein distance = 0.035738619635620625
Average Spearman correlation coefficiant = 0.9748158827313506
Average MSE = 0.005021439688423746
Average RSE = 0.04005142939412924


In [23]:
# --- Rank the graphs from lowest (best) to highest (worst) RSE.
# The indices are sorted directly instead of looking every value up with
# list.index(): tied RSE values then map to distinct graphs, and the dataset
# size no longer has to be hard-coded.
rank_graph = sorted(range(len(rse)), key=rse.__getitem__)
rank_rse = [rse[idx] for idx in rank_graph]

# --- Never report more entries than there are graphs
n_show = min(5, len(rank_rse))

print('The 5 best RSE values are:')
for x in range(n_show):
    print(f'RSE = {rank_rse[x]:.3f}, graph number = {rank_graph[x]}')

print('')
print('The 5 worst RSE values are:')
# Walk backwards from the end of the ranking: -1 is the worst graph
for x in range(-1, -n_show - 1, -1):
    print(f'RSE = {rank_rse[x]:.3f}, graph number = {rank_graph[x]}')

The 5 best RSE values are:
RSE = 0.023, graph number = 47
RSE = 0.023, graph number = 41
RSE = 0.023, graph number = 79
RSE = 0.024, graph number = 3
RSE = 0.025, graph number = 31

The 5 worst RSE values are:
RSE = 0.069, graph number = 34
RSE = 0.062, graph number = 77
RSE = 0.061, graph number = 25
RSE = 0.061, graph number = 44
RSE = 0.054, graph number = 42


In [26]:
# --- Plot the three circumcoronene spectra with the lowest RSE side by side
p1, p2, p3 = (
    bokeh_spectra(cir_predict[rank_graph[k]], cir_true[rank_graph[k]]) for k in (0, 1, 2)
)
p = row(p1, p2, p3)
show(p)
# --- Save the figure; the returned file path is the cell output
export_png(p, filename='GO_mol_cir_best_RSE.png')

'd:\\github\\GO_molecule_GNN\\GO_mol_cir_best_RSE.png'

In [27]:
# --- Plot the three circumcoronene spectra with the highest RSE side by side (worst first)
p1, p2, p3 = (
    bokeh_spectra(cir_predict[rank_graph[k]], cir_true[rank_graph[k]]) for k in (-1, -2, -3)
)
p = row(p1, p2, p3)
show(p)
# --- Save the figure; the returned file path is the cell output
export_png(p, filename='GO_mol_cir_worst_RSE.png')

'd:\\github\\GO_molecule_GNN\\GO_mol_cir_worst_RSE.png'