In [None]:
import os, torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score
import pickle
import lmdb
import pandas as pd
import numpy as np
from rdkit import Chem
from tqdm import tqdm
from rdkit.Chem import AllChem
from rdkit import RDLogger
from sklearn.metrics import roc_auc_score
RDLogger.DisableLog('rdApp.*')  
import warnings
warnings.filterwarnings(action='ignore')
from multiprocessing import Pool
from unimol.models import UniMolModel
from unicore import tasks
from unicore.data import Dictionary
from torch_ema import ExponentialMovingAverage
from unicore import checkpoint_utils, distributed_utils, options, utils
from unicore import models
from unicore.data.sort_dataset import  EpochShuffleDataset
from unicore.data import LMDBDataset, EpochBatchIterator, CountingIterator, UnicoreDataset
from unicore.data.data_utils import collate_tokens, collate_tokens_2d, batch_by_size
from unicore.data.sort_dataset import  EpochShuffleDataset
from torch import nn
from preprocessing import inner_smi2coords
from sklearn.model_selection import train_test_split
from rdkit.Chem import Draw, PandasTools

In [None]:
# Annotation columns, one per CYP isoform; the list position doubles as the
# integer label for that isoform.
cyp_col = [
    'BOM_1A2', 'BOM_2A6', 'BOM_2B6',
    'BOM_2C8', 'BOM_2C9', 'BOM_2C19',
    'BOM_2D6', 'BOM_2E1', 'BOM_3A4',
]
# Column name -> label id, derived from cyp_col so the two cannot drift apart.
cyp2label = {column: label for label, column in enumerate(cyp_col)}

In [None]:
# Inspect which reaction types appear in an annotation column
def reaction_unique(df, column=None):
    """Return the sorted, de-duplicated, lower-cased reaction types in ``df[column]``.

    Each cell holds newline-separated annotation lines such as ``<site;type;...>``;
    the second-to-last ``;``-separated field (after removing ``<``/``>``) is taken
    as the reaction type.

    Args:
        df: DataFrame containing the annotation column.
        column: name of the annotation column. Defaults to the notebook-global
            ``cyp`` so existing ``reaction_unique(df)`` calls keep working.

    Returns:
        Sorted list of distinct reaction-type strings.
    """
    if column is None:
        # Backward compatibility: the original implicitly read the global ``cyp``.
        column = cyp

    reaction_types = set()
    # Missing annotations (NaN/None) have no .split and carry no reactions.
    for annotation in df[column].dropna().unique():
        for reaction in annotation.split('\n'):
            # Skip blank lines, including stray '\r' / spaces left by the SDF text.
            if not reaction.strip():
                continue
            fields = reaction.replace('<', '').replace('>', '').split(';')
            reaction_types.add(fields[-2].strip().lower())
    return sorted(reaction_types)

In [None]:
# Reaction-type spellings (including misspellings found in the raw annotations),
# grouped by the class_type=1 bond label they map to. Label 0 (non-reaction) is
# not produced here.
_TYPE1_LABEL_GROUPS = {
    1: ('cleavage',),
    2: ('hdroxylation', 'hydrlxylation', 'hydroxyaltion', 'hydroxylation', 'hydroxylaton'),
    3: ('oxdation', 'oxidaiton', 'oxidation', 'oxidiation', 'oxydation'),
    4: ('dehydrogenation', 'denitrosation', 'desulfation', 'desulfuration',
        'desulphuration', 'epoxidation', 'n-oxidation', 'rearrangement',
        'redcution', 'reduction', 'reducttion', 'redunction', 'reuction',
        'ruction', 's-oxidation', 'unspoxidation', 'unspreduction'),
}
# Flattened spelling -> label table, built once rather than on every call.
_TYPE1_LABEL_OF = {
    spelling: label
    for label, spellings in _TYPE1_LABEL_GROUPS.items()
    for spelling in spellings
}


def reaction_type2label_type1(x):
    """Map a reaction-type spelling (lower-case) to its 5-class bond label.

    Labels: 0 non-reaction (never returned here), 1 cleavage, 2 hydroxylation,
    3 oxidation, 4 others.

    Raises:
        KeyError: if ``x`` is not one of the known spellings.
    """
    return _TYPE1_LABEL_OF[x]

# Reaction-type spellings (including misspellings found in the raw annotations),
# grouped by the class_type=2 bond label they map to; hydroxylation and
# oxidation share one label in this scheme.
_TYPE2_LABEL_GROUPS = {
    1: ('cleavage',),
    2: ('hdroxylation', 'hydrlxylation', 'hydroxyaltion', 'hydroxylation', 'hydroxylaton',
        'oxdation', 'oxidaiton', 'oxidation', 'oxidiation', 'oxydation'),
    3: ('dehydrogenation', 'denitrosation', 'desulfation', 'desulfuration',
        'desulphuration', 'epoxidation', 'n-oxidation', 'rearrangement',
        'redcution', 'reduction', 'reducttion', 'redunction', 'reuction',
        'ruction', 's-oxidation', 'unspoxidation', 'unspreduction'),
}
# Flattened spelling -> label table, built once rather than on every call.
_TYPE2_LABEL_OF = {
    spelling: label
    for label, spellings in _TYPE2_LABEL_GROUPS.items()
    for spelling in spellings
}


