In [1]:
import os
import pandas as pd
from tqdm.auto import tqdm
import numpy as np
# import pymol
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')


# Initialize PyMOL
# pymol.finish_launching(['pymol', '-cq'])

In [2]:
# Root folder of the EC-site dataset; the dataset, structure and output paths
# used by the cells below are all joined onto it.
ec_site_dataset_path = '../dataset/ec_site_dataset/'

In [3]:
def get_dataset(max_sample, ec_site_dataset_path=ec_site_dataset_path,  
                merge_dataset_name_str='uniprot_ecreact_merge_dataset_limit_', 
                sub_set=['train', 'valid', 'test']):
    """Load and concatenate every CSV shard of the requested dataset splits.

    Args:
        max_sample: Suffix appended to ``merge_dataset_name_str`` to form the
            dataset folder name.
        ec_site_dataset_path: Root folder that contains the dataset folder.
        merge_dataset_name_str: Prefix of the dataset folder name.
        sub_set: Split names; for each one, every file inside
            ``<dataset folder>/<split>_dataset`` is read with ``pd.read_csv``.

    Returns:
        One DataFrame with the rows of all files that were read (the original
        per-file indices are kept), or an empty DataFrame when no file exists.
    """
    uniprot_ecreact_merge_dataset_path = os.path.join(
        ec_site_dataset_path, f'{merge_dataset_name_str}{max_sample}')

    # Collect the shards first and concatenate once: calling pd.concat inside
    # the loop re-copies every previously loaded row on each iteration.
    frames = []
    for dataset_flag in sub_set:
        folder_path = os.path.join(uniprot_ecreact_merge_dataset_path, f'{dataset_flag}_dataset')
        # os.listdir returns entries in arbitrary order; sorting makes the row
        # order of the result reproducible across machines and runs.
        csv_fnames = sorted(os.listdir(folder_path))
        pbar = tqdm(
            csv_fnames,
            total=len(csv_fnames),
            desc=f'{dataset_flag}'
        )
        for fname in pbar:
            frames.append(pd.read_csv(os.path.join(folder_path, fname)))

    if not frames:
        # pd.concat raises on an empty list; keep returning an empty frame.
        return pd.DataFrame()
    return pd.concat(frames)

# def extract_ligand_with_pymol(pdb_file):
#     try:
#         pymol.cmd.load(pdb_file, 'complex')
#         pymol.cmd.select('ligand', 'organic')
#         ligand_pdbblock = pymol.cmd.get_pdbstr('ligand')
#         mol = Chem.MolFromPDBBlock(ligand_pdbblock)
#         smiles = Chem.MolToSmiles(mol)
#         pymol.cmd.delete('all') 
#     except:
#         pymol.cmd.delete('all') 
#         return ''


#     return smiles

