## Show distribution of structure prediction performance

Reuse the same format as the algo_benchmark plot of Figure 1.

Add our algorithm and, hopefully, show a smaller performance gap.

**Assigned to**: Alberic

Use Plotly with a white background.

In [1]:
import os
import warnings

import numpy as np
import pandas as pd
import torch
from rouskinhf import get_dataset

In [2]:
def compute_f1(pred_matrix, target_matrix, threshold=0.5):
    """
    F1 score between a thresholded pairing-probability matrix and a binary target.

    Computed over all matrix entries as 2 * |pred AND true| / (|pred| + |true|).

    :param pred_matrix: Predicted pairing matrix probability (L,L) tensor
    :param target_matrix: True binary pairing matrix (L,L) tensor
    :param threshold: entries of pred_matrix strictly greater than this value count as predicted pairs
    :return: F1 score as a Python float; 1.0 when neither matrix contains any pair
    """
    predicted = (pred_matrix > threshold).float()
    denominator = torch.sum(predicted) + torch.sum(target_matrix)

    # Nothing predicted and nothing to find: treat as a perfect match.
    if denominator == 0:
        return 1.0

    overlap = torch.sum(predicted * target_matrix)
    return (2 * overlap / denominator).item()
    
def pairList2pairMatrix(pair_list, len_seq):
    """
    Convert a list of base pairs into a symmetric binary pairing matrix.

    :param pair_list: sequence of (i, j) index pairs -- a list of lists, a 2-D array,
        or a 1-D object array of length-2 arrays. ``None`` or empty means no pairs.
    :param len_seq: sequence length L
    :return: float tensor (L, L) with 1.0 at (i, j) and (j, i) for every pair, 0.0 elsewhere
    """
    pairing_matrix = torch.zeros((len_seq, len_seq))
    if pair_list is None:
        return pairing_matrix

    # list(...) unpacks an object array of per-pair arrays (which pandas may produce
    # for list-valued columns) so np.array can stack it into a 2-D integer array;
    # reshape keeps an empty input at shape (0, 2) so the column slicing below is valid.
    pairs = np.array(list(pair_list)).astype(int).reshape(-1, 2)

    if len(pairs) > 0:
        # Write both (i, j) and (j, i): the pairing matrix is symmetric.
        pairing_matrix[pairs[:, 0], pairs[:, 1]] = 1.0
        pairing_matrix[pairs[:, 1], pairs[:, 0]] = 1.0

    return pairing_matrix

In [4]:
# Load the ground-truth structures of every test set into a single frame.
# The index holds each RNA's reference name; a later cell turns it into a
# 'reference' column to join against the model predictions.
TEST_SETS = ["PDB", "archiveII", "viral_fragments", "lncRNA_nonFiltered"]

ground_truth_parts = []
for test_set in TEST_SETS:
    # force_download=True re-downloads from the HuggingFace Hub on every run (see output below).
    data = get_dataset(test_set, force_download=True)
    # .copy() so adding the 'dataset' column works on an independent frame, not a slice.
    test_frame = pd.DataFrame(data).T[['sequence', 'structure']].copy()
    test_frame['dataset'] = test_set
    ground_truth_parts.append(test_frame)

# Concatenate once instead of growing a DataFrame inside the loop; the index is kept.
ground_truth = pd.concat(ground_truth_parts)

PDB: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

PDB: Download complete. File saved at data/PDB/data.json
archiveII: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data.json:   0%|          | 0.00/3.23M [00:00<?, ?B/s]

archiveII: Download complete. File saved at data/archiveII/data.json
viral_fragments: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

viral_fragments: Download complete. File saved at data/viral_fragments/data.json
lncRNA_nonFiltered: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

lncRNA_nonFiltered: Download complete. File saved at data/lncRNA_nonFiltered/data.json


In [32]:
# Score every saved prediction file against the ground truth.
# Each feather file in `path_data` holds one model's predictions; the model label is
# the last '_'-separated token of the file stem (e.g. '..._PT+FT.feather' -> 'PT+FT').
path_data = 'saved_data_plot/results_V1'

# Expose the ground-truth index (reference names) as a column once, not per file.
ground_truth_ref = ground_truth.reset_index().rename(columns={'index': 'reference'})

# Sorted for a deterministic processing/print order; only .feather files are read so
# stray files in the folder (e.g. .DS_Store) do not crash read_feather.
prediction_files = sorted(f for f in os.listdir(path_data) if f.endswith('.feather'))
if not prediction_files:
    raise FileNotFoundError(f"No .feather prediction files found in {path_data!r}")

comparison_parts = []
label_sources = {}
for file_name in prediction_files:
    prediction = pd.read_feather(os.path.join(path_data, file_name))
    merged = ground_truth_ref.merge(prediction, on=['sequence', 'reference'],
                                    suffixes=('_true', '_pred'))

    # structure_pred is stacked into a 2-D matrix (presumably stored row by row);
    # structure_true is a list of base pairs expanded to a matrix of the sequence length.
    merged['f1'] = [
        compute_f1(torch.tensor(np.stack(pred)),
                   pairList2pairMatrix(true, len(seq)), threshold=0.5)
        for pred, true, seq in zip(merged['structure_pred'],
                                   merged['structure_true'],
                                   merged['sequence'])
    ]

    model_name = file_name.split('.feather')[0].split('_')[-1]
    # Different files can end with the same token (two files both give 'PT' here);
    # their rows would silently be pooled under one label in the plot.
    if model_name in label_sources:
        warnings.warn(f"{file_name!r} and {label_sources[model_name]!r} both map to model "
                      f"label {model_name!r}; their rows are pooled under that label.")
    label_sources.setdefault(model_name, file_name)

    merged['model'] = model_name
    merged['source_file'] = file_name  # keeps rows traceable when labels collide
    print(model_name)
    print(merged.groupby('dataset')['f1'].mean())

    comparison_parts.append(merged[['reference', 'sequence', 'model', 'source_file', 'dataset',
                                    'f1', 'structure_true', 'structure_pred']])

