In [1]:
# %%

import pandas as pd
import os
from tqdm.auto import tqdm
from pandarallel import pandarallel
from rdkit import Chem
from tqdm import tqdm as top_tqdm



In [2]:
# %%

def get_structure_sequence(pdb_file):
    """Read a PDB file with RDKit and return its one-letter residue sequence.

    Args:
        pdb_file: Path of the PDB file to parse.

    Returns:
        The sequence string produced by ``Chem.MolToSequence``, or an empty
        string when the file cannot be read or parsed.
    """
    try:
        mol = Chem.MolFromPDBFile(pdb_file)
        if mol is None:
            # RDKit signals an unparsable structure by returning None.
            return ''
        return Chem.MolToSequence(mol)
    except Exception:
        # Best effort: any parsing failure is reported as "no sequence".
        # ``Exception`` (rather than a bare ``except``) keeps Ctrl-C and
        # interpreter shutdown working during long batch runs.
        return ''

def multiprocess_structure_check(df, nb_workers, pdb_file_path):
    """Attach structure-derived sequences to ``df`` and flag matching rows.

    Three columns are written onto ``df`` itself:
    ``pdb_files`` (path built from ``alphafolddb-id``),
    ``aa_sequence_calculated`` (output of ``get_structure_sequence``) and
    ``is_valid`` (whether the calculated sequence equals ``aa_sequence``).

    Args:
        df: Frame with ``alphafolddb-id`` and ``aa_sequence`` columns.
        nb_workers: Number of pandarallel workers; ``0`` selects the
            single-process tqdm ``progress_apply`` path instead.
        pdb_file_path: Directory holding the ``AF-<id>-F1-model_v4.pdb`` files.

    Returns:
        The same ``df`` object, with the three columns added.
    """
    # Register the apply variant first: both libraries attach their method
    # to pandas objects as a side effect of these calls.
    if nb_workers != 0:
        pandarallel.initialize(nb_workers=nb_workers, progress_bar=True)
        apply_method = 'parallel_apply'
    else:
        top_tqdm.pandas(desc='pandas bar')
        apply_method = 'progress_apply'

    def to_pdb_path(entry_id):
        return os.path.join(pdb_file_path, f'AF-{entry_id}-F1-model_v4.pdb')

    def to_sequence(pdb_file):
        return get_structure_sequence(pdb_file)

    id_column = df['alphafolddb-id']
    df['pdb_files'] = getattr(id_column, apply_method)(to_pdb_path)

    path_column = df['pdb_files']
    df['aa_sequence_calculated'] = getattr(path_column, apply_method)(to_sequence)

    sequences_match = df['aa_sequence_calculated'] == df['aa_sequence']
    df['is_valid'] = sequences_match
    return df
def get_blast_database(dir, fasta_path):
    """Merge every CSV in ``dir`` into one table and write it as a FASTA file.

    Args:
        dir: Directory whose entries are all read with ``pd.read_csv``; each
            file must provide the ``alphafolddb-id``, ``aa_sequence``,
            ``site_labels`` and ``site_types`` columns.
        fasta_path: Destination of the FASTA file (overwritten if present).

    Returns:
        The de-duplicated frame (index reset) with ``;`` stripped from
        ``alphafolddb-id``. When ``dir`` is empty the frame has the four
        columns and no rows, and an empty FASTA file is written.
    """
    columns = ['alphafolddb-id', 'aa_sequence', 'site_labels', 'site_types']
    csv_fnames = os.listdir(dir)

    # Collect the per-file frames and concatenate once; concatenating inside
    # the loop re-copies the accumulated rows on every iteration.
    frames = [
        pd.read_csv(os.path.join(dir, fname))[columns]
        for fname in tqdm(csv_fnames, total=len(csv_fnames))
    ]
    if frames:
        database_df = pd.concat(frames)
    else:
        database_df = pd.DataFrame(columns=columns)

    database_df = database_df.drop_duplicates(subset=columns).reset_index(drop=True)
    database_df['alphafolddb-id'] = database_df['alphafolddb-id'].apply(lambda x: x.replace(';', ''))

    with open(fasta_path, 'w', encoding='utf-8') as f:
        records = zip(database_df['alphafolddb-id'], database_df['aa_sequence'])
        for seq_id, sequence in tqdm(records, total=len(database_df)):
            f.write('>{}\n'.format(seq_id))
            f.write('{}\n'.format(sequence))
    return database_df

