In [193]:
# Load dependencies
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
import dgl
import torch
import torch.nn as nn
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer, mol_to_bigraph
import dgl.function as fn
from sklearn.preprocessing import OneHotEncoder
import dgl.nn.pytorch as dglnn
import torch.nn.functional as F
from sklearn.model_selection import KFold
from torch.utils.data import Dataset
from dgl.dataloading import GraphDataLoader

In [195]:
# Load raw data
df = pd.read_csv('./RuCHFunctionalizationDataset/dataset.csv')
df_DG = pd.read_csv('./RuCHFunctionalizationDataset/DG.csv')
df_RX = pd.read_csv('./RuCHFunctionalizationDataset/RX.csv')

# One entry per row of dataset.csv for every reaction component.
DG_smiles_list = df['DG'].to_list()
RX_smiles_list = df['RX'].to_list()
catalyst_smiles_list = df['catalyst'].to_list()
sol_smiles_list = df['solvent'].to_list()
ligand_smiles_list = df['ligand'].to_list()
# NOTE(review): the CSV column is read under the name 'addictive' (presumably
# "additive"); the key is kept as-is because it must match the file header.
ad_smiles_list = df['addictive'].to_list()


def _smiles_to_mols(smiles_list, column):
    """Parse every SMILES in `smiles_list` into an RDKit molecule.

    Chem.MolFromSmiles returns None for a string it cannot parse; such rows
    are reported here instead of surfacing later as an obscure failure when
    the molecule is featurized.

    Raises:
        ValueError: if at least one entry could not be parsed.
    """
    mols = [Chem.MolFromSmiles(smi) for smi in smiles_list]
    bad_rows = [row for row, mol in enumerate(mols) if mol is None]
    if bad_rows:
        raise ValueError(
            "Could not parse SMILES in column '{}' at row(s): {}".format(column, bad_rows[:10])
        )
    return mols


# The two columns are parsed exactly once each (the original cell contained
# the same pair of loops twice).
DG_mols_list = _smiles_to_mols(DG_smiles_list, 'DG')
RX_mols_list = _smiles_to_mols(RX_smiles_list, 'RX')

train_val_num = len(DG_mols_list)

In [197]:
# Load the classification target and the numbering columns as NumPy arrays
target = np.array(df['tag'].to_list())

num = np.array(df['number'].to_list())
num_DG = np.array(df['DG_num'].to_list())
num_RX = np.array(df['RX_num'].to_list())

In [199]:
# Prepare descriptors
node_featurizer = CanonicalAtomFeaturizer(atom_data_field='nfeat')
edge_featurizer = CanonicalBondFeaturizer(bond_data_field='efeat')

# scikit-learn renamed `sparse` to `sparse_output` (1.2) and later removed the
# old keyword; try the new spelling first and fall back for older releases.
try:
    enc = OneHotEncoder(sparse_output=False)
except TypeError:
    enc = OneHotEncoder(sparse=False)


def _one_hot_descriptors(values):
    """One-hot encode a list of categorical values (one per reaction).

    The shared encoder `enc` is re-fitted on every call, exactly as the
    original cell did, so after the cell has run it is fitted on the last
    component encoded.

    Returns:
        (column, encoded, encoded_tensor): the (N, 1) NumPy column that was
        encoded, the dense one-hot NumPy matrix, and the same matrix as a
        float32 torch tensor.
    """
    column = np.array(values).reshape(-1, 1)
    encoded = enc.fit_transform(column)
    return column, encoded, torch.from_numpy(encoded).type(torch.float)


ligand_all, l_descriptors_all, l_descriptors = _one_hot_descriptors(ligand_smiles_list)
l_descriptors_n = l_descriptors.shape[1]

catalyst_all, c_descriptors_all, c_descriptors = _one_hot_descriptors(catalyst_smiles_list)
c_descriptors_n = c_descriptors.shape[1]