def reaction_type2label_type2(x):
    """Map a reaction-type spelling (lower-case) to its 4-class bond label.

    Labels: 0 non-reaction (never returned here), 1 cleavage,
    2 hydroxylation or oxidation, 3 others.

    Raises:
        KeyError: if ``x`` is not one of the known spellings.
    """
    return _TYPE2_LABEL_OF[x]

def _reaction_hits_bond(s_atom_idx, e_atom_idx, e_atom, rea_sidx, rea_eidx):
    """Return True if the annotation site ``rea_sidx,rea_eidx`` designates this bond.

    ``rea_sidx`` is a 1-based atom index (string). ``rea_eidx`` is either an element
    symbol (e.g. ``'H'``), matched against the bond's end-atom symbol ``e_atom``, or a
    second atom index (string).
    """
    start = int(rea_sidx) - 1  # annotation index is 1-based, RDKit atom indices are 0-based
    if rea_eidx.isalpha():
        return start == s_atom_idx and rea_eidx == e_atom
    # NOTE(review): assumes the second index is 1-based like the first — confirm against the SDF.
    end = int(rea_eidx) - 1
    # A bond is an unordered atom pair; RDKit's begin/end order need not follow the annotation.
    return (start, end) in ((s_atom_idx, e_atom_idx), (e_atom_idx, s_atom_idx))


def check_reaction(s_atom_idx, e_atom_idx, s_atom, e_atom, reaction, class_type = 1):
    """Return the label that one reaction annotation assigns to one bond.

    Args:
        s_atom_idx, e_atom_idx: 0-based indices of the bond's begin/end atoms.
        s_atom, e_atom: element symbols of those atoms (``s_atom`` is unused; kept
            for call compatibility).
        reaction: one annotation line ``<site;type;...>`` whose ``site`` is
            ``start,end`` (see ``_reaction_hits_bond``); empty/blank/None means no reaction.
        class_type: 1 uses ``reaction_type2label_type1``, 2 uses ``reaction_type2label_type2``.

    Returns:
        The label of ``type`` when the site designates this bond, otherwise 0.

    Raises:
        ValueError: if ``class_type`` is not 1 or 2 (this used to silently yield 0).
    """
    if class_type not in (1, 2):
        raise ValueError(f'class_type must be 1 or 2, got {class_type!r}')
    # Strip so stray '\r'/spaces around a line neither break the '<...>' slicing
    # nor turn a blank line into a parse error.
    reaction = reaction.strip() if reaction else ''
    if not reaction:
        return 0

    rea_idx, reaction_type, _ = reaction[1:-1].split(';')
    rea_sidx, rea_eidx = (part.strip() for part in rea_idx.split(','))
    if not _reaction_hits_bond(s_atom_idx, e_atom_idx, e_atom, rea_sidx, rea_eidx):
        return 0

    # Bug fix: the index branch used to compare an int with the end-atom *symbol*,
    # so index-pair reactions never matched any bond.
    to_label = reaction_type2label_type1 if class_type == 1 else reaction_type2label_type2
    return to_label(reaction_type.strip().lower())

def mol2bond_label(bonds, bonds_idx, reactions, class_type):
    """Label every bond of a molecule with the reaction class that hits it.

    Args:
        bonds: ``(begin_symbol, end_symbol)`` per bond.
        bonds_idx: ``(begin_idx, end_idx)`` per bond, aligned with ``bonds``.
        reactions: annotation lines, each parsed by ``check_reaction``.
        class_type: label scheme forwarded to ``check_reaction``.

    Returns:
        One int per bond: 0 when no reaction matched, otherwise the label of the
        last reaction line that matched it.

    Raises:
        ValueError: if ``bonds`` and ``bonds_idx`` differ in length.
    """
    if len(bonds) != len(bonds_idx):
        raise ValueError(
            f'bonds ({len(bonds)}) and bonds_idx ({len(bonds_idx)}) must have the same length')

    bond_label = [0] * len(bonds)
    for reaction in reactions:
        for n, ((s_atom_idx, e_atom_idx), (s_atom, e_atom)) in enumerate(zip(bonds_idx, bonds)):
            label = check_reaction(s_atom_idx, e_atom_idx, s_atom, e_atom, reaction, class_type)
            # Only record hits: writing 0 for misses wiped labels set by earlier
            # reactions, so previously only the last reaction line survived.
            if label:
                bond_label[n] = label
    return bond_label

In [None]:
import re

_ALPHA_DOT_RE = re.compile(r'[A-Za-z]\.')  # one ASCII letter immediately followed by '.'


