In [1]:
# %%

import pandas as pd
import os
from tqdm.auto import tqdm
from pandarallel import pandarallel
from rdkit import Chem
from tqdm import tqdm as top_tqdm
import shutil


In [2]:
# %%

def get_structure_sequence(pdb_file):
    """Return the amino-acid sequence RDKit derives from a PDB file.

    Args:
        pdb_file: Path of the PDB file to read.

    Returns:
        The sequence string produced by ``Chem.MolToSequence``, or ``''`` when
        the file cannot be turned into a molecule or the conversion fails.
    """
    try:
        mol = Chem.MolFromPDBFile(pdb_file)
        if mol is None:
            # No molecule could be built from the file: report "no sequence".
            return ''
        return Chem.MolToSequence(mol)
    except Exception:
        # Best-effort by design: any parsing/conversion error means "no sequence".
        # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt/SystemExit
        # propagate so a long structure check can still be interrupted.
        return ''

def multiprocess_structure_check(df, nb_workers, pdb_file_path):
    """Attach PDB paths, structure-derived sequences and a validity flag to ``df``.

    Adds three columns to ``df`` in place and returns the same frame:
    ``pdb_files`` (path built from ``alphafolddb-id``), ``aa_sequence_calculated``
    (result of ``get_structure_sequence`` on that path) and ``is_valid`` (whether
    the calculated sequence equals ``aa_sequence``).

    Args:
        df: Frame with ``alphafolddb-id`` and ``aa_sequence`` columns.
        nb_workers: Number of pandarallel workers; ``0`` selects the
            single-process tqdm ``progress_apply`` path instead.
        pdb_file_path: Directory holding the ``AF-<id>-F1-model_v4.pdb`` files.

    Returns:
        ``df`` with the three columns added.
    """
    # Register the apply flavour first, then run both steps through it.
    if nb_workers != 0:
        pandarallel.initialize(nb_workers=nb_workers, progress_bar=True)
        apply_name = 'parallel_apply'
    else:
        top_tqdm.pandas(desc='pandas bar')
        apply_name = 'progress_apply'

    id_column = df['alphafolddb-id']
    df['pdb_files'] = getattr(id_column, apply_name)(
        lambda structure_id: os.path.join(pdb_file_path, f'AF-{structure_id}-F1-model_v4.pdb'))

    path_column = df['pdb_files']
    df['aa_sequence_calculated'] = getattr(path_column, apply_name)(
        lambda pdb_path: get_structure_sequence(pdb_path))

    df['is_valid'] = df['aa_sequence_calculated'].eq(df['aa_sequence'])
    return df
def get_blast_database(dir, fasta_path):
    """Merge every CSV in a directory into one frame and write it as FASTA.

    Args:
        dir: Directory whose entries are all read with ``pd.read_csv``. Each file
            must provide the ``alphafolddb-id``, ``aa_sequence``, ``site_labels``
            and ``site_types`` columns.
        fasta_path: Destination FASTA file (overwritten, UTF-8).

    Returns:
        The de-duplicated frame with ``;`` stripped from ``alphafolddb-id``.
        An empty directory yields an empty frame and an empty FASTA file.
    """
    keep_cols = ['alphafolddb-id', 'aa_sequence', 'site_labels', 'site_types']

    # Sorted so the row order (and therefore the FASTA) is reproducible;
    # os.listdir returns entries in arbitrary order.
    csv_fnames = sorted(os.listdir(dir))

    # Collect the frames and concatenate once: concatenating inside the loop
    # re-copies the accumulated frame on every iteration.
    frames = [
        pd.read_csv(os.path.join(dir, fname))[keep_cols]
        for fname in tqdm(csv_fnames, total=len(csv_fnames))
    ]
    if frames:
        database_df = pd.concat(frames)
    else:
        database_df = pd.DataFrame(columns=keep_cols)

    database_df = database_df.drop_duplicates(subset=keep_cols).reset_index(drop=True)
    database_df['alphafolddb-id'] = database_df['alphafolddb-id'].apply(
        lambda x: x.replace(';', ''))

    with open(fasta_path, 'w', encoding='utf-8') as f:
        records = zip(database_df['alphafolddb-id'], database_df['aa_sequence'])
        for seq_id, sequence in tqdm(records, total=len(database_df)):
            f.write('>{}\n'.format(seq_id))
            f.write('{}\n'.format(sequence))
    return database_df

