## Ablation study of our final best model, on structure prediction

Show performance difference (Delta of F1 score) of different versions of our model, vs our best model (baseline, ∆F1=0)

- Is pretraining useful (Ribo, ArchiveII, bpRNA, RNAstralign, zuber, synthetic)
- Is finetuning useful (UTR, pri-miRNA, human_mRNA fragments?)
- Is training on everything at once better ?
- Is it useful to have a SHAPE/DMS head?
- Pearson loss for SHAPE/DMS vs MSE
- model architecture (Evoformer, Transformer, CNN)

The test set could be a weighted average of the three standard sets (PDB, viral_fragments, lncRNA) if there are no outliers


**Assigned to**: Alberic

Use Plotly and a white background

In [1]:
import pandas as pd
from rouskinhf import get_dataset
import torch
import numpy as np
import os

def compute_f1(pred_matrix, target_matrix, threshold=0.5):
    """
    F1 score between a predicted and a reference base-pairing matrix.

    The prediction is binarised with ``pred_matrix > threshold`` before it is
    compared to the target.

    :param pred_matrix: Predicted pairing matrix probability  (L,L)
    :param target_matrix: True binary pairing matrix (L,L)
    :param threshold: Cut-off strictly above which a predicted entry counts as a pair
    :return: F1 score for this RNA structure, as a Python float
             (1.0 when the entries of both matrices sum to zero)
    """
    binarised = (pred_matrix > threshold).float()

    # Denominator of the F1 formula: number of predicted pairs + number of true pairs.
    total_pairs = torch.sum(binarised) + torch.sum(target_matrix)

    # Nothing predicted and nothing expected: score it as a perfect match
    # instead of dividing by zero.
    if total_pairs == 0:
        return 1.0

    # Entries set in both matrices are the true positives.
    overlap = torch.sum(binarised * target_matrix)
    return (2 * overlap / total_pairs).item()
    
def pairList2pairMatrix(pair_list, len_seq):
    """
    Build a symmetric pairing matrix from a list of index pairs.

    :param pair_list: Sequence of (i, j) index pairs; values are cast to int
    :param len_seq: Side length of the square output matrix
    :return: (len_seq, len_seq) torch tensor of zeros, with 1.0 written at
             [i, j] and [j, i] for every pair in ``pair_list``
    """
    pairs = np.array(pair_list).astype(int)
    matrix = torch.zeros((len_seq, len_seq))

    # With no pairs there is nothing to write (and no columns to slice).
    if len(pairs) == 0:
        return matrix

    first, second = pairs[:, 0], pairs[:, 1]
    # Write both orientations so the matrix is symmetric.
    matrix[first, second] = 1.0
    matrix[second, first] = 1.0
    return matrix

In [2]:
# Load the reference sequence/structure of every test set and stack them into
# one frame, tagging each row with the name of the set it came from.
ground_truth_parts = []

for test_set in ["PDB", "archiveII_blast", "lncRNA_nonFiltered", "viral_fragments"]:
    data = get_dataset(test_set, force_download=True)
    # Transpose so that each dataset entry becomes one row, keep the two
    # columns used for scoring, and add the dataset tag without mutating a slice.
    data = pd.DataFrame(data).T[['sequence', 'structure']].assign(dataset=test_set)

    ground_truth_parts.append(data)
    del data

# Concatenate once at the end: growing a frame with pd.concat inside the loop
# re-copies all previous rows on each iteration and starts from an empty frame,
# which recent pandas versions warn about.
ground_truth = pd.concat(ground_truth_parts)
del ground_truth_parts

PDB: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

PDB: Download complete. File saved at data/PDB/data.json
archiveII_blast: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

archiveII_blast: Download complete. File saved at data/archiveII_blast/data.json
lncRNA_nonFiltered: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

lncRNA_nonFiltered: Download complete. File saved at data/lncRNA_nonFiltered/data.json
viral_fragments: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

viral_fragments: Download complete. File saved at data/viral_fragments/data.json


In [3]:
# Score every saved model prediction against the ground truth and collect the
# per-sequence F1 of all models in `data_comparison`.
path_data = '../Figure5/saved_data_plot/results_main'

