In [1]:
# --- Standard libraries
import pickle as pkl
#import numpy as np
import heapq as hq
from icecream import ic
# --- PyTorch
import torch
from torch_geometric.data import Batch
# --- RDKit
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.molSize = 300,300
# --- Bokeh
from bokeh.io import output_notebook, show, export_svg
from bokeh.layouts import row
output_notebook()
# --- Science python
from sklearn.metrics import mean_squared_error
from scipy.stats import wasserstein_distance, spearmanr
# --- Modules from local files
from GNN_atom import GNN
from utils import bokeh_spectra, calculate_rse, count_funct_group

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def pred_spec(model, index, test_dataset):
    """Predict the spectrum of a single graph from the test dataset.

    The selected graph is moved to the module-level ``device`` (this function
    reads that global, so it must be defined before the first call), wrapped
    in a one-element batch and passed through ``model`` with gradient
    tracking disabled.

    Parameters
    ----------
    model : callable with an ``eval()`` method
        Trained network; it is switched to evaluation mode here.
    index : int
        Position of the graph inside ``test_dataset``.
    test_dataset : sequence
        Indexable collection of graph objects exposing ``.to(device)`` and a
        ``spectrum`` tensor attribute.

    Returns
    -------
    predicted_spectrum : numpy.ndarray
        Model output for the graph, flattened to one dimension.
    true_spectrum : numpy.ndarray
        The ``spectrum`` attribute of the graph converted to NumPy.
    """
    # --- Set the model to evaluation mode
    model.eval()

    # --- Wrap the requested graph in a one-element batch
    graph_data = test_dataset[index].to(device)
    batch = Batch.from_data_list([graph_data])

    # --- Pass the graph through the model without tracking gradients
    with torch.no_grad():
        pred = model(batch)

    # --- Convert the reference and the prediction to NumPy arrays
    true_spectrum = graph_data.spectrum.cpu().numpy()
    predicted_spectrum = pred.cpu().numpy().reshape(-1)

    return predicted_spectrum, true_spectrum

### Coronene

#### Set model parameters

In [3]:
# --- Size arguments handed to the GNN constructor
num_tasks, num_layers = 200, 4

# --- Channel widths: every layer's output width is the next layer's input width
emb_dim = 18
out_channels = [512, 256, 128, 64]
in_channels = [emb_dim] + out_channels[:-1]

# --- Remaining GNN constructor arguments
gnn_type, graph_pooling = 'gcn', 'mean'
heads = 1
drop_ratio = 0.25

# --- Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- File name of the saved model
model_name = 'model_gnn_atom_test.pt'

#### Load the saved model

In [4]:
# --- Build the network from the hyperparameters of the configuration cell.
#     ``num_tasks`` is intentionally not redefined here: a second hard-coded
#     copy would silently override any change made in the configuration cell.
model = GNN(
    num_tasks,
    num_layers,
    emb_dim,
    in_channels,
    out_channels,
    gnn_type,
    heads,
    drop_ratio,
    graph_pooling,
)
model = model.to(device)

# --- Load the saved weights into the model. The checkpoint tensors are mapped
#     to the CPU first; load_state_dict is the last expression so that its
#     result is displayed as the cell output.
# NOTE(review): torch.load unpickles the file - only load checkpoints that
#     come from a trusted source.
model.load_state_dict(torch.load("ML_models/" + model_name, map_location='cpu'))

<All keys matched successfully>

#### Load the test data

In [5]:
# --- Read the processed test dataset from disk and report how many graphs it holds
dataset_path = "./processed/atom_test_dataset.pt"
test_dataset = torch.load(dataset_path)
n_test_graphs = len(test_dataset)
print(f'Length of test dataset: {n_test_graphs}')

Length of test dataset: 795


#### Use the model to predict the test data

In [6]:
# --- Dictionaries mapping graph index -> predicted / true spectrum
predict, true = {}, {}

# --- Run the model on every graph of the test dataset
for index in range(len(test_dataset)):
    spectra_pair = pred_spec(model, index, test_dataset)
    predict[index] = spectra_pair[0]
    true[index] = spectra_pair[1]

