In [1]:
# --- Standard libraries
import pickle as pkl
#import numpy as np
import heapq as hq
from icecream import ic
# --- PyTorch
import torch
from torch_geometric.data import Batch
# --- Bokeh
from bokeh.io import output_notebook, show, export_svg
from bokeh.layouts import row
output_notebook()
# --- Science python
from sklearn.metrics import mean_squared_error
from scipy.stats import wasserstein_distance, spearmanr
# --- Modules from local files
from GNN_atom import GNN
from utils import bokeh_spectra, calculate_rse

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def pred_spec(model, index, test_dataset):
    """Predict the spectrum of one graph and return it with its reference.

    Relies on the notebook-level ``device`` variable, which must be defined
    before this function is called.

    Args:
        model: Network to evaluate; it is switched to evaluation mode and
            called on a single-graph batch.
        index: Position of the graph inside ``test_dataset``.
        test_dataset: Indexable collection of graph objects, each exposing a
            ``spectrum`` tensor.

    Returns:
        A ``(predicted_spectrum, true_spectrum)`` pair of NumPy arrays. The
        prediction is flattened to one dimension; the reference spectrum is
        returned in the shape in which it is stored on the graph.
    """
    # --- Set the model to evaluation mode
    model.eval()

    # --- Get a single graph from the test dataset and wrap it in a batch
    graph_data = test_dataset[index].to(device)
    batch = Batch.from_data_list([graph_data])

    # --- Pass the graph through the model without tracking gradients
    with torch.no_grad():
        pred = model(batch)

    # --- Move both spectra to the CPU as NumPy arrays
    true_spectrum = graph_data.spectrum.cpu().numpy()
    predicted_spectrum = pred.cpu().numpy().reshape(-1)

    return predicted_spectrum, true_spectrum

### Coronene

#### Set model parameters

In [3]:
# --- Architecture settings passed to the GNN constructor
num_tasks = 200
num_layers = 3
emb_dim = 15
in_channels = [emb_dim, 256, 128]
out_channels = [256, 128, 64]
gnn_type = 'gcn'
heads = 1
drop_ratio = 0.0
graph_pooling = 'mean'

# --- Prefer the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- File name of the saved model weights
model_name = 'model_gnn_atom_test.pt'

#### Load the saved model

In [4]:
# --- Rebuild the network from the architecture settings defined in the
#     parameter cell (num_tasks is taken from there rather than redefined here)
model = GNN(
    num_tasks,
    num_layers,
    emb_dim,
    in_channels,
    out_channels,
    gnn_type,
    heads,
    drop_ratio,
    graph_pooling,
)
model = model.to(device)

# --- Restore the saved weights. map_location remaps the stored tensors onto
#     the selected device, so a checkpoint written on a GPU machine can also
#     be loaded on a CPU-only one.
state_dict = torch.load("ML_models/" + model_name, map_location=device)
model.load_state_dict(state_dict)

<All keys matched successfully>

#### Load the test data

In [5]:
# --- Read the serialized test set from disk; later cells index it and take its length.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
test_dataset = torch.load("./processed/atom_test_dataset.pt")

#### Use the model to predict the test data

In [6]:
# --- Predicted and reference spectra, keyed by graph index in the test dataset
predict, true = {}, {}

# --- Run the model on every graph in the test dataset
for index in range(len(test_dataset)):
    spectra_pair = pred_spec(model, index, test_dataset)
    predict[index], true[index] = spectra_pair

# --- Bundle both dictionaries into a [predicted, reference] list and pickle it
name = 'spectra_ml_atom_test.pkl'
model_dict = [predict, true]

with open('spectra_results/' + name, "wb") as file:
    pkl.dump(model_dict, file)

#### View predictions

In [7]:
# --- Read the stored [predicted, reference] pair back from disk. The context
#     manager closes the file handle as soon as loading has finished.
with open('spectra_results/' + name, 'rb') as file:
    data = pkl.load(file)

predict = data[0]
true = data[1]

#### Perform analysis of predictions

In [8]:
# --- Per-graph error metrics; position i in each list belongs to graph i
wasser, mse, rse = [], [], []

for idx in range(len(predict)):
    predicted, target = predict[idx], true[idx]
    # Wasserstein metric
    wasser.append(wasserstein_distance(predicted, target))
    # Mean squared error
    mse.append(mean_squared_error(predicted, target))
    # RSE
    rse.append(calculate_rse(predicted, target))

# --- Report the average of each metric over all graphs
print(f"Average Wasserstein distance = {sum(wasser) / len(wasser)}")
print(f"Average MSE = {sum(mse) / len(mse)}")
print(f'Average RSE = {sum(rse) / len(rse)}')

Average Wasserstein distance = 0.06313094644022427
Average MSE = 0.030546438370142225
Average RSE = 0.1683145725442851


In [9]:
# --- Rank graph indices by their RSE. Ranking the indices themselves (with
#     the RSE as sort key) keeps tied RSE values attached to distinct graphs;
#     looking values up with list.index() would return the first match every
#     time and report the same graph number more than once.
graph_ids = range(len(rse))
best = hq.nsmallest(5, graph_ids, key=rse.__getitem__)
worst = hq.nlargest(5, graph_ids, key=rse.__getitem__)

# --- RSE values in the same order as the ranked graph indices
five_best = [rse[graph_id] for graph_id in best]
five_worst = [rse[graph_id] for graph_id in worst]

# --- zip() stops at the available entries, so fewer than five graphs
#     prints a shorter list instead of raising IndexError
print('The 5 best RSE values are:')
for rse_value, graph_id in zip(five_best, best):
    print(f'RSE = {rse_value:.3f}, graph number = {graph_id}')

print('')
print('The 5 worst RSE values are:')
for rse_value, graph_id in zip(five_worst, worst):
    print(f'RSE = {rse_value:.3f}, graph number = {graph_id}')

The 5 best RSE values are:
RSE = 0.057, graph number = 622
RSE = 0.074, graph number = 467
RSE = 0.074, graph number = 364
RSE = 0.077, graph number = 453
RSE = 0.078, graph number = 295

The 5 worst RSE values are:
RSE = 0.437, graph number = 607
RSE = 0.434, graph number = 303
RSE = 0.432, graph number = 522
RSE = 0.422, graph number = 169
RSE = 0.417, graph number = 43


#### View and compare predictions

In [10]:
# --- Side-by-side panels for the three graphs with the lowest RSE
p1, p2, p3 = [
    bokeh_spectra(predict[best[rank]], true[best[rank]]) for rank in range(3)
]
p = row(p1, p2, p3)
show(p)

In [11]:
# --- Side-by-side panels for the three graphs with the highest RSE
p1, p2, p3 = [
    bokeh_spectra(predict[worst[rank]], true[worst[rank]]) for rank in range(3)
]
p = row(p1, p2, p3)
show(p)