In [1]:
import sys
import torch
import os
sys.path.append('../../')
from tqdm.auto import tqdm
from collections import defaultdict
from functools import partial
import py3Dmol
from IPython.display import IFrame, SVG, display, HTML
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import rdChemReactions
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.Draw import IPythonConsole
from rdkit import Chem
from pandarallel import pandarallel
from webapp.utils import (
    EasIFAInferenceAPI,
    UniProtParserMysql,
    retrain_ec_site_model_state_path,
    get_structure_html_and_active_data,
    cmd,
)
from data_loaders.rxn_dataloader import process_reaction
from data_loaders.enzyme_rxn_dataloader import get_rxn_smiles
from common.utils import calculate_scores_vbin_test




In [2]:
# Inference wrapper built from the retrained EC-site checkpoint.
# pred_tolist=False presumably keeps predictions in their native (tensor) form
# instead of converting them to Python lists -- TODO confirm against webapp.utils.
ECSitePred = EasIFAInferenceAPI(
    model_checkpoint_path=retrain_ec_site_model_state_path,
    pred_tolist=False,
)

# NOTE(review): the recorded output of this cell reports that
# '../../webapp/mysql_config.json' was not found, and no later cell in this
# notebook uses `unprot_mysql_parser` -- confirm whether it is still needed.
unprot_mysql_parser = UniProtParserMysql(mysql_config_path="../../webapp/mysql_config.json")



Reaction attention model from scratch...
Loaded checkpoint from /home/xiaoruiwang/data/ubuntu_work_beta/single_step_work/EasIFA_v2/checkpoints/enzyme_site_type_predition_model/train_in_uniprot_ecreact_cluster_split_merge_dataset_limit_100_at_2024-05-26-02-48-38/global_step_86000
[Errno 2] No such file or directory: '../../webapp/mysql_config.json'


In [3]:
def calculate_active_sites(site_label, sequence_length):
    """Build a per-residue binary mask marking the annotated active sites.

    Args:
        site_label: The site annotation, either as a Python-literal string such
            as ``"[[41, 42], [320]]"`` or as the already-parsed list. Every
            entry is ``[position]`` for a single residue or ``[begin, end]`` for
            an inclusive range. Positions are 1-based.
        sequence_length: Number of residues; this is the length of the mask.

    Returns:
        A ``torch.Tensor`` of shape ``(sequence_length,)`` created with
        ``torch.zeros`` and set to 1 at every annotated residue.

    Raises:
        ValueError: If the string is not a plain Python literal, if an entry
            does not have exactly one or two elements, or if a position is
            smaller than 1.
        SyntaxError: If the string is not syntactically valid Python.
        IndexError: Raised by the tensor indexing when a position lies beyond
            ``sequence_length``.
    """
    import ast

    # literal_eval only accepts literals, so text coming from a CSV file cannot
    # execute arbitrary code the way eval() would.
    if isinstance(site_label, str):
        site_label = ast.literal_eval(site_label)

    active_site = torch.zeros((sequence_length, ))
    for one_site in site_label:
        if len(one_site) == 1:
            begin = end = one_site[0]
        elif len(one_site) == 2:
            begin, end = one_site
        else:
            raise ValueError(
                'The label of active site is not standard !!!')

        # Labels start from 1. Without this guard a 0 would become index -1 and
        # silently mark the last residue of the sequence.
        if begin < 1:
            raise ValueError(
                'Active site positions are 1-based, got {}'.format(one_site))

        # Convert the inclusive 1-based range into 0-based tensor indices.
        site_indices = list(range(begin - 1, end))
        active_site[site_indices] = 1
    return active_site