def get_query_database(path, pdb_file_path, copy_pdb_file_path=None, nb_workers=12):
    """Load a query CSV, validate it against its structures and stage the PDB files.

    Args:
        path: CSV file with ``alphafolddb-id``, ``aa_sequence``, ``site_labels``
            and ``site_types`` columns.
        pdb_file_path: Directory holding the ``AF-<id>-F1-model_v4.pdb`` files.
        copy_pdb_file_path: Optional directory to copy the referenced PDB files
            into. ``None`` (the default) skips the copy step.
        nb_workers: Worker count forwarded to ``multiprocess_structure_check``.

    Returns:
        The frame returned by ``multiprocess_structure_check`` (all rows, not
        de-duplicated).
    """
    keep_cols = ['alphafolddb-id', 'aa_sequence', 'site_labels', 'site_types']

    # .copy() so the column assignments below act on an independent frame
    # rather than on a slice of the frame read from disk.
    database_df = pd.read_csv(path)[keep_cols].copy()
    database_df['alphafolddb-id'] = database_df['alphafolddb-id'].apply(
        lambda x: x.replace(';', ''))

    database_df = multiprocess_structure_check(
        database_df, nb_workers=nb_workers, pdb_file_path=pdb_file_path)

    if copy_pdb_file_path is not None:
        os.makedirs(copy_pdb_file_path, exist_ok=True)
        # One copy per distinct identifier is enough.
        for pdb_id in database_df['alphafolddb-id'].drop_duplicates().tolist():
            pdb_fname = f'AF-{pdb_id}-F1-model_v4.pdb'
            src_path = os.path.join(pdb_file_path, pdb_fname)
            if not os.path.isfile(src_path):
                # A missing structure is skipped instead of aborting the whole
                # copy; the structure check treats an unreadable file as an
                # empty sequence rather than an error.
                continue
            shutil.copyfile(src_path, os.path.join(copy_pdb_file_path, pdb_fname))

    return database_df




In [3]:
# %%

# Paths used throughout the notebook.
dataset_path = '../../dataset/ec_site_dataset/uniprot_ecreact_cluster_split_merge_dataset_limit_100'
test_dataset_fasta_path = os.path.join(dataset_path, 'test_dataset.fasta')
copy_pdb_file_path = os.path.join('deepfri_workspace', 'test_dataset_pdb')
baseline_results_path = 'baseline_results'

train_database_df = pd.read_pickle(
    '../../dataset/raw_dataset/ec_datasets/split_ec_dataset/train_ec_uniprot_dataset_cluster_sample.pkl'
)

# Load the test split, check each sequence against its structure file and
# stage the PDB files in ``copy_pdb_file_path``.
test_dataset = get_query_database(
    os.path.join(dataset_path, 'test_dataset', 'uniprot_ecreact_merge.csv'),
    pdb_file_path=os.path.join(os.path.dirname(dataset_path), 'structures', 'alphafolddb_download'),
    copy_pdb_file_path=copy_pdb_file_path,
)

# Keep only the rows whose structure-derived sequence matches the annotated one.
test_dataset = test_dataset[test_dataset['is_valid']]
test_dataset



INFO: Pandarallel will run on 12 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=75), Label(value='0 / 75'))), HBox…

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=75), Label(value='0 / 75'))), HBox…

Unnamed: 0,alphafolddb-id,aa_sequence,site_labels,site_types,pdb_files,aa_sequence_calculated,is_valid
0,A0A1S7LCW6,MKLKGTTIVALGMLVVAIMVLASMIDLPGSDMSATPAPPDTPRGAP...,"[[206], [212], [215], [216], [246], [252], [25...","[0, 0, 0, 0, 0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MKLKGTTIVALGMLVVAIMVLASMIDLPGSDMSATPAPPDTPRGAP...,True
1,Q9F0J6,MQATKIIDGFHLVGAIDWNSRDFHGYTLSPMGTTYNAYLVEDEKTT...,"[[79], [81], [83], [146], [165], [165], [226]]","[0, 0, 0, 0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MQATKIIDGFHLVGAIDWNSRDFHGYTLSPMGTTYNAYLVEDEKTT...,True
2,Q5BEJ7,MADHEQEQEPLSIAIIGGGIIGLMTALGLLHRNIGKVTIYERASAW...,"[[41, 42], [245, 247], [320], [330, 334]]","[0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MADHEQEQEPLSIAIIGGGIIGLMTALGLLHRNIGKVTIYERASAW...,True
3,Q9HUH4,MPQALSTDILIVGGGIAGLWLNARLRRAGYATVLVESASLGGGQSV...,"[[17], [36], [44, 45], [49, 51], [346, 347]]","[0, 0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MPQALSTDILIVGGGIAGLWLNARLRRAGYATVLVESASLGGGQSV...,True
4,P96692,MAEFTHLVNERRSASNFLSGHPITKEDLNEMFELVALAPSAFNLQH...,"[[11, 13], [68, 70], [157, 158], [193], [196]]","[0, 0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MAEFTHLVNERRSASNFLSGHPITKEDLNEMFELVALAPSAFNLQH...,True
...,...,...,...,...,...,...,...
889,O30144,MFLKVRAEKRLGNFRLNVDFEMGRDYCVLLGPTGAGKSVFLELIAG...,"[[31, 38]]",[0],../../dataset/ec_site_dataset/structures/alpha...,MFLKVRAEKRLGNFRLNVDFEMGRDYCVLLGPTGAGKSVFLELIAG...,True
890,P28737,MSRKFDLKTITDLSVLVGTGISLYYLVSRLLNDVESGPLSGKSRES...,"[[133, 140]]",[0],../../dataset/ec_site_dataset/structures/alpha...,MSRKFDLKTITDLSVLVGTGISLYYLVSRLLNDVESGPLSGKSRES...,True
891,P37093,MTEMVISPAERQSIRRLPFSFANRFKLVLDWNEDFSQASIYYLAPL...,"[[397], [400], [430], [433]]","[0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MTEMVISPAERQSIRRLPFSFANRFKLVLDWNEDFSQASIYYLAPL...,True
892,P94360,MAELRMEHIYKFYDQKEPAVDDFNLHIADKEFIVFVGPSGCGKSTT...,"[[37, 44]]",[0],../../dataset/ec_site_dataset/structures/alpha...,MAELRMEHIYKFYDQKEPAVDDFNLHIADKEFIVFVGPSGCGKSTT...,True


