In [1]:
import numpy as np
import os
os.chdir(os.path.abspath('../'))
os.system('pwd')
from xgnn import xgnn_poly, xgnn_poly_global, xgnn_poly_noattn
from xgnn_equi import XGNN_Equi_force, XGNN_Equi
import json
import torch
from torch_geometric.loader import DataLoader
from qm9_allprop import QM9_allprop

/home/zfwang/X2-GNN


In [2]:
args = {}
save_dir = os.path.abspath("./ckpt/HS_model")
ckpt_mark = "U0"
bsz = 4

ckpt_path = os.path.abspath(save_dir)
ckpt = torch.load(os.path.join(ckpt_path, 'ckpt/ckpt_best.pth'))
with open(f'{ckpt_path}/args.json','rt') as f:
    args.update(json.load(f))

In [3]:
if 'attn' not in args:
	args['attn'] = True
if 'include_H' not in args:
	args['include_H'],args['include_S'] = True, True
if 'equi_model' not in args:
	args['equi_model'] = False

In [4]:
device="cuda"
if args['target'] in [5,6,7,8,9,10,11]:
    if not args["equi_model"]:
        if not args["attn"]:
            model = xgnn_poly_noattn(include_H = args["include_H"], include_S = args["include_S"], conv_layers=args['conv_layers'], sbf_dim=args['sbf_dim'], rbf_dim=args['rbf_dim'], in_channels=args['in_channels'], heads=args['heads'], embedding_size=args['embedding_size'], device=device).to(device)
        else:        
            model = xgnn_poly(include_H = args["include_H"], include_S = args["include_S"], conv_layers=args['conv_layers'], sbf_dim=args['sbf_dim'], rbf_dim=args['rbf_dim'], in_channels=args['in_channels'], heads=args['heads'], embedding_size=args['embedding_size'], device=device).to(device)
    else:
        model = XGNN_Equi(conv_layers=args['conv_layers'], rbf_dim=args['rbf_dim'], vector_irreps=args['vector_irreps'], heads=args['heads'], hidden_dim = args['embedding_size'], device = device).to(device)
else:
    model = xgnn_poly_global(conv_layers=args['conv_layers'], sbf_dim=args['sbf_dim'], rbf_dim=args['rbf_dim'], in_channels=args['in_channels'], heads=args['heads'], embedding_size=args['embedding_size'], device=device, pool_option=args['pool_option']).to(device)   #model = Lora_xgnn_poly(conv_layers=args['conv_layers'], sbf_dim=args['sbf_dim'], rbf_dim=args['rbf_dim'], in_channels=args['in_channels'], heads=args['heads'], embedding_size=args['embedding_size'], rank=args['lora_rank'], device=device).to(device)
ema_avg = lambda averaged_model_parameter, model_parameter:\
        args['ema_decay'] * averaged_model_parameter + (1-args['ema_decay']) * model_parameter
ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

model.load_state_dict(ckpt['model'], strict=False)
ema_model.load_state_dict(ckpt['ema_model'], strict=False)

<All keys matched successfully>

In [6]:
# AID dataset
dataset = QM9_allprop(input_file='./raw/AID_allprop.xyz')
loader = DataLoader(dataset=dataset,batch_size=bsz)

preds = torch.zeros(len(dataset))
start = 0
end = 0
with torch.no_grad():
    for data in loader:
        data=data.to(device)
        end += data.num_graphs
        preds[start:end] = model(data).detach()
        start += data.num_graphs
torch.save(preds,f'./pickles/AID_{ckpt_mark}.pt')

t_preds = torch.tensor(preds).squeeze(-1)/0.04336414
delta = t_preds - dataset.data.y
absolute = torch.abs(delta)

atom_ref = np.array([np.nan, -0.500273, np.nan, np.nan, np.nan, np.nan,-37.84677172,-54.583861,-75.064579,-99.718730])

from utils import read_xyz
outputs = read_xyz('./raw/AID)allprop.xyz')

Formula_list = []
Heavy_Num = []
sub_map = str.maketrans('0123456789', '₀₁₂₃₄₅₆₇₈₉')

for i,mol in enumerate(outputs):
    C_num = mol.Z.tolist().count(6)
    ChemFor = f'C{C_num}'
    H_num = mol.Z.tolist().count(1)
    if H_num:
        ChemFor += f'H{H_num}'
    N_num = mol.Z.tolist().count(7)
    if N_num:
        ChemFor += f'N{N_num}'
    O_num = mol.Z.tolist().count(8)
    if O_num:
        ChemFor += f'O{O_num}'
    F_num = mol.Z.tolist().count(9)
    if F_num:
        ChemFor += f'F{F_num}'
    Formula_list.append(ChemFor.translate(sub_map))
    Heavy_Num.append(C_num+N_num+O_num+F_num)

ids = range(len(dataset))

