In [11]:
import pandas as pd
import numpy as np
import torch

import os

from rouskinhf import get_dataset

In [12]:
def compute_metrics(pred_matrix, target_matrix, threshold=0.5):
    """
    Compute precision, recall, F1 score and MCC of a predicted pairing matrix.

    Counts are taken over every entry of the (L,L) matrices (both triangles and
    the diagonal).

    :param pred_matrix: Predicted pairing matrix probability (L,L); a torch tensor
        or any array-like accepted by ``torch.as_tensor``
    :param target_matrix: True binary pairing matrix (L,L)
    :param threshold: Entries of ``pred_matrix`` strictly above this value are
        counted as predicted pairs
    :return: [precision, recall, F1, MCC] as Python floats for this RNA structure.
        If neither the prediction nor the target contains a pair, returns
        [1.0, 1.0, 1.0, 1.0]. An undefined precision (no predicted pairs) or
        recall (no true pairs) is reported as 0.0 rather than NaN, and MCC is
        0.0 when its denominator is zero.
    """
    # float64: TN grows as L**2 and the MCC denominator multiplies four such
    # counts, which loses precision in float32 for long sequences.
    pred = (torch.as_tensor(pred_matrix) > threshold).double()
    target = torch.as_tensor(target_matrix).double()

    TP = torch.sum(pred * target).item()
    TN = torch.sum((1 - pred) * (1 - target)).item()
    FN = torch.sum((1 - pred) * target).item()
    FP = torch.sum(pred * (1 - target)).item()

    PP = torch.sum(pred).item()    # predicted positives
    P = torch.sum(target).item()   # true positives in the reference
    sum_pair = PP + P

    if sum_pair == 0:
        # Both structures are empty: treat as perfect agreement.
        return [1.0, 1.0, 1.0, 1.0]

    # Avoid 0/0 -> NaN, which would later be silently skipped by mean().
    precision = TP / PP if PP > 0 else 0.0
    recall = TP / P if P > 0 else 0.0
    f1 = 2 * TP / sum_pair

    mcc_denominator = ((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)) ** 0.5
    mcc = (TP * TN - FP * FN) / mcc_denominator if mcc_denominator != 0 else 0.0

    return [precision, recall, f1, mcc]
    
def pairList2pairMatrix(pair_list, len_seq):
    """
    Build a symmetric binary pairing matrix from a list of base pairs.

    :param pair_list: Sequence of (i, j) pairs; values are cast to int and used
        directly as row/column indices
    :param len_seq: Sequence length L
    :return: torch tensor of shape (L,L) with 1.0 at (i, j) and (j, i) for every
        pair, 0.0 elsewhere
    """
    pairs = np.array(pair_list).astype(int)
    matrix = torch.zeros((len_seq, len_seq))

    if len(pairs) == 0:
        return matrix

    rows, cols = pairs[:, 0], pairs[:, 1]
    matrix[rows, cols] = 1.0
    matrix[cols, rows] = 1.0
    return matrix

In [13]:
# Test sets to benchmark on; each name is passed to rouskinhf.get_dataset.
TEST_SETS = ["PDB", "archiveII", "viral_fragments", "lncRNA"]


def _load_test_set(name):
    """Download one test set and return its sequence/structure table tagged with the set name."""
    records = get_dataset(name, force_download=True)
    # get_dataset output is transposed so that each reference becomes a row.
    return pd.DataFrame(records).T[['sequence', 'structure']].assign(dataset=name)


# Build every frame first and concatenate once, instead of growing an initially
# empty DataFrame inside a loop.
ground_truth = pd.concat([_load_test_set(name) for name in TEST_SETS])

PDB: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

PDB: Download complete. File saved at data/PDB/data.json
archiveII: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

archiveII: Download complete. File saved at data/archiveII/data.json
viral_fragments: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

viral_fragments: Download complete. File saved at data/viral_fragments/data.json
lncRNA: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

lncRNA: Download complete. File saved at data/lncRNA/data.json


In [14]:
# Per-structure scores of the baseline algorithms, read from the Figure1 outputs.
result_algos = pd.read_feather('../Figure1/saved_data_plot/results_benchmark_algos.feather')

# Relabel datasets for display.
# NOTE(review): 'archiveII_blast' becomes 'ArchiveII' (capital A), but the final
# table selects the column 'archiveII' — confirm these rows are meant to be dropped.
for old_label, new_label in [('viral_fragments', 'Viral mRNA'), ('archiveII_blast', 'ArchiveII')]:
    result_algos.loc[result_algos['dataset'] == old_label, 'dataset'] = new_label

In [15]:
prediction = pd.read_feather('../Figure5/saved_data_plot/results_V2/test_results_eFoldPT+FT.feather')

# Pair each eFold prediction with its ground truth; after reset_index the
# reference name lives in the 'index' column, renamed to 'reference'.
merged = (
    ground_truth
    .reset_index()
    .rename(columns={'index': 'reference'})
    .merge(prediction, on=['sequence', 'reference'], suffixes=('_true', '_pred'))
)

# One [precision, recall, F1, MCC] list per merged row.
scores = [
    compute_metrics(
        torch.tensor(np.stack(pred_structure)),
        pairList2pairMatrix(true_pairs, len(sequence)),
        threshold=0.5,
    )
    for pred_structure, true_pairs, sequence in zip(
        merged['structure_pred'], merged['structure_true'], merged['sequence']
    )
]

