In [2]:
import pandas as pd
import argparse
import torch, os, random
import numpy as np
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
from sklearn.metrics import mean_squared_error
from torch_ema import ExponentialMovingAverage
from matplotlib import pyplot as plt
from glob import glob
from modules.dualgraph.mol import smiles2graphwithface
from modules.dualgraph.gnn import GNN
from modules.dualgraph.dataset import DGData

from torch.nn.utils import clip_grad_norm_

import torch_geometric
from torch_geometric.data import Dataset, InMemoryDataset
from torch_geometric.loader import DataLoader
import dgl
import warnings

import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image
import os

from sklearn.preprocessing import minmax_scale
from EdgeShaper.edgeshaper import batch_edgeshaper
from collections import defaultdict
import os
from PIL import Image, ImageDraw, ImageFont

from rdkit.Chem import rdFMCS
from rdkit import Chem, Geometry
from rdkit.Chem import AllChem, Draw, rdMolAlign, rdDepictor

from IPython.display import SVG
import seaborn as sns
from modules.equevlalent import get_equivalent_bonds
from copy import deepcopy
warnings.filterwarnings("ignore")

In [3]:
class CustomDataset(InMemoryDataset):
    """Graph dataset built in memory from the ``SMILES`` column of a DataFrame.

    Each row of ``df`` is converted with ``smiles2graphwithface`` into a
    ``DGData`` object. The ``MLM``/``HLM`` columns become the target ``y`` and
    the ``id`` column is stored on every graph. The graphs live in
    ``self.graph_list``; nothing is written to disk by ``process``.

    Args:
        root: Dataset root directory handed to the parent class.
        transform: Forwarded to the parent class.
        pre_transform: Forwarded to the parent class.
        df: DataFrame with ``SMILES``, ``MLM``, ``HLM`` and ``id`` columns.
        target_type: Stored on the instance; not read anywhere in this class.
        mode: Stored on the instance; not read anywhere in this class.

    NOTE(review): ``df`` is also passed as the fourth positional argument of
    the parent constructor (presumably the ``pre_filter`` slot) — confirm this
    is intended.
    """

    def __init__(self, root='dataset_path', transform=None, pre_transform=None, df=None, target_type='MLM', mode='train'):
        # The attributes are assigned before the parent constructor runs
        # because the file-name properties and process() read self.df.
        self.df, self.target_type, self.mode = df, target_type, mode
        super().__init__(root, transform, pre_transform, df)

    @property
    def raw_file_names(self):
        """One ``raw_<k>.pt`` name per DataFrame row (1-based)."""
        n_rows = self.df.shape[0]
        return [f'raw_{k}.pt' for k in range(1, n_rows + 1)]

    @property
    def processed_file_names(self):
        """One ``data_<k>.pt`` name per DataFrame row (1-based)."""
        n_rows = self.df.shape[0]
        return [f'data_{k}.pt' for k in range(1, n_rows + 1)]

    def len(self):
        """Number of graphs built by ``process``."""
        return len(self.graph_list)

    def get(self, idx):
        """Return the graph stored at position ``idx``."""
        return self.graph_list[idx]

    def process(self):
        """Convert every row of ``self.df`` into a ``DGData`` graph.

        Populates ``self.smiles_list``, ``self.graph_list`` and
        ``self.targets_list``.
        """
        frame = self.df
        smiles_list = frame["SMILES"].values
        targets_list = frame[['MLM', 'HLM']].values
        test_id_list = frame['id'].values

        def as_tensor(array, dtype=torch.int64):
            # NumPy array -> torch tensor of the requested dtype.
            return torch.from_numpy(array).to(dtype)

        data_list = []
        for smiles, targets, sample_id in zip(smiles_list, targets_list, test_id_list):
            data = DGData()
            graph = smiles2graphwithface(smiles)

            data.__num_nodes__ = int(graph["num_nodes"])
            data.edge_index = as_tensor(graph["edge_index"])
            data.edge_attr = as_tensor(graph["edge_feat"])
            data.x = as_tensor(graph["node_feat"])
            # Wrapping the (MLM, HLM) pair in a list yields one target row per graph.
            data.y = torch.Tensor([targets])

            data.ring_mask = as_tensor(graph["ring_mask"], torch.bool)
            data.ring_index = as_tensor(graph["ring_index"])
            data.nf_node = as_tensor(graph["nf_node"])
            data.nf_ring = as_tensor(graph["nf_ring"])
            data.num_rings = int(graph["num_rings"])
            data.n_edges = int(graph["n_edges"])
            data.n_nodes = int(graph["n_nodes"])
            data.n_nfs = int(graph["n_nfs"])
            data.smile = smiles
            data.id = sample_id

            data_list.append(data)

        self.smiles_list = smiles_list
        self.graph_list = data_list
        self.targets_list = targets_list