def extract_hetatm_from_pdb(pdb_file):
    """Return the non-water HETATM records of a PDB file as a PDB block.

    Args:
        pdb_file: Path of the PDB file to read (UTF-8 text).

    Returns:
        The whitespace-stripped HETATM lines joined by newlines and terminated
        by an ``END`` record. Records whose residue name is ``HOH`` are dropped.
    """
    hetatm_lines = []
    with open(pdb_file, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line.startswith('HETATM'):
                continue
            # The residue name lives in columns 18-20 of a PDB record. Testing
            # that field (rather than "HOH anywhere in the line") keeps ligand
            # atoms whose atom name merely contains "HOH".
            if line[17:20] == 'HOH':
                continue
            hetatm_lines.append(line)

    return '\n'.join(hetatm_lines) + '\nEND\n'

def extract_ligand_with_rdkit(pdb_file):
    """Return the SMILES of the HETATM records in ``pdb_file``.

    The HETATM block comes from ``extract_hetatm_from_pdb`` and is parsed with
    RDKit; an empty string is returned when RDKit cannot build a molecule.
    """
    mol = Chem.MolFromPDBBlock(extract_hetatm_from_pdb(pdb_file))
    if mol is not None:
        return Chem.MolToSmiles(mol)
    return ''
    
def get_reactants(rxn):
    """Return the canonical SMILES of every reactant in a reaction string.

    ``rxn`` is first reduced to ``reactants>>products`` by ``get_rxn_smiles``;
    the reactant side is split on ``.`` and each piece is canonicalized
    (a piece that cannot be parsed becomes ``''``).
    """
    reactant_side = get_rxn_smiles(rxn).split('>>')[0]
    return [canonicalize_smiles(smi) for smi in reactant_side.split('.')]

def get_rxn_smiles(rxn_aa_str):
    """Drop the ``|...`` suffix from the left-hand side of a reaction string.

    The input must contain exactly one ``>>`` and exactly one ``|`` on its
    left-hand side (otherwise the unpacking raises ``ValueError``). The text
    after ``|`` is discarded and ``<before |>>><right-hand side>`` is returned.
    """
    left_side, right_side = rxn_aa_str.split('>>')
    substrate_part, _discarded = left_side.split('|')
    return '>>'.join((substrate_part, right_side))

def canonicalize_smiles(smi):
    """Return RDKit's SMILES for ``smi``, or ``''`` when it cannot be parsed."""
    mol = Chem.MolFromSmiles(smi)
    return '' if mol is None else Chem.MolToSmiles(mol)

def check_ligand_in_reactants(ligands, reactants):
    """Flag which reactants also occur among the ligands.

    Args:
        ligands: Iterable of ligand SMILES strings.
        reactants: Sequence of reactant SMILES strings.

    Returns:
        ``(all_present, have_present, string_arr)`` where ``string_arr`` is a
        boolean array with one entry per reactant (True when that reactant is
        one of the ligands), and the first two values are ``np.all`` and
        ``np.any`` of that array (so ``all_present`` is True for an empty
        reactant list).
    """
    ligand_set = set(ligands)
    # Test every reactant position on its own so that a reactant listed more
    # than once (e.g. "O.O") is flagged at each occurrence, not only the first.
    string_arr = np.array([reactant in ligand_set for reactant in reactants], dtype=bool)
    all_present = np.all(string_arr)
    have_present = np.any(string_arr)
    return all_present, have_present, string_arr

def check_all_reactants_in_pdb(rxn, pdb_ids, pdb_files_folder):
    """Check whether the reactants of ``rxn`` appear as ligands in its PDB files.

    Args:
        rxn: Reaction string accepted by ``get_reactants``.
        pdb_ids: ``;``-separated PDB ids; empty entries are ignored.
        pdb_files_folder: Folder holding ``<lower-case id>.pdb`` files. Ids
            without a file are skipped.

    Returns:
        ``(all_reactants_in_pdb, have_reactants_in_pdb)`` as plain bools: whether
        every reactant / at least one reactant matches a ligand SMILES fragment
        extracted from the available PDB files. ``(False, False)`` is returned
        when any reactant SMILES could not be canonicalized.
    """
    reactants = get_reactants(rxn)
    if '' in reactants:
        # Always return a pair: callers unpack the result into two columns.
        return False, False
    pdb_id_list = [x for x in pdb_ids.split(';') if x]

    ligands = set()
    for pdb_id in pdb_id_list:
        pdb_file = os.path.join(pdb_files_folder, '{}.pdb'.format(pdb_id.lower()))
        if not os.path.exists(pdb_file):
            continue
        ligand_smiles = extract_ligand_with_rdkit(pdb_file)
        if ligand_smiles:
            # A multi-fragment SMILES is split so each fragment is compared on its own.
            ligands.update(ligand_smiles.split('.'))

    all_reactants_in_pdb, have_reactants_in_pdb, _ = check_ligand_in_reactants(list(ligands), reactants)
    return bool(all_reactants_in_pdb), bool(have_reactants_in_pdb)

In [4]:
# Build (or load from cache) the train+valid rows that carry a PDB id.
max_sample = 100
cluster_split_prefix = 'uniprot_ecreact_cluster_split_merge_dataset_limit_'
all_pdb_dataset_train_val_save_path = os.path.join(
    ec_site_dataset_path,
    f'{cluster_split_prefix}{max_sample}',
    'all_pdb_dataset_train_val.csv',
)

if os.path.exists(all_pdb_dataset_train_val_save_path):
    # Cached result from a previous run.
    all_pdb_dataset_train_val = pd.read_csv(all_pdb_dataset_train_val_save_path)
else:
    train_dataset_df = get_dataset(
        max_sample=max_sample,
        ec_site_dataset_path=ec_site_dataset_path,
        merge_dataset_name_str=cluster_split_prefix,
        sub_set=['train'],
    )
    valid_dataset_df = get_dataset(
        max_sample=max_sample,
        ec_site_dataset_path=ec_site_dataset_path,
        merge_dataset_name_str=cluster_split_prefix,
        sub_set=['valid'],
    )
    # Keep only rows that have a PDB id, then drop exact repeats of the key columns.
    train_with_pdb = train_dataset_df[train_dataset_df['pdb-id'].notna()]
    valid_with_pdb = valid_dataset_df[valid_dataset_df['pdb-id'].notna()]
    all_pdb_dataset_train_val = pd.concat([train_with_pdb, valid_with_pdb], axis=0)
    all_pdb_dataset_train_val = all_pdb_dataset_train_val.drop_duplicates(
        subset=['reaction', 'pdb-id', 'alphafolddb-id', 'aa_sequence']
    ).reset_index(drop=True)
    # Disabled alternative: one row per individual PDB id.
    # all_pdb_dataset_train_val['pdb-id'] = all_pdb_dataset_train_val['pdb-id'].apply(lambda x:[y for y in x.split(';') if y])
    # all_pdb_dataset_train_val = all_pdb_dataset_train_val.explode('pdb-id').drop_duplicates(subset=['reaction', 'pdb-id', 'alphafolddb-id', 'aa_sequence']).reset_index(drop=True)
    all_pdb_dataset_train_val.to_csv(all_pdb_dataset_train_val_save_path, index=False)

all_pdb_dataset_train_val

Unnamed: 0,reaction,ec,pdb-id,alphafolddb-id,aa_sequence,site_labels,site_types
0,CC(=O)CC(=O)SCCNC(=O)CCNC(=O)[C@H](O)C(C)(C)CO...,1.1.1.36,1UZL;1UZM;1UZN;2NTN;,P9WGT3;,MTATATEGAKPPFVSRSVLVTGGNRGIGLAIAQRLAADGHKVAVTH...,"[[25, 27], [47], [61, 62], [90], [153], [157],...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 2]"
1,CC[C@@H](O)CC(=O)SCCNC(=O)CCNC(=O)[C@H](O)C(C)...,1.1.1.36,1UZL;1UZM;1UZN;2NTN;,P9WGT3;,MTATATEGAKPPFVSRSVLVTGGNRGIGLAIAQRLAADGHKVAVTH...,"[[25, 27], [47], [61, 62], [90], [153], [157],...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 2]"
2,CC[C@@H](O)CC(=O)SCCNC(=O)CCNC(=O)[C@H](O)C(C)...,1.1.1.36,1UZL;1UZM;1UZN;2NTN;,P9WGT3;,MTATATEGAKPPFVSRSVLVTGGNRGIGLAIAQRLAADGHKVAVTH...,"[[25, 27], [47], [61, 62], [90], [153], [157],...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 2]"
3,CC(=O)CC(=O)SCCNC(=O)CCNC(=O)[C@H](O)C(C)(C)CO...,1.1.1.36,1UZL;1UZM;1UZN;2NTN;,P9WGT3;,MTATATEGAKPPFVSRSVLVTGGNRGIGLAIAQRLAADGHKVAVTH...,"[[25, 27], [47], [61, 62], [90], [153], [157],...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 2]"
4,CCCCCCCCCCCC(=O)CC(=O)SCCNC(=O)CCNC(=O)[C@H](O...,1.1.1.36,1UZL;1UZM;1UZN;2NTN;,P9WGT3;,MTATATEGAKPPFVSRSVLVTGGNRGIGLAIAQRLAADGHKVAVTH...,"[[25, 27], [47], [61, 62], [90], [153], [157],...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 2]"
...,...,...,...,...,...,...,...
13227,Nc1ncnc2c1ncn2[C@@H]1O[C@H](COP(=O)(O)OP(=O)(O...,4.1.1.49,1YTM;1YVY;,O09460;,MSLSESLAKYGITGATNIVHNPSHEELFAAETQASLEGFEKGTVTE...,"[[60], [200], [206], [206], [206], [225], [225...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
13228,O=O.Oc1ccc(O)nc1|MPVSNAQLTQMFEHVLKLSRVDETQSVAV...,1.13.11.9,7CN3;7CNT;7CUP;,Q88FY1;,MPVSNAQLTQMFEHVLKLSRVDETQSVAVLKSHYSDPRTVNAAMEA...,"[[265], [318], [320]]","[0, 0, 0]"
13229,O|MSKRPNFLVIVADDLGFSDIGAFGGEIATPNLDALAIAGLRLTD...,3.1.6.1,1HDH;4CXK;4CXS;4CXU;4CYR;4CYS;5AJ9;,P51691;,MSKRPNFLVIVADDLGFSDIGAFGGEIATPNLDALAIAGLRLTDFH...,"[[13], [14], [51], [317], [318], [51], [115]]","[0, 0, 0, 0, 0, 1, 1]"
13230,Nc1ccn([C@@H]2O[C@H](COP(=O)([O-])OP(=O)([O-])...,2.7.7.39,1COZ;1N1D;,P27623;,MKKVITYGTFDLLHWGHIKLLERAKQLGDYLVVAISTDEFNLQKQK...,"[[9, 10], [14, 17], [44], [46], [77], [113, 120]]","[0, 0, 0, 0, 0, 0]"


In [5]:
# Add the 'all_reactants_in_pdb' / 'have_reactants_in_pdb' flags, reusing the
# cached CSV when it already exists.
check_reactants_save_path = all_pdb_dataset_train_val_save_path.replace('.csv', '_check_reactants.csv')

if os.path.exists(check_reactants_save_path):
    all_pdb_dataset_train_val = pd.read_csv(check_reactants_save_path)
else:
    from pandarallel import pandarallel
    pandarallel.initialize(nb_workers=12, progress_bar=True)

    # One (all_present, have_present) result per row, computed in parallel.
    reactant_check_results = all_pdb_dataset_train_val.parallel_apply(
        lambda row: check_all_reactants_in_pdb(
            row['reaction'],
            row['pdb-id'],
            os.path.join(ec_site_dataset_path, 'structures/pdb_download'),
        ),
        axis=1,
    )
    all_flags, have_flags = map(list, zip(*reactant_check_results))
    all_pdb_dataset_train_val['all_reactants_in_pdb'] = all_flags
    all_pdb_dataset_train_val['have_reactants_in_pdb'] = have_flags
    all_pdb_dataset_train_val.to_csv(check_reactants_save_path, index=False)

print()
print(all_pdb_dataset_train_val['all_reactants_in_pdb'].sum())
print(all_pdb_dataset_train_val['have_reactants_in_pdb'].sum())


7
77


In [6]:
# Display the rows whose 'all_reactants_in_pdb' flag is True.
all_pdb_dataset_train_val.loc[all_pdb_dataset_train_val['all_reactants_in_pdb']]

Unnamed: 0,reaction,ec,pdb-id,alphafolddb-id,aa_sequence,site_labels,site_types,all_reactants_in_pdb,have_reactants_in_pdb
3976,OC[C@H]1O[C@H](O)[C@H](O)[C@@H](O)[C@@H]1O|MNY...,5.3.1.5,1GW9;1MNZ;1O1H;1OAD;1XIB;1XIC;1XID;1XIE;1XIF;1...,P24300;,MNYQPTPEDRFTFGLWTVGWQGRDPFGDATRRALDPVESVRRLAEL...,"[[181], [217], [217], [220], [245], [255], [25...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 1]",True,True
3977,OC[C@H]1O[C@H](O)[C@H](O)[C@@H](O)[C@@H]1O|MSY...,5.3.1.5,1MUW;1S5M;1S5N;1XYA;1XYB;1XYC;1XYL;1XYM;2GYI;,P15587;,MSYQPTPEDRFTFGLWTVGWQGRDPFGDATRPALDPVETVQRLAEL...,"[[181], [217], [217], [220], [245], [255], [25...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 1]",True,True
4199,C[C@@H]1O[C@@H](O)[C@@H](O)[C@H](O)[C@@H]1O|MV...,5.1.3.29,2WCU;,Q8R2K1;,MVALKGIPKVLSPELLFALARMGHGDEIVLADANFPTSSICQCGPV...,"[[32], [79], [120], [138], [140], [24], [69], ...","[0, 0, 0, 0, 0, 1, 1, 1]",True,True
5943,OC[C@H]1O[C@@H](O)[C@H](O)[C@@H](O)[C@H]1O|MAS...,5.1.3.3,1SNZ;1SO0;,Q96C23;,MASVTRAVFGELPSGGGTVEKFQLQSDLLRVDIISWGCTITALEVK...,"[[81, 82], [107], [176, 178], [243], [279], [3...","[0, 0, 0, 0, 0, 0, 1, 1]",True,True
10146,O[C@@H]1CO[C@@H](O)[C@H](O)[C@@H]1O|MKKHGILNSH...,5.4.99.62,1OGC;1OGD;1OGE;1OGF;,P36946;,MKKHGILNSHLAKILADLGHTDKIVIADAGLPVPDGVLKIDLSLKP...,"[[28], [98], [120, 122], [20]]","[0, 0, 0, 1]",True,True
10461,OC[C@H]1O[C@H](O)[C@H](O)[C@@H](O)[C@H]1O|MARA...,3.2.1.22,1UAS;,Q9FXT4;,MARASSSSSPPSPRLLLLLLVAVAATLLPEAAALGNFTAESRGARW...,"[[71], [106], [107], [156], [183], [185], [219...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1]",True,True
10463,OC[C@H]1O[C@H](O)[C@H](O)[C@@H](O)[C@H]1O|MQLR...,3.2.1.22,1R46;1R47;3GXN;3GXP;3GXT;3HG2;3HG3;3HG4;3HG5;3...,P06280;,MQLRNPELHLGCALALRFLALVSWDIPGARALDNGLARTPTMGWLH...,"[[203, 207], [170], [231]]","[0, 1, 1]",True,True


In [7]:
# Display the rows whose 'have_reactants_in_pdb' flag is True.
all_pdb_dataset_train_val.loc[all_pdb_dataset_train_val['have_reactants_in_pdb']]

Unnamed: 0,reaction,ec,pdb-id,alphafolddb-id,aa_sequence,site_labels,site_types,all_reactants_in_pdb,have_reactants_in_pdb
131,CC(=O)O[C@H]1CC[C@@]2(C)C(=CC[C@@H]3[C@@H]2CC[...,3.1.1.13,1MX1;1MX5;1MX9;1YA4;1YA8;1YAH;1YAJ;2DQY;2DQZ;2...,P23141;,MWLRAFILATLSASAAWGHPSSPPVVDTVHGKVLGKFVSLEGFAQP...,"[[221], [354], [468]]","[1, 1, 1]",False,True
132,CC(=O)O[C@H]1CC[C@]2(C)C(=CC[C@@H]3[C@@H]2CC[C...,3.1.1.13,1MX1;1MX5;1MX9;1YA4;1YA8;1YAH;1YAJ;2DQY;2DQZ;2...,P23141;,MWLRAFILATLSASAAWGHPSSPPVVDTVHGKVLGKFVSLEGFAQP...,"[[221], [354], [468]]","[1, 1, 1]",False,True
134,CC(=O)Oc1ccc([N+](=O)[O-])cc1.O|MWLRAFILATLSAS...,3.1.1.13,1MX1;1MX5;1MX9;1YA4;1YA8;1YAH;1YAJ;2DQY;2DQZ;2...,P23141;,MWLRAFILATLSASAAWGHPSSPPVVDTVHGKVLGKFVSLEGFAQP...,"[[221], [354], [468]]","[1, 1, 1]",False,True
184,O.OC[C@H]1O[C@H](OC[C@H]2O[C@H](O)[C@H](O)[C@@...,3.2.1.10,3A47;3A4A;3AJ7;3AXH;3AXI;,P53051;,MTISSAHPETEPKWWKEATFYQIYPASFKDSNDDGWGDMKGIASKL...,"[[215], [277], [352]]","[1, 1, 2]",False,True
187,O.OC[C@H]1O[C@H](OC[C@H]2O[C@H](O)[C@H](O)[C@@...,3.2.1.10,3A47;3A4A;3AJ7;3AXH;3AXI;,P53051;,MTISSAHPETEPKWWKEATFYQIYPASFKDSNDDGWGDMKGIASKL...,"[[215], [277], [352]]","[1, 1, 2]",False,True
...,...,...,...,...,...,...,...,...,...
11187,O.[Cu+2].[O-][As](O)O|MSDTINLTRRGFLKVSGSGVAVAA...,1.20.9.1,1G8J;1G8K;,Q7SIF3;,MSDTINLTRRGFLKVSGSGVAVAATLSPIASANAQKAPADAGRTTL...,"[[102], [104], [120], [123]]","[0, 0, 0, 0]",False,True
11256,COc1cc(C=CC(=O)SCCNC(=O)CCNC(=O)C(O)C(C)(C)COP...,2.3.1.248,6LPW;,O80467;,MPIHIGSSIPLMVEKMLTEMVKPSKHIPQQTLNLSTLDNDPYNEVI...,"[[47], [169], [294], [316], [378], [169], [391]]","[0, 0, 0, 0, 0, 1, 1]",False,True
11787,O=C1O[C@H]([C@@H](O)CO)C([O-])=C1O.OO|MTTAVRLL...,1.11.1.5,1A2F;1A2G;1AA4;1AC4;1AC8;1AEB;1AED;1AEE;1AEF;1...,P00431;,MTTAVRLLPSLGRTAHKRSLYLFSAAAAAAAAATFAYSQSQKRSSS...,"[[242], [119], [258], [115]]","[0, 1, 1, 2]",False,True
13155,COC(=O)c1cc(O)cc(OC)c1C(=O)c1c([O-])cc(C)cc1O....,1.11.1.-,6KMM;6KMN;7DLK;,P39597;,MSDEQKKPEQIHRRDILKWGAMAGAAVAIGASGLGGLAPLVQTAAK...,"[[241, 243], [326], [339], [240]]","[0, 0, 0, 1]",False,True


In [8]:
# Keep only the rows flagged 'have_reactants_in_pdb' as a new, re-indexed frame;
# the later cells add columns to this frame and export contacts from it.
have_reactants_in_pdb_dataset_train_val = all_pdb_dataset_train_val.loc[all_pdb_dataset_train_val['have_reactants_in_pdb']].reset_index(drop=True)
have_reactants_in_pdb_dataset_train_val

Unnamed: 0,reaction,ec,pdb-id,alphafolddb-id,aa_sequence,site_labels,site_types,all_reactants_in_pdb,have_reactants_in_pdb
0,CC(=O)O[C@H]1CC[C@@]2(C)C(=CC[C@@H]3[C@@H]2CC[...,3.1.1.13,1MX1;1MX5;1MX9;1YA4;1YA8;1YAH;1YAJ;2DQY;2DQZ;2...,P23141;,MWLRAFILATLSASAAWGHPSSPPVVDTVHGKVLGKFVSLEGFAQP...,"[[221], [354], [468]]","[1, 1, 1]",False,True
1,CC(=O)O[C@H]1CC[C@]2(C)C(=CC[C@@H]3[C@@H]2CC[C...,3.1.1.13,1MX1;1MX5;1MX9;1YA4;1YA8;1YAH;1YAJ;2DQY;2DQZ;2...,P23141;,MWLRAFILATLSASAAWGHPSSPPVVDTVHGKVLGKFVSLEGFAQP...,"[[221], [354], [468]]","[1, 1, 1]",False,True
2,CC(=O)Oc1ccc([N+](=O)[O-])cc1.O|MWLRAFILATLSAS...,3.1.1.13,1MX1;1MX5;1MX9;1YA4;1YA8;1YAH;1YAJ;2DQY;2DQZ;2...,P23141;,MWLRAFILATLSASAAWGHPSSPPVVDTVHGKVLGKFVSLEGFAQP...,"[[221], [354], [468]]","[1, 1, 1]",False,True
3,O.OC[C@H]1O[C@H](OC[C@H]2O[C@H](O)[C@H](O)[C@@...,3.2.1.10,3A47;3A4A;3AJ7;3AXH;3AXI;,P53051;,MTISSAHPETEPKWWKEATFYQIYPASFKDSNDDGWGDMKGIASKL...,"[[215], [277], [352]]","[1, 1, 2]",False,True
4,O.OC[C@H]1O[C@H](OC[C@H]2O[C@H](O)[C@H](O)[C@@...,3.2.1.10,3A47;3A4A;3AJ7;3AXH;3AXI;,P53051;,MTISSAHPETEPKWWKEATFYQIYPASFKDSNDDGWGDMKGIASKL...,"[[215], [277], [352]]","[1, 1, 2]",False,True
...,...,...,...,...,...,...,...,...,...
72,O.[Cu+2].[O-][As](O)O|MSDTINLTRRGFLKVSGSGVAVAA...,1.20.9.1,1G8J;1G8K;,Q7SIF3;,MSDTINLTRRGFLKVSGSGVAVAATLSPIASANAQKAPADAGRTTL...,"[[102], [104], [120], [123]]","[0, 0, 0, 0]",False,True
73,COc1cc(C=CC(=O)SCCNC(=O)CCNC(=O)C(O)C(C)(C)COP...,2.3.1.248,6LPW;,O80467;,MPIHIGSSIPLMVEKMLTEMVKPSKHIPQQTLNLSTLDNDPYNEVI...,"[[47], [169], [294], [316], [378], [169], [391]]","[0, 0, 0, 0, 0, 1, 1]",False,True
74,O=C1O[C@H]([C@@H](O)CO)C([O-])=C1O.OO|MTTAVRLL...,1.11.1.5,1A2F;1A2G;1AA4;1AC4;1AC8;1AEB;1AED;1AEE;1AEF;1...,P00431;,MTTAVRLLPSLGRTAHKRSLYLFSAAAAAAAAATFAYSQSQKRSSS...,"[[242], [119], [258], [115]]","[0, 1, 1, 2]",False,True
75,COC(=O)c1cc(O)cc(OC)c1C(=O)c1c([O-])cc(C)cc1O....,1.11.1.-,6KMM;6KMN;7DLK;,P39597;,MSDEQKKPEQIHRRDILKWGAMAGAAVAIGASGLGGLAPLVQTAAK...,"[[241, 243], [326], [339], [240]]","[0, 0, 0, 1]",False,True


In [9]:
from biotite.structure.io.pdb import PDBFile, get_structure
# from biotite.database import rcsb
from rdkit import Chem
import numpy as np
import pandas as pd
from scipy.spatial.distance import squareform, pdist, cdist
from collections import defaultdict
from Bio.SeqUtils import IUPACData
from Bio import pairwise2

def extend(a, b, c, L, A, D):
    """
    Compute a fourth coordinate from three reference coordinates.

    input:  3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral
            (A and D are passed straight to np.cos / np.sin)
    output: 4th coord, built as ``c`` plus a combination of three unit axes
            derived from ``a``, ``b`` and ``c``
    """

    def _unit(vec):
        # Scale to unit L2 norm along the last axis.
        return vec / np.linalg.norm(vec, ord=2, axis=-1, keepdims=True)

    bc_unit = _unit(b - c)
    normal = _unit(np.cross(b - a, bc_unit))
    axes = (bc_unit, np.cross(normal, bc_unit), normal)
    weights = (
        L * np.cos(A),
        L * np.sin(A) * np.cos(D),
        -L * np.sin(A) * np.sin(D),
    )

    offset = 0
    for axis, weight in zip(axes, weights):
        offset = offset + axis * weight
    return c + offset

def ligand2protein_contacts_from_pdb(
    structure,
    distance_threshold=8.0,
    chain=None,
    output_dataframe=True,
    ):
    """Compute a ligand-atom x residue contact map for one structure.

    Args:
        structure: Object exposing per-atom arrays ``hetero``, ``res_name``,
            ``chain_id``, ``atom_name``, ``res_id`` and ``coord`` (presumably a
            biotite AtomArray, as produced by ``get_structure`` in this notebook).
        distance_threshold: A ligand atom and a residue are in contact when the
            distance between the atom and the residue's extended point is below
            this value.
        chain: When given, only atoms of this chain id are used.
        output_dataframe: Return a labelled DataFrame instead of a bare array.

    Returns:
        ``(contacts, sequence)``. ``contacts`` has one row per non-water hetero
        atom and one column per residue; entries are 1 (contact), 0 (no contact)
        or -1 (distance is NaN). With ``output_dataframe=True`` rows are labelled
        ``<res_name>-<atom_name>`` and columns ``<one-letter code>-<res_id>``.
        ``sequence`` is the one-letter sequence of the residues that have a CA
        atom. ``(None, None)`` is returned when the DataFrame cannot be built
        because the labels do not match the shape of the contact array.

    Raises:
        KeyError: If a residue name is missing from
            ``IUPACData.protein_letters_3to1``.
    """
    protein_mask = ~structure.hetero
    ligand_mask = structure.hetero & (structure.res_name != 'HOH')
    if chain is not None:
        protein_mask &= structure.chain_id == chain
        ligand_mask &= structure.chain_id == chain

    ca_mask = protein_mask & (structure.atom_name == "CA")
    N = structure.coord[protein_mask & (structure.atom_name == "N")]
    CA = structure.coord[ca_mask]
    C = structure.coord[protein_mask & (structure.atom_name == "C")]
    if N.shape[0] == CA.shape[0] + 1:
        # One surplus N: drop the last one so the arrays line up.
        N = N[:-1]

    # One extended point per residue, built from its C, N and CA coordinates
    # (named Cbeta here; presumably a virtual C-beta position).
    Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143)

    ligand_atom_coord = structure.coord[ligand_mask]

    # Only the ligand x residue block is needed, so compute exactly that block
    # instead of the full pairwise matrix over ligand and residue points.
    dist = cdist(ligand_atom_coord, Cbeta)
    ligand2protein_contacts = (dist < distance_threshold).astype(np.int64)
    ligand2protein_contacts[np.isnan(dist)] = -1

    # Build the sequence unconditionally so it is defined for both output modes.
    # Names and ids are both read through the CA mask to keep them paired.
    columns = []
    sequence = ''
    for res_name, res_id in zip(structure.res_name[ca_mask], structure.res_id[ca_mask]):
        res_name_1 = IUPACData.protein_letters_3to1[res_name.capitalize()]
        columns.append(f'{res_name_1}-{res_id}')
        sequence += res_name_1

    if output_dataframe:
        index = [
            f'{res_name}-{atom_name}'
            for atom_name, res_name in zip(structure.atom_name[ligand_mask], structure.res_name[ligand_mask])
        ]
        try:
            ligand2protein_contacts = pd.DataFrame(data=ligand2protein_contacts, columns=columns, index=index)
        except ValueError:
            # pandas raises ValueError when the labels do not fit the array shape.
            return None, None

    return ligand2protein_contacts, sequence

def get_ligand(structure, chain=None):
    """Return the hetero atoms of ``structure`` that are not water (``HOH``).

    When ``chain`` is given, only atoms with that chain id are kept.
    """
    not_water = structure.res_name != 'HOH'
    selection = structure.hetero & not_water
    if chain is None:
        return structure[selection]
    return structure[selection & (structure.chain_id == chain)]

def get_protein(structure, chain=None):
    """Return the non-hetero atoms of ``structure``.

    When ``chain`` is given, only atoms with that chain id are kept.
    """
    selection = ~structure.hetero
    if chain is None:
        return structure[selection]
    return structure[selection & (structure.chain_id == chain)]

def get_canonical_order_contacts(mol, contacts_df):
    """Reorder the rows of ``contacts_df`` to follow the canonical SMILES atom order.

    ``mol`` is written to SMILES and re-read; a substructure match of ``mol`` on
    the re-read molecule gives, for each atom of ``mol``, its index in the
    re-read molecule. Rows are then taken in increasing order of that index.

    Returns:
        ``(reordered_df, canonical_smiles)``. The former index of ``contacts_df``
        becomes a regular column (via ``reset_index``) and the rows get a fresh
        0..n-1 index.
    """
    canonical_smiles = Chem.MolToSmiles(mol)
    canonical_order_mol = Chem.MolFromSmiles(canonical_smiles)

    # match[i] is the index, in the re-read molecule, of atom i of `mol`.
    match = canonical_order_mol.GetSubstructMatch(mol)
    canonical_order = sorted(range(len(match)), key=lambda query_idx: match[query_idx])

    reordered = contacts_df.reset_index()
    reordered = reordered.reindex(canonical_order)
    return reordered.reset_index(drop=True), canonical_smiles


def check_reactants_in_pdb_and_get_id(rxn, pdb_ids, pdb_files_folder):
    """Find which reactants of ``rxn`` occur as ligands and in which PDB entries.

    Args:
        rxn: Reaction string accepted by ``get_reactants``.
        pdb_ids: ``;``-separated PDB ids; empty entries are ignored.
        pdb_files_folder: Folder holding ``<lower-case id>.pdb`` files. Ids
            without a file are skipped.

    Returns:
        ``(ligands, contact_ligand_pdb_ids, pdb_id2contact_ligand)``:
        the reactant SMILES found as ligand fragments, the PDB ids in which a
        reactant was first found, and a mapping from PDB id to the reactants
        first found there. A ligand already recorded from an earlier PDB entry
        is not recorded again. Three empty containers are returned when any
        reactant SMILES could not be canonicalized.
    """
    ligands = []
    contact_ligand_pdb_ids = []
    pdb_id2contact_ligand = defaultdict(list)

    reactants = get_reactants(rxn)
    if '' in reactants:
        # Always return a triple: the caller unpacks the result into three columns.
        return ligands, contact_ligand_pdb_ids, pdb_id2contact_ligand

    pdb_id_list = [x for x in pdb_ids.split(';') if x]
    for pdb_id in pdb_id_list:
        pdb_file = os.path.join(pdb_files_folder, '{}.pdb'.format(pdb_id.lower()))
        if not os.path.exists(pdb_file):
            continue
        ligand_smiles = extract_ligand_with_rdkit(pdb_file)
        if not ligand_smiles:
            continue
        for ligand in ligand_smiles.split('.'):
            if (ligand in ligands) or (ligand not in reactants):
                continue
            ligands.append(ligand)
            pdb_id2contact_ligand[pdb_id].append(ligand)
            if pdb_id not in contact_ligand_pdb_ids:
                contact_ligand_pdb_ids.append(pdb_id)

    return ligands, contact_ligand_pdb_ids, pdb_id2contact_ligand

def get_atom_numbers(smi):
    """Return the atom count RDKit reports for the molecule parsed from ``smi``."""
    return Chem.MolFromSmiles(smi).GetNumAtoms()

def get_atom_numbers_from_ligand(canonical_ligand_smiles):
    """Split a SMILES on ``.`` and count the atoms of every fragment.

    Returns:
        ``(atom_numbers, fragments)``: the per-fragment atom counts reported by
        RDKit and the fragment SMILES, in the same order.
    """
    fragments = canonical_ligand_smiles.split('.')
    atom_counts = [Chem.MolFromSmiles(fragment).GetNumAtoms() for fragment in fragments]
    return atom_counts, fragments

def mapping_contacts(seqA, seqB, contactA_df):
    """Transfer contact columns from sequence A onto the positions of sequence B.

    Args:
        seqA: Sequence whose residues correspond to the columns of ``contactA_df``
            (an assertion enforces equal length).
        seqB: Target sequence.
        contactA_df: Contact table with one column per residue of ``seqA``.

    Returns:
        A DataFrame with the index of ``contactA_df`` and one column per residue
        of ``seqB``, labelled ``<residue>-<1-based position>``. Positions of
        ``seqB`` aligned to a gap in ``seqA`` are filled with 0. ``None`` is
        returned when ``pairwise2.align.localxx`` yields no alignment.
    """
    contact_values = contactA_df.values
    assert len(seqA) == contact_values.shape[1]

    alignments = pairwise2.align.localxx(seqB, seqA)
    if not alignments:
        return None

    # Only the first alignment is used; seqB was passed first, so it is item 0.
    best = alignments[0]
    gapped_b, gapped_a = best[0], best[1]

    # Spread A's columns over the gapped alignment; gap columns stay zero.
    padded = np.zeros((contactA_df.shape[0], len(gapped_a)), dtype=contact_values.dtype)
    source_col = 0
    for target_col, residue in enumerate(gapped_a):
        if residue == '-':
            continue
        padded[:, target_col] = contact_values[:, source_col]
        source_col += 1

    # Keep only the alignment columns where B has a residue.
    b_has_residue = np.asanyarray(list(gapped_b)) != '-'
    contact_b = padded[:, b_has_residue]

    column_labels = [f'{aa}-{position}' for position, aa in enumerate(seqB, start=1)]
    return pd.DataFrame(contact_b, columns=column_labels, index=contactA_df.index.tolist())

    
def _write_contacts_with_metadata(contacts_df, ligand_smiles, sequence, csv_path):
    """Append a SMILES row and a sequence row to a copy of ``contacts_df`` and save it as CSV."""
    # Work on a copy so rows are never appended to a slice of another frame.
    annotated = contacts_df.copy()
    padding = [''] * (annotated.shape[-1] - 2)
    annotated.loc[annotated.index.max() + 1, :] = ['ligand_canonical_smiles', ligand_smiles] + padding
    annotated.loc[annotated.index.max() + 1, :] = ['sequence', sequence] + padding
    annotated.to_csv(csv_path, index=False)


def get_ligand_protein_contacts(pdb_ids, reactants, aa_sequence, pdb_files_folder, contacts_save_folder):
    """Export ligand/protein contact tables for every chain of the given PDB entries.

    For each PDB id and each chain that yields a non-empty contact table, the
    following files are written to ``<contacts_save_folder>/<pdb_id>/``:

    * ``chain_<c>_ligand.pdb`` and ``chain_<c>_protein.pdb``;
    * ``chain_<c>_ligand_contacts.csv``: contacts of the whole ligand block in
      canonical SMILES atom order;
    * for every ligand fragment ``i`` whose SMILES is in ``reactants``:
      ``chain_<c>_ligand_contacts_react_<i>.csv`` (columns follow the chain's
      own sequence) and ``chain_<c>_ligand_contacts_for_aa_sequence_react_<i>.csv``
      (columns mapped onto ``aa_sequence``).

    Every CSV ends with two extra rows holding the SMILES and the sequence.

    Args:
        pdb_ids: Iterable of PDB ids; ``<lower-case id>.pdb`` is read from
            ``pdb_files_folder``.
        reactants: Collection of SMILES used to select ligand fragments.
        aa_sequence: Sequence the contact columns are mapped onto.
        pdb_files_folder: Folder containing the PDB files.
        contacts_save_folder: Root output folder.

    Returns:
        None. A chain is skipped (no CSV written) when its contact table is
        empty, when no alignment to ``aa_sequence`` exists, or when RDKit cannot
        parse the ligand block.
    """
    for pdb_id in pdb_ids:
        pdb_path = os.path.join(pdb_files_folder, '{}.pdb'.format(pdb_id.lower()))
        structure = get_structure(PDBFile.read(pdb_path))[0]

        for chain in np.unique(structure.chain_id):
            ligand2protein_contacts, this_sequence = ligand2protein_contacts_from_pdb(structure, chain=chain)

            if not isinstance(ligand2protein_contacts, pd.DataFrame):
                continue
            if not (ligand2protein_contacts.index.tolist() and ligand2protein_contacts.columns.tolist()):
                continue

            ligand2protein_contacts_for_aa_sequence = mapping_contacts(this_sequence, aa_sequence, ligand2protein_contacts)

            this_pdb_save_folder = os.path.join(contacts_save_folder, pdb_id)
            os.makedirs(this_pdb_save_folder, exist_ok=True)

            ligand_pdb_file = PDBFile()
            ligand_pdb_file.set_structure(get_ligand(structure, chain=chain))
            ligand_pdb_file.write(os.path.join(this_pdb_save_folder, 'chain_{}_ligand.pdb'.format(chain)))

            protein_pdb_file = PDBFile()
            protein_pdb_file.set_structure(get_protein(structure, chain=chain))
            protein_pdb_file.write(os.path.join(this_pdb_save_folder, 'chain_{}_protein.pdb'.format(chain)))

            mol = Chem.MolFromPDBBlock('\n'.join(ligand_pdb_file.lines))
            if mol is None or ligand2protein_contacts_for_aa_sequence is None:
                # Without a parsed ligand or an alignment there is no canonical
                # atom order / column mapping, so no contact CSV can be produced.
                continue

            canonical_ligand2protein_contacts, ligand_canonical_smiles = get_canonical_order_contacts(mol, contacts_df=ligand2protein_contacts)
            canonical_ligand2protein_contacts_for_aa_sequence, _ = get_canonical_order_contacts(mol, contacts_df=ligand2protein_contacts_for_aa_sequence)

            atom_numbers, ligand_canonical_smiles_splited = get_atom_numbers_from_ligand(ligand_canonical_smiles)
            stops = np.cumsum(atom_numbers).tolist()
            starts = [0] + stops[:-1]
            # .loc slicing includes the end label, hence the "- 1".
            label_slices = [slice(start, stop - 1) for start, stop in zip(starts, stops)]

            for i, (sl, smi) in enumerate(zip(label_slices, ligand_canonical_smiles_splited)):
                if smi not in reactants:
                    continue
                _write_contacts_with_metadata(
                    canonical_ligand2protein_contacts.loc[sl, :],
                    smi,
                    this_sequence,
                    os.path.join(this_pdb_save_folder, 'chain_{}_ligand_contacts_react_{}.csv'.format(chain, i)),
                )
                _write_contacts_with_metadata(
                    canonical_ligand2protein_contacts_for_aa_sequence.loc[sl, :],
                    smi,
                    aa_sequence,
                    os.path.join(this_pdb_save_folder, 'chain_{}_ligand_contacts_for_aa_sequence_react_{}.csv'.format(chain, i)),
                )

            _write_contacts_with_metadata(
                canonical_ligand2protein_contacts,
                ligand_canonical_smiles,
                this_sequence,
                os.path.join(this_pdb_save_folder, 'chain_{}_ligand_contacts.csv'.format(chain)),
            )
    



In [10]:
# For every row, record which reactants were found as ligands and in which PDB
# entries, then save the annotated frame.
reactant_ligand_hits = have_reactants_in_pdb_dataset_train_val.apply(
    lambda row: check_reactants_in_pdb_and_get_id(
        row['reaction'],
        row['pdb-id'],
        os.path.join(ec_site_dataset_path, 'structures/pdb_download'),
    ),
    axis=1,
)
ligands_column, contact_pdb_ids_column, pdb_id2ligand_column = map(list, zip(*reactant_ligand_hits))
have_reactants_in_pdb_dataset_train_val['ligands'] = ligands_column
have_reactants_in_pdb_dataset_train_val['contact_ligand_pdb_ids'] = contact_pdb_ids_column
have_reactants_in_pdb_dataset_train_val['pdb_id2contact_ligand'] = pdb_id2ligand_column

have_reactants_in_pdb_dataset_train_val.to_csv(
    os.path.join(
        ec_site_dataset_path,
        'uniprot_ecreact_cluster_split_merge_dataset_limit_{}'.format(max_sample),
        'have_reactants_in_pdb_dataset_train_val.csv',
    ),
    index=False,
)
have_reactants_in_pdb_dataset_train_val


Unnamed: 0,reaction,ec,pdb-id,alphafolddb-id,aa_sequence,site_labels,site_types,all_reactants_in_pdb,have_reactants_in_pdb,ligands,contact_ligand_pdb_ids,pdb_id2contact_ligand
0,CC(=O)O[C@H]1CC[C@@]2(C)C(=CC[C@@H]3[C@@H]2CC[...,3.1.1.13,1MX1;1MX5;1MX9;1YA4;1YA8;1YAH;1YAJ;2DQY;2DQZ;2...,P23141;,MWLRAFILATLSASAAWGHPSSPPVVDTVHGKVLGKFVSLEGFAQP...,"[[221], [354], [468]]","[1, 1, 1]",False,True,[O],[1YAH],{'1YAH': ['O']}
1,CC(=O)O[C@H]1CC[C@]2(C)C(=CC[C@@H]3[C@@H]2CC[C...,3.1.1.13,1MX1;1MX5;1MX9;1YA4;1YA8;1YAH;1YAJ;2DQY;2DQZ;2...,P23141;,MWLRAFILATLSASAAWGHPSSPPVVDTVHGKVLGKFVSLEGFAQP...,"[[221], [354], [468]]","[1, 1, 1]",False,True,[O],[1YAH],{'1YAH': ['O']}
2,CC(=O)Oc1ccc([N+](=O)[O-])cc1.O|MWLRAFILATLSAS...,3.1.1.13,1MX1;1MX5;1MX9;1YA4;1YA8;1YAH;1YAJ;2DQY;2DQZ;2...,P23141;,MWLRAFILATLSASAAWGHPSSPPVVDTVHGKVLGKFVSLEGFAQP...,"[[221], [354], [468]]","[1, 1, 1]",False,True,[O],[1YAH],{'1YAH': ['O']}
3,O.OC[C@H]1O[C@H](OC[C@H]2O[C@H](O)[C@H](O)[C@@...,3.2.1.10,3A47;3A4A;3AJ7;3AXH;3AXI;,P53051;,MTISSAHPETEPKWWKEATFYQIYPASFKDSNDDGWGDMKGIASKL...,"[[215], [277], [352]]","[1, 1, 2]",False,True,[OC[C@H]1O[C@H](OC[C@H]2O[C@H](O)[C@H](O)[C@@H...,[3AXH],{'3AXH': ['OC[C@H]1O[C@H](OC[C@H]2O[C@H](O)[C@...
4,O.OC[C@H]1O[C@H](OC[C@H]2O[C@H](O)[C@H](O)[C@@...,3.2.1.10,3A47;3A4A;3AJ7;3AXH;3AXI;,P53051;,MTISSAHPETEPKWWKEATFYQIYPASFKDSNDDGWGDMKGIASKL...,"[[215], [277], [352]]","[1, 1, 2]",False,True,[OC[C@H]1O[C@H](OC[C@H]2O[C@H](O)[C@H](O)[C@@H...,[3AXH],{'3AXH': ['OC[C@H]1O[C@H](OC[C@H]2O[C@H](O)[C@...
...,...,...,...,...,...,...,...,...,...,...,...,...
72,O.[Cu+2].[O-][As](O)O|MSDTINLTRRGFLKVSGSGVAVAA...,1.20.9.1,1G8J;1G8K;,Q7SIF3;,MSDTINLTRRGFLKVSGSGVAVAATLSPIASANAQKAPADAGRTTL...,"[[102], [104], [120], [123]]","[0, 0, 0, 0]",False,True,[O],[1G8J],{'1G8J': ['O']}
73,COc1cc(C=CC(=O)SCCNC(=O)CCNC(=O)C(O)C(C)(C)COP...,2.3.1.248,6LPW;,O80467;,MPIHIGSSIPLMVEKMLTEMVKPSKHIPQQTLNLSTLDNDPYNEVI...,"[[47], [169], [294], [316], [378], [169], [391]]","[0, 0, 0, 0, 0, 1, 1]",False,True,[NCCCCNCCCN],[6LPW],{'6LPW': ['NCCCCNCCCN']}
74,O=C1O[C@H]([C@@H](O)CO)C([O-])=C1O.OO|MTTAVRLL...,1.11.1.5,1A2F;1A2G;1AA4;1AC4;1AC8;1AEB;1AED;1AEE;1AEF;1...,P00431;,MTTAVRLLPSLGRTAHKRSLYLFSAAAAAAAAATFAYSQSQKRSSS...,"[[242], [119], [258], [115]]","[0, 1, 1, 2]",False,True,[OO],[1DCC],{'1DCC': ['OO']}
75,COC(=O)c1cc(O)cc(OC)c1C(=O)c1c([O-])cc(C)cc1O....,1.11.1.-,6KMM;6KMN;7DLK;,P39597;,MSDEQKKPEQIHRRDILKWGAMAGAAVAIGASGLGGLAPLVQTAAK...,"[[241, 243], [326], [339], [240]]","[0, 0, 0, 1]",False,True,[OO],[6KMM],{'6KMM': ['OO']}


In [11]:
# Write the ligand/protein contact files for every row into 'ligand_contacts'.
contacts_save_folder = os.path.join(
    ec_site_dataset_path,
    'uniprot_ecreact_cluster_split_merge_dataset_limit_{}'.format(max_sample),
    'ligand_contacts',
)
os.makedirs(contacts_save_folder, exist_ok=True)

# The call is made for its side effects (files on disk); every row yields None.
have_reactants_in_pdb_dataset_train_val.apply(
    lambda row: get_ligand_protein_contacts(
        row['contact_ligand_pdb_ids'],
        get_reactants(row['reaction']),
        aa_sequence=row['aa_sequence'],
        pdb_files_folder=os.path.join(ec_site_dataset_path, 'structures/pdb_download'),
        contacts_save_folder=contacts_save_folder,
    ),
    axis=1,
)

0     None
1     None
2     None
3     None
4     None
      ... 
72    None
73    None
74    None
75    None
76    None
Length: 77, dtype: object