sol_all, sol_descriptors_all, sol_descriptors = _one_hot_descriptors(sol_smiles_list)
sol_descriptors_n = sol_descriptors.shape[1]

ad_all, ad_descriptors_all, ad_descriptors = _one_hot_descriptors(ad_smiles_list)
ad_descriptors_n = ad_descriptors.shape[1]

In [201]:
# Find substituted arenes
def ortho_substituted(mol):
    """Return True if atom 1 or atom 5 of the H-stripped molecule has degree > 2.

    Only the atoms at indices 1 and 5 (in RDKit atom order after removing
    hydrogens) are inspected; molecules with fewer atoms simply have fewer
    positions to check.
    NOTE(review): presumably the SMILES are written so that these indices are
    the ring positions ortho to the directing group - confirm against the
    dataset.

    Args:
        mol: RDKit molecule.

    Returns:
        bool: True when either inspected atom has more than two neighbours,
        otherwise False (the original fell through and returned None).
    """
    mol_no_H = AllChem.RemoveHs(mol)
    ortho_positions = (1, 5)
    return any(
        atom.GetDegree() > 2
        for idx, atom in enumerate(mol_no_H.GetAtoms())
        if idx in ortho_positions
    )
            
def meta_substituted(mol):
    """Return True if atom 2 or atom 4 of the H-stripped molecule has degree > 2.

    Only the atoms at indices 2 and 4 (in RDKit atom order after removing
    hydrogens) are inspected; molecules with fewer atoms simply have fewer
    positions to check.
    NOTE(review): presumably the SMILES are written so that these indices are
    the ring positions meta to the directing group - confirm against the
    dataset.

    Args:
        mol: RDKit molecule.

    Returns:
        bool: True when either inspected atom has more than two neighbours,
        otherwise False (the original fell through and returned None).
    """
    mol_no_H = AllChem.RemoveHs(mol)
    meta_positions = (2, 4)
    return any(
        atom.GetDegree() > 2
        for idx, atom in enumerate(mol_no_H.GetAtoms())
        if idx in meta_positions
    )
            

# Positions (indices into DG_mols_list) of the molecules flagged by each
# substitution predicate.
ortho_sub_list = [
    position for position, candidate in enumerate(DG_mols_list)
    if ortho_substituted(candidate)
]

meta_sub_list = [
    position for position, candidate in enumerate(DG_mols_list)
    if meta_substituted(candidate)
]

In [203]:
# Create reaction graphs
def _single_node_graph(feature_row, width):
    """Build a one-node DGL graph (with a self-loop) carrying `feature_row`.

    The node feature 'h' has shape (1, width): `feature_row` followed by zero
    padding, so graphs built from descriptors of different lengths can be
    batched together.

    Raises:
        ValueError: if `feature_row` is longer than `width`.
    """
    n_values = feature_row.shape[-1]
    if n_values > width:
        raise ValueError(
            "feature of length {} does not fit into node width {}".format(n_values, width)
        )
    g = dgl.graph((torch.tensor([0]), torch.tensor([0])))
    padded = torch.zeros(1, width)
    padded[0, :n_values] = feature_row
    g.ndata['h'] = padded
    return g


def _virtual_node_readout(mol_graph, out_field):
    """Append a virtual node and run one sum message-passing step.

    Every atom gets an edge pointing TO the virtual node, so the virtual node
    aggregates the 'nfeat' vectors of all atoms. (The original added the edges
    in the opposite direction; the virtual node then had no in-edges and its
    aggregated feature was always zero.) Atom rows are unaffected by the
    direction because the virtual node's own 'nfeat' is zero-filled.

    Returns:
        (features, virtual_id): the `out_field` node data of the extended
        graph (atoms first, virtual node last) and the virtual node's id.
    """
    virtual_id = mol_graph.number_of_nodes()  # id the new node will receive
    extended = dgl.add_nodes(mol_graph, 1)
    extended = dgl.add_edges(extended, list(range(virtual_id)), [virtual_id] * virtual_id)
    extended.update_all(fn.copy_u('nfeat', 'm'), fn.sum('m', out_field))
    return extended.ndata[out_field], virtual_id


