In [105]:
# --- Standard libraries
import pickle as pkl
import numpy as np
import heapq as hq
from icecream import ic
# --- PyTorch
import torch
from torch_geometric.data import Batch
# --- RDKit
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.molSize = 300,300
# --- Bokeh
from bokeh.plotting import figure
from bokeh.io import output_notebook, show, export_png
from bokeh.layouts import row
output_notebook()
# --- Science python
from sklearn.metrics import mean_squared_error
from scipy.stats import wasserstein_distance, spearmanr
# --- Modules from local files
from GNN_mol import GNN
from utils import bokeh_spectra, calculate_rse, bokeh_hist

In [106]:
def pred_spec(model, index, test_dataset):
    # --- Set the model to evaluation mode
    model.eval()

    # --- Get a single graph from the test dataset
    graph_index = index # Index of the graph you want to predict on
    graph_data = test_dataset[graph_index].to(device)
    batch = Batch.from_data_list([graph_data])

    # --- Pass the graph through the model
    with torch.no_grad():
        pred = model(batch)

    # --- Access the predicted output for the single graph
    pred_graph = pred[0]
    true_spectrum = graph_data.spectrum.cpu().numpy()
    predicted_spectrum = pred.cpu().numpy()
    predicted_spectrum = predicted_spectrum.reshape(-1)
    
    return predicted_spectrum, true_spectrum

### Coronene

#### Set model value

In [107]:
num_tasks = 200
num_layers = 3
emb_dim = 15
in_channels = [int(emb_dim), 64, 128]
out_channels = [64, 128, 256]
gnn_type = "gcn"
heads = int(1)
drop_ratio = 0.75
graph_pooling = "mean"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = 'model_gnn_13.pt'

#### Load the saved model

In [108]:
num_tasks = 200

model = GNN(
    num_tasks,
    num_layers,
    emb_dim,
    in_channels,
    out_channels,
    gnn_type,
    heads,
    drop_ratio,
    graph_pooling,
)
model = model.to(device)
model.load_state_dict(torch.load("ML_models/" + model_name, map_location=device))

<All keys matched successfully>

#### Load the test data

In [109]:
test_dataset = torch.load("./processed/mol_test_dataset.pt")

In [110]:
test_dataset

XASDataset(48)

#### Use model to predict from test data

In [111]:
# --- Create empty dictionary to contain spectra
predict = {}
true = {}

# --- Loop over all molecules in test dateset and assign to dictionary index
for index in range(len(test_dataset)):
    predict[index], true[index] = pred_spec(model, index, test_dataset)

# --- Parse spectra into dictionary
model_dict = [predict, true]

# --- Save prediction results to file
name = 'spectra_ml_best.pkl'

with open("spectra_results/" + name, "wb") as file:
    pkl.dump(model_dict, file)

#### View predictions

In [112]:
file = open('spectra_results/' + name, 'rb')
data = pkl.load(file)

predict = data[0]
true = data[1]

#### Perform analysis of predictions

In [113]:
wasser = []
mse = []
rse = []
spear = []

for x in range(len(predict)):
    # Wasserstein metric
    wass_temp = wasserstein_distance(true[x], predict[x])
    wasser.append(wass_temp)
    # Mean squared error
    mse_temp = mean_squared_error(true[x], predict[x])
    mse.append(mse_temp)
    # RSE
    rse_temp = calculate_rse(true[x], predict[x])
    rse.append(rse_temp)
    # Spearman coefficient
    spear_temp = spearmanr(true[x], predict[x])
    spear.append(spear_temp[0])

ave_wasser = sum(wasser) / len(wasser)
ave_spear = sum(spear) / len(spear)
ave_mse = sum(mse) / len(mse)
ave_rse = sum(rse) / len(rse)

print(f"Average Wasserstein distance = {ave_wasser}")
print(f'Average spearman correlation coefficient = {ave_spear}')
print(f"Average MSE = {ave_mse}")
print(f'Average RSE = {ave_rse}')

Average Wasserstein distance = 0.04126785019559292
Average spearman correlation coefficient = 0.9578007887697195
Average MSE = 0.00677007495381068
Average RSE = 0.03750722214391836


In [114]:
rank_rse = hq.nsmallest(37, rse)

rank_graph = []

for x in range(37):
    rank_idx = rse.index(rank_rse[x])
    rank_graph.append(rank_idx)

print('The 5 best RSE values are:')
for x in range(5):
    print(f'RSE = {rank_rse[x]:.3f}, graph number = {rank_graph[x]}')

