In [1]:
import os

import numpy as np
import pandas as pd
import torch
from rouskinhf import get_dataset
from scipy.stats import ttest_ind

In [2]:
# Load the benchmark results saved alongside Figure 1.
result_algos = pd.read_feather('../Figure1/saved_data_plot/results_benchmark_algos.feather')

# Replace the raw dataset identifiers with their display labels.
for raw_name, display_name in [
    ('viral_fragments', 'Viral mRNA'),
    ('lncRNA', 'Long ncRNA'),
    ('archiveII', 'ArchiveII'),
]:
    result_algos.loc[result_algos['dataset'] == raw_name, 'dataset'] = display_name

# Relabel the 'UFold' entries as 'UFold (original)'.
result_algos.loc[result_algos['model'] == 'UFold', 'model'] = 'UFold (original)'

In [3]:
# Gather the reference sequence/structure pairs of every test set into one frame.
# The per-dataset frames are collected in a list and concatenated once: calling
# pd.concat inside the loop re-copies all accumulated rows on every iteration,
# and seeding the accumulator with an empty DataFrame relies on deprecated
# pandas concatenation behaviour.
ground_truth_parts = []

for test_set in ["PDB", "archiveII", "viral_fragments", "lncRNA"]:
    # NOTE(review): force_download=True is passed on every run (the recorded
    # output shows a fresh download per dataset) — confirm this is intended.
    records = get_dataset(test_set, force_download=True)

    # Transpose so each entry becomes a row, keep only the two columns used
    # downstream, and tag the rows with the test set they came from.
    ground_truth_parts.append(
        pd.DataFrame(records).T[['sequence', 'structure']].assign(dataset=test_set)
    )

ground_truth = pd.concat(ground_truth_parts)
del ground_truth_parts

PDB: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

PDB: Download complete. File saved at data/PDB/data.json
archiveII: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

archiveII: Download complete. File saved at data/archiveII/data.json
viral_fragments: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

viral_fragments: Download complete. File saved at data/viral_fragments/data.json
lncRNA: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

lncRNA: Download complete. File saved at data/lncRNA/data.json


In [4]:
def compute_f1(pred_matrix, target_matrix, threshold=0.5):
    """
    Compute precision, recall and F1 score of a predicted pairing matrix.

    The prediction is binarised with ``pred_matrix > threshold`` before being
    compared element-wise with the target.

    :param pred_matrix: Predicted pairing matrix probability  (L,L)
    :param target_matrix: True binary pairing matrix (L,L)
    :param threshold: Values strictly above this are counted as predicted pairs
    :return: ``[precision, recall, F1]`` as Python floats. All three are 1.0
        when neither matrix contains a pair. Precision is 0.0 when nothing is
        predicted but the target has pairs, and recall is 0.0 when the target
        has no pairs but something is predicted (instead of the NaN a 0/0
        division would give).
    """

    pred_binary = (pred_matrix > threshold).float()

    TP = torch.sum(pred_binary * target_matrix)  # pairs present in both matrices
    PP = torch.sum(pred_binary)                  # predicted pairs
    P = torch.sum(target_matrix)                 # true pairs
    sum_pair = PP + P

    # Nothing to predict and nothing predicted: count as a perfect prediction.
    if sum_pair == 0:
        return [1.0, 1.0, 1.0]

    # Guard each ratio against a zero denominator. Without the guard the
    # result is NaN, which pandas silently skips when averaging, so the
    # worst predictions would drop out of the reported means.
    precision = (TP / PP).item() if PP > 0 else 0.0
    recall = (TP / P).item() if P > 0 else 0.0
    f1 = (2 * TP / sum_pair).item()

    return [precision, recall, f1]
    
def pairList2pairMatrix(pair_list, len_seq):
    """
    Turn a list of index pairs into a symmetric pairing matrix.

    :param pair_list: Sequence of (i, j) index pairs; may be empty
    :param len_seq: Side length of the square output matrix
    :return: (len_seq, len_seq) torch tensor holding 1.0 at [i, j] and [j, i]
        for every listed pair and 0.0 everywhere else
    """
    pairs = np.array(pair_list).astype(int)
    matrix = torch.zeros((len_seq, len_seq))

    # An empty pair list has no columns to index, so return the zero matrix.
    if len(pairs) == 0:
        return matrix

    rows, cols = pairs[:, 0], pairs[:, 1]
    matrix[rows, cols] = 1.0
    matrix[cols, rows] = 1.0  # mirror each pair to keep the matrix symmetric

    return matrix

In [5]:
# Score every saved model prediction file against the ground-truth structures
# and gather the per-sequence precision / recall / F1 in one long-format frame.
path_data = '../Figure5/saved_data_plot/results_V2'

