In [15]:
import pandas as pd
from rouskinhf import get_dataset
import torch
import numpy as np
import os

In [17]:
result_algos = pd.read_feather('../Figure1/saved_data_plot/results_benchmark_algos.feather')
result_algos.loc[result_algos['dataset']=='viral_fragments', 'dataset'] = 'Viral mRNA'
result_algos.loc[result_algos['dataset']=='lncRNA_nonFiltered', 'dataset'] = 'Long ncRNA'
result_algos.loc[result_algos['dataset']=='archiveII', 'dataset'] = 'ArchiveII'

In [18]:
ground_truth = pd.DataFrame()

for test_set in ["PDB", "archiveII", "viral_fragments", "lncRNA_nonFiltered"]:
    data = get_dataset(test_set, force_download=True)
    data = pd.DataFrame(data).T[['sequence', 'structure']]
    data['dataset'] = test_set

    ground_truth = pd.concat([ground_truth, data])
    del data

PDB: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

PDB: Download complete. File saved at data/PDB/data.json
archiveII: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

archiveII: Download complete. File saved at data/archiveII/data.json
viral_fragments: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

viral_fragments: Download complete. File saved at data/viral_fragments/data.json
lncRNA_nonFiltered: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

lncRNA_nonFiltered: Download complete. File saved at data/lncRNA_nonFiltered/data.json


In [19]:
def compute_f1(pred_matrix, target_matrix, threshold=0.5):
    """
    Compute the F1 score of the predictions.

    :param pred_matrix: Predicted pairing matrix probability  (L,L)
    :param target_matrix: True binary pairing matrix (L,L)
    :return: F1 score for this RNA structure
    """

    pred_matrix = (pred_matrix > threshold).float()


    TP = torch.sum(pred_matrix*target_matrix)
    PP = torch.sum(pred_matrix)
    P = torch.sum(target_matrix)
    sum_pair = PP + P

    if sum_pair == 0:
        return [1.0, 1.0, 1.0]
    else:
        return [
                (TP / PP).item(),
                (TP / P).item(),
                (2 * TP / sum_pair).item()
                ]
    
def pairList2pairMatrix(pair_list, len_seq):
    pair_list = np.array(pair_list).astype(int)
    pairing_matrix = torch.zeros((len_seq, len_seq))

    if len(pair_list) > 0:
        pairing_matrix[pair_list[:,0], pair_list[:,1]] = 1.0
        pairing_matrix[pair_list[:,1], pair_list[:,0]] = 1.0

    return pairing_matrix

In [20]:
data_comparison = pd.DataFrame()

path_data = '../Figure5/saved_data_plot/results_V1'

for model in os.listdir(path_data):

    prediction = pd.read_feather(os.path.join(path_data, model)).drop_duplicates(['reference'])

    merged = ground_truth.reset_index().rename(columns={'index': 'reference'}).merge(prediction, on=['sequence', 'reference'], suffixes=('_true', '_pred'))

    F1s = []
    Precisions = []
    Recalls = []
    for i, row in merged.iterrows():
        Precision, Recall, F1 = compute_f1(torch.tensor(np.stack(row['structure_pred'])), 
                                        pairList2pairMatrix(row['structure_true'], len(row['sequence'])), threshold=0.5)
        
        F1s.append(F1)
        Precisions.append(Precision)
        Recalls.append(Recall)

    model = model.split('.feather')[0].split('_')[-1]
    print(model)
    merged['F1'] = F1s
    merged['Precision'] = Precisions
    merged['Recall'] = Recalls
    merged['model'] = model
    print(merged.groupby('dataset')['F1'].mean())

    data_comparison = pd.concat([data_comparison, merged[['reference', 'model', 'dataset', 'F1', 'Precision', 'Recall']]])

data_comparison.loc[data_comparison['model']=='PT+FT', 'model'] = 'eFold'
data_comparison.loc[data_comparison['model']=='PT', 'model'] = 'eFold (no finetuning)'

data_comparison.loc[data_comparison['dataset']=='viral_fragments', 'dataset'] = 'Viral mRNA'
data_comparison.loc[data_comparison['dataset']=='lncRNA_nonFiltered', 'dataset'] = 'Long ncRNA'
data_comparison.loc[data_comparison['dataset']=='archiveII', 'dataset'] = 'ArchiveII'