print('')
print('The 5 worst RSE values are:')
for x in range(-1, -6, -1):
    print(f'RSE = {rank_rse[x]:.3f}, graph number = {rank_graph[x]}')

The 5 best RSE values are:
RSE = 0.020, graph number = 42
RSE = 0.022, graph number = 7
RSE = 0.024, graph number = 21
RSE = 0.025, graph number = 13
RSE = 0.025, graph number = 22

The 5 worst RSE values are:
RSE = 0.041, graph number = 0
RSE = 0.041, graph number = 11
RSE = 0.041, graph number = 41
RSE = 0.041, graph number = 34
RSE = 0.040, graph number = 33


#### View and compare predictions

In [115]:
# --- Plot best spectra prediction
p1 = bokeh_spectra(predict[rank_graph[0]], true[rank_graph[0]])
p2 = bokeh_spectra(predict[rank_graph[1]], true[rank_graph[1]])
p3 = bokeh_spectra(predict[rank_graph[2]], true[rank_graph[2]])
p = row(p1, p2, p3)
show(p)
#export_png(p, filename='GO_mol_best_RSE.png')

In [116]:
# --- Plot worst spectra prediction
p1 = bokeh_spectra(predict[rank_graph[-1]], true[rank_graph[-1]])
p2 = bokeh_spectra(predict[rank_graph[-2]], true[rank_graph[-2]])
p3 = bokeh_spectra(predict[rank_graph[-3]], true[rank_graph[-3]])
p = row(p1, p2, p3)
show(p)
#export_png(p, filename='GO_mol_worst_RSE.png')

In [117]:
# --- Plot overview spectra prediction
p1 = bokeh_spectra(predict[rank_graph[0]], true[rank_graph[0]])
p1.title.text = 'Best RSE = 0.020'
p1.title.align = 'center'
p1.title.text_font_size = '24px'
p2 = bokeh_spectra(predict[rank_graph[24]], true[rank_graph[24]])
p2.title.text = 'RSE = 0.038'
p2.title.align = 'center'
p2.title.text_font_size = '24px'
p3 = bokeh_spectra(predict[rank_graph[-1]], true[rank_graph[-1]])
p3.title.text = 'Worst RSE = 0.040'
p3.title.align = 'center'
p3.title.text_font_size = '24px'
p = row(p1, p3)
show(p)
export_png(p, filename='GO_mol_RSE.png')

'd:\\github\\GO_molecule_GNN\\GO_mol_RSE.png'

In [118]:
bins = np.linspace(0.015, 0.045, 40)
hist, edges = np.histogram(rank_rse, density=True, bins=bins)
p_hist = bokeh_hist(hist, edges, ave_rse)
show(p_hist)
export_png(p_hist, filename='GO_mol_hist.png')

'd:\\github\\GO_molecule_GNN\\GO_mol_hist.png'

### Circumcoronene

#### Load the circumcoronene dataset

In [119]:
circum_dataset = torch.load('./processed/mol_cir_dataset.pt')

In [120]:
# --- Show details of the dataset
print(circum_dataset)
print('-------------')
print(f'Number of graphs: {len(circum_dataset)}')
print(f'Number of features: {circum_dataset.num_features}')
print(f'Number of classes: {circum_dataset.num_classes}')
print('')

# --- Show details of the first molecule/graph in the dataset
circum_data = circum_dataset[1]

print(circum_data)
print('---------------')
print(f'Number of nodes: {circum_data.num_nodes}')
print(f'Number of edges: {circum_data.num_edges}')
print(f'Average node degree: {circum_data.num_edges / circum_data.num_nodes:.2f} ')
print(f'Has isolated nodes: {circum_data.has_isolated_nodes()}')
print(f'Has self loops: {circum_data.has_self_loops()}')
print(f'Is undirected: {circum_data.is_undirected()}')

XASDataset(91)
-------------
Number of graphs: 91
Number of features: 15
Number of classes: 0

Data(x=[60, 15], edge_index=[2, 160], edge_attr=[160, 6], spectrum=[200], idx=[1], smiles='c12[c:1]3[c:2]4[c:3]([cH:4][cH:5]1)[cH:38][c:39]1[c:40]5[c:9]4[c:8]4[c:7]6[c:6]3[c:24]3[c:23]([cH:22]2)[cH:31][cH:32][c:33]2[c:25]3[c:26]3[c:10]6[c:11]6[c:12]7[c:13]4[c:42]4[c:41]5[c:49]([cH:48][cH:47]1)=[CH:50][C:51]15[c:43]4[c:44]4[c:17]7[c:16]7[c:15]8[c:14]6[c:28]6[c:27]3[C:35]([OH:58])([CH:34]=2)[CH:36]2[CH:37]([c:29]6[cH:30][c:18]8[cH:19][cH:20][c:21]7[c:46]([C:54](=[O:56])[OH:57])[c:45]4[CH2:53][CH:52]1[O:59]5)[O:55]2')
---------------
Number of nodes: 60
Number of edges: 160
Average node degree: 2.67 
Has isolated nodes: False
Has self loops: False
Is undirected: True