import openpyxl
wb=openpyxl.Workbook()
ws=wb["Sheet"]
ws.cell(row = 1,column=1).value = 'id'
ws.cell(row = 1,column=2).value = 'Fomula'
ws.cell(row = 1,column=3).value = 'Heavy'
ws.cell(row = 1,column=4).value = 'Total'
ws.cell(row = 1, column = 5).value = 'Label'
ws.cell(row = 1, column = 6).value = 'preds'
ws.cell(row = 1, column = 7).value = 'delta'
ws.cell(row = 1, column = 8).value = 'abs'
ws.cell(row = 1, column = 9).value = 'per_heavy'
ws.cell(row = 1, column = 10).value = 'per_atom'
ws.cell(row = 1, column = 11).value = delta.mean().numpy().item()
ws.cell(row = 1, column = 12).value = absolute.mean().numpy().item()
for i in range(len(outputs)):
    ws.cell(row=i+2,column=1).value = ids[i]
    ws.cell(row=i+2,column=2).value = Formula_list[i]
    ws.cell(row=i+2,column=3).value = Heavy_Num[i]
    ws.cell(row=i+2,column=4).value = outputs[i].N.item()
    ws.cell(row=i+2,column=5).value=dataset.data.y[i].numpy().item()
    ws.cell(row=i+2,column=6).value=t_preds[i].numpy().item()
    ws.cell(row = i+2, column = 7).value = delta[i].numpy().item()
    ws.cell(row = i+2, column = 8).value = absolute[i].numpy().item()
    ws.cell(row = i+2, column = 9).value = absolute[i].numpy().item()/Heavy_Num[i]
    ws.cell(row = i+2, column = 10).value = absolute[i].numpy().item()/outputs[i].N.item()
wb.save(f'./results_AID_{ckpt_mark}.xlsx')
print('Done')

  t_preds = torch.tensor(preds).squeeze(-1)/0.04336414


Done


In [None]:
# OCELOT dataset
dataset = QM9_allprop(input_file='./raw/ocelot_all.xyz')
loader = DataLoader(dataset=dataset,batch_size=bsz)

preds = torch.zeros(len(dataset))
start = 0
end = 0
with torch.no_grad():
    for data in loader:
        data=data.to(device)
        end += data.num_graphs
        preds[start:end] = model(data).detach()
        start += data.num_graphs
torch.save(preds,f'./pickles/ocelot_{ckpt_mark}.pt')

t_preds = torch.tensor(preds).squeeze(-1)/0.04336414
delta = t_preds - dataset.data.y
absolute = torch.abs(delta)

atom_ref = np.array([np.nan, -0.500273, np.nan, np.nan, np.nan, np.nan,-37.84677172,-54.583861,-75.064579,-99.718730])

from utils import read_xyz
outputs = read_xyz('./raw/ocelot_all.xyz')

Formula_list = []
Heavy_Num = []
sub_map = str.maketrans('0123456789', '₀₁₂₃₄₅₆₇₈₉')

for i,mol in enumerate(outputs):
    C_num = mol.Z.tolist().count(6)
    ChemFor = f'C{C_num}'
    H_num = mol.Z.tolist().count(1)
    if H_num:
        ChemFor += f'H{H_num}'
    N_num = mol.Z.tolist().count(7)
    if N_num:
        ChemFor += f'N{N_num}'
    O_num = mol.Z.tolist().count(8)
    if O_num:
        ChemFor += f'O{O_num}'
    F_num = mol.Z.tolist().count(9)
    if F_num:
        ChemFor += f'F{F_num}'
    Formula_list.append(ChemFor.translate(sub_map))
    Heavy_Num.append(C_num+N_num+O_num+F_num)

ids = range(len(dataset))

import openpyxl
wb=openpyxl.Workbook()
ws=wb["Sheet"]
ws.cell(row = 1,column=1).value = 'id'
ws.cell(row = 1,column=2).value = 'Fomula'
ws.cell(row = 1,column=3).value = 'Heavy'
ws.cell(row = 1,column=4).value = 'Total'
ws.cell(row = 1, column = 5).value = 'Label'
ws.cell(row = 1, column = 6).value = 'preds'
ws.cell(row = 1, column = 7).value = 'delta'
ws.cell(row = 1, column = 8).value = 'abs'
ws.cell(row = 1, column = 9).value = 'per_heavy'
ws.cell(row = 1, column = 10).value = 'per_atom'
ws.cell(row = 1, column = 11).value = delta.mean().numpy().item()
ws.cell(row = 1, column = 12).value = absolute.mean().numpy().item()
for i in range(len(outputs)):
    ws.cell(row=i+2,column=1).value = ids[i]
    ws.cell(row=i+2,column=2).value = Formula_list[i]
    ws.cell(row=i+2,column=3).value = Heavy_Num[i]
    ws.cell(row=i+2,column=4).value = outputs[i].N.item()
    ws.cell(row=i+2,column=5).value=dataset.data.y[i].numpy().item()
    ws.cell(row=i+2,column=6).value=t_preds[i].numpy().item()
    ws.cell(row = i+2, column = 7).value = delta[i].numpy().item()
    ws.cell(row = i+2, column = 8).value = absolute[i].numpy().item()
    ws.cell(row = i+2, column = 9).value = absolute[i].numpy().item()/Heavy_Num[i]
    ws.cell(row = i+2, column = 10).value = absolute[i].numpy().item()/outputs[i].N.item()
wb.save(f'./results_ocelot_{ckpt_mark}.xlsx')
print('Done')