# Loop-invariant: ground truth with its index exposed as the 'reference' merge key.
ground_truth_keyed = ground_truth.reset_index().rename(columns={'index': 'reference'})

# Read only the prediction files (any other directory entry would make
# read_feather fail), in sorted order so that runs are deterministic;
# os.listdir returns entries in arbitrary order.
prediction_files = sorted(name for name in os.listdir(path_data) if name.endswith('.feather'))
if not prediction_files:
    raise FileNotFoundError(f"No .feather prediction files found in {path_data}")

scored_frames = []

for file_name in prediction_files:

    prediction = pd.read_feather(os.path.join(path_data, file_name)).drop_duplicates(['reference'])

    # Inner join: only sequences present in both ground truth and predictions are scored.
    merged = ground_truth_keyed.merge(prediction, on=['sequence', 'reference'], suffixes=('_true', '_pred'))

    # One [precision, recall, F1] triple per merged row.
    scores = [
        compute_f1(torch.tensor(np.stack(structure_pred)),
                   pairList2pairMatrix(structure_true, len(sequence)),
                   threshold=0.5)
        for structure_pred, structure_true, sequence
        in zip(merged['structure_pred'], merged['structure_true'], merged['sequence'])
    ]

    # The model label is the last '_'-separated token of the file stem.
    model = file_name.split('.feather')[0].split('_')[-1]
    print(model)
    merged['F1'] = [f1 for _, _, f1 in scores]
    merged['Precision'] = [precision for precision, _, _ in scores]
    merged['Recall'] = [recall for _, recall, _ in scores]
    merged['model'] = model
    print(merged.groupby('dataset')['F1'].mean())

    scored_frames.append(merged[['reference', 'model', 'dataset', 'F1', 'Precision', 'Recall']])

# Concatenate once instead of growing a DataFrame inside the loop.
data_comparison = pd.concat(scored_frames)

# Replace file-derived model names and raw dataset identifiers with display labels.
data_comparison['model'] = data_comparison['model'].replace({
    'eFoldPT+FT': 'eFold',
    'eFoldPT': 'eFold (pre-trained)',
    'UFoldPT+FT': 'UFold (re-trained)',
})
data_comparison['dataset'] = data_comparison['dataset'].replace({
    'viral_fragments': 'Viral mRNA',
    'lncRNA': 'Long ncRNA',
    'archiveII': 'ArchiveII',
})

eFoldPT
dataset
PDB                0.886010
archiveII          0.584986
lncRNA             0.400431
viral_fragments    0.674164
Name: F1, dtype: float64
human-mRNA
dataset
PDB                0.890150
archiveII          0.580508
lncRNA             0.402036
viral_fragments    0.682623
Name: F1, dtype: float64
UFoldPT+FT
dataset
PDB                0.890532
archiveII          0.607041
lncRNA             0.347520
viral_fragments    0.674862
Name: F1, dtype: float64
eFoldPT+FT
dataset
PDB                0.894044
archiveII          0.599568
lncRNA             0.442343
viral_fragments    0.725702
Name: F1, dtype: float64
pri-miRNA
dataset
PDB                0.891003
archiveII          0.599262
lncRNA             0.318432
viral_fragments    0.707887
Name: F1, dtype: float64
UFoldPT
dataset
PDB                0.890010
archiveII          0.603470
lncRNA             0.377261
viral_fragments    0.669095
Name: F1, dtype: float64


In [6]:
# Show the dataset labels of both frames side by side before they are concatenated.
data_comparison['dataset'].unique(), result_algos['dataset'].unique()

(array(['PDB', 'ArchiveII', 'Viral mRNA', 'Long ncRNA'], dtype=object),
 array(['PDB', 'ArchiveII', 'Viral mRNA', 'Long ncRNA'], dtype=object))

In [7]:
# Stack the scores computed in this notebook with the benchmark results,
# after dropping the benchmark frame's 'length' and 'structure' columns.
all_scores = pd.concat([data_comparison, result_algos.drop(columns=['length', 'structure'])])

# Keep only the models of interest.
is_selected_model = all_scores.model.isin(['UFold (re-trained)', 'eFold (pre-trained)', 'eFold', 'SPOT-RNA', 'UFold (original)'])
results_perf = all_scores[is_selected_model]
results_perf