# The reference names are the index of `ground_truth`; expose them as a column
# once (loop-invariant) so each prediction file can be merged on (sequence, reference).
ground_truth_keyed = ground_truth.reset_index().rename(columns={'index': 'reference'})

# Only .feather files are predictions. Sorting makes the processing and printing
# order reproducible, because os.listdir returns entries in arbitrary order.
model_files = sorted(name for name in os.listdir(path_data) if name.endswith('.feather'))
if not model_files:
    raise FileNotFoundError(f"No .feather prediction file found in {path_data!r}")

per_model_scores = []

for model_file in model_files:

    prediction = pd.read_feather(os.path.join(path_data, model_file))

    merged = ground_truth_keyed.merge(prediction, on=['sequence', 'reference'], suffixes=('_true', '_pred'))
    # NOTE(review): none of the test sets loaded in the previous cell is named
    # exactly 'lncRNA' (the set is 'lncRNA_nonFiltered'), so this filter currently
    # removes nothing - confirm whether it is still wanted.
    # .copy() so the column assignments below act on an independent frame.
    merged = merged[merged['dataset'] != 'lncRNA'].copy()

    # One F1 per sequence: predicted pairing matrix vs. matrix rebuilt from the true pair list.
    merged['f1'] = [
        compute_f1(torch.tensor(np.stack(row['structure_pred'])),
                   pairList2pairMatrix(row['structure_true'], len(row['sequence'])), threshold=0.5)
        for _, row in merged.iterrows()
    ]

    # The model tag is the last '_'-separated token of the file name, extension removed.
    model = model_file.split('.feather')[0].split('_')[-1]
    merged['model'] = model
    print(model)
    print(merged.groupby('dataset')['f1'].mean())

    per_model_scores.append(merged[['reference', 'sequence', 'model', 'dataset', 'f1']])

# Single concatenation instead of growing the frame inside the loop.
data_comparison = pd.concat(per_model_scores)

# Display labels used by the figures.
dataset_labels = {
    'viral_fragments': 'viral mRNA',
    'lncRNA_nonFiltered': 'long ncRNA',
    'archiveII_blast': 'archiveII',
}
model_labels = {
    'PT+FT': 'eFold',
    'PT': 'eFold (no finetuning)',
    'PT-UFold': 'no finetuning, with <br> UFold architecture',
    'PT+FT-mRNA': 'finetune on mRNA',
    'PT+FT-primiRNA': 'finetune on pri-miRNA',
}
# Exact-value replacement; names that are not in the maps are left untouched.
data_comparison['dataset'] = data_comparison['dataset'].replace(dataset_labels)
data_comparison['model'] = data_comparison['model'].replace(model_labels)

PT+FT-primiRNA
dataset
PDB                   0.900898
archiveII_blast       0.666530
lncRNA_nonFiltered    0.407392
viral_fragments       0.738163
Name: f1, dtype: float64
ribonanza
dataset
PDB                   0.849423
archiveII_blast       0.492899
lncRNA_nonFiltered    0.298363
viral_fragments       0.640913
Name: f1, dtype: float64
RNAstralign
dataset
PDB                   0.728857
archiveII_blast       0.664344
lncRNA_nonFiltered    0.019084
viral_fragments       0.148102
Name: f1, dtype: float64
bpRNA
dataset
PDB                   0.816722
archiveII_blast       0.647147
lncRNA_nonFiltered    0.183122
viral_fragments       0.332785
Name: f1, dtype: float64
PT-UFold
dataset
PDB                   0.843829
archiveII_blast       0.627999
lncRNA_nonFiltered    0.301590
viral_fragments       0.644054
Name: f1, dtype: float64
PT+FT-mRNA
dataset
PDB                   0.890331
archiveII_blast       0.611764
lncRNA_nonFiltered    0.427032
viral_fragments       0.714323
Name: f1, dtype: flo

In [4]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

# Panel (a): one row of subplots per model variant, one column per test set.
# Each panel holds a single marker: the mean F1 of the variant minus the mean F1
# of the 'eFold' baseline on the same test set, with the variant's F1 standard
# deviation as horizontal error bar.
datasets = ['PDB', "archiveII", 'viral mRNA', 'long ncRNA']