In [4]:
# %%

# Number of distinct AlphaFold DB identifiers among the validated test rows.
len(set(test_dataset['alphafolddb-id'].tolist()))



851

In [5]:
# %%

import shlex
import subprocess

# Run the DeepFRI ``predict.py`` script on the staged PDB files, unless the
# expected saliency-map file is already present.
deepfri_model_root = '/home/xiaoruiwang/data/ubuntu_work_beta/protein_work/DeepFRI'
args = [
    "--pdb_dir", os.path.abspath(copy_pdb_file_path),
    "-ont", "ec",
    "-v",
    "--saliency",
    "--use_guided_grads",
    "-o", os.path.abspath(os.path.join(baseline_results_path, 'DeepFRI_PDB')),
    "--model_config", os.path.abspath(os.path.join(deepfri_model_root, 'trained_models/model_config.json')),
]

deepfri_results_fname = 'DeepFRI_PDB_EC_saliency_maps.json'

# The command is kept as an argument list and executed without a shell, so
# paths containing spaces or shell metacharacters are passed through verbatim.
deepfri_argv = [
    '/home/xiaoruiwang/software/miniconda3/envs/deepfri_env/bin/python',
    os.path.join(deepfri_model_root, 'predict.py'),
    *args,
]

# Shell-quoted, human-readable form of the same command (for display/logging).
python_cmd = ' '.join(shlex.quote(part) for part in deepfri_argv)
command = python_cmd

if not os.path.exists(os.path.join(baseline_results_path, deepfri_results_fname)):
    # Make sure the directory named by the ``-o`` prefix exists before the run.
    os.makedirs(baseline_results_path, exist_ok=True)
    subprocess.run(deepfri_argv, check=True, cwd=deepfri_model_root)
    print('')



In [6]:
# %%

import json
import numpy as np

# Read the saliency-map JSON written by the DeepFRI run.
deepfri_results_path = os.path.join(baseline_results_path, deepfri_results_fname)
with open(deepfri_results_path, 'r') as f:
    deepfri_results = json.load(f)



In [7]:
# %%

# One row per DeepFRI entry: the JSON's top-level keys become the index after
# the transpose.
deepfri_results_df = pd.DataFrame(deepfri_results).T

# The identifier is the second '-'-separated field of the entry name.
deepfri_results_df['alphafolddb-id'] = [
    entry_name.split('-')[1] for entry_name in deepfri_results_df.index
]
deepfri_results_df.index = list(range(len(deepfri_results_df)))

# Positional rename; the fourth column (saliency_maps) becomes predict_active_prob.
deepfri_results_df.columns = ['EC', 'EC Numbers', 'aa_sequence', 'predict_active_prob', 'alphafolddb-id']
deepfri_results_df