# Concatenate once instead of growing a DataFrame inside the loop.
data_comparison = pd.concat(comparison_parts, ignore_index=True)

# Display labels used in the figure.
data_comparison['model'] = data_comparison['model'].replace({
    'PT+FT': 'eFold',
    'PT': 'eFold (no finetuning)',
})
# The ground truth labels ArchiveII as 'archiveII' (see the printed means), while the
# Figure 1 benchmark presumably uses 'archiveII_blast'; map both so they share an x category.
# NOTE(review): confirm both sources are evaluated on the same ArchiveII subset.
data_comparison['dataset'] = data_comparison['dataset'].replace({
    'viral_fragments': 'Viral mRNA',
    'lncRNA_nonFiltered': 'Long ncRNA',
    'archiveII_blast': 'ArchiveII',
    'archiveII': 'ArchiveII',
})

PT+FT
dataset
PDB                   0.886967
archiveII             0.589195
lncRNA_nonFiltered    0.477255
viral_fragments       0.708671
Name: f1, dtype: float64
primiRNA
dataset
PDB                   0.894660
archiveII             0.591931
lncRNA_nonFiltered    0.449569
viral_fragments       0.713196
Name: f1, dtype: float64
PT
dataset
PDB                   0.887260
archiveII             0.593638
lncRNA_nonFiltered    0.360292
viral_fragments       0.662598
Name: f1, dtype: float64
PT
dataset
PDB                   0.880137
archiveII             0.589553
lncRNA_nonFiltered    0.405555
viral_fragments       0.660225
Name: f1, dtype: float64
PT+FT-humanmRNA
dataset
PDB                   0.890945
archiveII             0.588552
lncRNA_nonFiltered    0.463942
viral_fragments       0.705205
Name: f1, dtype: float64


In [7]:
# Baseline results from the Figure 1 benchmark (the plot uses its RNAformer, SPOT-RNA
# and UFold rows; scores are in the 'F1' column).
result_algos = pd.read_feather('../Figure1/saved_data_plot/results_benchmark_algos.feather')

# Relabel test sets with the same display names as data_comparison so that both
# sources fall into the same x-axis categories. Both ArchiveII spellings are mapped,
# because the eFold ground truth is loaded as 'archiveII' while this file presumably
# uses 'archiveII_blast'.
result_algos['dataset'] = result_algos['dataset'].replace({
    'viral_fragments': 'Viral mRNA',
    'lncRNA_nonFiltered': 'Long ncRNA',
    'archiveII_blast': 'ArchiveII',
    'archiveII': 'ArchiveII',
})

In [33]:
# Violin plot of per-structure F1 scores, grouped by test set:
# baseline algorithms (greys) next to eFold with and without finetuning.
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

BASELINE_MODELS = ['RNAformer', 'SPOT-RNA', 'UFold']
BASELINE_COLORS = ['#333333', '#666666', '#999999']
EFOLD_MODELS = ['eFold (no finetuning)', 'eFold']
EFOLD_COLORS = ["rgb(148, 103, 189)", '#FF3333']


def add_model_violins(fig, results, score_col, models, colors):
    """
    Add one violin trace per model to `fig`, with the test set on the x axis.

    :param fig: plotly figure to draw on
    :param results: DataFrame with 'model' and 'dataset' columns plus the score column
    :param score_col: name of the F1 column ('F1' in the benchmark file, 'f1' for eFold)
    :param models: model labels to plot, in legend order
    :param colors: one marker color per model
    """
    for model, color in zip(models, colors):
        results_model = results[results['model'] == model]
        fig.add_trace(go.Violin(x=results_model['dataset'], y=results_model[score_col],
                                name=model, marker_color=color,
                                meanline_visible=True, points=False))


fig = make_subplots(rows=1, cols=1, shared_xaxes=True)
add_model_violins(fig, result_algos, 'F1', BASELINE_MODELS, BASELINE_COLORS)
add_model_violins(fig, data_comparison, 'f1', EFOLD_MODELS, EFOLD_COLORS)

fig.update_layout(
    yaxis_title='F1 score', xaxis_title='',
    violinmode='group', yaxis_range=[0, 1],
    width=1300, height=400,
    template='plotly_white', font_size=20, font_color='black', font_family='times new roman',
    # horizontal legend placed above the plotting area
    legend=dict(orientation='h', x=0.1, y=1.3),
)
fig.show()

In [36]:
# Export the figure as a PDF (plotly static export needs an image engine such as kaleido).
# Create the output folder first: write_image fails if it does not exist.
os.makedirs("images", exist_ok=True)
fig.write_image("images/a_structure_performance.pdf")