def make_reaction_graphs(DG_mols_list, RX_mols_list, l_desc, c_desc, sol_desc, ad_desc):
    """Build the per-reaction graphs used by the notebook.

    For every (DG, RX) molecule pair this creates:
      * the molecular graphs of both molecules (module-level `node_featurizer`
        and `edge_featurizer` are used),
      * one single-node graph per molecule holding its virtual-node readout,
      * a four-node "reaction-condition" graph (catalyst, ligand, solvent,
        additive, in that node order), fully connected and with self-loops,
        whose node features are the one-hot descriptors zero-padded to the
        atom-feature width of the DG molecule.

    Args:
        DG_mols_list, RX_mols_list: RDKit molecules, paired by position.
        l_desc, c_desc, sol_desc, ad_desc: 2-D tensors with one descriptor row
            per reaction (ligand, catalyst, solvent, additive). The descriptor
            widths are taken from these arguments rather than from globals.

    Returns:
        (v1_graphs, v2_graphs, reaction_graphs, decs, decs_mp, DG_graphs,
        RX_graphs), all lists with one entry per reaction. `decs` holds the
        raw DG atom features and `decs_mp` the DG features after the
        message-passing step (virtual node in the last row).
    """
    DG_graphs, RX_graphs = [], []
    v1_graphs, v2_graphs = [], []
    reaction_graphs = []
    decs, decs_mp = [], []

    # Directed edges that fully connect the four condition nodes, in the same
    # order as before: 0->1,2,3  1->0,2,3  2->0,1,3  3->0,1,2.
    n_condition_nodes = 4
    pairs = [
        (src, dst)
        for src in range(n_condition_nodes)
        for dst in range(n_condition_nodes)
        if src != dst
    ]
    condition_src = [src for src, _ in pairs]
    condition_dst = [dst for _, dst in pairs]

    for n, (mol1, mol2) in enumerate(zip(DG_mols_list, RX_mols_list)):
        # DG molecule: graph, raw features, message-passed features, readout.
        g1 = mol_to_bigraph(mol1, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer)
        decs.append(g1.ndata['nfeat'])
        DG_graphs.append(g1)
        h1, virtual_node_id1 = _virtual_node_readout(g1, 'h1')
        decs_mp.append(h1)
        g1_ndata_num = h1.shape[1]
        v1_graphs.append(_single_node_graph(h1[virtual_node_id1], g1_ndata_num))

        # RX molecule: graph and readout.
        g2 = mol_to_bigraph(mol2, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer)
        RX_graphs.append(g2)
        h2, virtual_node_id2 = _virtual_node_readout(g2, 'h2')
        v2_graphs.append(_single_node_graph(h2[virtual_node_id2], h2.shape[1]))

        # Condition nodes, padded to a common width so they can be merged.
        g_c = _single_node_graph(c_desc[n], g1_ndata_num)
        g_l = _single_node_graph(l_desc[n], g1_ndata_num)
        g_s = _single_node_graph(sol_desc[n], g1_ndata_num)
        g_ad = _single_node_graph(ad_desc[n], g1_ndata_num)

        g_r = dgl.batch([g_c, g_l, g_s, g_ad])  # merge into one 4-node graph
        g_r.add_edges(condition_src, condition_dst)  # connect every node pair
        reaction_graphs.append(g_r)

    return v1_graphs, v2_graphs, reaction_graphs, decs, decs_mp, DG_graphs, RX_graphs
   
# Build every graph once, then wrap the graph lists in NumPy arrays.
(v1_graphs, v2_graphs, reaction_graphs,
 decs, decs_mp, DG_graphs, RX_graphs) = make_reaction_graphs(
    DG_mols_list, RX_mols_list, l_descriptors, c_descriptors, sol_descriptors, ad_descriptors)

