In [1]:
import pickle
import numpy as np
import os
import torch
import json
import pandas as pd
from torch_scatter import scatter_add
os.chdir(os.path.abspath('../'))
from qm9_allprop import QM9_allprop
from ase.io import read
from utils import read_xyz
from xgnn import xgnn_poly
from rdkit import Chem
import rdkit
from rdkit.Chem import Draw

In [2]:
# Read the hyper-parameters stored next to the H+S checkpoint. The context
# manager closes the file handle as soon as the JSON has been parsed.
with open('./ckpt/HS_model/args.json', 'rt') as args_file:
    args = json.load(args_file)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Both flags are forced on for this model, overriding whatever args.json holds.
args["include_H"], args["include_S"] = True, True

model = xgnn_poly(
    include_H=args["include_H"],
    include_S=args["include_S"],
    conv_layers=args['conv_layers'],
    sbf_dim=args['sbf_dim'],
    rbf_dim=args['rbf_dim'],
    in_channels=args['in_channels'],
    heads=args['heads'],
    embedding_size=args['embedding_size'],
    device=device,
).to(device)

In [3]:
# map_location remaps the stored tensors onto the device selected for this
# session, so a checkpoint written from a GPU run still loads on a CPU-only host.
ckpt = torch.load('./ckpt/HS_model/ckpt/ckpt_best.pth', map_location=device)
model.load_state_dict(ckpt['model'])
# Switch to inference mode; as the last expression this also displays the module.
model.eval()

