In [1]:
import heapq as hq
import torch
import optuna
from bokeh.io import output_notebook, show, export_svg
from bokeh.layouts import row
output_notebook()
from sklearn.metrics import mean_squared_error
from models import GNN_model
from utils.functions import get_spec_prediction, calculate_rse, bokeh_spectra
from dataset import XASMolDataset

In [2]:
gnn_type = 'gcn'

study_name = gnn_type + '_study'
storage_name = 'sqlite:///{}.db'.format(study_name)

study = optuna.load_study(study_name=study_name, storage=storage_name)

In [3]:
best = study.best_trial
print(f'Best model: {best.number}')
print(f'Best validation loss: {best.value:.5f}')
print('-----')
gnn_params = {}
for key, value in best.params.items():
    print(f'{key}: {value}')
    gnn_params[key] = value

Best model: 35
Best validation loss: 0.00824
-----
num_layers: 3
hidden_channels: 128
gnn_type: gcn
drop_ratio: 0.25
learning_rate: 0.015860148783355395


In [4]:
num_tasks = 200
num_layers = gnn_params['num_layers']
hidden_channels = gnn_params['hidden_channels']
in_channels = [hidden_channels] * (num_layers - 1)
in_channels.insert(0, 15)
out_channels = [hidden_channels] * num_layers
gnn_type = gnn_params['gnn_type']
heads = 1
drop_ratio = gnn_params['drop_ratio']
graph_pooling = 'mean'


model  = GNN_model(num_tasks, num_layers, in_channels, out_channels,
                   gnn_type, heads, drop_ratio, graph_pooling)

In [5]:
model.load_state_dict(torch.load('./gcn_model.pth'))

  model.load_state_dict(torch.load('./gcn_model.pth'))


<All keys matched successfully>

In [6]:
test_dataset = torch.load('../processed/test_coronene_pyg.pt')
print(test_dataset[0])

Data(x=[31, 15], edge_index=[2, 78], edge_attr=[78, 5], spectrum=[200], idx=[1], smiles='c12[c:4]3[c:9]4[c:13]5[c:17]6[c:19]7[c:20]([H:37])[c:23]([H:38])[c:24]([c:22]16)=[C:25]([H:42])[C:1]([H:31])([H:43])[C:2]2([O:27][H:39])[C:3]([H:32])=[C:5]([C:8](=[O:30])[H:41])[C:6]3([H:44])[C:7]1([H:33])[C:10]([H:34])([C:11]4=[C:12]([H:46])[C:14]([H:35])([H:45])[C:15]52[C:16]([H:36])([C:18]=7[C:21](=[O:29])[H:40])[O:28]2)[O:26]1', mol='R_eOH_0_eCOOH__0_eEPOXY_1_eCHO_0_eKETO_2_iOH_1_iEPOXY_0_ieEPOXY_1_1')


  test_dataset = torch.load('../processed/test_coronene_pyg.pt')


In [7]:
pred = {}
true = {}
device = 'cpu'

for index, data in enumerate(test_dataset):
    pred[index], true[index] = get_spec_prediction(model, data, device)

In [8]:
mse = []
rse = []

for x in range(len(pred)):
    mse_temp = mean_squared_error(true[x], pred[x])
    mse.append(mse_temp)
    rse_temp = calculate_rse(pred[x], true[x])
    rse.append(rse_temp)

ave_mse = sum(mse) / len(mse)
ave_rse = sum(rse) / len(rse)

print(f'Average MSE = {ave_mse}')
print(f'Average RSE = {ave_rse}')

Average MSE = 0.006760057527571917
Average RSE = 0.0366290919482708


In [9]:
rank_rse = hq.nsmallest(len(rse), rse)

rank_graph = []

for x in range(len(rse)):
    rank_idx = rse.index(rank_rse[x])
    rank_graph.append(rank_idx)

print('The 5 best RSE values are:')
for i in range(5):
    print(f'RSE = {rank_rse[i]:.3f}, graph number = {rank_graph[i]}')
print('')
print('The 5 worst RSE values are:')
for i in range(-1, -6, -1):
    print(f'RSE = {rank_rse[i]:.3f}, graph number = {rank_graph[i]}')