Unnamed: 0,EC,EC Numbers,aa_sequence,predict_active_prob,alphafolddb-id
0,"[3.4.13.-, 3.5.4.-, 2.4.1.-, 4.1.1.-]","[3.4.13.-, 3.5.4.-, 2.4.1.-, 4.1.1.-]",MQATKIIDGFHLVGAIDWNSRDFHGYTLSPMGTTYNAYLVEDEKTT...,"[[0.09321045875549316, 0.1256961077451706, 0.0...",Q9F0J6
1,[3.4.13.-],[3.4.13.-],MQGEIIAGFLAPHPPHLVYGENPPQNEPRSQGGWEVLRWAYERARE...,"[[0.1899009793996811, 0.3382459878921509, 0.29...",Q6J1Z6
2,"[3.4.13.-, 2.7.7.-, 2.7.1.-]","[3.4.13.-, 2.7.7.-, 2.7.1.-]",MITMTNWESLYEKALDKVEASIRKVRGVLLAYNTNIDAIKYLKRED...,"[[0.09452322870492935, 0.13817289471626282, 0....",O58328
3,"[3.4.13.-, 3.4.11.-, 3.4.13.9]","[3.4.13.-, 3.4.11.-, 3.4.13.9]",MSMKSQFERAKIEYGQWGIDVEEALERLKQVPISIHCWQGDDVGGF...,"[[0.005185815040022135, 0.005185815040022135, ...",Q9KCL9
4,"[3.4.13.-, 2.7.1.-, 3.5.1.-, 3.5.3.-]","[3.4.13.-, 2.7.1.-, 3.5.1.-, 3.5.3.-]",MKKLINDVQDVLDEQLAGLAKAHPSLTLHQDPVYVTRADAPVAGKV...,"[[0.0, 0.009247974492609501, 0.047250859439373...",P76015
...,...,...,...,...,...
756,"[6.1.1.12, 6.1.1.23, 6.1.1.-]","[6.1.1.12, 6.1.1.23, 6.1.1.-]",MYPKKTHWTAEITPNLHGTEVVVAGWVWELRDIGRVKFVVVRDREG...,"[[0.12965787947177887, 0.15144096314907074, 0....",Q8ZYM8
757,"[6.1.1.12, 6.1.1.-, 6.1.1.6]","[6.1.1.12, 6.1.1.-, 6.1.1.6]",MSQDENIVKAVEESAEPAQVILGEDGKPLSKKALKKLQKEQEKQRK...,"[[0.04784183204174042, 0.07571902126073837, 0....",P04802
758,"[6.1.1.-, 6.1.1.17]","[6.1.1.-, 6.1.1.17]",MTVRVRLAPSPTGNLHIGTARTAVFNWLYARHRGGKFILRIEDTDR...,"[[0.06084522232413292, 0.20098303258419037, 0....",Q8DLI5
759,"[1.16.3.-, 1.16.3.1]","[1.16.3.-, 1.16.3.1]",MQGDPEVIEFLNEQLTAELTAINQYFLHAKLQDHKGWTKLAKYTRA...,"[[0.11288239806890488, 0.08461116999387741, 0....",Q9S2N0


In [8]:
# %%

from utils import get_active_site_binary, calculate_score
def _parse_site_literal(value):
    """Parse ``value`` with ``ast.literal_eval`` when it is a string; otherwise return it unchanged."""
    import ast  # function-scope import: ``ast`` is not imported at the top of the notebook

    return ast.literal_eval(value) if isinstance(value, str) else value