#### Load the model

In [121]:
num_tasks = 200
num_layers = 3
emb_dim = 15
in_channels = [int(emb_dim), 64, 128]
out_channels = [64, 128, 256]
gnn_type = "gcn"
heads = int(1)
drop_ratio = 0.75
graph_pooling = "mean"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = 'model_gnn_13.pt'

In [122]:
num_tasks = 200

model = GNN(
    num_tasks,
    num_layers,
    emb_dim,
    in_channels,
    out_channels,
    gnn_type,
    heads,
    drop_ratio,
    graph_pooling,
)
model = model.to(device)
model.load_state_dict(torch.load("ML_models/" + model_name, map_location=device))

<All keys matched successfully>

In [123]:
cir_predict = {}
cir_true = {}

for index in range(len(circum_dataset)):
    cir_predict[index], cir_true[index] = pred_spec(model, index, circum_dataset)

# --- Parse spectra into dictionary
cir_model_dict = [cir_predict, cir_true]

name = 'cir_spectra_ml_best.pkl'

with open('spectra_results/' + name, 'wb') as file:
    pkl.dump(cir_model_dict, file)

In [124]:
file = open('spectra_results/' + name, 'rb')
cir_data = pkl.load(file)

cir_predict = cir_data[0]
cir_true = cir_data[1]

In [125]:
wasser = []
mse = []
rse = []
spear = []

for x in range(len(cir_predict)):
    # Wasserstein metric
    wass_temp = wasserstein_distance(cir_true[x], cir_predict[x])
    wasser.append(wass_temp)
    # Mean squared error
    mse_temp = mean_squared_error(cir_true[x], cir_predict[x])
    mse.append(mse_temp)
    # RSE
    rse_temp = calculate_rse(cir_true[x], cir_predict[x])
    rse.append(rse_temp)
    # Spearman
    spear_temp = spearmanr(cir_true[x], cir_predict[x])
    spear.append(spear_temp[0])

cir_ave_wasser = sum(wasser) / len(wasser)
cir_ave_spear = sum(spear) / len(spear)
cir_ave_mse = sum(mse) / len(mse)
cir_ave_rse = sum(rse) / len(rse)

print(f"Average Wasserstein distance = {cir_ave_wasser}")
print(f'Average Spearman correlation coefficiant = {cir_ave_spear}')
print(f"Average MSE = {cir_ave_mse}")
print(f'Average RSE = {cir_ave_rse}')

Average Wasserstein distance = 0.058077461228015555
Average Spearman correlation coefficiant = 0.9834619876485918
Average MSE = 0.006628370189072837
Average RSE = 0.042730185640862936


In [126]:
rank_rse = hq.nsmallest(91, rse)

rank_graph = []

for x in range(91):
    rank_idx = rse.index(rank_rse[x])
    rank_graph.append(rank_idx)

print('The 5 best RSE values are:')
for x in range(5):
    print(f'RSE = {rank_rse[x]:.3f}, graph number = {rank_graph[x]}')

print('')
print('The 5 worst RSE values are:')
for x in range(-1, -6, -1):
    print(f'RSE = {rank_rse[x]:.3f}, graph number = {rank_graph[x]}')

The 5 best RSE values are:
RSE = 0.020, graph number = 49
RSE = 0.021, graph number = 16
RSE = 0.024, graph number = 79
RSE = 0.025, graph number = 28
RSE = 0.026, graph number = 37

The 5 worst RSE values are:
RSE = 0.073, graph number = 34
RSE = 0.068, graph number = 77
RSE = 0.062, graph number = 44
RSE = 0.060, graph number = 13
RSE = 0.060, graph number = 58


In [127]:
# --- Plot best spectra prediction
p1 = bokeh_spectra(cir_predict[rank_graph[0]], cir_true[rank_graph[0]])
p2 = bokeh_spectra(cir_predict[rank_graph[1]], cir_true[rank_graph[1]])
p3 = bokeh_spectra(cir_predict[rank_graph[2]], cir_true[rank_graph[2]])
p = row(p1, p2, p3)
show(p)
#export_png(p, filename='GO_mol_cir_best_RSE.png')