def inference_and_scoring(test_dataset:pd.DataFrame, esmfold_pdb_path):
    """Run active-site inference on every test sample and score the predictions.

    For each row, the ground-truth mask is built from ``site_labels``, the
    structure ``<alphafolddb-id>.pdb`` is looked up in ``esmfold_pdb_path``, and
    the prediction is passed to ``calculate_scores_vbin_test``.

    Args:
        test_dataset: Frame with the columns ``alphafolddb-id``,
            ``canonicalize_rxn_smiles``, ``site_labels`` and ``aa_sequence``.
        esmfold_pdb_path: Directory holding one ``<alphafolddb-id>.pdb`` per sample.

    Returns:
        A ``pd.DataFrame`` with one row per evaluated sample and two columns:
        ``alphafolddb-id`` and ``scores``, the latter holding whatever
        ``calculate_scores_vbin_test`` returned for that sample. Previously the
        scores were computed and thrown away.
    """
    score_records = []
    for _, row in tqdm(test_dataset.iterrows(), total=len(test_dataset), desc='Testing'):
        uniprot_id = row['alphafolddb-id']
        rxn = row['canonicalize_rxn_smiles']
        site_label = row['site_labels']
        aa_sequence = row['aa_sequence']
        gts = calculate_active_sites(site_label, len(aa_sequence))

        enzyme_structure_path = os.path.join(esmfold_pdb_path, f'{uniprot_id}.pdb')
        # By default the API returns the result for a single sample.
        pred_active_labels = ECSitePred.inference(rxn, enzyme_structure_path)

        scores = calculate_scores_vbin_test(pred_active_labels, gts=gts, num_residues=len(aa_sequence))
        score_records.append({'alphafolddb-id': uniprot_id, 'scores': scores})

    # Explicit columns keep the frame well-formed even for an empty dataset.
    return pd.DataFrame(score_records, columns=['alphafolddb-id', 'scores'])
        


    
    
    

In [4]:






def get_structure_sequence(pdb_file):
    """Read a PDB file with RDKit and return its amino-acid sequence.

    This is a best-effort helper: any failure to read or convert the file
    yields an empty string so that callers can compare the result against the
    expected sequence without handling errors themselves.

    Args:
        pdb_file: Path of the PDB file to read.

    Returns:
        The sequence produced by ``Chem.MolToSequence``, or ``''`` when the file
        cannot be read or converted.
    """
    try:
        mol = Chem.MolFromPDBFile(pdb_file)
        # RDKit signals an unparseable file by returning None rather than raising.
        if mol is None:
            return ''
        return Chem.MolToSequence(mol)
    except Exception:
        # Deliberately not a bare `except:` so that KeyboardInterrupt and
        # SystemExit still stop a long-running structure check.
        return ''

def multiprocess_structure_check(df, nb_workers, pdb_file_path):
    """Check each row's sequence against the one read from its AlphaFold PDB file.

    Three columns are added to ``df`` (the frame is modified and returned):
    ``pdb_files`` (path ``AF-<alphafolddb-id>-F1-model_v4.pdb`` inside
    ``pdb_file_path``), ``aa_sequence_calculated`` (sequence read from that
    file) and ``is_valid`` (whether it equals ``aa_sequence``).

    Args:
        df: Frame with the columns ``alphafolddb-id`` and ``aa_sequence``.
        nb_workers: Number of pandarallel workers; ``0`` selects the
            single-process path with a tqdm progress bar.
        pdb_file_path: Directory containing the AlphaFold PDB files.

    Returns:
        The same frame with the three columns added.
    """
    # Pick the Series method once instead of repeating both branches in full.
    if nb_workers == 0:
        tqdm.pandas(desc='pandas bar')
        apply_method = 'progress_apply'
    else:
        pandarallel.initialize(nb_workers=nb_workers, progress_bar=True)
        apply_method = 'parallel_apply'

    def build_pdb_path(alphafold_id):
        return os.path.join(pdb_file_path, f'AF-{alphafold_id}-F1-model_v4.pdb')

    df['pdb_files'] = getattr(df['alphafolddb-id'], apply_method)(build_pdb_path)
    df['aa_sequence_calculated'] = getattr(df['pdb_files'], apply_method)(get_structure_sequence)

    sequences_match = df['aa_sequence_calculated'] == df['aa_sequence']
    df['is_valid'] = sequences_match

    return df