Unnamed: 0,reference,model,dataset,F1,Precision,Recall,MCC,structure_type
0,2N7X-2D,eFold (pre-trained),PDB,0.933333,0.875000,1.000000,,
1,8S95-2D,eFold (pre-trained),PDB,0.703704,0.666667,0.745098,,
2,2CD1-2D,eFold (pre-trained),PDB,1.000000,1.000000,1.000000,,
3,1WTT-2D,eFold (pre-trained),PDB,1.000000,1.000000,1.000000,,
4,2NCI-2D,eFold (pre-trained),PDB,0.888889,0.800000,1.000000,,
...,...,...,...,...,...,...,...,...
30250,XIST_10,SPOT-RNA,Long ncRNA,0.363919,0.422383,0.319672,0.367247,<class 'list'>
30251,MALAT1_0,SPOT-RNA,Long ncRNA,0.221729,0.364964,0.159236,0.240774,<class 'list'>
30252,XIST_11,SPOT-RNA,Long ncRNA,0.323917,0.445596,0.254438,0.336500,<class 'list'>
30253,NORAD1_55C,SPOT-RNA,Long ncRNA,0.253552,0.347305,0.199656,0.263151,<class 'list'>


In [8]:
# Group the data by model and dataset and calculate the mean for each group
grouped = results_perf.groupby(['model', 'dataset']).mean(numeric_only=True).reset_index()

# Row and column layout of the final table
new_order = ['SPOT-RNA', 'UFold (original)', 'UFold (re-trained)', 'eFold (pre-trained)', 'eFold']
dataset_order = ['PDB', 'ArchiveII', 'Viral mRNA', 'Long ncRNA']
metric_order = ['Precision', 'Recall', 'F1']

# Pivot to (metric, dataset) columns, then swap the levels so that the dataset
# is the top level and the metric the second level
perf_table = pd.pivot_table(grouped, index='model', columns='dataset', values=metric_order)
perf_table = perf_table.swaplevel(i=0, j=1, axis=1).sort_index(axis=1)

# Reorder the rows; models absent from the data become all-NaN rows
perf_table = perf_table.reindex(new_order)

# Select the columns explicitly, dataset-major and metric-minor; a missing
# (dataset, metric) combination raises a KeyError rather than passing silently
perf_table = perf_table[[(dataset, metric) for dataset in dataset_order for metric in metric_order]]

# NOTE: pivot_df is a pandas Styler (not a DataFrame); the export cell relies
# on Styler.to_excel to write the table.
pivot_df = (
    perf_table.style
    .format(precision=3)
    # `color` is the CSS property for text colour; `font-color` is not a CSS
    # property and is ignored.
    .highlight_max(axis=0, props="font-weight:bold;color:black;")
    .background_gradient(axis=1, vmin=-0.1, vmax=1, cmap="viridis", text_color_threshold=0)
    .set_properties(**{'text-align': 'center'})
    .set_table_styles(
        [{"selector": "th", "props": [('text-align', 'center')]},
         ])
)
pivot_df


dataset,PDB,PDB,PDB,ArchiveII,ArchiveII,ArchiveII,Viral mRNA,Viral mRNA,Viral mRNA,Long ncRNA,Long ncRNA,Long ncRNA
Unnamed: 0_level_1,Precision,Recall,F1,Precision,Recall,F1,Precision,Recall,F1,Precision,Recall,F1
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
SPOT-RNA,0.854,0.925,0.873,0.694,0.727,0.695,0.683,0.496,0.565,0.339,0.214,0.259
UFold (original),0.81,0.965,0.87,0.825,0.887,0.852,0.578,0.587,0.579,0.219,0.144,0.163
UFold (re-trained),0.898,0.916,0.891,0.588,0.636,0.607,0.666,0.699,0.675,0.555,0.276,0.348
eFold (pre-trained),0.872,0.928,0.886,0.547,0.636,0.585,0.638,0.719,0.674,0.409,0.399,0.4
eFold,0.888,0.928,0.894,0.572,0.636,0.6,0.704,0.753,0.726,0.459,0.432,0.442


In [22]:
# Two-sample t-test (scipy defaults) on the per-sequence F1 scores of
# eFold (pre-trained) versus UFold (re-trained) on the Long ncRNA test set.
# `ttest_ind` is imported in the notebook's top import cell.
lncrna_scores = results_perf[results_perf['dataset'] == 'Long ncRNA']

ttest_ind(lncrna_scores.loc[lncrna_scores['model'] == 'eFold (pre-trained)', 'F1'].tolist(),
          lncrna_scores.loc[lncrna_scores['model'] == 'UFold (re-trained)', 'F1'].tolist())

TtestResult(statistic=1.6957539791280387, pvalue=0.09529470245712583, df=58.0)

In [22]:
# Write the styled table to Excel. The output directory is created first,
# because to_excel fails when it does not exist (e.g. on a fresh checkout).
output_path = 'tables/T3_structure_performance.xlsx'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
pivot_df.to_excel(output_path)