In [4]:
# Load the test set, then turn the raw stability strings into floats.
test_df = pd.read_csv('data/test_paper.csv')
for target in ('MLM', 'HLM'):
    raw_col = f'{target}_raw'
    # Drop the '<' / '>' characters so the remaining text parses as a number.
    test_df[raw_col] = test_df[raw_col].str.replace('<', '').str.replace('>', '')
    test_df[target] = test_df[raw_col].astype(float)
# HLM values above 100 are capped at 100.
test_df.loc[test_df['HLM'] > 100, 'HLM'] = 100.0

In [5]:
# Split the molecules at HLM = 50: rows at or above the threshold are "stable".
hlm = test_df['HLM']
stable_df = test_df.loc[hlm >= 50].reset_index(drop=True)
unstable_df = test_df.loc[hlm < 50].reset_index(drop=True)

print(stable_df.shape, unstable_df.shape)

(288, 14) (195, 14)


In [6]:
# Both subsets are served one molecule at a time, in their original order.
loader_kwargs = dict(batch_size=1, shuffle=False, num_workers=8)

stable_dataset = CustomDataset(df=stable_df, mode='test', target_type='MLM')
stable_loader = DataLoader(stable_dataset, **loader_kwargs)

unstable_dataset = CustomDataset(df=unstable_df, mode='test', target_type='MLM')
unstable_loader = DataLoader(unstable_dataset, **loader_kwargs)

Processing...
Done!
Processing...
Done!


In [10]:
class MetaboGNN(torch.nn.Module):
    """GNN encoder followed by two small regression heads.

    ``forward`` returns ``(out1, out2)``:

    * ``out1`` is ``sigmoid(fc1(mol)) * 100``, i.e. a value in (0, 100).
    * ``out2`` depends on ``mode``: for ``'MetaboGNN'`` it is
      ``(sigmoid(fc2(mol)) - 0.5) * 200``, i.e. a value in (-100, 100);
      for any other mode it is a second evaluation of the ``fc1`` head.

    Elsewhere in this notebook the pair is consumed as ``mlm, res`` with
    ``hlm = mlm - res``.

    Args:
        mode: ``'Scratch'`` skips loading the pretrained encoder weights;
            ``'MetaboGNN'`` selects the ``fc2`` head for the second output.
    """

    def __init__(self, mode):
        super().__init__()
        self.mode = mode
        self.ddi = True

        encoder_config = dict(
            mlp_hidden_size=512, mlp_layers=2, latent_size=128, use_layer_norm=False,
            use_face=True, ddi=self.ddi, dropedge_rate=0.1, dropnode_rate=0.1, dropout=0.1,
            dropnet=0.1, global_reducer="sum", node_reducer="sum", face_reducer="sum",
            graph_pooling="sum", node_attn=True, face_attn=True,
        )
        self.gnn = GNN(**encoder_config)
        if self.mode != 'Scratch':
            # strict=False: keys that do not match between checkpoint and encoder are ignored.
            pretrained = torch.load('GraphCL/gnn_pretrain.pt', map_location='cpu')
            self.gnn.load_state_dict(pretrained, strict=False)

        # Both heads are created first and re-initialised afterwards so the
        # order in which random numbers are drawn stays fc1, fc2, fc1[-1], fc2[-1].
        self.fc1 = self._make_head()
        self.fc2 = self._make_head()
        for head in (self.fc1, self.fc2):
            head[-1].weight.data.normal_(mean=0.0, std=0.01)

    @staticmethod
    def _make_head():
        """Build one 128 -> 1 regression head."""
        return nn.Sequential(
            nn.LayerNorm(128),
            nn.Linear(128, 128),
            nn.BatchNorm1d(128),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, batch):
        """Encode ``batch`` with the GNN and return the two head outputs."""
        mol = self.gnn(batch)
        out1 = torch.sigmoid(self.fc1(mol).squeeze(1)) * 100

        if self.mode != 'MetaboGNN':
            # NOTE(review): this branch runs fc1 a second time instead of fc2 —
            # possibly a copy-paste slip; confirm against the training code.
            return out1, torch.sigmoid(self.fc1(mol).squeeze(1)) * 100

        out2 = (torch.sigmoid(self.fc2(mol).squeeze(1)) - 0.5) * 200
        return out1, out2