def remove_alpha_dot(text):
    """Delete every "<letter>." pair from ``text`` and return the result.

    Applied to the test-set annotation column before parsing; what these
    letter-dot pairs represent in the raw data is not visible here.
    """
    return _ALPHA_DOT_RE.sub('', text)


In [None]:
# Build the LMDB datasets: for every CYP isoform column and both label schemes
# (class_type 1: 5 bond classes; class_type 2: 4 classes with hydroxylation and
# oxidation merged) write train/valid/test files into LMDB_DIR.
TRAIN_SDF = '../data/EBoMD_train_0213.sdf'
TEST_SDF = '../data/EBoMD_test_0213.sdf'
LMDB_DIR = './som'
VALID_FRACTION = 0.2
SPLIT_SEED = 42  # fixed so the split is reproducible and identical for both class types


def write_lmdb(df, outpath='./som', save_name = 'train.lmdb', nthreads=16, column=None, label_class_type=None):
    """Serialize the molecules in ``df`` with per-bond reaction labels into an LMDB file.

    For each row, the ``column`` annotation is split into reaction lines and turned
    into bond labels by ``mol2bond_label``. The labels are passed to ``inner_smi2coords``,
    and non-None results are stored under consecutive ASCII keys ``'0'``, ``'1'``, ...
    An existing file at the target path is replaced.

    Args:
        df: frame indexed 0..n-1 with an ``ROMol`` column and the annotation ``column``.
        outpath: output directory (created if missing).
        save_name: LMDB file name inside ``outpath``.
        nthreads: unused; kept for call compatibility.
        column: annotation column; defaults to the notebook-global ``cyp`` (old behaviour).
        label_class_type: label scheme for ``mol2bond_label``; defaults to the
            notebook-global ``class_type`` (old behaviour).
    """
    if column is None:
        column = cyp
    if label_class_type is None:
        label_class_type = class_type

    os.makedirs(outpath, exist_ok=True)
    output_name = os.path.join(outpath, save_name)
    try:
        os.remove(output_name)
    except FileNotFoundError:
        pass  # nothing to replace yet; any other OS error should surface

    env_new = lmdb.open(
        output_name,
        subdir=False,
        readonly=False,
        lock=False,
        readahead=False,
        meminit=False,
        max_readers=1,
        map_size=int(100e9),
    )
    n_written = 0
    try:
        # The transaction commits on normal exit and aborts if any row raises.
        with env_new.begin(write=True) as txn_write:
            for idx in range(df.shape[0]):
                reactions = df.loc[idx, column].split('\n')
                # Copy the molecule: the frames are now loaded once and shared by every
                # dataset, so any in-place change made downstream must not leak across.
                mol = Chem.Mol(df.loc[idx, 'ROMol'])
                mol_h = AllChem.AddHs(mol, addCoords=True)
                smile = Chem.MolToSmiles(mol)

                h_bonds = list(mol_h.GetBonds())
                bonds_idx = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in h_bonds]
                bonds = [(b.GetBeginAtom().GetSymbol(), b.GetEndAtom().GetSymbol()) for b in h_bonds]

                bond_label = mol2bond_label(bonds, bonds_idx, reactions, label_class_type)
                inner_output = inner_smi2coords(mol, smile, bond_label, False)
                if inner_output is not None:
                    txn_write.put(f'{n_written}'.encode("ascii"), inner_output)
                    n_written += 1
        print('{} process {} lines'.format(output_name, n_written))
    finally:
        env_new.close()


def _load_sdf(path):
    """Load an SDF file with PandasTools (molecules in ``ROMol``) and reset to a 0..n-1 index."""
    return PandasTools.LoadSDF(path).reset_index(drop=True)


def _annotated_rows(frame, column):
    """Keep rows whose ``column`` annotation is present and non-empty, re-indexed from 0."""
    # The old train filter only dropped '' — NaN rows survived and crashed on .split.
    has_annotation = frame[column].notna() & (frame[column] != '')
    return frame[has_annotation].reset_index(drop=True)


# Read each SDF once instead of once per (cyp, class_type) pair.
train_raw = _load_sdf(TRAIN_SDF)
test_raw = _load_sdf(TEST_SDF)

for cyp in cyp_col:
    df = _annotated_rows(train_raw, cyp)
    train_df, valid_df = train_test_split(df, test_size=VALID_FRACTION, random_state=SPLIT_SEED)
    train_df, valid_df = train_df.reset_index(drop=True), valid_df.reset_index(drop=True)

    test_df = _annotated_rows(test_raw, cyp)
    # As before, only the test annotations are cleaned with remove_alpha_dot.
    test_df = test_df.assign(**{cyp: test_df[cyp].apply(remove_alpha_dot)})

    for class_type in [1, 2]:
        for split_name, split_df in (('train', train_df), ('valid', valid_df), ('test', test_df)):
            write_lmdb(
                split_df,
                outpath=LMDB_DIR,
                save_name=f'{split_name}_{cyp}_class_type{class_type}.lmdb',
                nthreads=8,
                column=cyp,
                label_class_type=class_type,
            )