In [1]:
import heapq as hq
import torch
import optuna
from bokeh.io import output_notebook, show, export_svg
from bokeh.layouts import row
output_notebook()
from sklearn.metrics import mean_squared_error
from models import GNN_model
from utils.functions import get_spec_prediction, calculate_rse, bokeh_spectra

In [2]:
# Architecture hyper-parameters handed to GNN_model below.
# NOTE(review): presumably these must mirror the values used when the saved
# checkpoint was trained — confirm against the training configuration.
num_tasks = 200
num_layers = 3
hidden_channels = 64
gnn_type = 'gcn'
heads = 1
drop_ratio = 0.3
graph_pooling = 'mean'

# Per-layer channel sizes: the first layer takes 15 input channels, every later
# layer takes the hidden width; all layers emit the hidden width.
in_channels = [15] + [hidden_channels] * (num_layers - 1)
out_channels = [hidden_channels for _ in range(num_layers)]

model = GNN_model(
    num_tasks, num_layers, in_channels, out_channels,
    gnn_type, heads, drop_ratio, graph_pooling,
)

In [3]:
# Restore the trained weights into the model built above.
# - map_location='cpu' lets a checkpoint written from a GPU run load on a
#   CPU-only machine (the model itself is never moved off the CPU here).
# - weights_only=True restricts unpickling to tensors and plain containers,
#   which is all a state_dict holds, and removes the warning torch.load emits
#   when that argument is left implicit.
state_dict = torch.load('./best_model.pth', map_location='cpu', weights_only=True)

# This notebook only runs inference, so put the model in eval mode.
# NOTE(review): drop_ratio is non-zero, so presumably the model contains dropout
# that must be disabled for deterministic predictions; this call is harmless if
# the prediction helper already switches modes.
model.eval()

# Kept as the last expression so the cell still displays the key-matching report.
model.load_state_dict(state_dict)

  model.load_state_dict(torch.load('./best_model.pth'))


<All keys matched successfully>

In [4]:
# Load the pre-processed test graphs. The file stores pickled graph objects (the
# sample printed below is a `Data(...)` record), not bare tensors, so it needs
# the full unpickler. Passing weights_only=False explicitly keeps this cell
# working on PyTorch releases whose default is weights_only=True and avoids the
# warning about relying on the implicit default.
# NOTE(review): full unpickling can execute arbitrary code — only load dataset
# files produced by this project's own preprocessing step.
test_dataset = torch.load('../processed/test_coronene_pyg.pt', weights_only=False)
print(test_dataset[0])

Data(x=[31, 15], edge_index=[2, 80], edge_attr=[80, 5], spectrum=[200], idx=[1], smiles='C12=[C:2]3[C:1]([O:25][H:40])=[C:24]([H:42])[C:23]45[C:21]1([C:17]16[c:13]7[c:9]8[c:11]([c:12]([H:34])[c:14]([H:35])[c:15]7[C:16]([H:36])=[C:18]([H:37])[C:19]1([C:20]([H:38])=[C:22]4[H:39])[O:28]6)[C:10]([H:33])=[C:8]([H:32])[C:7]1([C:5]28[O:29]1)[C:6]([H:31])=[C:3]3[C:4](=[O:26])[O:27][H:41])[O:30]5')


  test_dataset = torch.load('../processed/test_coronene_pyg.pt')


In [5]:
# Run the model over every test graph. Both dicts are keyed by the graph's
# position in test_dataset, so pred[i] and true[i] refer to the same sample.
device = 'cpu'
pred, true = {}, {}

for index, data in enumerate(test_dataset):
    prediction, target = get_spec_prediction(model, data, device)
    pred[index] = prediction
    true[index] = target

In [6]:
# Per-graph error metrics, stored in dataset order so that mse[i] / rse[i]
# belong to test graph i.
mse, rse = [], []

for graph_idx in range(len(pred)):
    target, estimate = true[graph_idx], pred[graph_idx]
    # Argument order differs between the two helpers: mean_squared_error is
    # called as (true, pred), calculate_rse as (pred, true).
    mse.append(mean_squared_error(target, estimate))
    rse.append(calculate_rse(estimate, target))

# Plain arithmetic means over the whole test set.
ave_mse = sum(mse) / len(mse)
ave_rse = sum(rse) / len(rse)

print(f'Average MSE = {ave_mse}')
print(f'Average RSE = {ave_rse}')

Average MSE = 0.006963788066059351
Average RSE = 0.036188337951898575


In [7]:
# Rank the test graphs by RSE, best (smallest) first.
#   rank_graph[k] -> dataset index of the k-th best prediction
#   rank_rse[k]   -> the RSE of that graph
# Sorting the indices directly (an argsort) guarantees every graph appears
# exactly once. Looking each sorted value up again with list.index() would
# return the first match only, so graphs with equal RSE would all be reported
# under the same graph number. sorted() is stable, so ties keep dataset order.
rank_graph = sorted(range(len(rse)), key=rse.__getitem__)
rank_rse = [rse[graph_idx] for graph_idx in rank_graph]

# Never index past the end when the test set holds fewer than 5 graphs.
n_shown = min(5, len(rank_graph))

print('The 5 best RSE values are:')
for i in range(n_shown):
    print(f'RSE = {rank_rse[i]:.3f}, graph number = {rank_graph[i]}')
print('')
print('The 5 worst RSE values are:')
# Walk backwards from the end of the ranking: worst graph first.
for i in range(-1, -n_shown - 1, -1):
    print(f'RSE = {rank_rse[i]:.3f}, graph number = {rank_graph[i]}')


The 5 best RSE values are:
RSE = 0.016, graph number = 36
RSE = 0.022, graph number = 24
RSE = 0.024, graph number = 10
RSE = 0.024, graph number = 38
RSE = 0.026, graph number = 8

The 5 worst RSE values are:
RSE = 0.063, graph number = 33
RSE = 0.061, graph number = 3
RSE = 0.058, graph number = 17
RSE = 0.050, graph number = 23
RSE = 0.047, graph number = 14


In [8]:
# Plot the three best-ranked spectra predictions side by side.
# rank_graph is ordered best-first, so positions 0..2 are the lowest-RSE graphs.
p1, p2, p3 = (
    bokeh_spectra(pred[rank_graph[position]], true[rank_graph[position]])
    for position in range(3)
)
p = row(p1, p2, p3)
show(p)