def get_query_database(path, fasta_path, pdb_file_path, nb_workers=12):
    """Load the query CSV, validate it against PDB structures, write a FASTA.

    Args:
        path: CSV file with ``alphafolddb-id``, ``aa_sequence``,
            ``site_labels`` and ``site_types`` columns.
        fasta_path: Destination of the FASTA file (overwritten if present).
        pdb_file_path: Directory passed to ``multiprocess_structure_check``
            for locating the structure files.
        nb_workers: Worker count forwarded to
            ``multiprocess_structure_check`` (``0`` = single process).
            Defaults to the previously hard-coded value of 12.

    Returns:
        The full frame including the ``pdb_files``,
        ``aa_sequence_calculated`` and ``is_valid`` columns. Note that this is
        NOT de-duplicated; only the rows written to the FASTA file are.
    """
    columns = ['alphafolddb-id', 'aa_sequence', 'site_labels', 'site_types']

    # Copy the column subset so the assignments below (and those inside
    # multiprocess_structure_check) act on an independent frame.
    database_df = pd.read_csv(path)[columns].copy()
    database_df['alphafolddb-id'] = database_df['alphafolddb-id'].apply(lambda x: x.replace(';', ''))

    database_df = multiprocess_structure_check(database_df, nb_workers=nb_workers, pdb_file_path=pdb_file_path)

    write_database_df = database_df.drop_duplicates(subset=columns).reset_index(drop=True)

    with open(fasta_path, 'w', encoding='utf-8') as f:
        records = zip(write_database_df['alphafolddb-id'], write_database_df['aa_sequence'])
        for seq_id, sequence in tqdm(records, total=len(write_database_df)):
            f.write('>{}\n'.format(seq_id))
            f.write('{}\n'.format(sequence))
    return database_df




In [3]:
# %%

# Locations of the dataset, the query FASTA written below, and the baseline outputs.
dataset_path = '../../dataset/ec_site_dataset/uniprot_ecreact_cluster_split_merge_dataset_limit_100'
test_dataset_fasta_path = os.path.join(dataset_path, 'test_dataset.fasta')
baseline_results_path = 'baseline_results'

# NOTE(review): read_pickle executes arbitrary code from the file; only load trusted pickles.
train_database_df = pd.read_pickle('../../dataset/raw_dataset/ec_datasets/split_ec_dataset/train_ec_uniprot_dataset_cluster_sample.pkl')

test_dataset = get_query_database(
    os.path.join(dataset_path, 'test_dataset', 'uniprot_ecreact_merge.csv'),
    fasta_path=test_dataset_fasta_path,
    pdb_file_path=os.path.join(os.path.dirname(dataset_path), 'structures', 'alphafolddb_download'),
)

# Keep only entries whose structure-derived sequence matches the recorded one.
# The explicit copy makes this an independent frame, so later cells can add
# result columns without pandas' SettingWithCopyWarning.
test_dataset = test_dataset.loc[test_dataset['is_valid']].copy()
test_dataset



INFO: Pandarallel will run on 12 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=75), Label(value='0 / 75'))), HBox…

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=75), Label(value='0 / 75'))), HBox…

  0%|          | 0/853 [00:00<?, ?it/s]