xgnn_poly(
  (AF): SiLU()
  (emb_block): EmbeddingBlock(
    (AF): SiLU()
    (embedding): Embedding(10, 128, padding_idx=0, max_norm=3.0, scale_grad_by_freq=True)
    (lin): Linear(in_features=128, out_features=128, bias=True)
  )
  (envelop_function): poly_envelop()
  (sbf_layer): F_B_2D(
    (envelope): poly_envelop()
  )
  (rbf_layer): RadialBasis()
  (fin_model): SBFTransformer(
    (edgenn): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): SiLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
    (convs): ModuleList(
      (0): SBFTransformerConv(128, 8, heads=16)
      (1): SBFTransformerConv(128, 8, heads=16)
      (2): SBFTransformerConv(128, 8, heads=16)
      (3): SBFTransformerConv(128, 8, heads=16)
    )
    (readouts): ModuleList(
      (0): AtomWise(
        (mlp): ModuleList(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): SiLU()
          (2): Linear(in_features=128, out_fea

In [4]:
# Read the hyper-parameters stored next to the S-only checkpoint. The context
# manager closes the file handle as soon as the JSON has been parsed.
# NOTE: this rebinds `args`, replacing the values loaded for the H+S model.
with open('./ckpt/S_model/args.json', 'rt') as args_file:
    args = json.load(args_file)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Unlike the H+S model, the include_H / include_S flags are taken from the file.
model_S = xgnn_poly(
    include_H=args["include_H"],
    include_S=args["include_S"],
    conv_layers=args['conv_layers'],
    sbf_dim=args['sbf_dim'],
    rbf_dim=args['rbf_dim'],
    in_channels=args['in_channels'],
    heads=args['heads'],
    embedding_size=args['embedding_size'],
    device=device,
).to(device)

In [5]:
# map_location remaps the stored tensors onto the device selected for this
# session, so a checkpoint written from a GPU run still loads on a CPU-only host.
ckpt_S = torch.load('./ckpt/S_model/ckpt/ckpt_best.pth', map_location=device)
model_S.load_state_dict(ckpt_S['model'])
# Switch to inference mode; as the last expression this also displays the module.
model_S.eval()

xgnn_poly(
  (AF): SiLU()
  (emb_block): EmbeddingBlock(
    (AF): SiLU()
    (embedding): Embedding(10, 128, padding_idx=0, max_norm=3.0, scale_grad_by_freq=True)
    (lin): Linear(in_features=128, out_features=128, bias=True)
  )
  (envelop_function): poly_envelop()
  (sbf_layer): F_B_2D(
    (envelope): poly_envelop()
  )
  (rbf_layer): RadialBasis()
  (fin_model): SBFTransformer(
    (edgenn): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): SiLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
    (convs): ModuleList(
      (0): SBFTransformerConv(128, 8, heads=16)
      (1): SBFTransformerConv(128, 8, heads=16)
      (2): SBFTransformerConv(128, 8, heads=16)
      (3): SBFTransformerConv(128, 8, heads=16)
    )
    (readouts): ModuleList(
      (0): AtomWise(
        (mlp): ModuleList(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): SiLU()
          (2): Linear(in_features=128, out_fea

In [None]:
def comparison_dataset(file_name):
    """Return the absolute prediction errors of both models on one dataset.

    Builds a ``QM9_allprop`` dataset from ``./raw/{file_name}.extxyz``,
    subtracts the summed per-atom reference values from column 4 of the
    targets, evaluates the notebook-level ``model`` and ``model_S`` on every
    sample, and compares the rescaled predictions with the rescaled targets.

    Relies on the notebook globals ``model``, ``model_S`` and ``device``.

    Args:
        file_name: Stem of the ``.extxyz`` file located in ``./raw``.

    Returns:
        Tuple ``(S_absolute, HS_absolute)`` holding the absolute errors of
        ``model_S`` and ``model`` respectively.
    """
    # NOTE(review): presumably eV per kcal/mol and eV per Hartree, i.e. the
    # errors come out in kcal/mol — confirm against the training units.
    ev_per_kcal_mol = 0.04336414
    ev_per_hartree = 27.211385056

    dataset = QM9_allprop(input_file=f'./raw/{file_name}.extxyz')
    dataset.data.atom_pos = dataset.data.atom_pos.float()
    dataset.data.edge_attr = dataset.data.edge_attr.float()

    # Reference value looked up by the entries of dataset.data.x; NaN marks
    # indices that have no reference (presumably atomic numbers — TODO confirm).
    atom_ref = torch.tensor([torch.nan, -0.500273, torch.nan, torch.nan, torch.nan, torch.nan,
                             -37.846772, -54.583861, -75.064579, -99.718730])
    # Map every atom row to the index of the sample it belongs to, then sum
    # the per-atom references per sample.
    atoms_per_sample = dataset.slices['x'][1:] - dataset.slices['x'][:-1]
    atom_affi = torch.arange(len(dataset)).repeat_interleave(atoms_per_sample)
    mol_ref = scatter_add(atom_ref[dataset.data.x], index=atom_affi, dim=0)
    dataset.data.y = dataset.data.y[:, 4].squeeze() - mol_ref

    HS_p = torch.zeros(len(dataset))
    S_p = torch.zeros(len(dataset))
    with torch.no_grad():
        for i, data in enumerate(dataset):
            # Use the device chosen for the models instead of assuming CUDA,
            # so the evaluation also runs on CPU-only machines.
            data = data.to(device)
            HS_p[i] = model(data).detach().cpu()
            S_p[i] = model_S(data).detach().cpu()

    # The converted target is shared by both comparisons; compute it once.
    target = dataset.data.y * ev_per_hartree / ev_per_kcal_mol

    S_t_preds = S_p.squeeze(-1) / ev_per_kcal_mol
    S_absolute = torch.abs(S_t_preds - target)

    HS_t_preds = HS_p.squeeze(-1) / ev_per_kcal_mol
    HS_absolute = torch.abs(HS_t_preds - target)

    return S_absolute, HS_absolute

In [8]:
from utils import read_xyz_allprop
from qm9_allprop import QM9_allprop

file_name = 'PAHs'
outputs = read_xyz_allprop(f'./raw/{file_name}.extxyz')

# Test the mean error per atom
dataset = QM9_allprop(input_file=f'./raw/{file_name}.extxyz')
dataset.data.atom_pos = dataset.data.atom_pos.float()
dataset.data.edge_attr = dataset.data.edge_attr.float()

# Reference value looked up by the entries of dataset.data.x; NaN marks
# indices that have no reference (presumably atomic numbers — TODO confirm).
atom_ref = torch.tensor([torch.nan, -0.500273, torch.nan, torch.nan, torch.nan, torch.nan,
                         -37.846772, -54.583861, -75.064579, -99.718730])
# Map every atom row to the index of the sample it belongs to, then sum the
# per-atom references per sample and subtract them from target column 4.
atom_affi = torch.arange(len(dataset)).repeat_interleave(dataset.slices['x'][1:] - dataset.slices['x'][:-1])
mol_ref = scatter_add(atom_ref[dataset.data.x], index=atom_affi, dim=0)
dataset.data.y = dataset.data.y[:, 4].squeeze() - mol_ref

# One prediction per sample from each model.
HS_p = torch.zeros(len(dataset))
S_p = torch.zeros(len(dataset))
with torch.no_grad():
    for i, data in enumerate(dataset):
        # Use the device chosen for the models instead of assuming CUDA, so
        # the cell also runs on CPU-only machines.
        data = data.to(device)
        HS_pred = model(data).detach().cpu()
        HS_p[i] = HS_pred
        S_pred = model_S(data).detach().cpu()
        S_p[i] = S_pred

In [11]:
# NOTE(review): presumably eV per kcal/mol and eV per Hartree, i.e. the errors
# below come out in kcal/mol — confirm against the training units.
EV_PER_KCAL_MOL = 0.04336414
EV_PER_HARTREE = 27.211385056

# Converted targets, shared by both model comparisons.
target_converted = dataset.data.y * EV_PER_HARTREE / EV_PER_KCAL_MOL

# Errors of the S-only model.
S_t_preds = S_p.squeeze(-1) / EV_PER_KCAL_MOL
S_delta = S_t_preds - target_converted
S_absolute = S_delta.abs()

# Errors of the H+S model.
HS_t_preds = HS_p.squeeze(-1) / EV_PER_KCAL_MOL
HS_delta = HS_t_preds - target_converted
HS_absolute = HS_delta.abs()