In [79]:
# --- Standard libraries
import pickle as pkl
import numpy as np
import heapq as hq
from icecream import ic
# --- PyTorch
import torch
from torch_geometric.data import Batch
# --- RDKit
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.molSize = 300,300
# --- Bokeh
from bokeh.io import output_notebook, show, export_png
from bokeh.layouts import row
from bokeh.models import Label
output_notebook()
# --- Science python
from sklearn.metrics import mean_squared_error
from scipy.stats import wasserstein_distance, spearmanr
# --- Modules from local files
from GNN_atom import GNN
from utils import bokeh_spectra, calculate_rse, count_funct_group, bokeh_hist

In [80]:
def pred_spec(model, index, test_dataset):
    # --- Set the model to evaluation mode
    model.eval()

    # --- Get a single graph from the test dataset
    graph_index = index # Index of the graph you want to predict on
    graph_data = test_dataset[graph_index].to(device)
    batch = Batch.from_data_list([graph_data])
    #print(batch.atom_num)

    # --- Pass the graph through the model
    with torch.no_grad():
        pred = model(batch)
    
    # ---  the predicted output for the single graph
    pred_graph = pred[0]
    true_spectrum = graph_data.spectrum.cpu().numpy()
    predicted_spectrum = pred.cpu().numpy()
    predicted_spectrum = predicted_spectrum.reshape(-1)
    
    return predicted_spectrum, true_spectrum

### Coronene

#### Set model parameters

In [103]:
num_tasks = 200
num_layers = 3
emb_dim = 15
in_channels = [int(emb_dim), 64, 128]
out_channels = [64, 128, 256]
gnn_type = 'gcn'
heads = int(1)
drop_ratio = 0.2
graph_pooling = 'mean'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = 'model_gnn_atom_1.pt'

#### Load the saved model

In [104]:
# --- Load the saved model
num_tasks = 200

model = GNN(
    num_tasks,
    num_layers,
    emb_dim,
    in_channels,
    out_channels,
    gnn_type,
    heads,
    drop_ratio,
    graph_pooling,
)
model = model.to(device)
model.load_state_dict(torch.load("ML_models/" + model_name, map_location='cpu'))

<All keys matched successfully>

#### Load the test data

In [105]:
test_dataset = torch.load("./processed/atom_test_dataset.pt")
print(f'Length of test dataset: {len(test_dataset)}')

Length of test dataset: 895


#### Use the model to predict the test data

In [106]:
# --- Create empty dictionary to contain spectra
predict = {}
true = {}

# --- Loop over all molecules in test dateset and assign to dictionary index
for index in range(len(test_dataset)):
    predict[index], true[index] = pred_spec(model, index, test_dataset)

# --- Parse spectra into dictionary
model_dict = [predict, true]

name = 'spectra_ml_atom_test.pkl'

with open('spectra_results/' + name, "wb") as file:
    pkl.dump(model_dict, file)

#### View predictions

In [107]:
file = open('spectra_results/' + name, 'rb')
data = pkl.load(file)

predict = data[0]
true = data[1]

#### Perform analysis of predictions

In [108]:
wasser = []
mse = []
rse = []

for x in range(len(predict)):
    # Wasserstein metric
    wass_temp = wasserstein_distance(predict[x], true[x])
    wasser.append(wass_temp)
    # Mean squared error
    mse_temp = mean_squared_error(predict[x], true[x])
    mse.append(mse_temp)
    # RSE
    rse_temp = calculate_rse(predict[x], true[x])
    rse.append(rse_temp)

print(f"Average Wasserstein distance = {sum(wasser) / len(wasser)}")
print(f"Average MSE = {sum(mse) / len(mse)}")
print(f'Average RSE = {sum(rse) / len(rse)}')

Average Wasserstein distance = 0.07697039921227654
Average MSE = 0.03160140264220955
Average RSE = 0.1760771114243605


In [109]:
rank_rse = hq.nsmallest(789, rse)

rank_graph = []

for x in range(789):
    rank_idx = rse.index(rank_rse[x])
    rank_graph.append(rank_idx)

print('The 5 best RSE values are:')
for x in range(5):
    print(f'RSE = {rank_rse[x]:.3f}, graph number = {rank_graph[x]}')

print('')
print('The 5 worst RSE values are:')
for x in range(-1, -6, -1):
    print(f'RSE = {rank_rse[x]:.3f}, graph number = {rank_graph[x]}')

The 5 best RSE values are:
RSE = 0.074, graph number = 887
RSE = 0.076, graph number = 350
RSE = 0.077, graph number = 103
RSE = 0.079, graph number = 1
RSE = 0.086, graph number = 471

The 5 worst RSE values are:
RSE = 0.245, graph number = 829
RSE = 0.242, graph number = 284
RSE = 0.241, graph number = 88
RSE = 0.241, graph number = 6
RSE = 0.241, graph number = 852


In [110]:
rank_rse[450]

0.16213004111483556

#### View and compare predictions

In [121]:
# --- Plot best spectra prediction
p1 = bokeh_spectra(predict[rank_graph[0]], true[rank_graph[0]])
p2 = bokeh_spectra(predict[rank_graph[443]], true[rank_graph[443]])
p2.legend.location = 'top_right'
p3 = bokeh_spectra(predict[rank_graph[2]], true[rank_graph[2]])
p = row(p1, p2, p3)
show(p2)
export_png(p2, filename='ave.png')

'/home/samjhall/github/GO_molecule_GNN/ave.png'

In [102]:
# --- Plot worst spectra prediction
p1 = bokeh_spectra(predict[rank_graph[-1]], true[rank_graph[-1]])
p2 = bokeh_spectra(predict[rank_graph[-2]], true[rank_graph[-2]])
p3 = bokeh_spectra(predict[rank_graph[-3]], true[rank_graph[-3]])
p = row(p1, p2, p3)
show(p)
#export_png(p, filename='worst.png')

In [78]:
bins = np.linspace(0.07, 0.5, 35)
hist, edges = np.histogram(rank_rse, density=True, bins=bins)
average_rse = sum(rse) / len(rse)
p_hist = bokeh_hist(hist, edges, average_rse, 2)
l1 = Label(x=0.19, y=11, x_units='data', y_units='data', text='RSE = 0.166', text_font_size='24px')
p_hist.add_layout(l1)

show(p_hist)
export_png(p_hist, filename='GO_atom_hist.png')

'/home/samjhall/github/GO_molecule_GNN/GO_atom_hist.png'