In [9]:
# --- Standard libraries
import os
import pickle as pkl
import heapq as hq
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# --- PyTorch
import torch
from torch_geometric.data import Batch
# --- RDKit
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.molSize = 300,300
# --- Bokeh
from bokeh.io import output_notebook, show, export_svg
from bokeh.models import Label
from bokeh.layouts import row
output_notebook()
# --- Science python
from sklearn.metrics import mean_squared_error
from scipy.stats import wasserstein_distance, spearmanr
# --- Modules from local files
from GNN import GNN
from utils.model import get_spec_prediction
from utils.utils import bokeh_spectra, calculate_rse, bokeh_hist

In [11]:
# --- Define GNN properties
# Width of the first layer's input (presumably the per-node feature dimension
# of the processed dataset — confirm) and width of every hidden layer.
num_node_features = 9
hidden_channels = 60

# Width of the final layer's output (one value per predicted target).
num_tasks = 300
num_layers = 5
# Layer i maps in_channels[i] -> out_channels[i]. Deriving both lists from the
# sizes above keeps them consistent with num_layers and num_tasks.
in_channels = [num_node_features] + [hidden_channels] * (num_layers - 1)
out_channels = [hidden_channels] * (num_layers - 1) + [num_tasks]
assert len(in_channels) == len(out_channels) == num_layers
assert out_channels[-1] == num_tasks
gnn_type = 'gcn'
heads = 1
drop_ratio = 0.0
graph_pooling = 'mean'

device = "cpu"

# Checkpoint file name, e.g. 'model_gcn_nocharge.pt'
model_name = 'model_' + gnn_type + '_nocharge.pt'

In [12]:
# --- Build the GNN from the configuration cell and load the trained weights.
# num_tasks is taken from the configuration cell rather than redefined here,
# so the two cells cannot drift apart.
model = GNN(
    num_tasks,
    num_layers,
    in_channels,
    out_channels,
    gnn_type,
    heads,
    drop_ratio,
    graph_pooling,
)
model = model.to(device)
# Inference only: switch any train/eval-dependent layers into eval mode.
model.eval()
# The checkpoint is a state dict (it is passed straight to load_state_dict), so
# weights_only=True suffices and avoids unpickling arbitrary objects.
model.load_state_dict(
    torch.load("ML_models/" + model_name, map_location=device, weights_only=True)
)

  model.load_state_dict(torch.load("ML_models/" + model_name, map_location=device))


<All keys matched successfully>

In [13]:
# --- Load the processed test set.
# This file stores a whole dataset object (its repr is XASDataset(...)), not
# just tensors, so it must be unpickled: weights_only=False is required now that
# PyTorch defaults to weights_only=True. Only load such files from a trusted source.
test_dataset = torch.load("./processed/test_dataset_nocharge.pt", weights_only=False)
test_dataset

  test_dataset = torch.load("./processed/test_dataset_nocharge.pt")


XASDataset(6292)

In [14]:
# --- Predict an output for every molecule in the test dataset.
# pred[i] / tru[i] hold what get_spec_prediction returns for graph i.
pred = {}
tru = {}

# Gradients are not needed for inference; disabling autograd saves memory/time.
with torch.no_grad():
    for index in range(len(test_dataset)):
        pred[index], tru[index] = get_spec_prediction(model, index, test_dataset, device)

# --- Bundle predictions and references as [pred, tru]
model_dict = [pred, tru]

# --- Save prediction results to file.
# NOTE: this is a pickle file, despite the ".pt" suffix inherited from model_name.
# Create the results directory first so a fresh checkout does not fail here.
results_dir = "ML_models/results/"
os.makedirs(results_dir, exist_ok=True)
with open(results_dir + model_name, "wb") as file:
    pkl.dump(model_dict, file)

# --- Unpack into the names used by the evaluation cells below
predict = model_dict[0]
true = model_dict[1]

In [15]:
# --- Per-graph error metrics between reference (true) and predicted spectra
wasser = []
mse = []
rse = []
spear = []

for x in range(len(predict)):
    # Wasserstein metric
    # NOTE(review): scipy's wasserstein_distance(u_values, v_values) treats the
    # two arrays as samples of 1-D distributions, so this compares the
    # distributions of intensity values and ignores *where* along the spectrum
    # they occur. If a spectral-shape distance is intended, pass a shared energy
    # grid as the values and the spectra as u_weights / v_weights — confirm.
    wass_temp = wasserstein_distance(true[x], predict[x])
    wasser.append(wass_temp)
    # Mean squared error
    mse_temp = mean_squared_error(true[x], predict[x])
    mse.append(mse_temp)
    # RSE (computed by the local calculate_rse helper)
    rse_temp = calculate_rse(true[x], predict[x])
    rse.append(rse_temp)
    # Spearman coefficient: element [0] of the result is the correlation.
    # scipy returns NaN when either input is constant.
    spear_temp = spearmanr(true[x], predict[x])
    spear.append(spear_temp[0])

# np.mean yields NaN (with a warning) for an empty list instead of raising
# ZeroDivisionError. np.nanmean skips undefined Spearman values so that one
# constant spectrum does not turn the whole average into NaN.
ave_wasser = np.mean(wasser)
ave_spear = np.nanmean(spear)
ave_mse = np.mean(mse)
ave_rse = np.mean(rse)

print(f"Average Wasserstein distance = {ave_wasser}")
print(f'Average spearman correlation coefficient = {ave_spear}')
print(f"Average MSE = {ave_mse}")
print(f'Average RSE = {ave_rse}')

Average Wasserstein distance = 0.02739413015431912
Average spearman correlation coefficient = 0.7390693123497473
Average MSE = 0.0055660707876086235
Average RSE = 0.07123085856437683


In [16]:
# --- Rank graphs by RSE, best (smallest) first.
# Sorting the indices themselves (stable sort) keeps every graph exactly once,
# even when two graphs share an RSE value, and runs in O(n log n).
# len(rse) replaces a hard-coded test-set size.
rank_graph = sorted(range(len(rse)), key=rse.__getitem__)
# RSE values in the same (ascending) order as rank_graph
rank_rse = [rse[idx] for idx in rank_graph]

n_show = 5

print(f'The {n_show} best RSE values are:')
for idx in rank_graph[:n_show]:
    print(f'RSE = {rse[idx]:.3f}, graph number = {idx}')

print('')
print(f'The {n_show} worst RSE values are:')
# Worst first: walk the tail of the ranking backwards
for idx in reversed(rank_graph[-n_show:]):
    print(f'RSE = {rse[idx]:.3f}, graph number = {idx}')

The 5 best RSE values are:
RSE = 0.021, graph number = 3602
RSE = 0.023, graph number = 3695
RSE = 0.024, graph number = 1771
RSE = 0.025, graph number = 1201
RSE = 0.027, graph number = 1978

The 5 worst RSE values are:
RSE = 0.519, graph number = 820
RSE = 0.385, graph number = 663
RSE = 0.229, graph number = 4719
RSE = 0.217, graph number = 2857
RSE = 0.216, graph number = 959


In [23]:
# --- Plot predicted vs. reference spectra for selected RSE ranks.
# Ranks 0 and 1 are the two best predictions. Rank 6000 (of the 6292 test
# graphs) lies near the worst end of the ranking, shown for contrast.
plot_ranks = [0, 1, 6000]
panels = [
    bokeh_spectra(predict[rank_graph[r]], true[rank_graph[r]]) for r in plot_ranks
]
p = row(*panels)
show(p)