In [11]:
# Prefer the first GPU; fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [18]:
model = MetaboGNN(mode = 'MetaboGNN').to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a different GPU index (or on a GPU at all) still loads here.
state_dict = torch.load('ckpt/2025_MetaboGNN.pt', map_location=device)
model.load_state_dict(state_dict)
# Inference only: eval() puts the Dropout / BatchNorm layers into evaluation mode.
model.eval()

MetaboGNN(
  (gnn): GNN(
    (encoder_edge): MLPwoLastAct(
      (module_list): ModuleList(
        (0): Linear(in_features=13, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=128, bias=True)
      )
    )
    (encoder_node): MLPwoLastAct(
      (module_list): ModuleList(
        (0): Linear(in_features=174, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=128, bias=True)
      )
    )
    (encoder_face): MLPwoLastAct(
      (module_list): ModuleList(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=128, bias=True)
      )
    )
    (gnn_layers): M

In [13]:
from rdkit import Chem

def extract_bond_fragments(smiles):
    """Describe every bond of a molecule together with its neighbouring bonds.

    For each bond (in ``mol.GetBonds()`` order) a two-line string is built:

    * line 1: ``<begin><bond><end>`` for the bond itself,
    * line 2: the other bonds attached to the begin atom and to the end atom,
      comma-separated within one atom and ``;``-separated between the two
      atoms. An atom without further neighbours contributes nothing, so the
      ``;`` is absent when only one side has neighbours.

    Bond symbols: ``:`` aromatic, ``=`` order 2, ``-`` order 1, ``#`` anything
    else. Aromatic atoms are written in lower case, all others in upper case
    (``str.upper`` is applied to the whole symbol, so two-letter symbols come
    out fully upper-cased).

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        List of fragment strings, one per bond, or ``None`` when RDKit cannot
        parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    def bond_symbol(bond):
        # The aromatic flag wins over the numeric bond order.
        if bond.GetIsAromatic():
            return ":"
        order = bond.GetBondTypeAsDouble()
        if order == 2.0:
            return "="
        if order == 1.0:
            return "-"
        return "#"

    def atom_label(atom):
        symbol = atom.GetSymbol()
        return symbol.lower() if atom.GetIsAromatic() else symbol.upper()

    fragments = []
    for bond in mol.GetBonds():
        begin_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
        bonded_pair = {begin_atom.GetIdx(), end_atom.GetIdx()}

        bond_representation = f"{atom_label(begin_atom)}{bond_symbol(bond)}{atom_label(end_atom)}"

        neighbors_repr = []
        for atom in (begin_atom, end_atom):
            # Bonds from this atom to atoms outside the central bond.
            atom_neighbors = [
                "{}{}{}".format(
                    atom_label(atom),
                    bond_symbol(mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx())),
                    atom_label(neighbor),
                )
                for neighbor in atom.GetNeighbors()
                if neighbor.GetIdx() not in bonded_pair
            ]
            if atom_neighbors:
                neighbors_repr.append(",".join(atom_neighbors))

        fragments.append(f"{bond_representation}\n{';'.join(neighbors_repr)}")

    return fragments

In [14]:
# Seed every random source used by the perturbation sampling: the NumPy
# generator draws the random edge masks, torch draws the edge permutations
# (torch.randperm), which were previously unseeded.
SEED = 42
torch.manual_seed(SEED)
rng = np.random.default_rng(seed=SEED)

In [15]:
class EdgeShaperDataset(InMemoryDataset):
    """Dataset that yields randomly perturbed graph pairs for one target edge.

    Every row of ``df`` describes one sample: a molecule (``SMILES``, ``MLM``,
    ``HLM``, ``id``) plus the index ``j`` of the edge under study. ``get``
    draws a random edge permutation and a random keep/drop mask, then returns
    two copies of the molecule graph:

    * ``plus_graph``: edges up to and including ``j`` in the permutation are
      kept, the remaining edges follow the random mask;
    * ``minus_graph``: the same, except that ``j`` itself also follows the
      random mask.

    (Presumably the Monte Carlo sampling scheme of EdgeSHAPer — confirm.)

    Args:
        root: Dataset root directory handed to the parent class.
        transform: Forwarded to the parent class.
        pre_transform: Forwarded to the parent class.
        df: DataFrame with ``SMILES``, ``MLM``, ``HLM``, ``id`` and ``j``
            columns; ``get`` looks rows up with ``df.loc[idx, 'j']``.
        target_type: Stored on the instance; not read anywhere in this class.
        mode: Stored on the instance; not read anywhere in this class.
        symmetric_pairs: Optional iterable of ``(edge_a, edge_b)`` index pairs
            that must always share the same mask value. ``None`` means no
            constraint. Presumably indices into the edge dimension of
            ``edge_index`` — confirm against ``get_equivalent_bonds``.
        keep_prob: Probability with which the random mask keeps an edge.

    NOTE(review): ``df`` is also passed as the fourth positional argument of
    the parent constructor (presumably the ``pre_filter`` slot) — confirm this
    is intended.
    NOTE(review): the random mask comes from the module-level NumPy generator
    ``rng``; with ``num_workers > 0`` each DataLoader worker works on its own
    copy of it — confirm that this is acceptable.
    """

    def __init__(self, root='dataset_path', transform=None, pre_transform=None, df=None, target_type='MLM', mode='train', symmetric_pairs=None, keep_prob=0.7):
        self.df = df
        self.target_type = target_type
        self.mode = mode
        self.symmetric_pairs = symmetric_pairs
        self.keep_prob = keep_prob
        super().__init__(root, transform, pre_transform, df)

    @property
    def raw_file_names(self):
        """One ``raw_<k>.pt`` name per DataFrame row (1-based)."""
        return [f'raw_{i+1}.pt' for i in range(self.df.shape[0])]

    @property
    def processed_file_names(self):
        """One ``data_<k>.pt`` name per DataFrame row (1-based)."""
        return [f'data_{i+1}.pt' for i in range(self.df.shape[0])]

    def len(self):
        """Number of graphs built by ``process``."""
        return len(self.graph_list)

    def _symmetric_lookup(self):
        """Map every edge of a symmetric pair to its partner (both directions)."""
        lookup = {}
        if self.symmetric_pairs is not None:
            for pair in self.symmetric_pairs:
                lookup[pair[0]] = pair[1]
                lookup[pair[1]] = pair[0]
        return lookup

    @staticmethod
    def _coalition_mask(pi, z_mask, cutoff, symmetric_dict):
        """Build a 0/1 keep-mask over the edges.

        Edges at permutation positions ``< cutoff`` are always kept, all
        others take their value from ``z_mask``. The permutation is walked in
        order and each edge also writes its value to its symmetric partner, so
        for a symmetric pair the edge that comes later in ``pi`` decides the
        value of both.
        """
        num_edges = pi.shape[0]
        mask = torch.ones(num_edges, dtype=torch.int)
        for k in range(num_edges):
            current_edge = pi[k].item()
            mask_value = 1 if k < cutoff else int(z_mask[current_edge])
            mask[current_edge] = mask_value
            # A symmetric partner always receives the same mask value.
            if current_edge in symmetric_dict:
                mask[symmetric_dict[current_edge]] = mask_value
        return mask

    @staticmethod
    def _edge_subgraph(graph, retained_indices):
        """Clone ``graph`` and keep only the edge columns in ``retained_indices``."""
        sub_graph = graph.clone()
        sub_graph.ring_index = graph.ring_index[:, retained_indices]
        sub_graph.edge_attr = graph.edge_attr[retained_indices]
        sub_graph.edge_index = graph.edge_index[:, retained_indices]
        sub_graph.num_edges = retained_indices.shape[0]
        sub_graph.n_edges = retained_indices.shape[0]
        return sub_graph

    def get(self, idx):
        """Return the ``(plus_graph, minus_graph)`` pair for sample ``idx``."""
        graph = self.graph_list[idx]
        j = self.df.loc[idx, 'j']
        num_edges = graph['n_edges']

        # Draw order is kept fixed: NumPy keep/drop mask first, torch permutation second.
        E_z_mask = rng.binomial(1, self.keep_prob, num_edges)
        pi = torch.randperm(num_edges)
        # Position of the edge under study inside the permutation.
        selected_edge_index = np.where(pi == j)[0].item()

        symmetric_dict = self._symmetric_lookup()

        # "plus" keeps positions <= selected, "minus" only positions < selected.
        E_j_plus_index = self._coalition_mask(pi, E_z_mask, selected_edge_index + 1, symmetric_dict)
        E_j_minus_index = self._coalition_mask(pi, E_z_mask, selected_edge_index, symmetric_dict)

        # reshape(-1) always gives a 1-D index tensor, also for zero or one retained edge.
        retained_indices_plus = torch.nonzero(E_j_plus_index).reshape(-1)
        retained_indices_minus = torch.nonzero(E_j_minus_index).reshape(-1)

        plus_graph = self._edge_subgraph(graph, retained_indices_plus)
        minus_graph = self._edge_subgraph(graph, retained_indices_minus)

        return plus_graph, minus_graph

    def process(self):
        """Convert every row of ``self.df`` into a ``DGData`` graph.

        Populates ``self.mol_list`` (RDKit molecules), ``self.smiles_list``,
        ``self.graph_list`` and ``self.targets_list``.
        """
        smiles_list = self.df["SMILES"].values
        targets_list = self.df[['MLM', 'HLM']].values
        test_id_list = self.df['id'].values
        self.mol_list = []
        data_list = []
        for i in range(len(smiles_list)):
            data = DGData()
            smiles = smiles_list[i]
            self.mol_list.append(Chem.MolFromSmiles(smiles))
            targets = targets_list[i]
            graph = smiles2graphwithface(smiles)

            data.__num_nodes__ = int(graph["num_nodes"])
            data.edge_index = torch.from_numpy(graph["edge_index"]).to(torch.int64)
            data.edge_attr = torch.from_numpy(graph["edge_feat"]).to(torch.int64)
            data.x = torch.from_numpy(graph["node_feat"]).to(torch.int64)
            # Wrapping the (MLM, HLM) pair in a list yields one target row per graph.
            data.y = torch.Tensor([targets])

            data.ring_mask = torch.from_numpy(graph["ring_mask"]).to(torch.bool)
            data.ring_index = torch.from_numpy(graph["ring_index"]).to(torch.int64)
            data.nf_node = torch.from_numpy(graph["nf_node"]).to(torch.int64)
            data.nf_ring = torch.from_numpy(graph["nf_ring"]).to(torch.int64)
            data.num_rings = int(graph["num_rings"])
            data.n_edges = int(graph["n_edges"])
            data.n_nodes = int(graph["n_nodes"])
            data.n_nfs = int(graph["n_nfs"])
            data.smile = smiles
            data.id = test_id_list[i]

            data_list.append(data)
        self.smiles_list = smiles_list
        self.graph_list = data_list
        self.targets_list = targets_list
   

In [16]:
def get_edge_shape_df(test_df, num_edges, M, row_idx=None):
    """Build the sampling table for one molecule: ``M`` rows per edge.

    The row of ``test_df`` labelled ``row_idx`` is repeated ``num_edges * M``
    times and a column ``j`` is added that holds the edge index of each row
    (``0`` for the first ``M`` rows, ``1`` for the next ``M``, ...).

    Args:
        test_df: DataFrame holding one molecule per row.
        num_edges: Number of edges of the molecule.
        M: Number of samples per edge.
        row_idx: Index label of the molecule's row in ``test_df``. When
            omitted, the notebook-level loop variable ``i`` is used, which is
            how this function picked its row before the parameter existed.

    Returns:
        DataFrame with ``num_edges * M`` rows, a fresh 0..n-1 index and the
        extra integer column ``j``. Empty when ``num_edges`` is 0.
    """
    if row_idx is None:
        # Backward compatibility with callers that rely on the global ``i``.
        row_idx = i  # noqa: F821

    # A single row sampled with replacement is just that row repeated, so the
    # table is built in one step (this also covers num_edges == 0, where
    # pd.concat on an empty list would raise).
    edge_shape_df = test_df.loc[[row_idx] * (num_edges * M)].reset_index(drop=True)
    edge_shape_df['j'] = np.repeat(np.arange(num_edges, dtype=np.int64), M)
    return edge_shape_df

In [22]:
# Folder that receives the per-molecule score arrays; created only if missing.
scores_dir = 'EdgeShaper/scores'
os.makedirs(scores_dir, exist_ok=True)

In [31]:
# Number of random perturbation samples drawn per edge.
M = 100
fragments_df = []

# NOTE: the loop variable must stay named ``i`` — get_edge_shape_df is called
# without a row index and picks the current row from the notebook-level ``i``.
for i in tqdm(range(test_df.shape[0])):
    smile = test_df.loc[i, 'SMILES']
    test_id = test_df.loc[i, 'id']
    mol = Chem.MolFromSmiles(smile)
    _, symmetric_pairs = get_equivalent_bonds(mol)
    graph = smiles2graphwithface(smile)
    num_edges = graph['edge_feat'].shape[0]

    if num_edges == 0:
        # Nothing to perturb (and building the sampling table would fail):
        # store an empty score array and move on.
        np.save(f'EdgeShaper/scores/{test_id}.npy', np.array([]))
        continue

    edge_shape_df = get_edge_shape_df(test_df, num_edges, M)

    edgeshaper_dataset = EdgeShaperDataset(df = edge_shape_df, mode='test', target_type='HLM', symmetric_pairs=symmetric_pairs)
    # batch_size=M with shuffle=False: every batch holds exactly the M samples of one edge j.
    edgeshaper_loader = DataLoader(edgeshaper_dataset, batch_size=M, shuffle=False, num_workers = 8)

    stable_edges_explanations = []

    for plus_batch, minus_batch in edgeshaper_loader:
        with torch.no_grad():
            # The model returns (mlm, residual); HLM is recovered as their difference.
            plus_mlm, plus_res = model(plus_batch.to(device))
            plus_hlm = plus_mlm - plus_res

            minus_mlm, minus_res = model(minus_batch.to(device))
            minus_hlm = minus_mlm - minus_res

        # Distance of the predicted HLM from the stability threshold (50),
        # with and without the edge under study.
        plus_stable, minus_stable = (plus_hlm-50).abs(), (minus_hlm-50).abs()
        stability_impact = plus_stable - minus_stable

        # Average over the M samples gives the score of this edge.
        stable_edges_explanations.append(stability_impact.mean().item())

    stable_edges_explanations = np.array(stable_edges_explanations)

    # extract_bond_fragments returns None for an unparsable SMILES.
    fragments = extract_bond_fragments(smile) or []

    for n, fragment in enumerate(fragments):
        # Assumes edges 2n and 2n+1 are the two directions of bond n — TODO confirm
        # against smiles2graphwithface.
        stable_score = stable_edges_explanations[n*2] + stable_edges_explanations[n*2+1]
        fragments_df.append({'fragment' : fragment, 'stable_score' : stable_score, 'mol_idx' : i})

    np.save(f'EdgeShaper/scores/{test_id}.npy', stable_edges_explanations)

In [33]:
# Turn the collected per-bond records into a table and persist it without the row index.
fragments_df = pd.DataFrame(data=fragments_df)
fragments_df.to_csv(path_or_buf='fragments_df.csv', index=None)

In [37]:
# Preview the first five fragment rows.
fragments_df.head(n=5)

Unnamed: 0,fragment,stable_score,mol_idx
0,"C-C\nC-C,C-N",0.202087,0
1,"C-C\nC-C,C-N",-0.019427,0
2,"C-N\nC-C,C-C;N-c",-0.4554,0
3,"N-c\nN-C;c:c,c:n",0.091446,0
4,"c:c\nc-N,c:n;c:c",-0.557052,0
