In [1]:
import sys
import json
import os

from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
import numpy as np

import torch
import torch.nn as nn

from dGbyG import mol_to_graph_data, MP_network, threo, best_model_params

In [2]:
class networks(nn.Module):
    def __init__(self, dir) -> None:
        super().__init__()
        self.nets = nn.ModuleList([])
        for file in os.listdir(dir):
            path = os.path.join(dir, file)
            net = MP_network(atom_dim=139, bond_dim=23, emb_dim=300, num_layer=2)
            net.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
            self.nets.append(net)
        self.num = len(self.nets)
    
    def forward(self, data):
        outputs = torch.zeros(size=(self.num, 1)).to(data.x.device)  # shape=[number of net, 1]

        for i, net in enumerate(self.nets):
            outputs[i] = net(data)  # net.shape = [1,1] or [atom number, 1]

        return outputs  # .squeeze()


# 使用方式
network = networks(best_model_params)
network.eval()
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
network.to(device)


networks(
  (nets): ModuleList(
    (0-99): 100 x MP_network(
      (atom_lin): Linear(in_features=139, out_features=300, bias=True)
      (bond_lin): Linear(in_features=23, out_features=300, bias=True)
      (MP_layers): ModuleList(
        (0-1): 2 x MP_layer()
      )
      (energy_lin): Sequential(
        (0): ReLU()
        (1): Linear(in_features=300, out_features=300, bias=True)
        (2): ReLU()
        (3): Linear(in_features=300, out_features=150, bias=True)
        (4): ReLU()
        (5): Linear(in_features=150, out_features=1, bias=False)
      )
    )
  )
)

In [5]:
def predict_standard_dGf_prime(
    inchi: str, 
    network: nn.Module, 
    device: str, 
    mode: str = 'molecule mode'
) -> tuple:
    """
    预测标准吉布斯自由能
    
    Args:
        inchi: 分子的 InChI 字符串
        network: 预测网络
        device: 计算设备
        mode: 'molecule mode' 或 'atom mode'
    
    Returns:
        (mean, std)
    """
    try:
        # 1. 从 InChI 读取分子
        mol = Chem.MolFromInchi(inchi, removeHs=False, sanitize=True)
        if mol is None:
            raise ValueError(f"无法解析 InChI: {inchi}")
        
        # 2. 第一次 normalize（Normalize + Uncharger）
        mol = rdMolStandardize.Normalize(mol)
        mol = rdMolStandardize.Uncharger().uncharge(mol)
        
        # 3. 转换成 SMILES 再读回来（关键步骤！）
        smiles = Chem.MolToSmiles(mol)
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        
        # 4. 第二次 normalize
        mol = rdMolStandardize.Normalize(mol)
        mol = rdMolStandardize.Uncharger().uncharge(mol)
        
        # 5. 添加氢原子
        mol = Chem.AddHs(mol)
        
        # 6. 转换为图数据
        data = mol_to_graph_data(mol).to(device)
        
        # 7. 预测
        with torch.no_grad():
            predictions = network(data).cpu().numpy()
        
        print(predictions)

        mean_pred = np.mean(predictions)
        std_pred = np.std(predictions)
        return mean_pred, std_pred
        
    
    except Exception as e:
        print(f"预测失败 - InChI: {inchi}, 错误: {e}")
        return np.nan, np.nan


In [6]:
# 使用示例
inchi = "InChI=1S/C6H12O6/c7-1-2-3(8)4(9)5(10)6(11)12-2/h2-11H,1H2/t2-,3-,4+,5-,6?/m1/s1"
mean, std = predict_standard_dGf_prime(inchi, network, device)
print(f"预测值: {mean:.2f} ± {std:.2f}")

[01:07:32] Initializing Normalizer
[01:07:32] Running Normalizer
[01:07:32] Running Uncharger
[01:07:32] Initializing Normalizer
[01:07:32] Running Normalizer
[01:07:32] Running Uncharger


[[-423.66382]
 [-424.16553]
 [-423.28647]
 [-424.70627]
 [-423.76373]
 [-423.99884]
 [-424.60098]
 [-423.46918]
 [-424.2022 ]
 [-424.1288 ]
 [-424.53967]
 [-424.85046]
 [-423.49292]
 [-423.67776]
 [-424.84042]
 [-424.66394]
 [-424.76013]
 [-423.00464]
 [-424.10117]
 [-425.02606]
 [-424.8807 ]
 [-424.0773 ]
 [-424.5262 ]
 [-424.92255]
 [-424.51514]
 [-424.88535]
 [-425.6269 ]
 [-424.6825 ]
 [-425.46802]
 [-423.60117]
 [-423.65677]
 [-424.99573]
 [-423.85922]
 [-424.79523]
 [-424.38586]
 [-425.10565]
 [-424.92474]
 [-425.89398]
 [-423.78992]
 [-425.00082]
 [-425.54065]
 [-424.57733]
 [-423.03598]
 [-424.57056]
 [-423.84772]
 [-424.55597]
 [-424.71564]
 [-424.81488]
 [-424.86362]
 [-424.47864]
 [-424.9416 ]
 [-424.09766]
 [-424.18262]
 [-424.328  ]
 [-424.08752]
 [-424.92703]
 [-424.37866]
 [-424.95782]
 [-423.99866]
 [-424.45718]
 [-423.60495]
 [-424.15564]
 [-424.65192]
 [-424.70248]
 [-423.1867 ]
 [-425.3801 ]
 [-425.2685 ]
 [-424.2978 ]
 [-424.43405]
 [-424.28763]
 [-425.3189 ]
 [-424