In [2]:
import pickle
import numpy as np
import os
import torch
import json
import pandas as pd
from torch_scatter import scatter_add
# Switch to the parent of the notebook's starting directory; the relative paths
# used later (./ckpt/..., ./raw/...) and presumably the local modules imported
# below (qm9_allprop, utils, xgnn) live there. The starting directory is
# remembered on first run so re-executing this cell does not climb one more
# level each time.
if '_NOTEBOOK_DIR' not in globals():
    _NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.abspath(os.path.join(_NOTEBOOK_DIR, '..')))
from qm9_allprop import QM9_allprop
from ase.io import read
from utils import read_xyz
from xgnn import xgnn_poly
from rdkit import Chem
import rdkit
from rdkit.Chem import Draw

In [3]:
# Rebuild the HS model from the hyper-parameters saved with its checkpoint.
with open('./ckpt/HS_model/args.json', 'rt') as args_file:
    args = json.load(args_file)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Force both H and S inputs on for this model, regardless of what args.json contains.
args["include_H"], args["include_S"] = True, True
model = xgnn_poly(
    include_H=args["include_H"],
    include_S=args["include_S"],
    conv_layers=args['conv_layers'],
    sbf_dim=args['sbf_dim'],
    rbf_dim=args['rbf_dim'],
    in_channels=args['in_channels'],
    heads=args['heads'],
    embedding_size=args['embedding_size'],
    device=device,
).to(device)

In [4]:
# Restore the best HS checkpoint. map_location lets a checkpoint that was saved
# on GPU load when `device` fell back to 'cpu' in the model-building cell.
ckpt = torch.load('./ckpt/HS_model/ckpt/ckpt_best.pth', map_location=device)
model.load_state_dict(ckpt['model'])
model.eval()