reaction_graphs, v1_graphs, v2_graphs, DG_graphs, RX_graphs = (
    np.array(graph_list)
    for graph_list in (reaction_graphs, v1_graphs, v2_graphs, DG_graphs, RX_graphs)
)

In [204]:
# Generate dataloader
class GraphDataset(Dataset):
    """Index-aligned collection of (graph, label, number) triples.

    Args:
        graph_list: indexable sequence of graphs.
        label_list: indexable sequence of labels, same length as `graph_list`.
        num_list: indexable sequence of per-sample numbers, same length.

    Raises:
        ValueError: if the three sequences do not have the same length, which
            would otherwise pair graphs with the wrong label or fail only
            when a late index is requested.
    """

    def __init__(self, graph_list, label_list, num_list):
        n_graphs = len(graph_list)
        if len(label_list) != n_graphs or len(num_list) != n_graphs:
            raise ValueError(
                "graph_list, label_list and num_list must have the same length "
                "(got {}, {} and {})".format(n_graphs, len(label_list), len(num_list))
            )
        self.graph_list = graph_list
        self.label_list = label_list
        self.num_list = num_list

    def __len__(self):
        """Number of samples."""
        return len(self.graph_list)

    def __getitem__(self, idx):
        """Return the (graph, label, number) triple stored at position `idx`."""
        return self.graph_list[idx], self.label_list[idx], self.num_list[idx]

# Labels and reaction numbers as int64 tensors, paired with the reaction graphs.
target = torch.tensor(target, dtype=torch.long)
num = torch.tensor(num, dtype=torch.long)
dataset = GraphDataset(graph_list=reaction_graphs, label_list=target, num_list=num)

# Shuffled loader over the complete dataset, 30 graphs per batch.
dataloader = GraphDataLoader(dataset, batch_size=30, shuffle=True, drop_last=False)

In [207]:
# single-task classifier and evaluate function
class Classifier(nn.Module):
    """Two GraphConv layers, mean-node readout and a linear classification head.

    Args:
        in_dim: size of the input node features.
        hidden_dim: output size of both graph-convolution layers.
        n_classes: number of output logits.
    """

    def __init__(self, in_dim, hidden_dim, n_classes):
        super().__init__()
        self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)
        self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g, h):
        """Return (logits, graph_embedding) for the batched graph `g` with node features `h`."""
        # Graph convolution followed by ReLU, applied twice.
        for conv in (self.conv1, self.conv2):
            h = F.relu(conv(g, h))
        # Mean readout: average the node embeddings of each graph in the batch.
        # local_scope keeps the temporary 'h' from overwriting the caller's data.
        with g.local_scope():
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg), hg
        
# (label, prediction) pairs counted as correct for samples that are in neither
# the ortho nor the meta list.
_UNFLAGGED_CORRECT = frozenset({(0, 0), (0, 4), (1, 1), (1, 3), (2, 2)})


def _score_prediction(label, pred, is_ortho, is_meta):
    """Score one sample; return (is_correct, reported_prediction).

    Rules (unchanged from the original implementation):
      * ortho samples: correct if pred == label, or label 4 predicted as 0
        (the prediction is then reported as 4);
      * meta samples: as above, plus label 3 predicted as 1 (reported as 3);
      * all other samples: correct for the pairs in _UNFLAGGED_CORRECT; the
        prediction is reported unchanged.
    NOTE(review): classes 0/4 and 1/3 are presumably symmetry-equivalent
    positions - confirm. For unflagged samples an accepted (0, 4) or (1, 3)
    pair is NOT rewritten, so the returned predictions can differ from the
    labels for samples that were counted as correct.
    """
    if is_ortho:
        if label == pred:
            return True, pred
        if label == 4 and pred == 0:
            return True, 4
        return False, pred
    if is_meta:
        if label == pred:
            return True, pred
        if (label, pred) in ((4, 0), (3, 1)):
            return True, label
        return False, pred
    return (label, pred) in _UNFLAGGED_CORRECT, pred