In [128]:
# --- Plot best spectra prediction
p1 = bokeh_spectra(cir_predict[rank_graph[-1]], cir_true[rank_graph[-1]])
p2 = bokeh_spectra(cir_predict[rank_graph[-2]], cir_true[rank_graph[-2]])
p3 = bokeh_spectra(cir_predict[rank_graph[-3]], cir_true[rank_graph[-3]])
p = row(p1, p2, p3)
show(p)
#export_png(p, filename='GO_mol_cir_worst_RSE.png')

In [129]:
# --- Plot best spectra prediction
p1 = bokeh_spectra(cir_predict[rank_graph[0]], cir_true[rank_graph[0]])
p1.title.text = 'Best RSE = 0.020'
p1.title.align = 'center'
p1.title.text_font_size = '24px'
p2 = bokeh_spectra(cir_predict[rank_graph[44]], cir_true[rank_graph[44]])
p3 = bokeh_spectra(cir_predict[rank_graph[-1]], cir_true[rank_graph[-1]])
p3.title.text = 'Worst RSE = 0.060'
p3.title.align = 'center'
p3.title.text_font_size = '24px'
p = row(p1, p3)
show(p)
export_png(p, filename='GO_cir_RSE.png')

'd:\\github\\GO_molecule_GNN\\GO_cir_RSE.png'

In [130]:
bins = np.linspace(0.015, 0.065, 40)
cir_hist, cir_edges = np.histogram(rank_rse, density=True, bins=bins)
p_hist = bokeh_hist(cir_hist, cir_edges, ave_rse)
show(p_hist)
export_png(p_hist, filename='GO_cir_hist.png')

'd:\\github\\GO_molecule_GNN\\GO_cir_hist.png'

In [133]:
from bokeh.plotting import figure
from bokeh.models import SingleIntervalTicker, LinearAxis, NumeralTickFormatter, Span
from bokeh.palettes import HighContrast3

p = figure(
        x_axis_label = 'RSE value', y_axis_label = 'Frequency',
        x_range = (cir_edges[0], cir_edges[-1]), y_range = (0, max(hist)+10),
        width = 500, height = 450,
        outline_line_color = 'black', outline_line_width = 2
    )

p.toolbar.logo = None
p.toolbar_location = None
p.min_border = 25

# --- x-axis settings
p.xaxis.ticker.desired_num_ticks = 3
p.xaxis.axis_label_text_font_size = '24px'
p.xaxis.major_label_text_font_size = '24px'
p.xaxis.major_tick_in = 0
p.xaxis.major_tick_out = 10
p.xaxis.minor_tick_out = 6
p.xaxis.major_tick_line_width = 2
p.xaxis.minor_tick_line_width = 2
p.xaxis.major_tick_line_color = 'black'
p.xaxis.minor_tick_line_color = 'black'
p.xaxis[0].ticker.desired_num_ticks = 3
# --- y-axis settings
p.yaxis.axis_label_text_font_size = '24px'
p.yaxis.major_label_text_font_size = '24px'
p.yaxis.major_tick_in = 0
p.yaxis.major_tick_out = 10
p.yaxis.major_tick_line_width = 2
p.yaxis.major_tick_line_color = 'black'
p.yaxis.minor_tick_line_color = None
p.yaxis.major_label_text_color = 'black'
# --- grid settings
p.grid.grid_line_color = 'grey'
p.grid.grid_line_alpha = 0.3
p.grid.grid_line_width = 1.5
p.grid.grid_line_dash = "dashed"

# --- Format x-axis
ticker = SingleIntervalTicker(interval=20)
xaxis = LinearAxis(ticker=ticker)
p.add_layout(xaxis, 'below')

# --- Plot data
# --- Add histogram
p.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:], fill_color='skyblue', line_color='black')
p.quad(top=cir_hist, bottom=0, left=cir_edges[:-1], right=cir_edges[1:], fill_color='tomato', line_color='black')
# --- Add average line
vline = Span(location=ave_rse, dimension='height', line_color='darkblue', line_width=3, line_dash='dashed')
vline1 = Span(location=cir_ave_rse, dimension='height', line_color='darkred', line_width=3, line_dash='dashed')
p.renderers.extend([vline, vline1])

show(p)
export_png(p, filename='hist_comp.png')

'd:\\github\\GO_molecule_GNN\\hist_comp.png'