# --- Bundle both dictionaries as [predict, true] and pickle them to disk
model_dict = [predict, true]
name = 'spectra_ml_atom_test.pkl'

with open('spectra_results/' + name, "wb") as file:
    pkl.dump(model_dict, file)

#### View predictions

In [7]:
# --- Reload the pickled [predict, true] pair. The context manager closes the
#     file handle as soon as loading has finished (a bare open() left it open).
# NOTE(review): pickle can execute arbitrary code while loading - only read
#     files produced by a trusted run of this notebook.
with open('spectra_results/' + name, 'rb') as file:
    data = pkl.load(file)

# --- First entry holds the predicted spectra, second entry the true spectra
predict = data[0]
true = data[1]

#### Perform analysis of predictions

In [8]:
# --- One entry per graph for each error metric
wasser, mse, rse = [], [], []

for x in range(len(predict)):
    predicted_spectrum, true_spectrum = predict[x], true[x]
    # Wasserstein metric
    wasser.append(wasserstein_distance(predicted_spectrum, true_spectrum))
    # Mean squared error
    mse.append(mean_squared_error(predicted_spectrum, true_spectrum))
    # RSE
    rse.append(calculate_rse(predicted_spectrum, true_spectrum))

# --- Average every metric over all graphs and report it
avg_wasser = sum(wasser) / len(wasser)
avg_mse = sum(mse) / len(mse)
avg_rse = sum(rse) / len(rse)

print(f"Average Wasserstein distance = {avg_wasser}")
print(f"Average MSE = {avg_mse}")
print(f'Average RSE = {avg_rse}')

Average Wasserstein distance = 0.06591194126662177
Average MSE = 0.029256331380305995
Average RSE = 0.16570095091775666


In [9]:
# --- Rank every graph from the lowest (best) to the highest (worst) RSE.
#     Sorting the graph indices directly
#       * works for any dataset size instead of a hard-coded 795,
#       * avoids a list.index() scan per entry (quadratic in the dataset size),
#       * gives each graph its own rank even when two graphs share the same
#         RSE value (list.index() would return the first match every time).
rank_graph = sorted(range(len(rse)), key=rse.__getitem__)
rank_rse = [rse[idx] for idx in rank_graph]

# --- Never try to report more entries than there are graphs
n_report = min(5, len(rank_graph))

print('The 5 best RSE values are:')
for x in range(n_report):
    print(f'RSE = {rank_rse[x]:.3f}, graph number = {rank_graph[x]}')

print('')
print('The 5 worst RSE values are:')
# Walk backwards from the end of the ranking: -1 is the worst prediction
for x in range(-1, -n_report - 1, -1):
    print(f'RSE = {rank_rse[x]:.3f}, graph number = {rank_graph[x]}')

The 5 best RSE values are:
RSE = 0.078, graph number = 482
RSE = 0.083, graph number = 13
RSE = 0.084, graph number = 283
RSE = 0.087, graph number = 161
RSE = 0.088, graph number = 30

The 5 worst RSE values are:
RSE = 0.459, graph number = 303
RSE = 0.444, graph number = 522
RSE = 0.428, graph number = 510
RSE = 0.420, graph number = 169
RSE = 0.414, graph number = 520


#### View and compare predictions

In [10]:
# --- Plot the three predictions ranked first in rank_graph, side by side
p1, p2, p3 = (
    bokeh_spectra(predict[rank_graph[pos]], true[rank_graph[pos]])
    for pos in (0, 1, 2)
)
p = row(p1, p2, p3)
show(p)

In [11]:
# --- Plot the three predictions ranked last in rank_graph, side by side
p1, p2, p3 = (
    bokeh_spectra(predict[rank_graph[pos]], true[rank_graph[pos]])
    for pos in (-1, -2, -3)
)
p = row(p1, p2, p3)
show(p)

In [12]:
# --- Plot the prediction that sits fifth from the end of rank_graph
graph = rank_graph[-5]
predicted_spectrum, true_spectrum = predict[graph], true[graph]
p = bokeh_spectra(predicted_spectrum, true_spectrum)
show(p)