Unnamed: 0,alphafolddb-id,aa_sequence,site_labels,site_types,pdb_files,aa_sequence_calculated,is_valid
0,A0A1S7LCW6,MKLKGTTIVALGMLVVAIMVLASMIDLPGSDMSATPAPPDTPRGAP...,"[[206], [212], [215], [216], [246], [252], [25...","[0, 0, 0, 0, 0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MKLKGTTIVALGMLVVAIMVLASMIDLPGSDMSATPAPPDTPRGAP...,True
1,Q9F0J6,MQATKIIDGFHLVGAIDWNSRDFHGYTLSPMGTTYNAYLVEDEKTT...,"[[79], [81], [83], [146], [165], [165], [226]]","[0, 0, 0, 0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MQATKIIDGFHLVGAIDWNSRDFHGYTLSPMGTTYNAYLVEDEKTT...,True
2,Q5BEJ7,MADHEQEQEPLSIAIIGGGIIGLMTALGLLHRNIGKVTIYERASAW...,"[[41, 42], [245, 247], [320], [330, 334]]","[0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MADHEQEQEPLSIAIIGGGIIGLMTALGLLHRNIGKVTIYERASAW...,True
3,Q9HUH4,MPQALSTDILIVGGGIAGLWLNARLRRAGYATVLVESASLGGGQSV...,"[[17], [36], [44, 45], [49, 51], [346, 347]]","[0, 0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MPQALSTDILIVGGGIAGLWLNARLRRAGYATVLVESASLGGGQSV...,True
4,P96692,MAEFTHLVNERRSASNFLSGHPITKEDLNEMFELVALAPSAFNLQH...,"[[11, 13], [68, 70], [157, 158], [193], [196]]","[0, 0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MAEFTHLVNERRSASNFLSGHPITKEDLNEMFELVALAPSAFNLQH...,True
...,...,...,...,...,...,...,...
889,O30144,MFLKVRAEKRLGNFRLNVDFEMGRDYCVLLGPTGAGKSVFLELIAG...,"[[31, 38]]",[0],../../dataset/ec_site_dataset/structures/alpha...,MFLKVRAEKRLGNFRLNVDFEMGRDYCVLLGPTGAGKSVFLELIAG...,True
890,P28737,MSRKFDLKTITDLSVLVGTGISLYYLVSRLLNDVESGPLSGKSRES...,"[[133, 140]]",[0],../../dataset/ec_site_dataset/structures/alpha...,MSRKFDLKTITDLSVLVGTGISLYYLVSRLLNDVESGPLSGKSRES...,True
891,P37093,MTEMVISPAERQSIRRLPFSFANRFKLVLDWNEDFSQASIYYLAPL...,"[[397], [400], [430], [433]]","[0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MTEMVISPAERQSIRRLPFSFANRFKLVLDWNEDFSQASIYYLAPL...,True
892,P94360,MAELRMEHIYKFYDQKEPAVDDFNLHIADKEFIVFVGPSGCGKSTT...,"[[37, 44]]",[0],../../dataset/ec_site_dataset/structures/alpha...,MAELRMEHIYKFYDQKEPAVDDFNLHIADKEFIVFVGPSGCGKSTT...,True


In [4]:
# %%

# Number of distinct AlphaFold DB identifiers among the validated test entries.
unique_test_ids = set(test_dataset['alphafolddb-id'])
len(unique_test_ids)



851

In [5]:
# %%

import shlex
import subprocess

# Machine-specific locations; the DEEPFRI_ROOT / DEEPFRI_PYTHON environment
# variables override the original hard-coded defaults when set.
deepfri_model_root = os.environ.get(
    'DEEPFRI_ROOT',
    '/home/xiaoruiwang/data/ubuntu_work_beta/protein_work/DeepFRI')
deepfri_python = os.environ.get(
    'DEEPFRI_PYTHON',
    '/home/xiaoruiwang/software/miniconda3/envs/deepfri_env/bin/python')

args = [
    "--fasta_fn",
    os.path.abspath(test_dataset_fasta_path),
    "-ont",
    "ec",
    "-v",
    "--saliency",
    "--use_guided_grads",
    "-o", os.path.abspath(os.path.join(baseline_results_path, 'DeepFRI')),
    "--model_config", os.path.abspath(os.path.join(deepfri_model_root, 'trained_models/model_config.json')),
]

# With the "-o .../DeepFRI" prefix above, this is the file the next cell reads.
deepfri_results_fname = 'DeepFRI_EC_saliency_maps.json'

# Argument list (no shell): paths containing spaces or shell metacharacters
# are passed to predict.py verbatim.
command = [deepfri_python, os.path.join(deepfri_model_root, 'predict.py'), *args]
# Shell-quoted rendering of the same command, kept for logging / copy-paste.
python_cmd = ' '.join(shlex.quote(part) for part in command)