xgnn_poly(
  (AF): SiLU()
  (emb_block): EmbeddingBlock(
    (AF): SiLU()
    (embedding): Embedding(10, 128, padding_idx=0, max_norm=3.0, scale_grad_by_freq=True)
    (lin): Linear(in_features=128, out_features=128, bias=True)
  )
  (envelop_function): poly_envelop()
  (sbf_layer): F_B_2D(
    (envelope): poly_envelop()
  )
  (rbf_layer): RadialBasis()
  (fin_model): SBFTransformer(
    (edgenn): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): SiLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
    (convs): ModuleList(
      (0): SBFTransformerConv(128, 8, heads=16)
      (1): SBFTransformerConv(128, 8, heads=16)
      (2): SBFTransformerConv(128, 8, heads=16)
      (3): SBFTransformerConv(128, 8, heads=16)
    )
    (readouts): ModuleList(
      (0): AtomWise(
        (mlp): ModuleList(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): SiLU()
          (2): Linear(in_features=128, out_fea

In [5]:
# Rebuild the S model from the hyper-parameters saved with its checkpoint
# (include_H / include_S are taken from args.json as-is here).
with open('./ckpt/S_model/args.json', 'rt') as args_file:
    args = json.load(args_file)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_S = xgnn_poly(
    include_H=args["include_H"],
    include_S=args["include_S"],
    conv_layers=args['conv_layers'],
    sbf_dim=args['sbf_dim'],
    rbf_dim=args['rbf_dim'],
    in_channels=args['in_channels'],
    heads=args['heads'],
    embedding_size=args['embedding_size'],
    device=device,
).to(device)

In [31]:
# Restore the best S checkpoint. map_location lets a checkpoint that was saved
# on GPU load when `device` fell back to 'cpu' in the model-building cell.
ckpt_S = torch.load('./ckpt/S_model/ckpt/ckpt_best.pth', map_location=device)
model_S.load_state_dict(ckpt_S['model'])
model_S.eval()

xgnn_poly(
  (AF): SiLU()
  (emb_block): EmbeddingBlock(
    (AF): SiLU()
    (embedding): Embedding(10, 128, padding_idx=0, max_norm=3.0, scale_grad_by_freq=True)
    (lin): Linear(in_features=128, out_features=128, bias=True)
  )
  (envelop_function): poly_envelop()
  (sbf_layer): F_B_2D(
    (envelope): poly_envelop()
  )
  (rbf_layer): RadialBasis()
  (fin_model): SBFTransformer(
    (edgenn): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): SiLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
    (convs): ModuleList(
      (0): SBFTransformerConv(128, 8, heads=16)
      (1): SBFTransformerConv(128, 8, heads=16)
      (2): SBFTransformerConv(128, 8, heads=16)
      (3): SBFTransformerConv(128, 8, heads=16)
    )
    (readouts): ModuleList(
      (0): AtomWise(
        (mlp): ModuleList(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): SiLU()
          (2): Linear(in_features=128, out_fea

In [34]:
# Scale factors applied to targets and predictions below. Presumably
# 27.211385056 converts Hartree -> eV and 0.04336414 is eV per kcal/mol
# (so dividing by it converts eV -> kcal/mol) — confirm against training code.
HARTREE_TO_EV = 27.211385056
EV_PER_KCAL_MOL = 0.04336414

# Per-atom reference energies indexed by the values in dataset.data.x
# (presumably atomic numbers: entries 1, 6, 7, 8, 9 -> H, C, N, O, F; NaN for
# anything else so unexpected elements poison the result instead of passing silently).
ATOM_REF_HARTREE = torch.tensor([torch.nan, -0.500273, torch.nan, torch.nan, torch.nan, torch.nan,
                                 -37.846772, -54.583861, -75.064579, -99.718730])


def comparison_dataset(file_name, prop_len, prop_index, hs_model=None, s_model=None, device=None):
    """Evaluate the HS and S models on one raw dataset file.

    Loads ``./raw/<file_name>`` with ``QM9_allprop``, subtracts the summed
    per-atom reference energies (``ATOM_REF_HARTREE``) from the selected target,
    runs both models on every molecule and rescales targets and predictions to
    a common unit.

    Args:
        file_name: file name inside ``./raw/``.
        prop_len: number of properties per molecule stored in the file.
        prop_index: column of ``y`` used as the target when ``prop_len != 1``;
            ignored when ``prop_len == 1``.
        hs_model: model producing the "HS" predictions; defaults to the
            notebook-global ``model``.
        s_model: model producing the "S" predictions; defaults to the
            notebook-global ``model_S``.
        device: device to run inference on; defaults to the device holding
            ``hs_model``'s parameters (previously hard-coded to ``'cuda'``).

    Returns:
        Tuple ``(S_preds, S_abs_err, HS_preds, HS_abs_err)`` of 1-D tensors of
        length ``len(dataset)``. Predictions are the model outputs divided by
        ``EV_PER_KCAL_MOL``; errors are ``|pred - target|`` with targets scaled
        by ``HARTREE_TO_EV / EV_PER_KCAL_MOL`` (presumably kcal/mol — confirm).
    """
    hs_model = model if hs_model is None else hs_model
    s_model = model_S if s_model is None else s_model
    if device is None:
        device = next(hs_model.parameters()).device

    dataset = QM9_allprop(input_file=f'./raw/{file_name}', prop_len=prop_len)
    dataset.data.atom_pos = dataset.data.atom_pos.float()
    dataset.data.edge_attr = dataset.data.edge_attr.float()

    # Molecule index of each atom, so per-atom references can be summed per molecule.
    atoms_per_mol = dataset.slices['x'][1:] - dataset.slices['x'][:-1]
    atom_affi = torch.arange(len(dataset)).repeat_interleave(atoms_per_mol)
    mol_ref = scatter_add(ATOM_REF_HARTREE[dataset.data.x], index=atom_affi, dim=0,
                          dim_size=len(dataset))

    target = dataset.data.y[:, prop_index] if prop_len != 1 else dataset.data.y
    dataset.data.y = target.squeeze() - mol_ref

    hs_raw = torch.zeros(len(dataset))
    s_raw = torch.zeros(len(dataset))
    with torch.no_grad():
        for i, data in enumerate(dataset):
            data = data.to(device)
            hs_raw[i] = hs_model(data).cpu()
            s_raw[i] = s_model(data).cpu()

    target_scaled = dataset.data.y * HARTREE_TO_EV / EV_PER_KCAL_MOL
    S_t_preds = s_raw / EV_PER_KCAL_MOL
    HS_t_preds = hs_raw / EV_PER_KCAL_MOL
    S_absolute = torch.abs(S_t_preds - target_scaled)
    HS_absolute = torch.abs(HS_t_preds - target_scaled)

    return S_t_preds, S_absolute, HS_t_preds, HS_absolute

In [35]:
# Each call: raw file, number of properties stored per molecule, and which
# property column is the energy target (only used for multi-property files).
alkanes_res = comparison_dataset('alkanes.xyz', prop_len=1, prop_index=0)
polyenes_res = comparison_dataset('polyenes.xyz', prop_len=1, prop_index=0)
PAHs_res = comparison_dataset('PAHs.extxyz', prop_len=9, prop_index=8)

In [37]:
# Show S-model then HS-model predictions for each dataset; the per-molecule
# absolute errors returned alongside them are not printed here.
for s_preds, _s_abs_err, hs_preds, _hs_abs_err in (alkanes_res, polyenes_res, PAHs_res):
    print(s_preds, hs_preds)

tensor([ -595.9776,  -672.0606,  -948.0549, -1224.4592, -1501.1942, -1777.9180,
        -2054.4985, -2331.1096, -2607.6929, -2884.2759, -3160.8442, -3437.4160]) tensor([ -396.0373,  -670.8577,  -947.6395, -1224.5973, -1501.3271, -1778.2006,
        -2055.0557, -2331.9207, -2608.7617, -2885.6016, -3162.4353, -3439.2825])
tensor([-4423.7578, -1824.9060, -3556.9897, -3123.8848, -1392.6462, -2690.7322,
        -2257.6973,  -960.5262, -3990.3064,  -529.7487]) tensor([-4404.0474, -1822.0908, -3543.3799, -3113.0481, -1391.9341, -2682.7178,
        -2252.3982,  -961.4845, -3973.7061,  -530.2637])
tensor([-8442.2432, -8122.0625, -6934.5190, -7981.4038, -8121.9600, -6796.3560,
        -6929.1406, -8885.1260, -6675.3613, -8123.0483, -8879.5498, -6932.2178,
        -7245.3267, -8878.0107, -7546.9253, -6991.3325, -8000.6836, -8123.6572,
        -7236.8901, -7364.7305, -7665.8950, -8448.2061, -7121.3281, -7362.9990,
        -6796.9600, -6803.9346, -7682.1475, -8432.4062, -7116.4038, -9641.5664,
    