## Show distribution of structure prediction performance

Reuse same format as the algo_benchmark plot of Figure 1.

Add our algorithm and hopefully show less performance gap

**Assigned to**: Alberic

Use Ploty, and a white background

In [1]:
import pandas as pd
from rouskinhf import get_dataset
import torch
import numpy as np
import os

In [2]:
def compute_f1(pred_matrix, target_matrix, threshold=0.5):
    """
    Compute the F1 score of the predictions.

    :param pred_matrix: Predicted pairing matrix probability  (L,L)
    :param target_matrix: True binary pairing matrix (L,L)
    :return: F1 score for this RNA structure
    """

    pred_matrix = (pred_matrix > threshold).float()

    sum_pair = torch.sum(pred_matrix) + torch.sum(target_matrix)

    if sum_pair == 0:
        return 1.0
    else:
        return (2 * torch.sum(pred_matrix * target_matrix) / sum_pair).item()
    
def pairList2pairMatrix(pair_list, len_seq):
    pair_list = np.array(pair_list).astype(int)
    pairing_matrix = torch.zeros((len_seq, len_seq))

    if len(pair_list) > 0:
        pairing_matrix[pair_list[:,0], pair_list[:,1]] = 1.0
        pairing_matrix[pair_list[:,1], pair_list[:,0]] = 1.0

    return pairing_matrix

In [3]:
ground_truth = pd.DataFrame()

for test_set in ["PDB", "archiveII", "viral_fragments", "lncRNA"]:
    data = get_dataset(test_set, force_download=True)
    data = pd.DataFrame(data).T[['sequence', 'structure']]
    data['dataset'] = test_set

    ground_truth = pd.concat([ground_truth, data])
    del data

PDB: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

PDB: Download complete. File saved at data/PDB/data.json
archiveII: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

archiveII: Download complete. File saved at data/archiveII/data.json
viral_fragments: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

viral_fragments: Download complete. File saved at data/viral_fragments/data.json
lncRNA: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

lncRNA: Download complete. File saved at data/lncRNA/data.json


In [4]:
data_comparison = pd.DataFrame()

path_data = 'saved_data_plot/results_V2'

for model in os.listdir(path_data):

    prediction = pd.read_feather(os.path.join(path_data, model))
    print(os.path.join(path_data, model))

    merged = ground_truth.reset_index().rename(columns={'index': 'reference'}).merge(prediction, on=['sequence', 'reference'], suffixes=('_true', '_pred'))

    f1s = []
    for i, row in merged.iterrows():
        f1s.append(compute_f1(torch.tensor(np.stack(row['structure_pred'])), 
                            pairList2pairMatrix(row['structure_true'], len(row['sequence'])), threshold=0.5))

    model = model.split('.feather')[0].split('_')[-1]
    print(model)
    merged['f1'] = f1s
    merged['model'] = model
    print(merged.groupby('dataset')['f1'].mean())

    data_comparison = pd.concat([data_comparison, merged[['reference', 'sequence', 'model', 'dataset', 'f1', 'structure_true', 'structure_pred']]])

data_comparison.loc[data_comparison['model']=='eFoldPT+FT', 'model'] = 'eFold'
data_comparison.loc[data_comparison['model']=='eFoldPT', 'model'] = 'eFold (pre-trained)'
data_comparison.loc[data_comparison['model']=='UFoldPT+FT', 'model'] = 'UFold (re-trained)'

data_comparison.loc[data_comparison['dataset']=='viral_fragments', 'dataset'] = 'Viral mRNA'
data_comparison.loc[data_comparison['dataset']=='lncRNA', 'dataset'] = 'Long ncRNA'
data_comparison.loc[data_comparison['dataset']=='archiveII', 'dataset'] = 'ArchiveII'