PT+FT
dataset
PDB                   0.886967
archiveII             0.589195
lncRNA_nonFiltered    0.477255
viral_fragments       0.708671
Name: F1, dtype: float64
primiRNA
dataset
PDB                   0.894660
archiveII             0.591931
lncRNA_nonFiltered    0.449569
viral_fragments       0.713196
Name: F1, dtype: float64
PT
dataset
PDB                   0.887260
archiveII             0.593638
lncRNA_nonFiltered    0.360292
viral_fragments       0.662598
Name: F1, dtype: float64
PT
dataset
PDB                   0.880137
archiveII             0.589553
lncRNA_nonFiltered    0.405555
viral_fragments       0.660225
Name: F1, dtype: float64
PT+FT-humanmRNA
dataset
PDB                   0.890945
archiveII             0.588552
lncRNA_nonFiltered    0.463942
viral_fragments       0.705205
Name: F1, dtype: float64


In [13]:
results_perf = pd.concat([data_comparison, result_algos.drop(columns=['length', 'structure'])])

results_perf = results_perf[results_perf.model.isin(['eFold', 'eFold (no finetuning)', 'UFold', 'SPOT-RNA', 'RNAformer'])]#.groupby(['model', 'dataset']).mean()
results_perf

Unnamed: 0,reference,model,dataset,F1,Precision,Recall
0,2N7X-2D,eFold,PDB,0.933333,0.875000,1.000000
1,8S95-2D,eFold,PDB,0.680000,0.693878,0.666667
2,2CD1-2D,eFold,PDB,1.000000,1.000000,1.000000
3,1WTT-2D,eFold,PDB,1.000000,1.000000,1.000000
4,2NCI-2D,eFold,PDB,0.888889,0.800000,1.000000
...,...,...,...,...,...,...
3149,XIST_10,UFold,Long ncRNA,0.263254,0.397790,0.196721
3150,MALAT1_0,UFold,Long ncRNA,0.272912,0.378531,0.213376
3151,XIST_11,UFold,Long ncRNA,0.182524,0.265537,0.139053
3152,NORAD1_55C,UFold,Long ncRNA,0.002683,0.006079,0.001721


In [14]:
# Group the data by model and dataset and calculate the mean for each group
grouped = results_perf.groupby(['model', 'dataset']).mean(numeric_only=True).reset_index()

# Pivot the table to create a multi-level column structure
pivot_df = pd.pivot_table(grouped, index='model', columns='dataset', values=['Precision', 'Recall', 'F1'])

# Swap the level of the columns to have dataset as the top level and the metrics as the second level
pivot_df = pivot_df.swaplevel(i=0, j=1, axis=1).sort_index(axis=1)

# Define the new order for the models and reorder the rows
new_order = ['UFold', 'eFold (no finetuning)', 'eFold']
pivot_df = pivot_df.reindex(new_order)

pivot_df = pivot_df.reindex(columns=pivot_df.columns.reindex(['Precision', 'Recall', 'F1'], level=1)[0])[['PDB', 'ArchiveII', 'Viral mRNA', 'Long ncRNA']]

pivot_df = pivot_df.style\
            .format(precision=3)\
            .highlight_max(axis=0, props="font-weight:bold;font-color:black;")\
            .background_gradient(axis=1, vmin=-0.1, vmax=1, cmap="viridis", text_color_threshold=0)\
            .set_properties(**{'text-align': 'center'})\
            .set_table_styles(
                        [{"selector": "th", "props": [('text-align', 'center')]},
                        ])
pivot_df


dataset,PDB,PDB,PDB,ArchiveII,ArchiveII,ArchiveII,Viral mRNA,Viral mRNA,Viral mRNA,Long ncRNA,Long ncRNA,Long ncRNA
Unnamed: 0_level_1,Precision,Recall,F1,Precision,Recall,F1,Precision,Recall,F1,Precision,Recall,F1
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
UFold,0.809,0.963,0.868,0.757,0.8,0.773,0.578,0.587,0.579,0.219,0.144,0.163
eFold (no finetuning),0.863,0.923,0.88,0.553,0.639,0.59,0.632,0.699,0.66,0.38,0.446,0.406
eFold,0.869,0.932,0.887,0.559,0.629,0.589,0.689,0.736,0.709,0.458,0.506,0.477


In [60]:
pivot_df.to_excel('tables/T3_structure_performance.xlsx')