def get_query_database(path, fasta_path, pdb_file_path, nb_workers=12):
    """Load the dataset CSV, validate its structures and write a FASTA file.

    Steps performed:
      1. Read ``path`` and keep the id, sequence, site and reaction columns.
      2. Strip ``;`` from ``alphafolddb-id`` and derive ``rxn_smiles`` and
         ``canonicalize_rxn_smiles`` from ``reaction``.
      3. Run ``multiprocess_structure_check`` to add ``pdb_files``,
         ``aa_sequence_calculated`` and ``is_valid``.
      4. Write one FASTA record per row that is unique in id, sequence, site
         labels and site types to ``fasta_path``.

    Args:
        path: CSV file to load.
        fasta_path: Destination of the FASTA file (overwritten).
        pdb_file_path: Directory with the AlphaFold PDB files.
        nb_workers: Worker count forwarded to ``multiprocess_structure_check``.
            Defaults to 12, the value that used to be hard-coded.

    Returns:
        The full (not de-duplicated, not filtered) frame including ``is_valid``.
    """
    selected_columns = ['alphafolddb-id', 'aa_sequence','site_labels', 'site_types', 'reaction']
    # .copy() makes this an independent frame; assigning new columns to the raw
    # column selection is what triggers pandas' SettingWithCopyWarning.
    database_df = pd.read_csv(path)[selected_columns].copy()

    database_df['alphafolddb-id'] = database_df['alphafolddb-id'].apply(
        lambda raw_id: raw_id.replace(';', ''))
    database_df["rxn_smiles"] = database_df["reaction"].apply(get_rxn_smiles)
    database_df['canonicalize_rxn_smiles'] = database_df["rxn_smiles"].apply(process_reaction)

    database_df = multiprocess_structure_check(
        database_df, nb_workers=nb_workers, pdb_file_path=pdb_file_path)

    # Rows that share id, sequence and site annotation (but differ in reaction)
    # would otherwise be written to the FASTA file several times.
    write_database_df = database_df.drop_duplicates(
        subset=['alphafolddb-id', 'aa_sequence','site_labels', 'site_types']
    ).reset_index(drop=True)

    with open(fasta_path, 'w', encoding='utf-8') as f:
        for _, row in tqdm(write_database_df.iterrows(), total=len(write_database_df)):
            f.write('>{}\n'.format(row['alphafolddb-id']))
            f.write('{}\n'.format(row['aa_sequence']))
    return database_df

    

           

In [5]:
# Input locations, relative to the directory this notebook runs in.
dataset_path = '../../dataset/ec_site_dataset/uniprot_ecreact_cluster_split_merge_dataset_limit_100'
blast_database_path = '../../dataset/raw_dataset/uniprot/uniprot_sprot.fasta'
test_dataset_fasta_path = os.path.join(dataset_path, 'test_dataset.fasta')

# Directory holding the ESMFold-predicted structures; created on first run.
esmfold_pdb_path = './esmfold_pdb'
os.makedirs(esmfold_pdb_path, exist_ok=True)

# Load the test split, check every sequence against its AlphaFold structure and
# write the FASTA file, then keep only the rows whose sequences match.
test_dataset = get_query_database(
    os.path.join(dataset_path, 'test_dataset', 'uniprot_ecreact_merge.csv'),
    fasta_path=test_dataset_fasta_path,
    pdb_file_path=os.path.join(os.path.dirname(dataset_path), 'structures', 'alphafolddb_download'),
)
test_dataset = test_dataset[test_dataset['is_valid']].reset_index(drop=True)
test_dataset


INFO: Pandarallel will run on 12 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=75), Label(value='0 / 75'))), HBox…

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=75), Label(value='0 / 75'))), HBox…

  0%|          | 0/853 [00:00<?, ?it/s]