The 5 best RSE values are:
RSE = 0.018, graph number = 27
RSE = 0.020, graph number = 22
RSE = 0.020, graph number = 33
RSE = 0.023, graph number = 26
RSE = 0.024, graph number = 2

The 5 worst RSE values are:
RSE = 0.068, graph number = 17
RSE = 0.066, graph number = 37
RSE = 0.058, graph number = 11
RSE = 0.056, graph number = 3
RSE = 0.051, graph number = 24


In [10]:
# Plot best spectra predictions
p1 = bokeh_spectra(pred[rank_graph[0]], true[rank_graph[0]])
p2 = bokeh_spectra(pred[rank_graph[1]], true[rank_graph[1]])
p3 = bokeh_spectra(pred[rank_graph[2]], true[rank_graph[2]])
p = row(p1, p2, p3)
show(p)

In [11]:
# Plot worst spectra predictions
p1 = bokeh_spectra(pred[rank_graph[-1]], true[rank_graph[-1]])
p2 = bokeh_spectra(pred[rank_graph[-2]], true[rank_graph[-2]])
p3 = bokeh_spectra(pred[rank_graph[-3]], true[rank_graph[-3]])
p = row(p1, p2, p3)
show(p)

In [12]:
path = '../'
cir_dataset = XASMolDataset(path)

  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):


In [13]:
cir_pred = {}
cir_true = {}
device = 'cpu'

for index, data in enumerate(cir_dataset):
    cir_pred[index], cir_true[index] = get_spec_prediction(model, data, device)

In [14]:
cir_mse = []
cir_rse = []

for x in range(len(cir_pred)):
    mse_temp = mean_squared_error(cir_true[x], cir_pred[x])
    cir_mse.append(mse_temp)
    rse_temp = calculate_rse(cir_pred[x], cir_true[x])
    cir_rse.append(rse_temp)

cir_ave_mse = sum(cir_mse) / len(cir_mse)
cir_ave_rse = sum(cir_rse) / len(cir_rse)

print(f'Average MSE = {cir_ave_mse}')
print(f'Average RSE = {cir_ave_rse}')

Average MSE = 0.006466479506343603
Average RSE = 0.05095427483320236


In [15]:
cir_rank_rse = hq.nsmallest(len(cir_rse), cir_rse)

cir_rank_graph = []

for x in range(len(cir_rse)):
    cir_rank_idx = cir_rse.index(cir_rank_rse[x])
    cir_rank_graph.append(cir_rank_idx)

print('The 5 best RSE values are:')
for i in range(5):
    print(f'RSE = {cir_rank_rse[i]:.3f}, graph number = {cir_rank_graph[i]}')
print('')
print('The 5 worst RSE values are:')
for i in range(-1, -6, -1):
    print(f'RSE = {cir_rank_rse[i]:.3f}, graph number = {cir_rank_graph[i]}')

The 5 best RSE values are:
RSE = 0.034, graph number = 29
RSE = 0.037, graph number = 5
RSE = 0.038, graph number = 11
RSE = 0.040, graph number = 63
RSE = 0.040, graph number = 3

The 5 worst RSE values are:
RSE = 0.073, graph number = 22
RSE = 0.072, graph number = 1
RSE = 0.069, graph number = 7
RSE = 0.069, graph number = 6
RSE = 0.068, graph number = 86


In [16]:
# Plot best spectra predictions
p1 = bokeh_spectra(cir_pred[cir_rank_graph[0]], cir_true[cir_rank_graph[0]])
p2 = bokeh_spectra(cir_pred[cir_rank_graph[1]], cir_true[cir_rank_graph[1]])
p3 = bokeh_spectra(cir_pred[cir_rank_graph[2]], cir_true[rank_graph[2]])
p = row(p1, p2, p3)
show(p)

In [17]:
# Plot worst spectra predictions
p1 = bokeh_spectra(cir_pred[cir_rank_graph[-1]], cir_true[cir_rank_graph[-1]])
p2 = bokeh_spectra(cir_pred[cir_rank_graph[-2]], cir_true[cir_rank_graph[-2]])
p3 = bokeh_spectra(cir_pred[cir_rank_graph[-3]], cir_true[rank_graph[-3]])
p = row(p1, p2, p3)
show(p)