def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when ``values`` is empty."""
    return sum(values) / len(values) if len(values) else float('nan')


def predict_activate_site_with_deepfri(test_dataset, deepfri_results_df, scoring=True, threshold=0.5, output_results=False, evaluate_threshold=False, reprot=True, evaluate_col_name='overlap_score'):
    """Threshold DeepFRI saliency maps into predicted sites and score them.

    For every row of ``test_dataset`` the first row of ``deepfri_results_df``
    with the same ``alphafolddb-id`` is looked up. Positions whose value is
    ``>= threshold`` in any of that row's maps form the predicted site set.
    Rows without a DeepFRI result get an empty prediction (all negative).

    Args:
        test_dataset: Frame with ``alphafolddb-id``, ``aa_sequence`` and
            ``site_labels`` columns. ``site_labels`` may be a Python literal
            string or an already-parsed object.
        deepfri_results_df: Frame with ``alphafolddb-id`` and
            ``predict_active_prob`` columns.
        scoring: Whether to call ``calculate_score`` for every row.
        threshold: Cut-off applied to the saliency values.
        output_results: When True, result columns are added to ``test_dataset``
            IN PLACE and the frame is returned.
        evaluate_threshold: With ``output_results=True``, return only the mean
            of ``evaluate_col_name`` over rows that had a DeepFRI result.
        reprot: Whether to show the progress bar and print summaries
            (the spelling is kept for interface compatibility).
        evaluate_col_name: Score column used by ``evaluate_threshold``. The
            legacy default ``'overlap_score'`` is mapped to ``'overlap_scores'``.

    Returns:
        * ``output_results=False``: ``(predicted_sites, overlap_scores,
          false_positive_rates)`` as lists (the score lists are empty when
          ``scoring`` is False).
        * ``output_results=True``: ``test_dataset`` with the added columns, or
          a single mean score when ``evaluate_threshold`` is True.

    Raises:
        ValueError: If ``evaluate_threshold`` is requested with
            ``output_results=True`` but ``scoring=False``.
    """
    score_cols = ['accuracy', 'precision', 'specificity', 'overlap_scores',
                  'false_positive_rates', 'f1_scores', 'mcc_scores']
    summary_template = (
        'Accuracy: {:.4f}, Precision: {:.4f}, Specificity: {:.4f}, Overlap Score: {:.4f}, '
        'False Positive Rate: {:.4f}, F1: {:.4f}, MCC: {:.4f}'
    )

    if output_results and evaluate_threshold and not scoring:
        raise ValueError('evaluate_threshold=True requires scoring=True')

    # Index the DeepFRI output once (first occurrence wins) instead of
    # scanning the whole frame for every test row.
    probs_by_id = {}
    for result_id, probs in zip(deepfri_results_df['alphafolddb-id'],
                                deepfri_results_df['predict_active_prob']):
        probs_by_id.setdefault(result_id, probs)

    predicted_activate_sites = []
    prediction_succ = []
    score_lists = {name: [] for name in score_cols}

    pbar = tqdm(test_dataset.iterrows(), total=len(test_dataset), disable=not reprot)
    for _, row in pbar:
        sequence_id = row['alphafolddb-id']
        seq_len = len(row['aa_sequence'])

        # Labels are parsed as literals only; arbitrary expressions are rejected.
        active_site_gt_bin = get_active_site_binary(
            _parse_site_literal(row['site_labels']), seq_len, begain_zero=False)
        active_site_gt = set(
            np.argwhere(active_site_gt_bin == 1).reshape(-1).tolist())

        if sequence_id in probs_by_id:
            # A flat vector is treated as a single saliency map.
            prob_maps = np.atleast_2d(np.array(probs_by_id[sequence_id]))
            # Union over maps: a position is positive if any map reaches the threshold.
            positive_mask = (prob_maps >= threshold).any(axis=0)
            merge_predicted_results = set(np.flatnonzero(positive_mask).tolist())
            prediction_succ.append(True)
        else:
            # No prediction available: every position is negative.
            merge_predicted_results = set()
            prediction_succ.append(False)

        predicted_activate_sites.append(merge_predicted_results)

        if scoring:
            acc, prec, spec, overlap_score, fpr, f1, mcc = calculate_score(
                merge_predicted_results, active_site_gt, seq_len)
            for name, value in zip(score_cols, (acc, prec, spec, overlap_score, fpr, f1, mcc)):
                score_lists[name].append(value)
            if reprot:
                pbar.set_description(summary_template.format(
                    *(_mean_or_nan(score_lists[name]) for name in score_cols)))

    if scoring and reprot:
        n_scored = len(score_lists['overlap_scores'])
        print(f'Get {n_scored} results')
        print(summary_template.format(
            *(_mean_or_nan(score_lists[name]) for name in score_cols)))

    if not output_results:
        return (predicted_activate_sites,
                score_lists['overlap_scores'],
                score_lists['false_positive_rates'])

    # NOTE: columns are written onto the caller's frame.
    test_dataset['predict_active_label'] = predicted_activate_sites
    if scoring:
        for name in score_cols:
            test_dataset[name] = score_lists[name]
    test_dataset['prediction_succ'] = prediction_succ

    if scoring:
        succ_prediction_df = test_dataset.loc[test_dataset['prediction_succ']]
        n_succ = len(succ_prediction_df)

        def succ_mean(col_name):
            # sum/len (not .mean()) keeps the original treatment of NaN scores.
            if not n_succ:
                return float('nan')
            return succ_prediction_df[col_name].sum() / n_succ

        scoring_str = 'Succ Predictions Score: {} results\n'.format(n_succ)
        for score_name in score_cols:
            scoring_str += '{}: {:.4f}, '.format(score_name, succ_mean(score_name))
        if reprot:
            print(scoring_str)

        if evaluate_threshold:
            # The historical default name has no matching column; map it.
            col_name = {'overlap_score': 'overlap_scores'}.get(evaluate_col_name, evaluate_col_name)
            return succ_mean(col_name)

    return test_dataset





In [9]:
# %%

# Pick the saliency threshold with the best mean F1 over the rows DeepFRI
# produced a result for, then re-score the test set at that threshold.
threshold_grid = [0.050, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2, 0.3, 0.4, 0.5]
threshold_scores = {}
for threshold in threshold_grid:
    threshold_scores[threshold] = predict_activate_site_with_deepfri(test_dataset, deepfri_results_df, threshold=threshold, scoring=True, output_results=True, evaluate_threshold=True, reprot=False, evaluate_col_name='f1_scores')

# The winner always comes from the grid: NaN scores rank lowest and ties keep
# the first (smallest) threshold, so an all-zero or all-NaN sweep can no
# longer fall back to an off-grid threshold of 0.
best_threshold, best_score = max(
    threshold_scores.items(),
    key=lambda item: item[1] if item[1] == item[1] else float('-inf'),
)
print()
print('#'*20)
print()
print('Best:')
print(f'Best Threshold: {best_threshold}')
test_dataset_with_results:pd.DataFrame = predict_activate_site_with_deepfri(test_dataset, deepfri_results_df, threshold=best_threshold, scoring=True, output_results=True)

# Write next to the DeepFRI outputs, using the shared results-directory setting.
os.makedirs(baseline_results_path, exist_ok=True)
test_dataset_with_results.to_csv(os.path.join(baseline_results_path, 'deepfri_gradcam_pdb_files.csv'), index=False)
test_dataset_with_results.to_json(os.path.join(baseline_results_path, 'deepfri_gradcam_pdb_files.json'))






####################

Best:
Best Threshold: 0.5


  0%|          | 0/892 [00:00<?, ?it/s]

Get 892 results
Accuracy: 0.7722, Precision: 0.0713, Specificity: 0.7856, Overlap Score: 0.4577, False Positive Rate: 0.2144, F1: 0.1070, MCC: 0.0979
Succ Predictions Score: 798 results
accuracy: 0.7488, precision: 0.0797, specificity: 0.7604, overlap_scores: 0.5116, false_positive_rates: 0.2396, f1_scores: 0.1196, mcc_scores: 0.1094, 


In [10]:
# %%
# Model variant retrained on the SwissProt-E-RXN ASA dataset
import shlex
import subprocess

# Run ``predict.py`` inside the DeepFRI docker container, unless the expected
# saliency-map file is already present.
deepfri_model_root = '/home/xiaoruiwang/data/ubuntu_work_beta/protein_work/DeepFRI'
args = [
    "--pdb_dir", os.path.abspath(copy_pdb_file_path),
    "-ont", "ec_swissprot",
    "-v",
    "--saliency",
    "--use_guided_grads",
    "-o", os.path.abspath(os.path.join(baseline_results_path, 'DeepFRI_EC_SwissProt')),
    "--model_config", os.path.abspath(os.path.join(deepfri_model_root, 'results/model_config.json')),
]

deepfri_results_fname = 'DeepFRI_EC_SwissProt_EC_SWISSPROT_saliency_maps.json'

docker_container_name = 'deepfri_rtx3060'
docker_cmd = f'docker exec {docker_container_name} bash -c'

# Script run by ``bash -c`` inside the container. Every interpolated value is
# shell-quoted so paths with spaces or metacharacters survive intact.
python_cmd = 'cd {} && python {} {}'.format(
    shlex.quote(deepfri_model_root),
    shlex.quote(os.path.join(deepfri_model_root, 'predict.py')),
    ' '.join(shlex.quote(arg) for arg in args),
)

# ``docker exec`` is invoked directly as an argument list (no host shell).
docker_argv = ['docker', 'exec', docker_container_name, 'bash', '-c', python_cmd]

# Shell-quoted, human-readable form of the same command (for display/logging).
command = ' '.join(shlex.quote(part) for part in docker_argv)

if not os.path.exists(os.path.join(baseline_results_path, deepfri_results_fname)):
    # Make sure the results directory exists on the host before the run.
    os.makedirs(baseline_results_path, exist_ok=True)
    subprocess.run(docker_argv, check=True, cwd=deepfri_model_root)
    print('')



In [11]:
# Rebuild the test frame from the CSV before scoring the second model.
test_dataset = get_query_database(
    os.path.join(dataset_path, 'test_dataset', 'uniprot_ecreact_merge.csv'),
    pdb_file_path=os.path.join(os.path.dirname(dataset_path), 'structures', 'alphafolddb_download'),
    copy_pdb_file_path=copy_pdb_file_path,
)

# Keep only the rows whose structure-derived sequence matches the annotated one.
test_dataset = test_dataset[test_dataset['is_valid']]
test_dataset

INFO: Pandarallel will run on 12 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=75), Label(value='0 / 75'))), HBox…

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=75), Label(value='0 / 75'))), HBox…

Unnamed: 0,alphafolddb-id,aa_sequence,site_labels,site_types,pdb_files,aa_sequence_calculated,is_valid
0,A0A1S7LCW6,MKLKGTTIVALGMLVVAIMVLASMIDLPGSDMSATPAPPDTPRGAP...,"[[206], [212], [215], [216], [246], [252], [25...","[0, 0, 0, 0, 0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MKLKGTTIVALGMLVVAIMVLASMIDLPGSDMSATPAPPDTPRGAP...,True
1,Q9F0J6,MQATKIIDGFHLVGAIDWNSRDFHGYTLSPMGTTYNAYLVEDEKTT...,"[[79], [81], [83], [146], [165], [165], [226]]","[0, 0, 0, 0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MQATKIIDGFHLVGAIDWNSRDFHGYTLSPMGTTYNAYLVEDEKTT...,True
2,Q5BEJ7,MADHEQEQEPLSIAIIGGGIIGLMTALGLLHRNIGKVTIYERASAW...,"[[41, 42], [245, 247], [320], [330, 334]]","[0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MADHEQEQEPLSIAIIGGGIIGLMTALGLLHRNIGKVTIYERASAW...,True
3,Q9HUH4,MPQALSTDILIVGGGIAGLWLNARLRRAGYATVLVESASLGGGQSV...,"[[17], [36], [44, 45], [49, 51], [346, 347]]","[0, 0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MPQALSTDILIVGGGIAGLWLNARLRRAGYATVLVESASLGGGQSV...,True
4,P96692,MAEFTHLVNERRSASNFLSGHPITKEDLNEMFELVALAPSAFNLQH...,"[[11, 13], [68, 70], [157, 158], [193], [196]]","[0, 0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MAEFTHLVNERRSASNFLSGHPITKEDLNEMFELVALAPSAFNLQH...,True
...,...,...,...,...,...,...,...
889,O30144,MFLKVRAEKRLGNFRLNVDFEMGRDYCVLLGPTGAGKSVFLELIAG...,"[[31, 38]]",[0],../../dataset/ec_site_dataset/structures/alpha...,MFLKVRAEKRLGNFRLNVDFEMGRDYCVLLGPTGAGKSVFLELIAG...,True
890,P28737,MSRKFDLKTITDLSVLVGTGISLYYLVSRLLNDVESGPLSGKSRES...,"[[133, 140]]",[0],../../dataset/ec_site_dataset/structures/alpha...,MSRKFDLKTITDLSVLVGTGISLYYLVSRLLNDVESGPLSGKSRES...,True
891,P37093,MTEMVISPAERQSIRRLPFSFANRFKLVLDWNEDFSQASIYYLAPL...,"[[397], [400], [430], [433]]","[0, 0, 0, 0]",../../dataset/ec_site_dataset/structures/alpha...,MTEMVISPAERQSIRRLPFSFANRFKLVLDWNEDFSQASIYYLAPL...,True
892,P94360,MAELRMEHIYKFYDQKEPAVDDFNLHIADKEFIVFVGPSGCGKSTT...,"[[37, 44]]",[0],../../dataset/ec_site_dataset/structures/alpha...,MAELRMEHIYKFYDQKEPAVDDFNLHIADKEFIVFVGPSGCGKSTT...,True


In [12]:

import json
import numpy as np

# Read the saliency-map JSON currently named by ``deepfri_results_fname``.
deepfri_results_path = os.path.join(baseline_results_path, deepfri_results_fname)
with open(deepfri_results_path, 'r') as f:
    deepfri_results = json.load(f)


In [15]:
# Display the results file name currently in effect.
# NOTE(review): this cell carries execution count 15 although it sits between
# cells 12 and 13, i.e. it was run out of order.
deepfri_results_fname

'DeepFRI_EC_SwissProt_EC_SWISSPROT_saliency_maps.json'

In [13]:

# One row per DeepFRI entry: the JSON's top-level keys become the index after
# the transpose.
deepfri_results_df = pd.DataFrame(deepfri_results).T

# The identifier is the second '-'-separated field of the entry name.
deepfri_results_df['alphafolddb-id'] = [
    entry_name.split('-')[1] for entry_name in deepfri_results_df.index
]
deepfri_results_df.index = list(range(len(deepfri_results_df)))

# Positional rename; the fourth column (saliency_maps) becomes predict_active_prob.
deepfri_results_df.columns = ['EC', 'EC Numbers', 'aa_sequence', 'predict_active_prob', 'alphafolddb-id']
deepfri_results_df

Unnamed: 0,EC,EC Numbers,aa_sequence,predict_active_prob,alphafolddb-id
0,"[3.2.1.14, 3.2.1.4, 3.2.1.78]","[3.2.1.14, 3.2.1.4, 3.2.1.78]",MVKLFSFLLLVWVASPAFSSEFLKASGSNFYYGGQKVFLSGVNFAW...,"[[0.006930299568921328, 0.0, 0.004468691535294...",B4XC07
1,"[3.2.1.14, 3.2.1.4]","[3.2.1.14, 3.2.1.4]",MKSYTSLLAVAILCLFGGVNGACTKNAIAQTGFNKDKYFNGDVWYV...,"[[0.00027659436454996467, 0.010176261886954308...",Q94734
2,"[3.2.1.14, 3.2.1.4]","[3.2.1.14, 3.2.1.4]",MRPRPIRLLLTALVGAGLAFAPVSAVAAPTATASASADVGALDGCY...,"[[0.03727957233786583, 0.047191519290208817, 0...",P00733
3,"[3.2.1.14, 3.2.1.4]","[3.2.1.14, 3.2.1.4]",MSSTKLISLIVSITFFLTLQCSMAQTVVKASYWFPASEFPVTDIDS...,"[[0.005285911727696657, 0.005285911727696657, ...",O81862
4,"[3.2.1.14, 3.2.1.4]","[3.2.1.14, 3.2.1.4]",MDGVLWRVRTAALMAALLALAAWALVWASPSVEAQSNPYQRGPNPT...,"[[0.023156287148594856, 0.02454942651093006, 0...",G9BY57
...,...,...,...,...,...
390,[4.3.3.7],[4.3.3.7],MFQGSIVALITPFKEGEVDYEALGNLIEFHVDNGTDAILVCGTTGE...,"[[0.11112800985574722, 0.24381521344184875, 0....",O67216
391,"[1.1.1.193, 3.5.4.26]","[1.1.1.193, 3.5.4.26]",MEEYYMKLALDLAKQGEGQTESNPLVGAVVVKDGQIVGMGAHLKYG...,"[[0.0, 0.08471090346574783, 0.0826701670885086...",P17618
392,[1.7.7.1],[1.7.7.1],MASLPVNKIIPSSTTLLSSSNNNRRRNNSSIRCQKAVSPAAETAAV...,"[[0.0172539371997118, 0.04702288657426834, 0.0...",P05314
393,[6.3.5.5],[6.3.5.5],MTLMVPFKQVDVFTEKPFMGNPVAVINFLEIDENEVSQEELQAIAN...,"[[0.09309086203575134, 0.1859711855649948, 0.3...",P38765


In [14]:
# Pick the saliency threshold with the best mean F1 over the rows DeepFRI
# produced a result for, then re-score the test set at that threshold.
threshold_grid = [0.050, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2, 0.3, 0.4, 0.5]
threshold_scores = {}
for threshold in threshold_grid:
    threshold_scores[threshold] = predict_activate_site_with_deepfri(test_dataset, deepfri_results_df, threshold=threshold, scoring=True, output_results=True, evaluate_threshold=True, reprot=False, evaluate_col_name='f1_scores')

# The winner always comes from the grid: NaN scores rank lowest and ties keep
# the first (smallest) threshold, so an all-zero or all-NaN sweep can no
# longer fall back to an off-grid threshold of 0.
best_threshold, best_score = max(
    threshold_scores.items(),
    key=lambda item: item[1] if item[1] == item[1] else float('-inf'),
)
print()
print('#'*20)
print()
print('Best:')
print(f'Best Threshold: {best_threshold}')
test_dataset_with_results:pd.DataFrame = predict_activate_site_with_deepfri(test_dataset, deepfri_results_df, threshold=best_threshold, scoring=True, output_results=True)

# Write next to the DeepFRI outputs, using the shared results-directory setting.
# The file names (including the 'swissport' spelling) are left untouched so
# existing consumers of these files keep working.
os.makedirs(baseline_results_path, exist_ok=True)
test_dataset_with_results.to_csv(os.path.join(baseline_results_path, 'deepfri_ec_swissport_gradcam_pdb_files.csv'), index=False)
test_dataset_with_results.to_json(os.path.join(baseline_results_path, 'deepfri_ec_swissport_gradcam_pdb_files.json'))


####################

Best:
Best Threshold: 0.5


  0%|          | 0/892 [00:00<?, ?it/s]

Get 892 results
Accuracy: 0.7584, Precision: 0.0202, Specificity: 0.7763, Overlap Score: 0.2913, False Positive Rate: 0.2237, F1: 0.0352, MCC: 0.0230
Succ Predictions Score: 418 results
accuracy: 0.5250, precision: 0.0430, specificity: 0.5226, overlap_scores: 0.6215, false_positive_rates: 0.4774, f1_scores: 0.0751, mcc_scores: 0.0490, 