saved_data_plot/results_V2/test_results_eFoldPT.feather
eFoldPT
dataset
PDB                0.886010
archiveII          0.584986
lncRNA             0.400431
viral_fragments    0.674164
Name: f1, dtype: float64
saved_data_plot/results_V2/test_results_eFoldPT+FT_human-mRNA.feather
human-mRNA
dataset
PDB                0.890150
archiveII          0.580508
lncRNA             0.402036
viral_fragments    0.682623
Name: f1, dtype: float64
saved_data_plot/results_V2/test_results_UFoldPT+FT.feather
UFoldPT+FT
dataset
PDB                0.890532
archiveII          0.607041
lncRNA             0.347520
viral_fragments    0.674862
Name: f1, dtype: float64
saved_data_plot/results_V2/test_results_eFoldPT+FT.feather
eFoldPT+FT
dataset
PDB                0.894044
archiveII          0.599568
lncRNA             0.442343
viral_fragments    0.725702
Name: f1, dtype: float64
saved_data_plot/results_V2/test_results_eFoldPT+FT_pri-miRNA.feather
pri-miRNA
dataset
PDB                0.891003
archiveII          0

In [5]:
result_algos = pd.read_feather('../Figure1/saved_data_plot/results_benchmark_algos.feather')
result_algos.loc[result_algos['dataset']=='viral_fragments', 'dataset'] = 'Viral mRNA'
result_algos.loc[result_algos['dataset']=='archiveII', 'dataset'] = 'ArchiveII'
result_algos.loc[result_algos['dataset']=='lncRNA', 'dataset'] = 'Long ncRNA'
result_algos.loc[result_algos['model']=='UFold', 'model'] = 'UFold (original)'

In [6]:
# Creat a box plot with plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
# fig = make_subplots(rows=2, cols=1)
fig = go.Figure()

colors = ['#222222', '#888888']
    
# UFold results
for i, model in enumerate(['SPOT-RNA', 'UFold (original)']):
    results_dataset = result_algos[result_algos['model']==model]
    if model == 'RNAformer': 
       
       fig.add_trace(go.Violin(x=results_dataset[results_dataset['dataset']!='Long ncRNA']['dataset'], y=results_dataset[results_dataset['dataset']!='Long ncRNA']['F1'], 
                            name=model, marker_color=colors[i], 
                            meanline_visible=True, points=False))
       
       fig.add_trace(go.Violin(x=results_dataset[results_dataset['dataset']=='Long ncRNA']['dataset'], y=results_dataset[results_dataset['dataset']=='Long ncRNA']['F1'], 
                            marker_color=colors[i], showlegend=False,
                            meanline_visible=True, points=False))
    else:       
      fig.add_trace(go.Violin(x=results_dataset['dataset'], y=results_dataset['F1'], 
                              name=model, marker_color=colors[i], 
                              meanline_visible=True, points=False))

colors = ['blue', "rgb(148, 103, 189)", '#FF3333']
for i, model in enumerate(['UFold (re-trained)', 'eFold (pre-trained)', 'eFold']):
    results_dataset = data_comparison[data_comparison['model']==model]
    fig.add_trace(go.Violin(x=results_dataset['dataset'], y=results_dataset['f1'], 
                            name=model, 
                            marker_color=colors[i], 
                            meanline_visible=True, points=False))
    
fig.update_layout(
                # title='F1 score on test sets for pretraining vs pretraining + fine-tuning models', 
                yaxis_title='F1 score', xaxis_title='',
                  violinmode='group', yaxis_range=[0, 1],
                  width=1300, height=650,
                  template='plotly_white', font_size=20, font_color='black', font_family='times new roman',)
# fig.update_xaxes(categoryorder='array', categoryarray= ['RNAstructure', 'EternaFold', 'MXFold2', 'UFold'])

# make legend horizontal
fig.update_layout(legend_orientation="h", legend=dict(x=0.1, y=1.22), height=380)
fig.show()

In [7]:
# save pdf
# fig.write_image("images/a_structure_performance.pdf")

In [16]:
(result_algos[(result_algos['dataset']=='PDB') & (result_algos['model']=='SPOT-RNA')]['F1']<0.8).mean()

0.15774647887323945

In [11]:
result_algos.model.unique()

array(['RNAstructure', 'MXFold2', 'UFold (original)', 'EternaFold',
       'SPOT-RNA', 'E2EFold', 'CNNFold', 'NeuralFold', 'RNAformer'],
      dtype=object)