# Cached run: DeepFRI is only invoked when its result file is missing.
if not os.path.exists(os.path.join(baseline_results_path, deepfri_results_fname)):
    # The output directory must exist before DeepFRI writes into it.
    os.makedirs(baseline_results_path, exist_ok=True)
    subprocess.run(command, check=True, cwd=deepfri_model_root)



In [6]:
# %%

import json
import numpy as np
# Read the DeepFRI output JSON into memory.
deepfri_results_path = os.path.join(baseline_results_path, deepfri_results_fname)
with open(deepfri_results_path, 'r') as results_file:
    deepfri_results = json.load(results_file)



In [7]:
# %%

# One row per predicted protein: the JSON's top-level keys become the row
# labels after the transpose and are then copied into an explicit id column.
deepfri_results_df = pd.DataFrame(deepfri_results).T
deepfri_results_df['alphafolddb-id'] = deepfri_results_df.index
deepfri_results_df.index = list(range(len(deepfri_results_df)))

# Columns are renamed by POSITION; the fourth one (DeepFRI's saliency maps)
# is used downstream as the per-residue score 'predict_active_prob'.
renamed_columns = ['EC', 'EC Numbers', 'aa_sequence', 'predict_active_prob', 'alphafolddb-id']
deepfri_results_df.columns = renamed_columns
deepfri_results_df



Unnamed: 0,EC,EC Numbers,aa_sequence,predict_active_prob,alphafolddb-id
0,[1.1.1.-],[1.1.1.-],MRCVVFNLREEEAPYVEKWKQSHPGVVVDTYEEPLTAKNKELLKGY...,"[[0.004285963252186775, 0.006816012319177389, ...",O83080
1,[1.1.1.-],[1.1.1.-],MADHEQEQEPLSIAIIGGGIIGLMTALGLLHRNIGKVTIYERASAW...,"[[0.015918593853712082, 0.009973255917429924, ...",Q5BEJ7
2,"[1.1.1.-, 1.1.1.85]","[1.1.1.-, 1.1.1.85]",MGFTVALIQGDGIGPEIVSKSKRILAKINELYSLPIEYIEVEAGDR...,"[[0.008776959963142872, 0.174948588013649, 0.1...",P50455
3,[1.1.1.-],[1.1.1.-],MGRIGILGAGLAGLAAATKLAEAGENVTVFEARNRPGGRVWSETLD...,"[[0.43652114272117615, 0.3544602394104004, 0.3...",Q8GAJ0
4,[1.1.1.-],[1.1.1.-],MSSSDGKLRYDGRVAVVTGAGAGLGREYALLFAERGAKVVVNDLGG...,"[[0.0028535460587590933, 0.02065255306661129, ...",Q9VXJ0
...,...,...,...,...,...
461,"[6.3.4.-, 6.3.4.3]","[6.3.4.-, 6.3.4.3]",MSKVPSDIEIAQAAKMKPVMELARGLGIQEDEVELYGKYKAKISLD...,"[[0.009033123031258583, 0.05017774552106857, 0...",P21164
462,"[6.3.4.-, 6.3.4.2]","[6.3.4.-, 6.3.4.2]",MPNKYIVVTGGVLSSVGKGTLVASIGMLLKRRGYNVTAVKIDPYIN...,"[[0.0, 0.01902197115123272, 0.0555660575628280...",Q980S6
463,"[6.3.4.-, 6.3.4.3]","[6.3.4.-, 6.3.4.3]",MSKVPSDIEIAQAAKMKPVMELARGLGIQEDEVELYGKYKAKISLD...,"[[0.009033121168613434, 0.050177741795778275, ...",Q2RM91
464,[6.5.1.-],[6.5.1.-],MLFAEFAEFCERLEKISSTLELTARIAAFLQKIEDERDLYDVVLFI...,"[[0.230945885181427, 0.3018735945224762, 0.296...",O29632


In [8]:
# %%