def evaluate(dataloader, model, o_sub_list, m_sub_list):
    """Compute the accuracy of `model` over `dataloader`.

    Args:
        dataloader: iterable of (batched_graph, labels, num) batches; the node
            features are read from batched_graph.ndata['h'].
        model: callable returning (logits, graph_embedding); switched to eval
            mode here and left in eval mode.
        o_sub_list: sample numbers scored with the ortho rules (checked first).
        m_sub_list: sample numbers scored with the meta rules.

    Returns:
        (acc, predicted, labels, hg, num_list): the accuracy, then the
        reported predictions, labels and graph embeddings concatenated over
        ALL batches (the original returned only the last batch), and the list
        of per-batch `num` tensors in iteration order.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    # Plain-int sets give O(1) membership tests instead of comparing a tensor
    # against every list element.
    ortho_ids = {int(i) for i in o_sub_list}
    meta_ids = {int(i) for i in m_sub_list}

    total = 0
    total_correct = 0
    num_list = []
    all_predicted, all_labels, all_hg = [], [], []

    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        for batched_graph, labels, num in dataloader:
            feat = batched_graph.ndata['h']
            logits, hg = model(batched_graph, feat)
            _, predicted = torch.max(logits, 1)

            samples = zip(labels.tolist(), predicted.tolist(), num.tolist())
            for i, (label, pred, sample_num) in enumerate(samples):
                is_ortho = sample_num in ortho_ids
                is_meta = (not is_ortho) and sample_num in meta_ids
                correct, reported = _score_prediction(label, pred, is_ortho, is_meta)
                total_correct += int(correct)
                if reported != pred:
                    predicted[i] = reported

            total += len(labels)
            num_list.append(num)
            all_predicted.append(predicted)
            all_labels.append(labels)
            all_hg.append(hg)

    if total == 0:
        raise ValueError("evaluate() received a dataloader without samples")

    acc = 1.0 * total_correct / total
    return acc, torch.cat(all_predicted), torch.cat(all_labels), torch.cat(all_hg), num_list

In [209]:
# Training: K-fold cross-validation of the single-task classifier
N_SPLITS = 10
N_EPOCHS = 200
BATCH_SIZE = 30
HIDDEN_DIM = 74
N_CLASSES = 5
SEED = 0
# Input width is read from the data instead of being hard-coded.
IN_DIM = dataset[0][0].ndata['h'].shape[1]

# Seed torch so weight initialisation and batch shuffling can be reproduced;
# the fold split itself is already fixed through KFold's random_state.
torch.manual_seed(SEED)

kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
fold_num = 1
all_acc_list = []
target_list_cm = []
pred_list_cm = []
hg_all = torch.empty(0)
h3_all = torch.empty(0)
num_val_end = []
wrong_list = []

for train_indices, val_indices in kf.split(np.arange(len(dataset))):

    train_dataset = [dataset[i] for i in train_indices]
    val_dataset = [dataset[i] for i in val_indices]

    train_loader = GraphDataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        drop_last=False,
        shuffle=True)

    val_loader = GraphDataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        drop_last=False,
        shuffle=True)

    model = Classifier(IN_DIM, HIDDEN_DIM, N_CLASSES)
    opt = torch.optim.Adam(model.parameters())
    train_acc_list = []
    valid_acc_list = [0.0]
    loss_list = []
    predicted_val_list = []
    labels_val_list = []

    # Best-epoch records are reset for every fold, so a fold whose validation
    # accuracy never rises above 0 cannot raise a NameError or silently reuse
    # the values of the previous fold.
    h3_all_m = torch.empty(0)
    num_val_end_m = torch.empty(0, dtype=torch.long)
    target_max = []
    pred_max = []

    for epoch in range(N_EPOCHS):
        # evaluate() switches the model to eval mode; switch back before training.
        model.train()
        # The batch numbers get their own name so the loop does not overwrite
        # the module-level `num` tensor the dataset was built from.
        for graph, labels, batch_num in train_loader:
            feats = graph.ndata['h']
            logits, hg = model(graph, feats)
            loss = F.cross_entropy(logits, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
        train_acc, predicted_train, labels_train, hg_t, num_train = evaluate(train_loader, model, ortho_sub_list, meta_sub_list)
        train_acc_list.append(train_acc)
        valid_acc, predicted_val, labels_val, hg_v, num_val = evaluate(val_loader, model, ortho_sub_list, meta_sub_list)
        valid_acc_list.append(valid_acc)
        predicted_val_list.append(predicted_val)
        labels_val_list.append(labels_val)
        # Loss of the last training batch of this epoch.
        loss1 = loss.detach().numpy()
        loss_list.append(loss1)
        if epoch == N_EPOCHS - 1:
            # Validation embeddings of the final epoch, detached from autograd.
            hg_all = torch.cat((hg_all, hg_v.detach()), dim=0)
        if max(valid_acc_list[:-1]) < valid_acc_list[-1]:
            # New best validation accuracy in this fold: remember its outputs.
            h3_all_m = hg_v.detach()
            # All validation batches are kept (not only the first) and the
            # integer dtype is preserved.
            num_val_end_m = torch.cat(num_val, dim=0)
            target_max = [labels_val]
            pred_max = [predicted_val]

    # NOTE(review): each fold is PREPENDED here while hg_all and num_val_end
    # grow in fold order - confirm that downstream code expects this order.
    h3_all = torch.cat((h3_all_m, h3_all), dim=0)

    # NOTE(review): the reported accuracies are the per-fold maxima over all
    # epochs (chosen on the validation fold itself), whereas the
    # confusion-matrix lists below use the last epoch's predictions.
    print(
        "Fold {:05d} | Epoch {:05d} | Loss {:.4f} | Train Acc. {:.4f} | Validation Acc. {:.4f} ".format(
        fold_num, epoch, loss.item(), max(train_acc_list), max(valid_acc_list))
        )
    fold_num += 1
    all_acc_list.append(max(valid_acc_list))

    temp1 = labels_val_list[-1].tolist()
    temp2 = predicted_val_list[-1].tolist()
    target_list_cm.extend(temp1)
    pred_list_cm.extend(temp2)
    num_val_end.append(num_val_end_m)

average_accuracy = np.mean(all_acc_list)
print("Accuracy after CV:", average_accuracy)

Fold 00001 | Epoch 00199 | Loss 0.1140 | Train Acc. 0.9130 | Validation Acc. 0.7308 
Fold 00002 | Epoch 00199 | Loss 0.3363 | Train Acc. 0.8826 | Validation Acc. 0.8846 
Fold 00003 | Epoch 00199 | Loss 0.1725 | Train Acc. 0.9043 | Validation Acc. 0.8462 
Fold 00004 | Epoch 00199 | Loss 0.1677 | Train Acc. 0.8913 | Validation Acc. 0.8846 
Fold 00005 | Epoch 00199 | Loss 0.2780 | Train Acc. 0.9043 | Validation Acc. 0.9231 
Fold 00006 | Epoch 00199 | Loss 0.3094 | Train Acc. 0.8913 | Validation Acc. 0.8846 
Fold 00007 | Epoch 00199 | Loss 0.5308 | Train Acc. 0.9004 | Validation Acc. 0.8400 
Fold 00008 | Epoch 00199 | Loss 0.2684 | Train Acc. 0.8961 | Validation Acc. 1.0000 
Fold 00009 | Epoch 00199 | Loss 0.2308 | Train Acc. 0.9048 | Validation Acc. 0.8000 
Fold 00010 | Epoch 00199 | Loss 0.0175 | Train Acc. 0.9048 | Validation Acc. 0.9200 
Accuracy after CV: 0.8713846153846154