Unnamed: 0,alphafolddb-id,aa_sequence,site_labels,site_types,reaction,rxn_smiles,canonicalize_rxn_smiles,pdb_files,aa_sequence_calculated,is_valid
0,A0A1S7LCW6,MKLKGTTIVALGMLVVAIMVLASMIDLPGSDMSATPAPPDTPRGAP...,"[[206], [212], [215], [216], [246], [252], [25...","[0, 0, 0, 0, 0, 0, 0, 0]",Cc1c([N+](=O)[O-])cc([N+](=O)[O-])cc1[N+](=O)[...,Cc1c([N+](=O)[O-])cc([N+](=O)[O-])cc1[N+](=O)[...,Cc1c([N+](=O)[O-])cc([N+](=O)[O-])cc1[N+](=O)[...,../../dataset/ec_site_dataset/structures/alpha...,MKLKGTTIVALGMLVVAIMVLASMIDLPGSDMSATPAPPDTPRGAP...,True
1,Q9F0J6,MQATKIIDGFHLVGAIDWNSRDFHGYTLSPMGTTYNAYLVEDEKTT...,"[[79], [81], [83], [146], [165], [165], [226]]","[0, 0, 0, 0, 0, 0, 0]",NC(=O)C1=CN([C@@H]2O[C@H](COP(=O)([O-])OP(=O)(...,NC(=O)C1=CN([C@@H]2O[C@H](COP(=O)([O-])OP(=O)(...,NC(=O)C1=CN([C@@H]2O[C@H](COP(=O)([O-])OP(=O)(...,../../dataset/ec_site_dataset/structures/alpha...,MQATKIIDGFHLVGAIDWNSRDFHGYTLSPMGTTYNAYLVEDEKTT...,True
2,Q5BEJ7,MADHEQEQEPLSIAIIGGGIIGLMTALGLLHRNIGKVTIYERASAW...,"[[41, 42], [245, 247], [320], [330, 334]]","[0, 0, 0, 0]",NC(=O)C1=CN([C@@H]2O[C@H](COP(=O)([O-])OP(=O)(...,NC(=O)C1=CN([C@@H]2O[C@H](COP(=O)([O-])OP(=O)(...,NC(=O)C1=CN([C@@H]2O[C@H](COP(=O)([O-])OP(=O)(...,../../dataset/ec_site_dataset/structures/alpha...,MADHEQEQEPLSIAIIGGGIIGLMTALGLLHRNIGKVTIYERASAW...,True
3,Q9HUH4,MPQALSTDILIVGGGIAGLWLNARLRRAGYATVLVESASLGGGQSV...,"[[17], [36], [44, 45], [49, 51], [346, 347]]","[0, 0, 0, 0, 0]",NC(=O)c1ccc[n+]([C@@H]2O[C@H](COP(=O)([O-])OP(...,NC(=O)c1ccc[n+]([C@@H]2O[C@H](COP(=O)([O-])OP(...,NC(=O)c1ccc[n+]([C@@H]2O[C@H](COP(=O)([O-])OP(...,../../dataset/ec_site_dataset/structures/alpha...,MPQALSTDILIVGGGIAGLWLNARLRRAGYATVLVESASLGGGQSV...,True
4,P96692,MAEFTHLVNERRSASNFLSGHPITKEDLNEMFELVALAPSAFNLQH...,"[[11, 13], [68, 70], [157, 158], [193], [196]]","[0, 0, 0, 0, 0]",NC(=O)c1ccc[n+]([C@@H]2O[C@H](COP(=O)([O-])OP(...,NC(=O)c1ccc[n+]([C@@H]2O[C@H](COP(=O)([O-])OP(...,NC(=O)c1ccc[n+]([C@@H]2O[C@H](COP(=O)([O-])OP(...,../../dataset/ec_site_dataset/structures/alpha...,MAEFTHLVNERRSASNFLSGHPITKEDLNEMFELVALAPSAFNLQH...,True
...,...,...,...,...,...,...,...,...,...,...
887,O30144,MFLKVRAEKRLGNFRLNVDFEMGRDYCVLLGPTGAGKSVFLELIAG...,"[[31, 38]]",[0],Nc1ncnc2c1ncn2[C@@H]1O[C@H](COP(=O)([O-])OP(=O...,Nc1ncnc2c1ncn2[C@@H]1O[C@H](COP(=O)([O-])OP(=O...,Nc1ncnc2c1ncn2[C@@H]1O[C@H](COP(=O)([O-])OP(=O...,../../dataset/ec_site_dataset/structures/alpha...,MFLKVRAEKRLGNFRLNVDFEMGRDYCVLLGPTGAGKSVFLELIAG...,True
888,P28737,MSRKFDLKTITDLSVLVGTGISLYYLVSRLLNDVESGPLSGKSRES...,"[[133, 140]]",[0],N[C@@H](CSC[C@H](N)C(=O)O)C(=O)O.Nc1ncnc2c1ncn...,N[C@@H](CSC[C@H](N)C(=O)O)C(=O)O.Nc1ncnc2c1ncn...,N[C@@H](CSC[C@H](N)C(=O)O)C(=O)O.Nc1ncnc2c1ncn...,../../dataset/ec_site_dataset/structures/alpha...,MSRKFDLKTITDLSVLVGTGISLYYLVSRLLNDVESGPLSGKSRES...,True
889,P37093,MTEMVISPAERQSIRRLPFSFANRFKLVLDWNEDFSQASIYYLAPL...,"[[397], [400], [430], [433]]","[0, 0, 0, 0]",Nc1ncnc2c1ncn2[C@@H]1O[C@H](COP(=O)(O)OP(=O)(O...,Nc1ncnc2c1ncn2[C@@H]1O[C@H](COP(=O)(O)OP(=O)(O...,Nc1ncnc2c1ncn2[C@@H]1O[C@H](COP(=O)(O)OP(=O)(O...,../../dataset/ec_site_dataset/structures/alpha...,MTEMVISPAERQSIRRLPFSFANRFKLVLDWNEDFSQASIYYLAPL...,True
890,P94360,MAELRMEHIYKFYDQKEPAVDDFNLHIADKEFIVFVGPSGCGKSTT...,"[[37, 44]]",[0],CCCCCCCCCCCCCCCCCCCCCCCCC(C(=O)OC[C@H]1O[C@H](...,CCCCCCCCCCCCCCCCCCCCCCCCC(C(=O)OC[C@H]1O[C@H](...,CCCCCCCCCCCCCCCCCCCCCCCCC(C(=O)OC[C@H]1O[C@H](...,../../dataset/ec_site_dataset/structures/alpha...,MAELRMEHIYKFYDQKEPAVDDFNLHIADKEFIVFVGPSGCGKSTT...,True


In [6]:
import subprocess

# Absolute paths, so the child process does not depend on its working directory.
esmfold_script = os.path.abspath('../esmfold_inference.py')
test_dataset_fasta_abspath = os.path.abspath(test_dataset_fasta_path)
esmfold_pdb_abspath = os.path.abspath(esmfold_pdb_path)

pdb_fnames = [x for x in os.listdir(esmfold_pdb_abspath) if x.endswith('.pdb')]

# Argument list instead of a shell string: paths containing spaces or shell
# metacharacters are passed through unchanged and no shell is involved.
esmfold_cmd = [
    'python', esmfold_script,
    '-i', test_dataset_fasta_abspath,
    '-o', esmfold_pdb_abspath,
]

# Run ESMFold only when a structure needed for scoring is actually missing.
# Comparing file counts re-ran the prediction whenever the folder held more (or
# other) PDB files than the filtered test set has ids, even if every needed
# '<alphafolddb-id>.pdb' was already there.
expected_pdb_fnames = {f'{uniprot_id}.pdb' for uniprot_id in test_dataset['alphafolddb-id']}
missing_pdb_fnames = expected_pdb_fnames - set(pdb_fnames)

if missing_pdb_fnames:
    # check=True reports a failed ESMFold run here instead of letting the
    # scoring cell fail later on missing structures.
    subprocess.run(esmfold_cmd, check=True)


24/05/29 11:26:56 | INFO | root | Reading sequences from /home/xiaoruiwang/data/ubuntu_work_beta/single_step_work/EasIFA_v2/dataset/ec_site_dataset/uniprot_ecreact_cluster_split_merge_dataset_limit_100/test_dataset.fasta
24/05/29 11:26:56 | INFO | root | Loaded 853 sequences from /home/xiaoruiwang/data/ubuntu_work_beta/single_step_work/EasIFA_v2/dataset/ec_site_dataset/uniprot_ecreact_cluster_split_merge_dataset_limit_100/test_dataset.fasta
24/05/29 11:26:56 | INFO | root | Loading model
24/05/29 11:27:56 | INFO | root | Starting Predictions
24/05/29 11:28:15 | INFO | root | Predicted structure for Q38HX3 with length 86, pLDDT 74.9, pTM 0.614 in 1.9s (amortized, batch size 10). 1 / 853 completed.
24/05/29 11:28:15 | INFO | root | Predicted structure for Q5SKS6 with length 88, pLDDT 90.1, pTM 0.887 in 1.9s (amortized, batch size 10). 2 / 853 completed.
24/05/29 11:28:15 | INFO | root | Predicted structure for Q746F4 with length 90, pLDDT 81.2, pTM 0.786 in 1.9s (amortized, batch size 10

In [None]:
# Score the model on the validated test set using the ESMFold-predicted structures.
inference_and_scoring(
    test_dataset=test_dataset,
    esmfold_pdb_path=esmfold_pdb_abspath,
)