from utils import get_active_site_binary, calculate_score
def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when it is empty."""
    return sum(values) / len(values) if len(values) > 0 else float('nan')


def _format_score_summary(metric_lists):
    """Render the mean of each per-protein metric list as the one-line report."""
    return (
        'Accuracy: {:.4f}, Precision: {:.4f}, Specificity: {:.4f}, Overlap Score: {:.4f}, False Positive Rate: {:.4f}, F1: {:.4f}, MCC: {:.4f}'
        .format(*[_mean_or_nan(values) for values in metric_lists]))


def predict_activate_site_with_deepfri(test_dataset, deepfri_results_df, scoring=True, threshold=0.5, output_results=False, evaluate_threshold=False, reprot=True, evaluate_col_name='overlap_score'):
    """Threshold DeepFRI saliency maps into active-site predictions and score them.

    For every row of ``test_dataset`` the first ``deepfri_results_df`` row with
    the same ``alphafolddb-id`` is looked up. Each saliency vector stored in
    its ``predict_active_prob`` entry is thresholded (``>= threshold``) and the
    resulting position sets are merged. Proteins without a DeepFRI result get
    an empty prediction (all sites negative).

    Args:
        test_dataset: Frame with ``alphafolddb-id``, ``aa_sequence`` and
            ``site_labels`` (a Python-literal string) columns. When
            ``output_results`` is true, result columns are added to this
            frame IN PLACE.
        deepfri_results_df: Frame with ``alphafolddb-id`` and
            ``predict_active_prob`` columns.
        scoring: Compute per-protein metrics with ``calculate_score``.
        threshold: Minimum saliency value for a position to count as active.
        output_results: Return the annotated ``test_dataset`` instead of the
            raw lists.
        evaluate_threshold: With ``output_results``, return only the mean of
            ``evaluate_col_name`` over proteins that received a prediction.
        reprot: Show the progress bar and print the summaries (parameter name
            kept as-is for backward compatibility).
        evaluate_col_name: Score column used by ``evaluate_threshold``. The
            historical default ``'overlap_score'`` is accepted as an alias of
            the actual column ``'overlap_scores'``.

    Returns:
        * ``output_results=False``: ``(predicted_sites, overlap_scores,
          false_positive_rates)`` as lists (the score lists are empty when
          ``scoring`` is false).
        * ``output_results=True``: the annotated ``test_dataset``; or, when
          ``evaluate_threshold`` is also true, a single mean score (NaN when
          no protein received a prediction).

    Raises:
        ValueError: ``evaluate_threshold`` requested together with
            ``output_results`` but without ``scoring``.
    """
    import ast  # local import: keeps this cell runnable on its own

    if output_results and evaluate_threshold and not scoring:
        raise ValueError('evaluate_threshold=True requires scoring=True')

    # Column names in the order calculate_score returns its values.
    score_cols = ['accuracy', 'precision', 'specificity', 'overlap_scores', 'false_positive_rates', 'f1_scores', 'mcc_scores']
    scores = {col: [] for col in score_cols}
    predicted_activate_sites = []
    prediction_succ = []

    # Build the id -> saliency lookup once instead of scanning the results
    # frame for every test row; setdefault keeps the FIRST row per id.
    saliency_by_id = {}
    for result_id, probs in zip(deepfri_results_df['alphafolddb-id'],
                                deepfri_results_df['predict_active_prob']):
        saliency_by_id.setdefault(result_id, probs)

    pbar = tqdm(test_dataset.iterrows(), total=len(test_dataset), disable=not reprot)
    for _, row in pbar:
        sequence_id = row['alphafolddb-id']
        aa_sequence = row['aa_sequence']

        # site_labels comes from a CSV as text; literal_eval parses the list
        # literal without executing arbitrary code (unlike eval).
        site_labels = ast.literal_eval(row['site_labels'])
        active_site_gt_bin = get_active_site_binary(site_labels,
                                                    len(aa_sequence),
                                                    begain_zero=False)
        active_site_gt = set(
            np.argwhere(active_site_gt_bin == 1).reshape(-1).tolist())

        merge_predicted_results = set()
        if sequence_id in saliency_by_id:
            # One saliency vector per entry; union the positions of all of them.
            for predicted_active_prob in np.array(saliency_by_id[sequence_id]):
                merge_predicted_results.update(
                    np.argwhere(predicted_active_prob >= threshold).reshape(-1).tolist())
            prediction_succ.append(True)
        else:
            # No prediction available: every site is treated as negative.
            prediction_succ.append(False)
        predicted_activate_sites.append(merge_predicted_results)

        if scoring:
            row_scores = calculate_score(
                merge_predicted_results, active_site_gt, len(aa_sequence))
            for col, value in zip(score_cols, row_scores):
                scores[col].append(value)
            pbar.set_description(
                _format_score_summary([scores[col] for col in score_cols]))

    if scoring and reprot:
        print(f'Get {len(scores["overlap_scores"])} results')
        print(_format_score_summary([scores[col] for col in score_cols]))

    if output_results:
        # Column order matches the historical output files.
        test_dataset['predict_active_label'] = predicted_activate_sites
        if scoring:
            for col in score_cols:
                test_dataset[col] = scores[col]
        test_dataset['prediction_succ'] = prediction_succ

        if scoring:
            # Summary restricted to proteins DeepFRI actually predicted.
            succ_prediction_df = test_dataset.loc[test_dataset['prediction_succ']]
            n_succ = len(succ_prediction_df)
            scoring_str = 'Succ Predictions Score: {} results\n'.format(n_succ)
            for score_name in score_cols:
                succ_mean = succ_prediction_df[score_name].sum() / n_succ if n_succ else float('nan')
                scoring_str += '{}: {:.4f}, '.format(score_name, succ_mean)
            if reprot:
                print(scoring_str)

            if evaluate_threshold:
                # The old default named a column that never existed.
                col_name = {'overlap_score': 'overlap_scores'}.get(evaluate_col_name, evaluate_col_name)
                if not n_succ:
                    return float('nan')
                return succ_prediction_df[col_name].sum() / n_succ
        return test_dataset
    return predicted_activate_sites, scores['overlap_scores'], scores['false_positive_rates']





In [9]:
# %%

# Sweep the saliency threshold and keep the one with the highest mean F1 over
# proteins that received a DeepFRI prediction.
# NOTE(review): the threshold is selected on the same test set it is then
# reported on, so the reported scores are optimistic — confirm this is intended.
candidate_thresholds = [0.050, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2, 0.3, 0.4, 0.5]

# Start below every reachable score so a real candidate is always chosen;
# starting from (score 0, threshold 0) would silently fall back to threshold 0
# — i.e. "every residue is active" — whenever no candidate scored above zero.
best_score = float('-inf')
best_threshold = candidate_thresholds[0]
for threshold in candidate_thresholds:
    score_mean = predict_activate_site_with_deepfri(test_dataset, deepfri_results_df, threshold=threshold, scoring=True, output_results=True, evaluate_threshold=True, reprot=False, evaluate_col_name='f1_scores')
    if best_score < score_mean:
        best_score = score_mean
        best_threshold = threshold
print()
print('#'*20)
print()
print('Best:')
print(f'Best threshold: {best_threshold}')
test_dataset_with_results:pd.DataFrame = predict_activate_site_with_deepfri(test_dataset, deepfri_results_df, threshold=best_threshold, scoring=True, output_results=True)

# Write next to the other baseline artefacts (same directory the DeepFRI
# results were read from).
os.makedirs(baseline_results_path, exist_ok=True)
test_dataset_with_results.to_csv(os.path.join(baseline_results_path, 'deepfri_gradcam_aa_sequence.csv'), index=False)
test_dataset_with_results.to_json(os.path.join(baseline_results_path, 'deepfri_gradcam_aa_sequence.json'))






####################

Best:
Best threshold: 0.175


  0%|          | 0/892 [00:00<?, ?it/s]

Get 892 results
Accuracy: 0.8628, Precision: 0.0357, Specificity: 0.8881, Overlap Score: 0.1368, False Positive Rate: 0.1119, F1: 0.0442, MCC: 0.0187
Succ Predictions Score: 488 results
accuracy: 0.7727, precision: 0.0652, specificity: 0.7955, overlap_scores: 0.2500, false_positive_rates: 0.2045, f1_scores: 0.0807, mcc_scores: 0.0341, 