# Marker colour of each model variant; the key order is also the row order.
palette = px.colors.qualitative.Pastel
colors = {'eFold': '#FF3333',
          'finetune on mRNA': palette[7],
          'finetune on pri-miRNA': palette[4],
          'eFold (no finetuning)': "rgb(148, 103, 189)",
          'no finetuning, with <br> UFold architecture': '#444444'}
models = list(colors)

# Mean F1 of the reference model per test set (the ∆F1 = 0 line).
is_baseline = data_comparison['model'] == 'eFold'
baseline = data_comparison[is_baseline].groupby('dataset')['f1'].mean()

# One title per column; the leading spaces are part of the title strings.
column_titles = [f'        ∆F1 for <br>      {dataset}' for dataset in datasets]
fig = make_subplots(
    rows=len(models),
    cols=len(datasets),
    subplot_titles=column_titles,
    shared_yaxes=True,
    shared_xaxes=True,
    vertical_spacing=0,
    horizontal_spacing=0.08,
)

# Subplot rows/cols are 1-based, hence start=1.
for row_idx, model in enumerate(models, start=1):
    for col_idx, dataset in enumerate(datasets, start=1):

        in_panel = (data_comparison['model'] == model) & (data_comparison['dataset'] == dataset)
        f1_scores = data_comparison[in_panel]['f1']

        marker = go.Scatter(
            x=[f1_scores.mean() - baseline[dataset]],
            y=[model],
            error_x=dict(type='data', array=[f1_scores.std()], thickness=3),
            mode='markers',
            marker_color=colors[model],
            marker_size=7,
            showlegend=False,
        )
        fig.add_trace(marker, row=row_idx, col=col_idx)

fig.update_annotations(font_size=20)
fig.update_xaxes(range=[-0.4, 0.25], mirror=True)
fig.update_layout(
    # Tick font of `xaxis` only; 5 is the value that was in effect after the
    # two successive settings (1, then 5) this call replaces.
    xaxis=dict(tickfont=dict(size=5)),
    height=len(models) * 100 + 200,
    width=1000,
    template='plotly_white',
    font_size=20,
    font_color='black',
    font_family='helvetica light',
)
fig.show()

In [5]:
# Save the ablation figure as a PDF. The target folder is created first so the
# export does not fail on a fresh checkout where images/S4 does not exist yet.
output_pdf = "images/S4/a_ablation_study.pdf"
os.makedirs(os.path.dirname(output_pdf), exist_ok=True)
fig.write_image(output_pdf)

# Pretraining eFold on each database individually

In [14]:
# Create a violin plot with plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
# Panel (b): F1 distribution on each test set for eFold pretrained on a single
# database, next to the final eFold model. One violin per (model, test set).
fig = make_subplots(rows=1, cols=1, shared_xaxes=True)

colors = [
    'rgb(176, 200, 180)', # RNAstralign
    'rgb(233, 176, 174)', # bpRNA
    '#636EFA', # ribonanza
    'rgb(0, 200, 250)', # RNAcentral-synthetic
    '#FF3333', # eFold
]
pretraining_models = ['RNAstralign', 'bpRNA', 'ribonanza', 'RNAcentral-synthetic', 'eFold']

# Pair each model with its colour explicitly instead of relying on a shared index.
for model, color in zip(pretraining_models, colors):
    results_dataset = data_comparison[data_comparison['model']==model]
    fig.add_trace(go.Violin(x=results_dataset['dataset'], y=results_dataset['f1'],
                            name=model, marker_color=color,
                            meanline_visible=True, points=False))

# Single layout call. The figure width used to be set twice (1200, then 1000);
# only the value that was actually in effect, 1000, is kept.
fig.update_layout(
    yaxis_title='F1 score', xaxis_title='',
    violinmode='group', yaxis_range=[0, 1],
    width=1000, height=400,
    template='plotly_white', font_size=18, font_color='black', font_family='times new roman',
    # Horizontal legend at the given x/y position.
    legend=dict(orientation="h", x=0.2, y=1.3),
)

fig.show()

# Save as PDF; create the target folder first so the export does not fail when
# images/S4 does not exist yet.
output_pdf = "images/S4/b_database_training.pdf"
os.makedirs(os.path.dirname(output_pdf), exist_ok=True)
fig.write_image(output_pdf)