precisions = [score[0] for score in scores]
recalls = [score[1] for score in scores]
f1s = [score[2] for score in scores]
mccs = [score[3] for score in scores]

merged['F1'] = f1s
merged['model'] = 'eFold'
merged['Precision'] = precisions
merged['Recall'] = recalls
merged['MCC'] = mccs

merged.loc[merged['dataset'] == 'viral_fragments', 'dataset'] = 'Viral mRNA'


In [19]:
# Append the eFold scores to the baseline results.
efold_scores = merged[['model', 'dataset', 'Precision', 'Recall', 'F1', 'MCC']]
result_algos = pd.concat([result_algos, efold_scores])

# Display names: (column, old value, new value).
for column, old_value, new_value in [('dataset', 'lncRNA', 'Long ncRNA'),
                                     ('model', 'RNAstructure', 'RNAstructure Fold')]:
    result_algos.loc[result_algos[column] == old_value, column] = new_value

In [21]:
models = ['RNAstructure Fold', 'EternaFold', 'MXFold2', 'UFold', 'E2EFold', 'SPOT-RNA', 'CNNFold', 'NeuralFold', 'RNAformer', 'eFold']
metrics = ['Precision', 'Recall', 'F1', 'MCC']
# Top-level column order of the final table.
# NOTE(review): rows labelled 'ArchiveII' (capital A) are not selected here — confirm intent.
datasets = ['PDB', 'archiveII', 'Viral mRNA', 'Long ncRNA']

# Mean score of every (model, dataset) group.
grouped = result_algos.groupby(['model', 'dataset']).mean(numeric_only=True).reset_index()

# One row per model; columns are (dataset, metric) with dataset as the top level.
pivot_table_df = (
    pd.pivot_table(grouped, index='model', columns='dataset', values=metrics)
    .swaplevel(i=0, j=1, axis=1)
    .sort_index(axis=1)
    .reindex(models)
)

# Order the metric level as in `metrics`, then select/order the datasets.
pivot_table_df = pivot_table_df.reindex(
    columns=pivot_table_df.columns.reindex(metrics, level=1)[0]
)[datasets]

# Styled table: best model per column in bold; `color` is the CSS text-colour
# property (`font-color` is not valid CSS and is ignored by renderers).
pivot_df = (
    pivot_table_df.style
    .format(precision=2)
    .highlight_max(axis=0, props="font-weight:bold;color:black;")
    .background_gradient(axis=1, vmin=-0.1, vmax=1, cmap="viridis", text_color_threshold=0)
    .set_properties(**{'text-align': 'center'})
    .set_table_styles([{"selector": "th", "props": [('text-align', 'center')]}])
)
pivot_df

dataset,PDB,PDB,PDB,PDB,archiveII,archiveII,archiveII,archiveII,Viral mRNA,Viral mRNA,Viral mRNA,Viral mRNA,Long ncRNA,Long ncRNA,Long ncRNA,Long ncRNA
Unnamed: 0_level_1,Precision,Recall,F1,MCC,Precision,Recall,F1,MCC,Precision,Recall,F1,MCC,Precision,Recall,F1,MCC
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2
RNAstructure Fold,0.9,0.91,0.89,0.89,0.55,0.6,0.57,0.57,0.69,0.74,0.71,0.71,0.46,0.52,0.49,0.49
EternaFold,0.88,0.91,0.88,0.88,0.57,0.64,0.6,0.6,0.75,0.81,0.77,0.78,0.45,0.47,0.46,0.46
MXFold2,0.91,0.93,0.9,0.91,0.73,0.76,0.74,0.74,0.7,0.72,0.71,0.71,0.41,0.43,0.42,0.42
UFold,0.81,0.97,0.87,0.87,0.83,0.89,0.85,0.85,0.58,0.59,0.58,0.58,0.22,0.14,0.16,0.17
E2EFold,0.21,0.1,0.13,0.13,0.29,0.21,0.24,0.24,0.04,0.03,0.03,0.03,0.03,0.03,0.03,0.03
SPOT-RNA,0.85,0.92,0.87,0.87,0.69,0.73,0.7,0.7,0.68,0.5,0.56,0.58,0.34,0.21,0.26,0.27
CNNFold,0.92,0.6,0.65,0.67,0.91,0.43,0.52,0.57,0.56,0.04,0.07,0.12,0.39,0.01,0.02,0.07
NeuralFold,0.81,0.84,0.81,0.81,0.72,0.72,0.72,0.72,0.26,0.24,0.25,0.25,0.15,0.13,0.14,0.14
RNAformer,0.74,0.96,0.82,0.82,0.54,0.84,0.65,0.67,0.48,0.5,0.48,0.48,0.27,0.08,0.07,0.07
eFold,0.89,0.93,0.89,0.89,0.57,0.64,0.6,0.6,0.7,0.75,0.73,0.73,0.46,0.43,0.44,0.44


In [23]:
# to_excel does not create the target directory, so ensure 'tables/' exists first.
os.makedirs('tables', exist_ok=True)
pivot_df.to_excel('tables/ST1_full_benchmark.xlsx')